HN 2 weeks ago
parent 928efab79b
commit ad4aaa0fad

74
.gitignore vendored

@ -0,0 +1,74 @@
# =========================
# Build / CMake
# =========================
build/
build_*/
cmake-build-*/
out/
output/
dist/
test/
doc/
CMakeFiles/
CMakeCache.txt
cmake_install.cmake
install_manifest.txt
Makefile
compile_commands.json
.ninja_deps
.ninja_log
# =========================
# Generated / intermediate
# =========================
*.o
*.obj
*.a
*.lib
*.so
*.dylib
*.dll
*.exe
*.out
!test/test_case/**/*.out
*.app
*.pdb
*.ilk
*.dSYM/
*.log
*.tmp
*.swp
*.swo
*.bak
# ANTLR 生成物(通常在 build/,这里额外兜底)
**/generated/antlr4/
**/antlr4-generated/
*.tokens
*.interp
# =========================
# IDE / Editor
# =========================
.vscode/
.idea/
.fleet/
.vs/
*.code-workspace
# CLion
cmake-build-debug/
cmake-build-release/
# =========================
# OS / misc
# =========================
.DS_Store
Thumbs.db
# =========================
# Project outputs
# =========================
test/test_result/

@ -0,0 +1,79 @@
cmake_minimum_required(VERSION 3.20)
project(compiler LANGUAGES C CXX)
# C++
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
# <build>/bin build/src/
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
foreach(cfg IN ITEMS Debug Release RelWithDebInfo MinSizeRel)
string(TOUPPER "${cfg}" cfg_upper)
set("CMAKE_RUNTIME_OUTPUT_DIRECTORY_${cfg_upper}" "${CMAKE_BINARY_DIR}/bin")
endforeach()
# ANTLR
set(ANTLR4_GENERATED_DIR "${CMAKE_BINARY_DIR}/generated/antlr4")
# /
add_library(build_options INTERFACE)
target_compile_features(build_options INTERFACE cxx_std_17)
target_include_directories(build_options INTERFACE
"${PROJECT_SOURCE_DIR}/include"
"${PROJECT_SOURCE_DIR}/src"
"${ANTLR4_GENERATED_DIR}"
# 使 -Xexact-output-dir ANTLR
"${ANTLR4_GENERATED_DIR}/src/antlr4"
)
option(COMPILER_ENABLE_WARNINGS "Enable common compiler warnings" ON)
if(COMPILER_ENABLE_WARNINGS)
if(MSVC)
target_compile_options(build_options INTERFACE /W4)
else()
target_compile_options(build_options INTERFACE -Wall -Wextra -Wpedantic)
endif()
endif()
option(COMPILER_PARSE_ONLY "Build only the frontend parser pipeline" OFF)
# 使 third_party ANTLR4 C++ runtime
# third_party runtime third_party/antlr4-runtime-4.13.2/runtime/src
set(ANTLR4_RUNTIME_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/antlr4-runtime-4.13.2/runtime/src")
add_library(antlr4_runtime STATIC)
target_compile_features(antlr4_runtime PUBLIC cxx_std_17)
target_include_directories(antlr4_runtime PUBLIC
"${ANTLR4_RUNTIME_SRC_DIR}"
"${ANTLR4_RUNTIME_SRC_DIR}/atn"
"${ANTLR4_RUNTIME_SRC_DIR}/dfa"
"${ANTLR4_RUNTIME_SRC_DIR}/internal"
"${ANTLR4_RUNTIME_SRC_DIR}/misc"
"${ANTLR4_RUNTIME_SRC_DIR}/support"
"${ANTLR4_RUNTIME_SRC_DIR}/tree"
"${ANTLR4_RUNTIME_SRC_DIR}/tree/pattern"
"${ANTLR4_RUNTIME_SRC_DIR}/tree/xpath"
)
file(GLOB_RECURSE ANTLR4_RUNTIME_SOURCES CONFIGURE_DEPENDS
"${ANTLR4_RUNTIME_SRC_DIR}/*.cpp"
"${ANTLR4_RUNTIME_SRC_DIR}/atn/*.cpp"
"${ANTLR4_RUNTIME_SRC_DIR}/dfa/*.cpp"
"${ANTLR4_RUNTIME_SRC_DIR}/internal/*.cpp"
"${ANTLR4_RUNTIME_SRC_DIR}/misc/*.cpp"
"${ANTLR4_RUNTIME_SRC_DIR}/support/*.cpp"
"${ANTLR4_RUNTIME_SRC_DIR}/tree/*.cpp"
"${ANTLR4_RUNTIME_SRC_DIR}/tree/pattern/*.cpp"
"${ANTLR4_RUNTIME_SRC_DIR}/tree/xpath/*.cpp"
)
target_sources(antlr4_runtime PRIVATE ${ANTLR4_RUNTIME_SOURCES})
find_package(Threads REQUIRED)
target_link_libraries(antlr4_runtime PUBLIC Threads::Threads)
set(ANTLR4_RUNTIME_TARGET antlr4_runtime)
add_subdirectory(src)

@ -1,2 +1,218 @@
# test
# SysY 编译器课程实验C++
本仓库为“并行编译课程实验”提供一个 SysY 编译器的最小可运行示例,实验按 Lab1Lab6 逐步完成:
从前端(词法/语法分析与语法树处理到中端IR 生成、基本标量优化再到后端ARM64/AArch64 汇编生成、寄存器分配与后端优化),最后进行循环/并行相关优化。
## 1. 实验介绍
| 实验 | 名称 | 任务/目标 |
| --- | --- | --- |
| Lab1 | 语法树构建 | 基于 SysY 源程序完成语法分析与语法树构建 |
| Lab2 | 中间表示生成 | 将语法树翻译为 LLVM 风格的中间表示IR并输出 IR |
| Lab3 | 指令选择与汇编生成 | 将 IR 翻译为目标平台汇编代码(本项目以 ARM64/AArch64 为主) |
| Lab4 | 基本标量优化 | 实现常见的标量优化(如常量传播、死代码删除、简化 CFG 等) |
| Lab5 | 寄存器分配与后端优化 | 为后端生成的虚拟寄存器分配物理寄存器,并完成 spill/reload、冗余指令消除与局部后端优化 |
| Lab6 | 并行与循环优化 | 面向循环的优化(循环变换/并行化等),进一步提升程序性能 |
## 2. 参考资料
本仓库提供的示例代码和实验文档只是参考。我们非常鼓励大家在阅读当前仓库实现的同时,也结合自己的理解重新设计框架并完成实现,而不是机械照搬。
如果希望进一步参考编译相关项目和往届优秀实现,可以查看编译比赛官网的技术支持栏目:<https://compiler.educg.net/#/index?TYPE=26COM>。其中的“备赛推荐”整理了一些编译相关项目,也能看到往届优秀作品的开源实现,这些内容都很值得参考。
## 3. 头歌平台协作流程
头歌平台的代码托管方式与 GitHub/Gitee 类似。如果你希望基于当前仓库快速开始协作,可以参考下面这套流程。
### 3.1 组长 fork 课程仓库
组长打开课程仓库页面,点击右上角的 `Fork`,创建你们小组自己的仓库副本。后续组内开发统一基于这个 fork 后的仓库进行。
![组长 fork 课程仓库](doc/images/01.png)
### 3.2 组长邀请组员加入仓库
fork 完成后,组长进入自己的仓库页面,在右侧可以看到邀请码。把邀请码发给组员即可,组员不需要再 fork 课程仓库。
![组长查看邀请码](doc/images/02.png)
### 3.3 组员申请加入,组长审批通过
组员拿到邀请码后,可以在页面右上角的 `+` 菜单里选择 `加入项目`,然后提交加入申请。
![组员加入项目入口](doc/images/03.png)
申请发出后,组长到个人主页的待办事项中审批成员申请,同意后组员就可以正常参与仓库协作。
![组长审批组员申请](doc/images/04.png)
### 3.4 在本地克隆小组仓库并配置远端
组长和组员在成功加入小组仓库后,就可以从仓库页面复制 HTTPS 地址,在本地克隆代码:
![复制仓库 HTTPS 地址](doc/images/05.png)
下面示例使用 HTTPS 方式:
```bash
git clone <仓库 HTTPS 地址>
cd nudt-compiler-cpp
```
如果希望后续同步课程仓库更新,可以额外把课程主仓库配置为 `upstream`
```bash
git remote add upstream https://bdgit.educoder.net/NUDT-compiler/nudt-compiler-cpp.git
git remote -v
```
配置完成后,常见的远端分工如下:
- `origin`:你们小组 fork 后的仓库,日常提交代码、推送分支都使用这个远端。
- `upstream`:课程主仓库,通常用于查看或同步课程团队发布的更新。
如果后续需要同步主仓库更新,可以先抓取远端信息:
```bash
git fetch upstream
```
### 3.5 提交与协作建议
借助 Git 进行协作开发,是当前软件开发中非常常见的一种工作方式,也是这门课程里需要大家掌握的基本能力。如果你对 Git 还不太熟悉,可以先看一下网络上的 Git 教程,例如:<https://liaoxuefeng.com/books/git/introduction/index.html>
当然也没有必要一开始学得特别深入,只需要记住常见操作即可,例如 `clone`、`status`、`add`、`commit`、`pull`、`push`、分支切换与合并。遇到具体报错或不会处理的冲突时,可以把现象和命令发给大模型帮你分析。
Git Commit 提交的信息建议尽量写清楚,推荐使用下面的格式:
```text
<type>(<scope>): <subject>
```
常见的 `type` 有:
- `feat`:新增功能
- `fix`:修复 bug
- `refactor`:重构但不改变外部行为
- `docs`:文档修改
- `test`:测试相关
- `chore`:杂项维护
`scope` 用来说明改动的大致范围,例如 `frontend`、`irgen`、`backend`、`test`、`doc`。
`subject` 用一句简短的话说明“这次改了什么”。
例如:
```text
feat(irgen): 支持一元表达式生成
fix(frontend): 修复空语句解析错误
docs(doc): 补充实验环境配置说明
```
除了提交代码本身,也推荐大家把头歌平台上的协作功能真正用起来:
- `Issue` 适合用来拆分任务、记录 bug、整理讨论结果和跟踪待办。
- `PR` / `Merge Request` 适合用来做分支合并和代码评审。比较推荐的流程是:每个人在自己的分支上开发,完成一个相对独立的小功能后提交 PR再由组内其他同学帮忙检查实现思路、代码质量和测试结果。
## 4. 实验环境配置
### 4.1 系统建议
建议使用 Ubuntu 22.04 或 WSLUbuntu 22.04 环境)。
### 4.2 安装基础依赖
本项目使用 CMake + C++17 构建;前端基于 ANTLR运行 ANTLR 的 `antlr-*.jar` 需要 Java。
```bash
sudo apt update
sudo apt install -y build-essential cmake git openjdk-11-jre
```
### 4.3 安装 LLVM 工具链
`scripts/verify_ir.sh``--run` 模式下会调用 LLVM 工具链(`llc` 与 `clang`)将生成的 IR 编译、运行,并在存在同名 `.out` 时自动比对输出结果。
```bash
sudo apt update
sudo apt install -y llvm clang
```
### 4.4 安装 ARM64 交叉编译工具链与 QEMU
后续实验会生成 ARM64/AArch64 汇编代码,并使用 ARM64 交叉编译工具链完成汇编、链接;再用 QEMU 用户态模拟器运行生成的 ARM 可执行文件。
```bash
# 安装 ARM64 交叉编译工具链
sudo apt update
sudo apt install gcc-aarch64-linux-gnu
# 安装 QEMU 用户模式模拟器
sudo apt install qemu-user
```
## 5. 编译与运行
### 5.1 生成 Lexer/Parser
本仓库已内置 ANTLR jar`third_party/antlr-4.13.2-complete.jar`。
当前 CMake 只会收集构建目录中的 Lexer/Parser 生成文件,不会自动调用 ANTLR因此首次构建前需要先生成 Lexer/Parser 及相关生成文件。
生成文件不提交到仓库,统一输出到 `build/generated/antlr4/`
```bash
mkdir -p build/generated/antlr4
java -jar third_party/antlr-4.13.2-complete.jar \
-Dlanguage=Cpp \
-visitor -no-listener \
-Xexact-output-dir \
-o build/generated/antlr4 \
src/antlr4/SysY.g4
```
### 5.2 Lab1 语法树构建
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON
cmake --build build -j "$(nproc)"
```
该模式只构建前端解析与语法树打印,不编译 `sem` / `irgen` / `mir`,适合 Lab1。
构建成功后,可执行文件位于:`./build/bin/compiler`。
运行语法树打印:
```bash
./build/bin/compiler --emit-parse-tree test/test_case/functional/simple_add.sy
```
### 5.3 全量构建
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF
cmake --build build -j "$(nproc)"
```
该模式会继续编译 `sem` / `irgen` / `mir`,用于后续实验。
### 5.4 运行自检
运行帮助信息能正常输出,说明基本环境与可执行文件均正常:
```bash
./build/bin/compiler --help
```
若当前处于 Lab1只需检查语法树输出是否符合预期。
若需要跑完整编译流程自检,则先使用全量构建模式,再执行下面的命令:从 SysY 源码生成 AArch64 汇编,完成汇编、链接,在 QEMU 下运行结果程序,并与 `test/test_case` 下同名 `.out` 自动比对:
```bash
./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/function/asm --run
```
如果最终看到 `输出匹配: test/test_case/simple_add.out`,说明当前示例用例 `return a + b` 的完整链路已经跑通。
但这条命令只适合做单个用例检查。完成对应实验后,不能只停留在 `simple_add`,还应覆盖 `test/test_case` 下全部测试用例;如有需要,也可以自行编写批量测试脚本统一执行。

@ -0,0 +1,20 @@
// 包装 ANTLR4提供简易的解析入口。
#pragma once
#include <memory>
#include <string>
#include "SysYLexer.h"
#include "SysYParser.h"
#include "antlr4-runtime.h"
struct AntlrResult {
std::unique_ptr<antlr4::ANTLRInputStream> input;
std::unique_ptr<SysYLexer> lexer;
std::unique_ptr<antlr4::CommonTokenStream> tokens;
std::unique_ptr<SysYParser> parser;
antlr4::tree::ParseTree* tree = nullptr; // owned by parser
};
// 解析指定文件,发生错误时抛出 std::runtime_error。
AntlrResult ParseFileWithAntlr(const std::string& path);

@ -0,0 +1,9 @@
#pragma once
#include <iosfwd>
#include "antlr4-runtime.h"
// 以树状缩进形式直接打印 ANTLR parse tree。
void PrintSyntaxTree(antlr4::tree::ParseTree* tree, antlr4::Parser* parser,
std::ostream& os);

@ -0,0 +1,73 @@
#pragma once
#include "ir/IR.h"
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
class DominatorTree {
public:
explicit DominatorTree(Function& function);
void Recalculate();
Function& GetFunction() const { return *function_; }
bool IsReachable(BasicBlock* block) const;
bool Dominates(BasicBlock* dom, BasicBlock* node) const;
bool Dominates(Instruction* dom, Instruction* user) const;
BasicBlock* GetIDom(BasicBlock* block) const;
const std::vector<BasicBlock*>& GetChildren(BasicBlock* block) const;
const std::vector<BasicBlock*>& GetReversePostOrder() const {
return reverse_post_order_;
}
private:
Function* function_ = nullptr;
std::vector<BasicBlock*> reverse_post_order_;
std::unordered_map<BasicBlock*, std::size_t> block_index_;
std::vector<std::vector<std::uint8_t>> dominates_;
std::unordered_map<BasicBlock*, BasicBlock*> immediate_dominator_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_children_;
};
struct Loop {
BasicBlock* header = nullptr;
std::unordered_set<BasicBlock*> blocks;
std::vector<BasicBlock*> block_list;
std::vector<BasicBlock*> latches;
std::vector<BasicBlock*> exiting_blocks;
std::vector<BasicBlock*> exit_blocks;
BasicBlock* preheader = nullptr;
Loop* parent = nullptr;
std::vector<Loop*> subloops;
bool Contains(BasicBlock* block) const;
bool Contains(const Loop* other) const;
bool IsInnermost() const;
};
class LoopInfo {
public:
LoopInfo(Function& function, const DominatorTree& dom_tree);
void Recalculate();
const std::vector<std::unique_ptr<Loop>>& GetLoops() const { return loops_; }
std::vector<Loop*> GetTopLevelLoops() const;
std::vector<Loop*> GetLoopsInPostOrder() const;
Loop* GetLoopFor(BasicBlock* block) const;
private:
Function* function_ = nullptr;
const DominatorTree* dom_tree_ = nullptr;
std::vector<std::unique_ptr<Loop>> loops_;
std::vector<Loop*> top_level_loops_;
std::unordered_map<BasicBlock*, Loop*> block_to_loop_;
};
} // namespace ir

@ -0,0 +1,799 @@
#pragma once
#include "utils.h"
#include <iosfwd>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
class Value;
class User;
class BasicBlock;
class Function;
class Instruction;
class Argument;
class ConstantInt;
class ConstantFloat;
class ConstantI1;
class ConstantArrayValue;
class Type;
class Use {
public:
Use() = default;
Use(Value* value, User* user, size_t operand_index)
: value_(value), user_(user), operand_index_(operand_index) {}
Value* GetValue() const { return value_; }
User* GetUser() const { return user_; }
size_t GetOperandIndex() const { return operand_index_; }
void SetValue(Value* value) { value_ = value; }
void SetUser(User* user) { user_ = user; }
void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; }
private:
Value* value_ = nullptr;
User* user_ = nullptr;
size_t operand_index_ = 0;
};
class Context {
public:
Context() = default;
~Context();
ConstantInt* GetConstInt(int v);
ConstantI1* GetConstBool(bool v);
std::string NextTemp();
std::string NextBlockName(const std::string& prefix = "bb");
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<bool, std::unique_ptr<ConstantI1>> const_bools_;
int temp_index_ = -1;
int block_index_ = -1;
};
class Type {
public:
enum class Kind {
Void,
Int1,
Int32,
Float,
Label,
Function,
Pointer,
PtrInt32 = Pointer,
Array
};
explicit Type(Kind kind);
Type(Kind kind, std::shared_ptr<Type> element_type, size_t num_elements = 0);
static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetFloatType();
static const std::shared_ptr<Type>& GetLabelType();
static const std::shared_ptr<Type>& GetBoolType();
static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> pointee = nullptr);
static const std::shared_ptr<Type>& GetPtrInt32Type();
static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> element_type,
size_t num_elements);
Kind GetKind() const { return kind_; }
bool IsVoid() const { return kind_ == Kind::Void; }
bool IsInt1() const { return kind_ == Kind::Int1; }
bool IsInt32() const { return kind_ == Kind::Int32; }
bool IsFloat() const { return kind_ == Kind::Float; }
bool IsLabel() const { return kind_ == Kind::Label; }
bool IsFunction() const { return kind_ == Kind::Function; }
bool IsBool() const { return kind_ == Kind::Int1; }
bool IsPointer() const { return kind_ == Kind::Pointer; }
bool IsPtrInt32() const { return IsPointer(); }
bool IsArray() const { return kind_ == Kind::Array; }
std::shared_ptr<Type> GetElementType() const { return element_type_; }
size_t GetNumElements() const { return num_elements_; }
int GetSize() const;
void Print(std::ostream& os) const;
private:
Kind kind_;
std::shared_ptr<Type> element_type_;
size_t num_elements_ = 0;
};
class Value {
public:
Value(std::shared_ptr<Type> ty, std::string name);
virtual ~Value() = default;
const std::shared_ptr<Type>& GetType() const { return type_; }
const std::string& GetName() const { return name_; }
void SetName(std::string name) { name_ = std::move(name); }
bool IsVoid() const { return type_ && type_->IsVoid(); }
bool IsInt32() const { return type_ && type_->IsInt32(); }
bool IsPtrInt32() const { return type_ && type_->IsPtrInt32(); }
bool IsFloat() const { return type_ && type_->IsFloat(); }
bool IsBool() const { return type_ && type_->IsBool(); }
bool IsArray() const { return type_ && type_->IsArray(); }
bool IsLabel() const { return type_ && type_->IsLabel(); }
virtual bool IsConstant() const { return false; }
virtual bool IsInstruction() const { return false; }
virtual bool IsUser() const { return false; }
virtual bool IsFunction() const { return false; }
virtual bool IsArgument() const { return false; }
void AddUse(User* user, size_t operand_index);
void RemoveUse(User* user, size_t operand_index);
const std::vector<Use>& GetUses() const { return uses_; }
void ReplaceAllUsesWith(Value* new_value);
virtual void Print(std::ostream& os) const;
protected:
std::shared_ptr<Type> type_;
std::string name_;
std::vector<Use> uses_;
};
template <typename T>
inline bool isa(const Value* value) {
return value && T::classof(value);
}
template <typename T>
inline T* dyncast(Value* value) {
return isa<T>(value) ? dynamic_cast<T*>(value) : nullptr;
}
template <typename T>
inline const T* dyncast(const Value* value) {
return isa<T>(value) ? dynamic_cast<const T*>(value) : nullptr;
}
class ConstantValue : public Value {
public:
ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
bool IsConstant() const override final { return true; }
};
class ConstantInt : public ConstantValue {
public:
ConstantInt(std::shared_ptr<Type> ty, int value);
int GetValue() const { return value_; }
static bool classof(const Value* value) {
return value && value->IsConstant() &&
dynamic_cast<const ConstantInt*>(value) != nullptr;
}
private:
int value_;
};
class ConstantFloat : public ConstantValue {
public:
ConstantFloat(std::shared_ptr<Type> ty, float value);
float GetValue() const { return value_; }
static bool classof(const Value* value) {
return value && value->IsConstant() &&
dynamic_cast<const ConstantFloat*>(value) != nullptr;
}
private:
float value_;
};
class ConstantI1 : public ConstantValue {
public:
ConstantI1(std::shared_ptr<Type> ty, bool value);
bool GetValue() const { return value_; }
static bool classof(const Value* value) {
return value && value->IsConstant() &&
dynamic_cast<const ConstantI1*>(value) != nullptr;
}
private:
bool value_;
};
class ConstantArrayValue : public Value {
public:
ConstantArrayValue(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name = "");
const std::vector<Value*>& GetElements() const { return elements_; }
const std::vector<size_t>& GetDims() const { return dims_; }
void Print(std::ostream& os) const override;
static bool classof(const Value* value) {
return value && dynamic_cast<const ConstantArrayValue*>(value) != nullptr;
}
private:
std::vector<Value*> elements_;
std::vector<size_t> dims_;
};
enum class Opcode {
Add,
Sub,
Mul,
Div,
Rem,
FAdd,
FSub,
FMul,
FDiv,
FRem,
And,
Or,
Xor,
Shl,
AShr,
LShr,
ICmpEQ,
ICmpNE,
ICmpLT,
ICmpGT,
ICmpLE,
ICmpGE,
FCmpEQ,
FCmpNE,
FCmpLT,
FCmpGT,
FCmpLE,
FCmpGE,
Neg,
Not,
FNeg,
FtoI,
IToF,
Call,
CondBr,
Br,
Return,
Ret = Return,
Unreachable,
Alloca,
Load,
Store,
Memset,
GetElementPtr,
Phi,
Zext
};
class User : public Value {
public:
User(std::shared_ptr<Type> ty, std::string name);
bool IsUser() const override final { return true; }
size_t GetNumOperands() const { return operands_.size(); }
Value* GetOperand(size_t index) const;
void SetOperand(size_t index, Value* value);
void AddOperand(Value* value);
void AddOperands(const std::vector<Value*>& values);
void RemoveOperand(size_t index);
void ClearAllOperands();
protected:
std::vector<Use> operands_;
};
class Argument : public Value {
public:
Argument(std::shared_ptr<Type> type, std::string name, size_t index);
size_t GetIndex() const { return index_; }
bool IsArgument() const override final { return true; }
static bool classof(const Value* value) {
return value && dynamic_cast<const Argument*>(value) != nullptr;
}
private:
size_t index_;
};
class GlobalValue : public User {
public:
GlobalValue(std::shared_ptr<Type> object_type, const std::string& name,
bool is_const = false, Value* init = nullptr);
bool IsConstant() const override { return is_const_; }
bool HasInitializer() const { return init_ != nullptr; }
Value* GetInitializer() const { return init_; }
std::shared_ptr<Type> GetObjectType() const { return object_type_; }
void SetConstant(bool is_const) { is_const_ = is_const; }
void SetInitializer(Value* init) { init_ = init; }
static bool classof(const Value* value) {
return value && dynamic_cast<const GlobalValue*>(value) != nullptr;
}
private:
std::shared_ptr<Type> object_type_;
bool is_const_ = false;
Value* init_ = nullptr;
};
class Instruction : public User {
public:
Instruction(Opcode opcode, std::shared_ptr<Type> ty,
BasicBlock* parent = nullptr, const std::string& name = "");
bool IsInstruction() const override final { return true; }
Opcode GetOpcode() const { return opcode_; }
bool IsTerminator() const;
BasicBlock* GetParent() const { return parent_; }
void SetParent(BasicBlock* parent) { parent_ = parent; }
static bool classof(const Value* value) {
return value && value->IsInstruction();
}
private:
Opcode opcode_;
BasicBlock* parent_;
};
class BinaryInst : public Instruction {
public:
BinaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
BasicBlock* parent = nullptr, const std::string& name = "");
Value* GetLhs() const { return GetOperand(0); }
Value* GetRhs() const { return GetOperand(1); }
static bool classof(const Value* value);
};
class UnaryInst : public Instruction {
public:
UnaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* operand,
BasicBlock* parent = nullptr, const std::string& name = "");
Value* GetOprd() const { return GetOperand(0); }
static bool classof(const Value* value);
};
class ReturnInst : public Instruction {
public:
ReturnInst(Value* value = nullptr, BasicBlock* parent = nullptr);
bool HasReturnValue() const { return GetNumOperands() > 0; }
Value* GetReturnValue() const {
return HasReturnValue() ? GetOperand(0) : nullptr;
}
Value* GetValue() const { return GetReturnValue(); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Return;
}
};
class AllocaInst : public Instruction {
public:
AllocaInst(std::shared_ptr<Type> allocated_type, BasicBlock* parent = nullptr,
const std::string& name = "");
std::shared_ptr<Type> GetAllocatedType() const { return allocated_type_; }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Alloca;
}
private:
std::shared_ptr<Type> allocated_type_;
};
class LoadInst : public Instruction {
public:
LoadInst(std::shared_ptr<Type> value_type, Value* ptr,
BasicBlock* parent = nullptr, const std::string& name = "");
Value* GetPtr() const { return GetOperand(0); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Load;
}
};
class StoreInst : public Instruction {
public:
StoreInst(Value* value, Value* ptr, BasicBlock* parent = nullptr);
Value* GetValue() const { return GetOperand(0); }
Value* GetPtr() const { return GetOperand(1); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Store;
}
};
class UncondBrInst : public Instruction {
public:
UncondBrInst(BasicBlock* dest, BasicBlock* parent = nullptr);
BasicBlock* GetDest() const;
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Br;
}
};
class CondBrInst : public Instruction {
public:
CondBrInst(Value* cond, BasicBlock* then_block, BasicBlock* else_block,
BasicBlock* parent = nullptr);
Value* GetCondition() const { return GetOperand(0); }
BasicBlock* GetThenBlock() const;
BasicBlock* GetElseBlock() const;
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::CondBr;
}
};
class UnreachableInst : public Instruction {
public:
explicit UnreachableInst(BasicBlock* parent = nullptr);
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Unreachable;
}
};
class CallInst : public Instruction {
public:
CallInst(Function* callee, const std::vector<Value*>& args = {},
BasicBlock* parent = nullptr, const std::string& name = "");
Function* GetCallee() const;
std::vector<Value*> GetArguments() const;
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Call;
}
};
class GetElementPtrInst : public Instruction {
public:
GetElementPtrInst(std::shared_ptr<Type> source_type, Value* ptr,
const std::vector<Value*>& indices,
BasicBlock* parent = nullptr,
const std::string& name = "");
Value* GetPointer() const { return GetOperand(0); }
size_t GetNumIndices() const {
return GetNumOperands() > 0 ? GetNumOperands() - 1 : 0;
}
Value* GetIndex(size_t index) const { return GetOperand(index + 1); }
std::shared_ptr<Type> GetSourceType() const { return source_type_; }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() ==
Opcode::GetElementPtr;
}
private:
std::shared_ptr<Type> source_type_;
};
class PhiInst : public Instruction {
public:
PhiInst(std::shared_ptr<Type> type, BasicBlock* parent = nullptr,
const std::string& name = "");
void AddIncoming(Value* value, BasicBlock* block);
int GetNumIncomings() const {
return static_cast<int>(GetNumOperands() / 2);
}
Value* GetIncomingValue(int index) const {
return GetOperand(static_cast<size_t>(2 * index));
}
BasicBlock* GetIncomingBlock(int index) const;
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Phi;
}
};
class ZextInst : public Instruction {
public:
ZextInst(Value* value, std::shared_ptr<Type> target_type,
BasicBlock* parent = nullptr, const std::string& name = "");
Value* GetValue() const { return GetOperand(0); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Zext;
}
};
class MemsetInst : public Instruction {
public:
MemsetInst(Value* dst, Value* value, Value* len, Value* is_volatile,
BasicBlock* parent = nullptr);
Value* GetDest() const { return GetOperand(0); }
Value* GetValue() const { return GetOperand(1); }
Value* GetLength() const { return GetOperand(2); }
Value* GetIsVolatile() const { return GetOperand(3); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Memset;
}
};
class BasicBlock : public Value {
public:
explicit BasicBlock(const std::string& name);
BasicBlock(Function* parent, const std::string& name);
Function* GetParent() const { return parent_; }
void SetParent(Function* parent) { parent_ = parent; }
bool HasTerminator() const;
std::vector<std::unique_ptr<Instruction>>& GetInstructions() {
return instructions_;
}
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const {
return instructions_;
}
void EraseInstruction(Instruction* inst);
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
void RemovePredecessor(BasicBlock* pred);
void RemoveSuccessor(BasicBlock* succ);
const std::vector<BasicBlock*>& GetPredecessors() const {
return predecessors_;
}
const std::vector<BasicBlock*>& GetSuccessors() const {
return successors_;
}
template <typename T, typename... Args>
T* Insert(size_t index, Args&&... args) {
if (index > instructions_.size()) {
throw std::out_of_range("BasicBlock insert index out of range");
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin() + static_cast<long long>(index),
std::move(inst));
return ptr;
}
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
throw std::runtime_error("BasicBlock already has terminator");
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.push_back(std::move(inst));
return ptr;
}
static bool classof(const Value* value) {
return value && dynamic_cast<const BasicBlock*>(value) != nullptr;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_;
};
class Function : public Value {
public:
Function(std::string name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types = {},
const std::vector<std::string>& param_names = {},
bool is_external = false);
bool IsFunction() const override final { return true; }
std::shared_ptr<Type> GetReturnType() const { return return_type_; }
const std::vector<std::shared_ptr<Type>>& GetParamTypes() const {
return param_types_;
}
const std::vector<std::unique_ptr<Argument>>& GetArguments() const {
return arguments_;
}
Argument* GetArgument(size_t index) const;
bool IsExternal() const { return is_external_; }
void SetExternal(bool is_external) { is_external_ = is_external; }
void SetEffectInfo(bool reads_global_memory, bool writes_global_memory,
bool reads_param_memory, bool writes_param_memory,
bool has_io, bool has_unknown_effects, bool is_recursive) {
reads_global_memory_ = reads_global_memory;
writes_global_memory_ = writes_global_memory;
reads_param_memory_ = reads_param_memory;
writes_param_memory_ = writes_param_memory;
has_io_ = has_io;
has_unknown_effects_ = has_unknown_effects;
is_recursive_ = is_recursive;
}
bool ReadsGlobalMemory() const { return reads_global_memory_; }
bool WritesGlobalMemory() const { return writes_global_memory_; }
bool ReadsParamMemory() const { return reads_param_memory_; }
bool WritesParamMemory() const { return writes_param_memory_; }
bool HasIO() const { return has_io_; }
bool HasUnknownEffects() const { return has_unknown_effects_; }
bool IsRecursive() const { return is_recursive_; }
bool MayReadMemory() const {
return has_unknown_effects_ || reads_global_memory_ || writes_global_memory_ ||
reads_param_memory_ || writes_param_memory_;
}
bool MayWriteMemory() const {
return has_unknown_effects_ || writes_global_memory_ || writes_param_memory_;
}
bool HasObservableSideEffects() const {
return has_unknown_effects_ || writes_global_memory_ ||
writes_param_memory_ || has_io_;
}
bool CanDiscardUnusedCall() const {
return !has_unknown_effects_ && !writes_global_memory_ &&
!writes_param_memory_ && !has_io_ && !is_recursive_;
}
BasicBlock* GetEntryBlock() const { return entry_; }
BasicBlock* GetEntry() const { return entry_; }
void SetEntryBlock(BasicBlock* bb) { entry_ = bb; }
BasicBlock* EnsureEntryBlock();
BasicBlock* CreateBlock(const std::string& name);
BasicBlock* AddBlock(std::unique_ptr<BasicBlock> block);
std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() { return blocks_; }
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const {
return blocks_;
}
static bool classof(const Value* value) {
return value && value->IsFunction();
}
private:
std::shared_ptr<Type> return_type_;
std::vector<std::shared_ptr<Type>> param_types_;
std::vector<std::unique_ptr<Argument>> arguments_;
bool is_external_ = false;
bool reads_global_memory_ = false;
bool writes_global_memory_ = false;
bool reads_param_memory_ = false;
bool writes_param_memory_ = false;
bool has_io_ = false;
bool has_unknown_effects_ = true;
bool is_recursive_ = false;
BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_;
};
class Module {
public:
Module() = default;
Context& GetContext() { return context_; }
const Context& GetContext() const { return context_; }
Function* CreateFunction(const std::string& name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types = {},
const std::vector<std::string>& param_names = {},
bool is_external = false);
Function* GetFunction(const std::string& name) const;
const std::vector<std::unique_ptr<Function>>& GetFunctions() const {
return functions_;
}
GlobalValue* CreateGlobalValue(const std::string& name,
std::shared_ptr<Type> object_type,
bool is_const = false, Value* init = nullptr);
GlobalValue* GetGlobalValue(const std::string& name) const;
const std::vector<std::unique_ptr<GlobalValue>>& GetGlobalValues() const {
return globals_;
}
private:
Context context_;
std::vector<std::unique_ptr<Function>> functions_;
std::map<std::string, Function*> function_map_;
std::vector<std::unique_ptr<GlobalValue>> globals_;
std::map<std::string, GlobalValue*> global_map_;
};
class IRBuilder {
public:
IRBuilder(Context& ctx, BasicBlock* bb);
void SetInsertPoint(BasicBlock* bb);
BasicBlock* GetInsertBlock() const { return insert_block_; }
ConstantInt* CreateConstInt(int v);
ConstantFloat* CreateConstFloat(float v);
ConstantI1* CreateConstBool(bool v);
ConstantArrayValue* CreateConstArray(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name = "");
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name = "");
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateRem(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateAnd(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateOr(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateXor(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateShl(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateAShr(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateLShr(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateICmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name = "");
BinaryInst* CreateFCmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name = "");
UnaryInst* CreateNeg(Value* operand, const std::string& name = "");
UnaryInst* CreateNot(Value* operand, const std::string& name = "");
UnaryInst* CreateFNeg(Value* operand, const std::string& name = "");
UnaryInst* CreateFtoI(Value* operand, const std::string& name = "");
UnaryInst* CreateIToF(Value* operand, const std::string& name = "");
AllocaInst* CreateAlloca(std::shared_ptr<Type> allocated_type,
const std::string& name = "");
LoadInst* CreateLoad(Value* ptr, std::shared_ptr<Type> value_type,
const std::string& name = "");
LoadInst* CreateLoad(Value* ptr, const std::string& name = "") {
return CreateLoad(ptr, Type::GetInt32Type(), name);
}
StoreInst* CreateStore(Value* val, Value* ptr);
UncondBrInst* CreateBr(BasicBlock* dest);
CondBrInst* CreateCondBr(Value* cond, BasicBlock* then_bb,
BasicBlock* else_bb);
ReturnInst* CreateRet(Value* val = nullptr);
UnreachableInst* CreateUnreachable();
CallInst* CreateCall(Function* callee, const std::vector<Value*>& args,
const std::string& name = "");
GetElementPtrInst* CreateGEP(Value* ptr, std::shared_ptr<Type> source_type,
const std::vector<Value*>& indices,
const std::string& name = "");
PhiInst* CreatePhi(std::shared_ptr<Type> type, const std::string& name = "");
ZextInst* CreateZext(Value* val, std::shared_ptr<Type> target_type,
const std::string& name = "");
MemsetInst* CreateMemset(Value* dst, Value* val, Value* len,
Value* is_volatile);
private:
Context& ctx_;
BasicBlock* insert_block_;
};
class IRPrinter {
public:
void Print(const Module& module, std::ostream& os);
};
inline std::ostream& operator<<(std::ostream& os, const Type& type) {
type.Print(os);
return os;
}
inline std::ostream& operator<<(std::ostream& os, const Value& value) {
value.Print(os);
return os;
}
} // namespace ir

@ -0,0 +1,373 @@
// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。
//
// 当前已经实现:
// 1. 基础类型系统void / i32 / i32*
// 2. Value 体系Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction
// 3. 最小指令集Add / Alloca / Load / Store / Ret
// 4. BasicBlock / Function / Module 三层组织结构
// 5. IRBuilder便捷创建常量和最小指令
// 6. def-use 关系的轻量实现:
// - Instruction 保存 operand 列表
// - Value 保存 uses
// - 支持 ReplaceAllUsesWith 的简化实现
//
// 当前尚未实现或只做了最小占位:
// 1. 完整类型系统数组、函数类型、label 类型等
// 2. 更完整的指令系统br / condbr / call / phi / gep 等
// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构)
// 4. 更完整的 IR verifier 和优化基础设施
//
// 当前需要特别说明的两个简化点:
// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位,
// 后续如果补 label type可以再改成更合理的块标签类型。
// 2. ConstantValue 体系目前只实现了 ConstantInt后续可以继续补 ConstantFloat、
// ConstantArray等更完整的常量种类。
//
// 建议的扩展顺序:
// 1. 先补更多指令和类型
// 2. 再补控制流相关 IR
// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架
#pragma once
#include <iosfwd>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
class Type;
class Value;
class User;
class ConstantValue;
class ConstantInt;
class GlobalValue;
class Instruction;
class BasicBlock;
class Function;
// Use 表示一个 Value 的一次使用记录。
// 当前实现设计:
// - value被使用的值
// - user使用该值的 User
// - operand_index该值在 user 操作数列表中的位置
class Use {
public:
Use() = default;
Use(Value* value, User* user, size_t operand_index)
: value_(value), user_(user), operand_index_(operand_index) {}
Value* GetValue() const { return value_; }
User* GetUser() const { return user_; }
size_t GetOperandIndex() const { return operand_index_; }
void SetValue(Value* value) { value_ = value; }
void SetUser(User* user) { user_ = user; }
void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; }
private:
Value* value_ = nullptr;
User* user_ = nullptr;
size_t operand_index_ = 0;
};
// IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。
class Context {
public:
Context() = default;
~Context();
// 去重创建 i32 常量。
ConstantInt* GetConstInt(int v);
std::string NextTemp();
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
int temp_index_ = -1;
};
class Type {
public:
enum class Kind { Void, Int1, Int32, Float, Label, Function, PtrInt32, Array };
explicit Type(Kind k);
// 静态工厂方法:返回对应类型的共享单例
static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetFloatType();
static const std::shared_ptr<Type>& GetLabelType();
static const std::shared_ptr<Type>& GetFunctionType();
static const std::shared_ptr<Type>& GetBoolType();
static const std::shared_ptr<Type>& GetPtrInt32Type();
static const std::shared_ptr<Type>& GetArrayType();
Kind GetKind() const;
// 便捷类型判断
bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const;
bool IsFloat() const;
bool IsLabel() const;
bool IsFunction() const;
bool IsBool() const;
bool IsPtrInt32() const;
bool IsArray() const;
private:
Kind kind_;
};
class Value {
public:
Value(std::shared_ptr<Type> ty, std::string name);
virtual ~Value() = default;
const std::shared_ptr<Type>& GetType() const;
const std::string& GetName() const;
void SetName(std::string n);
bool IsVoid() const;
bool IsInt32() const;
bool IsPtrInt32() const;
bool IsConstant() const;
bool IsInstruction() const;
bool IsUser() const;
bool IsFunction() const;
void AddUse(User* user, size_t operand_index);
void RemoveUse(User* user, size_t operand_index);
const std::vector<Use>& GetUses() const;
void ReplaceAllUsesWith(Value* new_value);
protected:
std::shared_ptr<Type> type_;
std::string name_;
std::vector<Use> uses_;
};
// ConstantValue 是常量体系的基类。
// 当前只实现了 ConstantInt后续可继续扩展更多常量种类。
class ConstantValue : public Value {
public:
ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
};
class ConstantInt : public ConstantValue {
public:
ConstantInt(std::shared_ptr<Type> ty, int v);
int GetValue() const { return value_; }
private:
int value_{};
};
class ConstantFloat : public ConstantValue {
public:
ConstantFloat(std::shared_ptr<Type> ty, float v);
float GetValue() const { return value_; }
private:
float value_{};
};
class ConstantI1 : public ConstantValue {
public:
ConstantI1(std::shared_ptr<Type> ty, bool v);
int GetValue() const { return value_; }
private:
bool value_{};
};
class ConstantArrayValue : public Value {
public:
ConstantArrayValue()
};
//暂时先设计这些
enum class Opcode {
// 二元算术
Add,Sub,Mul,Div,Rem,FAdd,FSub,FMul,FDiv,FRem,
// 位运算
And,Or,Xor,Shl,AShr,LShr,
// 整数比较
ICmpEQ,ICmpNE,ICmpLT,ICmpGT,ICmpLE,ICmpGE,
// 浮点比较
FCmpEQ,FCmpNE,FCmpLT,FCmpGT,FCmpLE,FCmpGE,
// 一元运算
Neg,Not,FNeg,FtoI,IToF,
// 调用与终止
Call,CondBr,Br,Return,Unreachable,
// 内存操作
Alloca,Load,Store,Memset,
// 其他
GetElementPtr,Phi,Zext
};
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
// 当前实现中只有 Instruction 继承自 User。
class User : public Value {
public:
User(std::shared_ptr<Type> ty, std::string name);
size_t GetNumOperands() const;
Value* GetOperand(size_t index) const;
void SetOperand(size_t index, Value* value);
protected:
void AddOperand(Value* value);
private:
std::vector<Value*> operands_;
};
// GlobalValue 是全局值/全局变量体系的空壳占位类。
// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。
class GlobalValue : public User {
public:
GlobalValue(std::shared_ptr<Type> ty, std::string name);
};
class Instruction : public User {
public:
Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
Opcode GetOpcode() const;
bool IsTerminator() const;
BasicBlock* GetParent() const;
void SetParent(BasicBlock* parent);
private:
Opcode opcode_;
BasicBlock* parent_ = nullptr;
};
class BinaryInst : public Instruction {
public:
BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
std::string name);
Value* GetLhs() const;
Value* GetRhs() const;
};
class ReturnInst : public Instruction {
public:
ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
Value* GetValue() const;
};
class AllocaInst : public Instruction {
public:
AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name);
};
class LoadInst : public Instruction {
public:
LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name);
Value* GetPtr() const;
};
class StoreInst : public Instruction {
public:
StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr);
Value* GetValue() const;
Value* GetPtr() const;
};
// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。
// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。
class BasicBlock : public Value {
public:
explicit BasicBlock(std::string name);
Function* GetParent() const;
void SetParent(Function* parent);
bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const;
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
throw std::runtime_error("BasicBlock 已有 terminator不能继续追加指令: " +
name_);
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.push_back(std::move(inst));
return ptr;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_;
};
// Function 当前也采用了最小实现。
// 需要特别注意:由于项目里还没有单独的 FunctionType
// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”,
// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。
// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、
// 形参和调用,通常需要引入专门的函数类型表示。
class Function : public Value {
public:
// 当前构造函数接收的也是返回类型,而不是完整函数类型。
Function(std::string name, std::shared_ptr<Type> ret_type);
BasicBlock* CreateBlock(const std::string& name);
BasicBlock* GetEntry();
const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
private:
BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_;
};
class Module {
public:
Module() = default;
Context& GetContext();
const Context& GetContext() const;
// 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。
Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type);
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
private:
Context context_;
std::vector<std::unique_ptr<Function>> functions_;
};
class IRBuilder {
public:
IRBuilder(Context& ctx, BasicBlock* bb);
void SetInsertPoint(BasicBlock* bb);
BasicBlock* GetInsertBlock() const;
// 构造常量、二元运算、返回指令的最小集合。
ConstantInt* CreateConstInt(int v);
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);
ReturnInst* CreateRet(Value* v);
private:
Context& ctx_;
BasicBlock* insert_block_;
};
class IRPrinter {
public:
void Print(const Module& module, std::ostream& os);
};
} // namespace ir

@ -0,0 +1,24 @@
#pragma once
namespace ir {
class Module;
void RunMem2Reg(Module& module);
bool RunConstFold(Module& module);
bool RunConstProp(Module& module);
bool RunFunctionInlining(Module& module);
bool RunCSE(Module& module);
bool RunGVN(Module& module);
bool RunLoadStoreElim(Module& module);
bool RunDCE(Module& module);
bool RunCFGSimplify(Module& module);
bool RunLICM(Module& module);
bool RunLoopMemoryPromotion(Module& module);
bool RunLoopUnswitch(Module& module);
bool RunLoopStrengthReduction(Module& module);
bool RunLoopUnroll(Module& module);
bool RunLoopFission(Module& module);
void RunIRPassPipeline(Module& module);
} // namespace ir

@ -0,0 +1,41 @@
#pragma once
#include <iterator>
namespace ir {
template <typename IterT> struct range {
using iterator = IterT;
using value_type = typename std::iterator_traits<iterator>::value_type;
using reference = typename std::iterator_traits<iterator>::reference;
private:
iterator b;
iterator e;
public:
explicit range(iterator b, iterator e) : b(b), e(e) {}
iterator begin() { return b; }
iterator end() { return e; }
iterator begin() const { return b; }
iterator end() const { return e; }
auto size() const { return std::distance(b, e); }
auto empty() const { return b == e; }
};
//! create `range` object from iterator pair [begin, end)
template <typename IterT> range<IterT> make_range(IterT b, IterT e) {
return range<IterT>(b, e);
}
//! create `range` object from a container who has `begin()` and `end()` methods
template <typename ContainerT>
range<typename ContainerT::iterator> make_range(ContainerT &c) {
return make_range(c.begin(), c.end());
}
//! create `range` object from a container who has `begin()` and `end()` methods
template <typename ContainerT>
range<typename ContainerT::const_iterator> make_range(const ContainerT &c) {
return make_range(c.begin(), c.end());
}
} // namespace ir

@ -0,0 +1,193 @@
#pragma once
#include <any>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include "ir/IR.h"
#include "sem/Sema.h"
#include "sem/SymbolTable.h"
class IRGenImpl final : public SysYBaseVisitor {
public:
IRGenImpl(ir::Module& module, const SemanticContext& sema);
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
private:
enum class FlowState {
Continue,
Terminated,
};
struct TypedValue {
ir::Value* value = nullptr;
SemanticType type = SemanticType::Int;
bool is_array = false;
std::vector<int> dims;
};
struct LValueInfo {
SymbolEntry* symbol = nullptr;
ir::Value* addr = nullptr;
SemanticType type = SemanticType::Int;
bool is_array = false;
std::vector<int> dims;
bool root_param_array_no_index = false;
};
struct LoopContext {
ir::BasicBlock* cond_block = nullptr;
ir::BasicBlock* exit_block = nullptr;
};
struct InitExprSlot {
size_t index = 0;
SysYParser::ExpContext* expr = nullptr;
};
[[noreturn]] void ThrowError(const antlr4::ParserRuleContext* ctx,
const std::string& message) const;
void ApplyFunctionSema(const std::string& name, ir::Function& function);
void RegisterBuiltinFunctions();
void PredeclareTopLevel(SysYParser::CompUnitContext& ctx);
void PredeclareFunction(SysYParser::FuncDefContext& ctx);
void PredeclareGlobalDecl(SysYParser::DeclContext& ctx);
void EmitGlobalDecl(SysYParser::DeclContext& ctx);
void EmitFunction(SysYParser::FuncDefContext& ctx);
void BindFunctionParams(SysYParser::FuncDefContext& ctx, ir::Function& func);
void EmitBlock(SysYParser::BlockContext& ctx, bool create_scope = true);
FlowState EmitBlockItem(SysYParser::BlockItemContext& ctx);
FlowState EmitStmt(SysYParser::StmtContext& ctx);
void EmitDecl(SysYParser::DeclContext& ctx, bool is_global);
void EmitVarDecl(SysYParser::VarDeclContext* ctx, bool is_global, bool is_const);
void EmitConstDecl(SysYParser::ConstDeclContext* ctx, bool is_global);
void EmitGlobalVarDef(SysYParser::VarDefContext& ctx, SemanticType type);
void EmitGlobalConstDef(SysYParser::ConstDefContext& ctx, SemanticType type);
void EmitLocalVarDef(SysYParser::VarDefContext& ctx, SemanticType type, bool is_const);
void EmitLocalConstDef(SysYParser::ConstDefContext& ctx, SemanticType type);
std::string ExpectIdent(const antlr4::ParserRuleContext& ctx,
antlr4::tree::TerminalNode* ident) const;
SemanticType ParseBType(SysYParser::BTypeContext* ctx) const;
SemanticType ParseFuncType(SysYParser::FuncTypeContext* ctx) const;
std::shared_ptr<ir::Type> GetIRScalarType(SemanticType type) const;
std::shared_ptr<ir::Type> BuildArrayType(SemanticType base_type,
const std::vector<int>& dims) const;
std::vector<int> ParseArrayDims(const std::vector<SysYParser::ConstExpContext*>& dims_ctx);
std::vector<int> ParseParamDims(SysYParser::FuncFParamContext& ctx);
FunctionTypeInfo BuildFunctionTypeInfo(SysYParser::FuncDefContext& ctx);
std::vector<std::shared_ptr<ir::Type>> BuildFunctionIRParamTypes(
const FunctionTypeInfo& function_type) const;
std::vector<std::string> BuildFunctionIRParamNames(SysYParser::FuncDefContext& ctx) const;
TypedValue EmitExp(SysYParser::ExpContext& ctx);
TypedValue EmitAddExp(SysYParser::AddExpContext& ctx);
TypedValue EmitMulExp(SysYParser::MulExpContext& ctx);
TypedValue EmitUnaryExp(SysYParser::UnaryExpContext& ctx);
TypedValue EmitPrimaryExp(SysYParser::PrimaryExpContext& ctx);
TypedValue EmitRelExp(SysYParser::RelExpContext& ctx);
TypedValue EmitEqExp(SysYParser::EqExpContext& ctx);
TypedValue EmitLValValue(SysYParser::LValContext& ctx);
LValueInfo ResolveLVal(SysYParser::LValContext& ctx);
ir::Value* GenLValAddr(SysYParser::LValContext& ctx);
void EmitCond(SysYParser::CondContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block);
void EmitLOrCond(SysYParser::LOrExpContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block);
void EmitLAndCond(SysYParser::LAndExpContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block);
TypedValue CastScalar(TypedValue value, SemanticType target_type,
const antlr4::ParserRuleContext* ctx);
ir::Value* CastToCondition(TypedValue value,
const antlr4::ParserRuleContext* ctx);
TypedValue NormalizeLogicalValue(TypedValue value,
const antlr4::ParserRuleContext* ctx);
bool IsNumeric(const TypedValue& value) const;
bool IsSameDims(const std::vector<int>& lhs, const std::vector<int>& rhs) const;
ConstantValue ParseNumber(SysYParser::NumberContext& ctx) const;
ConstantValue EvalConstExp(SysYParser::ExpContext& ctx);
ConstantValue EvalConstAddExp(SysYParser::AddExpContext& ctx);
ConstantValue EvalConstMulExp(SysYParser::MulExpContext& ctx);
ConstantValue EvalConstUnaryExp(SysYParser::UnaryExpContext& ctx);
ConstantValue EvalConstPrimaryExp(SysYParser::PrimaryExpContext& ctx);
ConstantValue EvalConstLVal(SysYParser::LValContext& ctx);
ConstantValue ConvertConst(ConstantValue value, SemanticType target_type) const;
bool IsZeroConstant(const ConstantValue& value) const;
bool IsExplicitZeroConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type);
bool IsExplicitZeroInitVal(SysYParser::InitValContext* ctx,
SemanticType base_type);
std::vector<ConstantValue> FlattenConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims);
std::vector<ConstantValue> FlattenInitVal(SysYParser::InitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims);
std::vector<InitExprSlot> FlattenLocalInitVal(SysYParser::InitValContext* ctx,
const std::vector<int>& dims);
void FlattenConstInitValImpl(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor, std::vector<ConstantValue>& out);
void FlattenInitValImpl(SysYParser::InitValContext* ctx, SemanticType base_type,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor, std::vector<ConstantValue>& out);
void FlattenLocalInitValImpl(SysYParser::InitValContext* ctx,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor, std::vector<InitExprSlot>& out);
size_t CountArrayElements(const std::vector<int>& dims, size_t start = 0) const;
size_t AlignInitializerCursor(const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t cursor) const;
size_t FlattenIndices(const std::vector<int>& dims,
const std::vector<int>& indices) const;
ConstantValue ZeroConst(SemanticType type) const;
ir::Value* ZeroIRValue(SemanticType type);
ir::Value* CreateTypedConstant(const ConstantValue& value);
ir::AllocaInst* CreateEntryAlloca(std::shared_ptr<ir::Type> allocated_type,
const std::string& name);
void ZeroInitializeLocalArray(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims);
void StoreLocalArrayElements(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims,
const std::vector<InitExprSlot>& init_slots);
ir::Value* CreateArrayElementAddr(ir::Value* base_addr, bool is_param_array,
SemanticType base_type,
const std::vector<int>& dims,
const std::vector<ir::Value*>& indices,
const antlr4::ParserRuleContext* ctx);
std::string NextTemp();
std::string NextBlockName(const std::string& prefix);
ir::Module& module_;
const SemanticContext& sema_;
ir::IRBuilder builder_;
SymbolTable symbols_;
ir::Function* current_function_ = nullptr;
SemanticType current_return_type_ = SemanticType::Void;
std::vector<LoopContext> loop_stack_;
bool builtins_registered_ = false;
};
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema);

@ -0,0 +1,296 @@
#pragma once
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace ir {
class Module;
}
namespace mir {
class MIRContext {
public:
MIRContext() = default;
};
MIRContext& DefaultContext();
enum class ValueType { Void, I1, I32, F32, Ptr };
enum class RegClass { GPR, FPR };
enum class CondCode { EQ, NE, LT, GT, LE, GE };
enum class StackObjectKind { Local, Spill, SavedGPR, SavedFPR };
enum class AddrBaseKind { None, FrameObject, Global, VReg };
enum class OperandKind { Invalid, VReg, Imm, Block, Symbol };
struct PhysReg {
RegClass reg_class = RegClass::GPR;
int index = -1;
bool IsValid() const { return index >= 0; }
bool operator==(const PhysReg& rhs) const {
return reg_class == rhs.reg_class && index == rhs.index;
}
};
bool IsGPR(ValueType type);
bool IsFPR(ValueType type);
int GetValueSize(ValueType type);
int GetValueAlign(ValueType type);
const char* GetPhysRegName(PhysReg reg, ValueType type);
class MachineOperand {
public:
MachineOperand() = default;
static MachineOperand VReg(int reg);
static MachineOperand Imm(std::int64_t value);
static MachineOperand Block(std::string name);
static MachineOperand Symbol(std::string name);
OperandKind GetKind() const { return kind_; }
int GetVReg() const { return vreg_; }
std::int64_t GetImm() const { return imm_; }
const std::string& GetText() const { return text_; }
private:
MachineOperand(OperandKind kind, int vreg, std::int64_t imm, std::string text);
OperandKind kind_ = OperandKind::Invalid;
int vreg_ = -1;
std::int64_t imm_ = 0;
std::string text_;
};
struct AddressExpr {
AddrBaseKind base_kind = AddrBaseKind::None;
int base_index = -1;
std::string symbol;
std::int64_t const_offset = 0;
std::vector<std::pair<int, std::int64_t>> scaled_vregs;
};
struct StackObject {
int index = -1;
StackObjectKind kind = StackObjectKind::Local;
int size = 0;
int align = 1;
int offset = 0;
std::string name;
};
struct VRegInfo {
int id = -1;
ValueType type = ValueType::Void;
};
struct Allocation {
enum class Kind { Unassigned, PhysReg, Spill };
Kind kind = Kind::Unassigned;
PhysReg phys;
int stack_object = -1;
};
class MachineInstr {
public:
enum class Opcode {
Arg,
Copy,
Load,
Store,
Lea,
Add,
Sub,
Mul,
Div,
Rem,
And,
Or,
Xor,
Shl,
AShr,
LShr,
FAdd,
FSub,
FMul,
FDiv,
FNeg,
ICmp,
FCmp,
ZExt,
ItoF,
FtoI,
Br,
CondBr,
Call,
Ret,
Memset,
Unreachable,
};
explicit MachineInstr(Opcode opcode,
std::vector<MachineOperand> operands = {});
Opcode GetOpcode() const { return opcode_; }
const std::vector<MachineOperand>& GetOperands() const { return operands_; }
std::vector<MachineOperand>& GetOperands() { return operands_; }
void SetCondCode(CondCode code) { cond_code_ = code; }
CondCode GetCondCode() const { return cond_code_; }
void SetAddress(AddressExpr address) {
address_ = std::move(address);
has_address_ = true;
}
bool HasAddress() const { return has_address_; }
const AddressExpr& GetAddress() const { return address_; }
AddressExpr& GetAddress() { return address_; }
void SetCallInfo(std::string callee, std::vector<ValueType> arg_types,
ValueType return_type) {
callee_ = std::move(callee);
call_arg_types_ = std::move(arg_types);
call_return_type_ = return_type;
}
const std::string& GetCallee() const { return callee_; }
const std::vector<ValueType>& GetCallArgTypes() const { return call_arg_types_; }
ValueType GetCallReturnType() const { return call_return_type_; }
void SetValueType(ValueType type) { value_type_ = type; }
ValueType GetValueType() const { return value_type_; }
bool IsTerminator() const;
std::vector<int> GetDefs() const;
std::vector<int> GetUses() const;
private:
Opcode opcode_;
std::vector<MachineOperand> operands_;
CondCode cond_code_ = CondCode::EQ;
AddressExpr address_;
bool has_address_ = false;
std::string callee_;
std::vector<ValueType> call_arg_types_;
ValueType call_return_type_ = ValueType::Void;
ValueType value_type_ = ValueType::Void;
};
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
const std::string& GetName() const { return name_; }
std::vector<MachineInstr>& GetInstructions() { return instructions_; }
const std::vector<MachineInstr>& GetInstructions() const { return instructions_; }
MachineInstr& Append(MachineInstr::Opcode opcode,
std::vector<MachineOperand> operands = {});
MachineInstr& Append(MachineInstr instr);
private:
std::string name_;
std::vector<MachineInstr> instructions_;
};
class MachineFunction {
public:
MachineFunction(std::string name, ValueType return_type,
std::vector<ValueType> param_types);
const std::string& GetName() const { return name_; }
ValueType GetReturnType() const { return return_type_; }
const std::vector<ValueType>& GetParamTypes() const { return param_types_; }
MachineBasicBlock* CreateBlock(const std::string& name);
std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() { return blocks_; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const {
return blocks_;
}
int NewVReg(ValueType type);
const VRegInfo& GetVRegInfo(int id) const;
VRegInfo& GetVRegInfo(int id);
const std::vector<VRegInfo>& GetVRegs() const { return vregs_; }
int CreateStackObject(int size, int align, StackObjectKind kind,
const std::string& name = "");
StackObject& GetStackObject(int index);
const StackObject& GetStackObject(int index) const;
const std::vector<StackObject>& GetStackObjects() const { return stack_objects_; }
void SetAllocation(int vreg, Allocation allocation);
const Allocation& GetAllocation(int vreg) const;
Allocation& GetAllocation(int vreg);
void AddUsedCalleeSavedGPR(int reg_index);
void AddUsedCalleeSavedFPR(int reg_index);
const std::vector<int>& GetUsedCalleeSavedGPRs() const {
return used_callee_saved_gprs_;
}
const std::vector<int>& GetUsedCalleeSavedFPRs() const {
return used_callee_saved_fprs_;
}
void SetFrameSize(int size) { frame_size_ = size; }
int GetFrameSize() const { return frame_size_; }
void SetMaxOutgoingArgBytes(int bytes) { max_outgoing_arg_bytes_ = bytes; }
int GetMaxOutgoingArgBytes() const { return max_outgoing_arg_bytes_; }
private:
std::string name_;
ValueType return_type_ = ValueType::Void;
std::vector<ValueType> param_types_;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<VRegInfo> vregs_;
std::vector<StackObject> stack_objects_;
std::vector<Allocation> allocations_;
std::vector<int> used_callee_saved_gprs_;
std::vector<int> used_callee_saved_fprs_;
int frame_size_ = 0;
int max_outgoing_arg_bytes_ = 0;
};
class MachineModule {
public:
explicit MachineModule(const ir::Module& source) : source_(&source) {}
const ir::Module& GetSourceModule() const { return *source_; }
std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() { return functions_; }
const std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() const {
return functions_;
}
MachineFunction* AddFunction(std::unique_ptr<MachineFunction> function);
private:
const ir::Module* source_ = nullptr;
std::vector<std::unique_ptr<MachineFunction>> functions_;
};
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
bool RunPeephole(MachineModule& module);
bool RunSpillReduction(MachineModule& module);
bool RunCFGCleanup(MachineModule& module);
void RunMIRPreRegAllocPassPipeline(MachineModule& module);
void RunMIRPostRegAllocPassPipeline(MachineModule& module);
void RunAddressHoisting(MachineModule& module);
void RunRegAlloc(MachineModule& module);
void RunFrameLowering(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);
} // namespace mir

@ -0,0 +1,94 @@
#pragma once
#include "SysYParser.h"
#include "sem/SymbolTable.h"
#include <string>
#include <unordered_map>
#include <vector>
struct GlobalSemanticInfo {
SemanticType type = SemanticType::Int;
bool is_const = false;
bool is_array = false;
std::vector<int> dims;
};
struct FunctionSemanticInfo {
SemanticType return_type = SemanticType::Void;
std::vector<bool> param_is_array;
bool is_builtin = false;
bool is_defined = false;
bool reads_global_memory = false;
bool writes_global_memory = false;
bool reads_param_memory = false;
bool writes_param_memory = false;
bool has_io = false;
bool has_unknown_effects = true;
bool is_recursive = false;
std::vector<std::string> direct_callees;
bool MayReadMemory() const {
return has_unknown_effects || reads_global_memory || writes_global_memory ||
reads_param_memory || writes_param_memory;
}
bool MayWriteMemory() const {
return has_unknown_effects || writes_global_memory || writes_param_memory;
}
bool HasObservableSideEffects() const {
return has_unknown_effects || writes_global_memory || writes_param_memory ||
has_io;
}
bool CanDiscardUnusedCall() const {
return !has_unknown_effects && !writes_global_memory &&
!writes_param_memory && !has_io && !is_recursive;
}
};
class SemanticContext {
public:
FunctionSemanticInfo* LookupFunction(const std::string& name) {
auto it = functions_.find(name);
return it == functions_.end() ? nullptr : &it->second;
}
const FunctionSemanticInfo* LookupFunction(const std::string& name) const {
auto it = functions_.find(name);
return it == functions_.end() ? nullptr : &it->second;
}
GlobalSemanticInfo* LookupGlobal(const std::string& name) {
auto it = globals_.find(name);
return it == globals_.end() ? nullptr : &it->second;
}
const GlobalSemanticInfo* LookupGlobal(const std::string& name) const {
auto it = globals_.find(name);
return it == globals_.end() ? nullptr : &it->second;
}
FunctionSemanticInfo& UpsertFunction(const std::string& name) {
return functions_[name];
}
GlobalSemanticInfo& UpsertGlobal(const std::string& name) {
return globals_[name];
}
const std::unordered_map<std::string, FunctionSemanticInfo>& GetFunctions() const {
return functions_;
}
const std::unordered_map<std::string, GlobalSemanticInfo>& GetGlobals() const {
return globals_;
}
private:
std::unordered_map<std::string, FunctionSemanticInfo> functions_;
std::unordered_map<std::string, GlobalSemanticInfo> globals_;
};
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);

@ -0,0 +1,69 @@
#pragma once
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
namespace ir {
class Function;
class Value;
}
enum class SemanticType {
Void,
Int,
Float,
};
enum class SymbolKind {
Variable,
Constant,
Function,
};
struct ConstantValue {
SemanticType type = SemanticType::Int;
int int_value = 0;
float float_value = 0.0f;
};
struct FunctionTypeInfo {
SemanticType return_type = SemanticType::Void;
std::vector<SemanticType> param_types;
std::vector<bool> param_is_array;
std::vector<std::vector<int>> param_dims;
};
struct SymbolEntry {
SymbolKind kind = SymbolKind::Variable;
SemanticType type = SemanticType::Int;
bool is_const = false;
bool is_array = false;
bool is_param_array = false;
std::vector<int> dims;
ir::Value* ir_value = nullptr;
ir::Function* function = nullptr;
std::optional<ConstantValue> const_scalar;
std::vector<ConstantValue> const_array;
bool const_array_all_zero = false;
FunctionTypeInfo function_type;
};
class SymbolTable {
public:
void Clear();
void EnterScope();
void ExitScope();
bool Insert(const std::string& name, const SymbolEntry& entry);
bool ContainsInCurrentScope(const std::string& name) const;
SymbolEntry* Lookup(const std::string& name);
const SymbolEntry* Lookup(const std::string& name) const;
private:
std::vector<std::unordered_map<std::string, SymbolEntry>> scopes_;
};

@ -0,0 +1,14 @@
// 简易命令行解析:支持帮助、输入文件与输出阶段选择。
#pragma once
#include <string>
struct CLIOptions {
std::string input;
bool emit_parse_tree = false;
bool emit_ir = true;
bool emit_asm = false;
bool show_help = false;
};
CLIOptions ParseCLI(int argc, char** argv);

@ -0,0 +1,20 @@
// 轻量日志接口。
#pragma once
#include <cstddef>
#include <exception>
#include <iosfwd>
#include <string>
#include <string_view>
void LogInfo(std::string_view msg, std::ostream& os);
void LogError(std::string_view msg, std::ostream& os);
std::string FormatError(std::string_view stage, std::string_view msg);
std::string FormatErrorAt(std::string_view stage, std::size_t line,
std::size_t column, std::string_view msg);
bool HasErrorPrefix(std::string_view msg, std::string_view stage);
void PrintException(std::ostream& os, const std::exception& ex);
// 打印命令行帮助信息(用于 `compiler --help`)。
void PrintHelp(std::ostream& os);

@ -0,0 +1,195 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
BUILD_DIR="$REPO_ROOT/build_lab1"
COMPILER="$BUILD_DIR/bin/compiler"
ANTLR_JAR="$REPO_ROOT/third_party/antlr-4.13.2-complete.jar"
RUN_ROOT="$REPO_ROOT/output/logs/lab1"
RUN_NAME="lab1_$(date +%Y%m%d_%H%M%S)"
RUN_DIR="$RUN_ROOT/$RUN_NAME"
WHOLE_LOG="$RUN_DIR/whole.log"
FAIL_DIR="$RUN_DIR/failures"
LEGACY_SAVE_TREE=false
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
TEST_DIRS=()
while [[ $# -gt 0 ]]; do
case "$1" in
--save-tree)
LEGACY_SAVE_TREE=true
;;
*)
TEST_DIRS+=("$1")
;;
esac
shift
done
mkdir -p "$RUN_DIR"
: > "$WHOLE_LOG"
log_plain() {
printf '%s\n' "$*"
printf '%s\n' "$*" >> "$WHOLE_LOG"
}
log_color() {
local color="$1"
shift
local message="$*"
printf '%b%s%b\n' "$color" "$message" "$NC"
printf '%s\n' "$message" >> "$WHOLE_LOG"
}
append_file_to_whole_log() {
local title="$1"
local file="$2"
{
printf '\n===== %s =====\n' "$title"
cat "$file"
printf '\n'
} >> "$WHOLE_LOG"
}
cleanup_tmp_dir() {
local dir="$1"
if [[ -d "$dir" ]]; then
rm -rf "$dir"
fi
}
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
prune_empty_run_dirs() {
if [[ -d "$RUN_DIR/.tmp" ]]; then
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
fi
if [[ -d "$FAIL_DIR" ]]; then
rmdir "$FAIL_DIR" 2>/dev/null || true
fi
}
if [[ ${#TEST_DIRS[@]} -eq 0 ]]; then
while IFS= read -r -d '' test_dir; do
TEST_DIRS+=("$test_dir")
done < <(discover_default_test_dirs)
fi
log_plain "Run directory: $RUN_DIR"
log_plain "Whole log: $WHOLE_LOG"
if [[ "$LEGACY_SAVE_TREE" == true ]]; then
log_color "$YELLOW" "Warning: --save-tree is deprecated; successful case artifacts will still be deleted."
fi
log_plain "==> [1/3] Generate ANTLR Lexer/Parser"
mkdir -p "$BUILD_DIR/generated/antlr4"
if ! java -jar "$ANTLR_JAR" \
-Dlanguage=Cpp \
-visitor -no-listener \
-Xexact-output-dir \
-o "$BUILD_DIR/generated/antlr4" \
"$REPO_ROOT/src/antlr4/SysY.g4" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "ANTLR generation failed. See $WHOLE_LOG"
exit 1
fi
log_plain "==> [2/3] Configure and build parse-only compiler"
if ! cmake -S "$REPO_ROOT" -B "$BUILD_DIR" -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
exit 1
fi
if ! cmake --build "$BUILD_DIR" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
exit 1
fi
log_plain "==> [3/3] Run parse validation suite"
PASS=0
FAIL=0
FAIL_LIST=()
test_one() {
local sy_file="$1"
local rel="$2"
local safe_name="${rel//\//_}"
local case_key="${safe_name%.sy}"
local tmp_dir="$RUN_DIR/.tmp/$case_key"
local fail_case_dir="$FAIL_DIR/$case_key"
local tree_file="$tmp_dir/parse.tree"
local case_log="$tmp_dir/error.log"
cleanup_tmp_dir "$tmp_dir"
cleanup_tmp_dir "$fail_case_dir"
mkdir -p "$tmp_dir"
if "$COMPILER" --emit-parse-tree "$sy_file" > "$tree_file" 2> "$case_log"; then
cleanup_tmp_dir "$tmp_dir"
return 0
fi
mkdir -p "$FAIL_DIR"
{
printf 'Command: %s --emit-parse-tree %s\n' "$COMPILER" "$sy_file"
if [[ -s "$case_log" ]]; then
printf '\n'
cat "$case_log"
fi
} > "$tmp_dir/error.log.tmp"
mv "$tmp_dir/error.log.tmp" "$case_log"
mv "$tmp_dir" "$fail_case_dir"
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
return 1
}
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
log_color "$YELLOW" "skip missing dir: $test_dir"
continue
fi
while IFS= read -r -d '' sy_file; do
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
if test_one "$sy_file" "$rel"; then
log_color "$GREEN" "PASS $rel"
PASS=$((PASS + 1))
else
log_color "$RED" "FAIL $rel"
FAIL=$((FAIL + 1))
FAIL_LIST+=("$rel")
fi
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done
prune_empty_run_dirs
log_plain ""
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
log_plain "failed cases:"
for f in "${FAIL_LIST[@]}"; do
safe_name="${f//\//_}"
log_plain "- $f"
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
done
else
log_plain "all successful case artifacts have been deleted automatically."
fi
log_plain "whole log saved to: $WHOLE_LOG"
[[ $FAIL -eq 0 ]]

@ -0,0 +1,288 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
VERIFY_SCRIPT="$REPO_ROOT/scripts/verify_ir.sh"
BUILD_DIR="$REPO_ROOT/build_lab2"
RUN_ROOT="$REPO_ROOT/output/logs/lab2"
LAST_RUN_FILE="$RUN_ROOT/last_run.txt"
LAST_FAILED_FILE="$RUN_ROOT/last_failed.txt"
RUN_NAME="lab2_$(date +%Y%m%d_%H%M%S)"
RUN_DIR="$RUN_ROOT/$RUN_NAME"
WHOLE_LOG="$RUN_DIR/whole.log"
FAIL_DIR="$RUN_DIR/failures"
LEGACY_SAVE_IR=false
FAILED_ONLY=false
FALLBACK_TO_FULL=false
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
TEST_DIRS=()
TEST_FILES=()
while [[ $# -gt 0 ]]; do
case "$1" in
--save-ir)
LEGACY_SAVE_IR=true
;;
--failed-only)
FAILED_ONLY=true
;;
*)
if [[ -f "$1" ]]; then
TEST_FILES+=("$1")
else
TEST_DIRS+=("$1")
fi
;;
esac
shift
done
mkdir -p "$RUN_DIR"
: > "$WHOLE_LOG"
printf '%s\n' "$RUN_DIR" > "$LAST_RUN_FILE"
log_plain() {
printf '%s\n' "$*"
printf '%s\n' "$*" >> "$WHOLE_LOG"
}
log_color() {
local color="$1"
shift
local message="$*"
printf '%b%s%b\n' "$color" "$message" "$NC"
printf '%s\n' "$message" >> "$WHOLE_LOG"
}
append_file_to_whole_log() {
local title="$1"
local file="$2"
{
printf '\n===== %s =====\n' "$title"
cat "$file"
printf '\n'
} >> "$WHOLE_LOG"
}
cleanup_tmp_dir() {
local dir="$1"
if [[ -d "$dir" ]]; then
rm -rf "$dir"
fi
}
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
prune_empty_run_dirs() {
if [[ -d "$RUN_DIR/.tmp" ]]; then
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
fi
if [[ -d "$FAIL_DIR" ]]; then
rmdir "$FAIL_DIR" 2>/dev/null || true
fi
}
now_ns() {
date +%s%N
}
format_duration_ns() {
local ns="$1"
local sec=$((ns / 1000000000))
local ms=$(((ns % 1000000000) / 1000000))
printf '%d.%03ds' "$sec" "$ms"
}
is_transient_io_failure() {
local log_file="$1"
[[ -f "$log_file" ]] || return 1
grep -Eq \
'Permission denied|Text file busy|Device or resource busy|Stale file handle|Input/output error|Resource temporarily unavailable|Read-only file system' \
"$log_file"
}
test_one() {
local sy_file="$1"
local rel="$2"
local safe_name="${rel//\//_}"
local case_key="${safe_name%.sy}"
local tmp_dir="$RUN_DIR/.tmp/$case_key"
local fail_case_dir="$FAIL_DIR/$case_key"
local case_log="$tmp_dir/error.log"
local attempt=1
cleanup_tmp_dir "$fail_case_dir"
while true; do
cleanup_tmp_dir "$tmp_dir"
mkdir -p "$tmp_dir"
if "$VERIFY_SCRIPT" "$sy_file" "$tmp_dir" --run > "$case_log" 2>&1; then
cleanup_tmp_dir "$tmp_dir"
return 0
fi
if [[ $attempt -eq 1 ]] && is_transient_io_failure "$case_log"; then
log_color "$YELLOW" "RETRY $rel (transient I/O failure)"
attempt=$((attempt + 1))
continue
fi
break
done
mkdir -p "$FAIL_DIR"
mv "$tmp_dir" "$fail_case_dir"
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
return 1
}
run_case() {
local sy_file="$1"
local rel
local case_start_ns
local case_end_ns
local case_elapsed_ns
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
case_start_ns=$(now_ns)
if test_one "$sy_file" "$rel"; then
case_end_ns=$(now_ns)
case_elapsed_ns=$((case_end_ns - case_start_ns))
log_color "$GREEN" "PASS $rel [$(format_duration_ns "$case_elapsed_ns")]"
PASS=$((PASS + 1))
else
case_end_ns=$(now_ns)
case_elapsed_ns=$((case_end_ns - case_start_ns))
log_color "$RED" "FAIL $rel [$(format_duration_ns "$case_elapsed_ns")]"
FAIL=$((FAIL + 1))
FAIL_LIST+=("$rel")
fi
}
TOTAL_START_NS=$(now_ns)
if [[ "$FAILED_ONLY" == true ]]; then
if [[ -f "$LAST_FAILED_FILE" ]]; then
while IFS= read -r sy_file; do
[[ -n "$sy_file" ]] || continue
[[ -f "$sy_file" ]] || continue
TEST_FILES+=("$sy_file")
done < "$LAST_FAILED_FILE"
fi
if [[ ${#TEST_FILES[@]} -eq 0 ]]; then
FALLBACK_TO_FULL=true
FAILED_ONLY=false
fi
fi
if [[ "$FAILED_ONLY" == false && ${#TEST_DIRS[@]} -eq 0 && ${#TEST_FILES[@]} -eq 0 ]]; then
while IFS= read -r -d '' test_dir; do
TEST_DIRS+=("$test_dir")
done < <(discover_default_test_dirs)
fi
log_plain "Run directory: $RUN_DIR"
log_plain "Whole log: $WHOLE_LOG"
if [[ "$LEGACY_SAVE_IR" == true ]]; then
log_color "$YELLOW" "Warning: --save-ir is deprecated; successful case artifacts will still be deleted."
fi
if [[ "$FAILED_ONLY" == true ]]; then
log_plain "Mode: rerun cached failed cases only"
fi
if [[ "$FALLBACK_TO_FULL" == true ]]; then
log_color "$YELLOW" "No cached failed cases found, fallback to full suite."
fi
if [[ ! -f "$VERIFY_SCRIPT" ]]; then
log_color "$RED" "missing verify script: $VERIFY_SCRIPT"
exit 1
fi
log_plain "==> [1/2] Configure and build compiler"
BUILD_START_NS=$(now_ns)
if ! cmake -S "$REPO_ROOT" -B "$BUILD_DIR" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
exit 1
fi
if ! cmake --build "$BUILD_DIR" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
exit 1
fi
BUILD_END_NS=$(now_ns)
BUILD_ELAPSED_NS=$((BUILD_END_NS - BUILD_START_NS))
log_plain "==> [2/2] Run IR validation suite"
VALIDATION_START_NS=$(now_ns)
PASS=0
FAIL=0
FAIL_LIST=()
if [[ "$FAILED_ONLY" == true ]]; then
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
else
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
log_color "$YELLOW" "skip missing dir: $test_dir"
continue
fi
while IFS= read -r -d '' sy_file; do
run_case "$sy_file"
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done
fi
rm -f "$LAST_FAILED_FILE"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
for f in "${FAIL_LIST[@]}"; do
printf '%s/%s\n' "$REPO_ROOT" "$f" >> "$LAST_FAILED_FILE"
done
fi
prune_empty_run_dirs
VALIDATION_END_NS=$(now_ns)
VALIDATION_ELAPSED_NS=$((VALIDATION_END_NS - VALIDATION_START_NS))
TOTAL_END_NS=$(now_ns)
TOTAL_ELAPSED_NS=$((TOTAL_END_NS - TOTAL_START_NS))
log_plain ""
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
log_plain "build elapsed: $(format_duration_ns "$BUILD_ELAPSED_NS")"
log_plain "validation elapsed: $(format_duration_ns "$VALIDATION_ELAPSED_NS")"
log_plain "total elapsed: $(format_duration_ns "$TOTAL_ELAPSED_NS")"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
log_plain "failed cases:"
for f in "${FAIL_LIST[@]}"; do
safe_name="${f//\//_}"
log_plain "- $f"
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
done
else
log_plain "all successful case artifacts have been deleted automatically."
fi
log_plain "whole log saved to: $WHOLE_LOG"
[[ $FAIL -eq 0 ]]

@ -0,0 +1,295 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
VERIFY_SCRIPT="$REPO_ROOT/scripts/verify_asm.sh"
BUILD_DIR="$REPO_ROOT/build_lab3"
RUN_ROOT="$REPO_ROOT/output/logs/lab3"
LAST_RUN_FILE="$RUN_ROOT/last_run.txt"
LAST_FAILED_FILE="$RUN_ROOT/last_failed.txt"
RUN_NAME="lab3_$(date +%Y%m%d_%H%M%S)"
RUN_DIR="$RUN_ROOT/$RUN_NAME"
WHOLE_LOG="$RUN_DIR/whole.log"
FAIL_DIR="$RUN_DIR/failures"
LEGACY_SAVE_ASM=false
FAILED_ONLY=false
FALLBACK_TO_FULL=false
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
TEST_DIRS=()
TEST_FILES=()
while [[ $# -gt 0 ]]; do
case "$1" in
--save-asm)
LEGACY_SAVE_ASM=true
;;
--failed-only)
FAILED_ONLY=true
;;
*)
if [[ -f "$1" ]]; then
TEST_FILES+=("$1")
else
TEST_DIRS+=("$1")
fi
;;
esac
shift
done
mkdir -p "$RUN_DIR"
: > "$WHOLE_LOG"
printf '%s\n' "$RUN_DIR" > "$LAST_RUN_FILE"
log_plain() {
printf '%s\n' "$*"
printf '%s\n' "$*" >> "$WHOLE_LOG"
}
log_color() {
local color="$1"
shift
local message="$*"
printf '%b%s%b\n' "$color" "$message" "$NC"
printf '%s\n' "$message" >> "$WHOLE_LOG"
}
append_file_to_whole_log() {
local title="$1"
local file="$2"
{
printf '\n===== %s =====\n' "$title"
cat "$file"
printf '\n'
} >> "$WHOLE_LOG"
}
cleanup_tmp_dir() {
local dir="$1"
if [[ -d "$dir" ]]; then
rm -rf "$dir"
fi
}
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
prune_empty_run_dirs() {
if [[ -d "$RUN_DIR/.tmp" ]]; then
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
fi
if [[ -d "$FAIL_DIR" ]]; then
rmdir "$FAIL_DIR" 2>/dev/null || true
fi
}
now_ns() {
date +%s%N
}
format_duration_ns() {
local ns="$1"
local sec=$((ns / 1000000000))
local ms=$(((ns % 1000000000) / 1000000))
printf '%d.%03ds' "$sec" "$ms"
}
is_transient_io_failure() {
local log_file="$1"
[[ -f "$log_file" ]] || return 1
grep -Eq \
'Permission denied|Text file busy|Device or resource busy|Stale file handle|Input/output error|Resource temporarily unavailable|Read-only file system' \
"$log_file"
}
test_one() {
local sy_file="$1"
local rel="$2"
local safe_name="${rel//\//_}"
local case_key="${safe_name%.sy}"
local tmp_dir="$RUN_DIR/.tmp/$case_key"
local fail_case_dir="$FAIL_DIR/$case_key"
local case_log="$tmp_dir/error.log"
local attempt=1
cleanup_tmp_dir "$fail_case_dir"
while true; do
cleanup_tmp_dir "$tmp_dir"
mkdir -p "$tmp_dir"
if "$VERIFY_SCRIPT" "$sy_file" "$tmp_dir" --run > "$case_log" 2>&1; then
cleanup_tmp_dir "$tmp_dir"
return 0
fi
if [[ $attempt -eq 1 ]] && is_transient_io_failure "$case_log"; then
log_color "$YELLOW" "RETRY $rel (transient I/O failure)"
attempt=$((attempt + 1))
continue
fi
break
done
mkdir -p "$FAIL_DIR"
mv "$tmp_dir" "$fail_case_dir"
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
return 1
}
run_case() {
local sy_file="$1"
local rel
local case_start_ns
local case_end_ns
local case_elapsed_ns
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
case_start_ns=$(now_ns)
if test_one "$sy_file" "$rel"; then
case_end_ns=$(now_ns)
case_elapsed_ns=$((case_end_ns - case_start_ns))
log_color "$GREEN" "PASS $rel [$(format_duration_ns "$case_elapsed_ns")]"
PASS=$((PASS + 1))
else
case_end_ns=$(now_ns)
case_elapsed_ns=$((case_end_ns - case_start_ns))
log_color "$RED" "FAIL $rel [$(format_duration_ns "$case_elapsed_ns")]"
FAIL=$((FAIL + 1))
FAIL_LIST+=("$rel")
fi
}
TOTAL_START_NS=$(now_ns)
if [[ "$FAILED_ONLY" == true ]]; then
if [[ -f "$LAST_FAILED_FILE" ]]; then
while IFS= read -r sy_file; do
[[ -n "$sy_file" ]] || continue
[[ -f "$sy_file" ]] || continue
TEST_FILES+=("$sy_file")
done < "$LAST_FAILED_FILE"
fi
if [[ ${#TEST_FILES[@]} -eq 0 ]]; then
FALLBACK_TO_FULL=true
FAILED_ONLY=false
fi
fi
if [[ "$FAILED_ONLY" == false && ${#TEST_DIRS[@]} -eq 0 && ${#TEST_FILES[@]} -eq 0 ]]; then
while IFS= read -r -d '' test_dir; do
TEST_DIRS+=("$test_dir")
done < <(discover_default_test_dirs)
fi
log_plain "Run directory: $RUN_DIR"
log_plain "Whole log: $WHOLE_LOG"
if [[ "$LEGACY_SAVE_ASM" == true ]]; then
log_color "$YELLOW" "Warning: --save-asm is deprecated; successful case artifacts will still be deleted."
fi
if [[ "$FAILED_ONLY" == true ]]; then
log_plain "Mode: rerun cached failed cases only"
fi
if [[ "$FALLBACK_TO_FULL" == true ]]; then
log_color "$YELLOW" "No cached failed cases found, fallback to full suite."
fi
if [[ ! -f "$VERIFY_SCRIPT" ]]; then
log_color "$RED" "missing verify script: $VERIFY_SCRIPT"
exit 1
fi
for tool in llc aarch64-linux-gnu-gcc qemu-aarch64; do
if ! command -v "$tool" >/dev/null 2>&1; then
log_color "$RED" "missing required tool: $tool"
exit 1
fi
done
log_plain "==> [1/2] Configure and build compiler"
BUILD_START_NS=$(now_ns)
if ! cmake -S "$REPO_ROOT" -B "$BUILD_DIR" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
exit 1
fi
if ! cmake --build "$BUILD_DIR" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
exit 1
fi
BUILD_END_NS=$(now_ns)
BUILD_ELAPSED_NS=$((BUILD_END_NS - BUILD_START_NS))
log_plain "==> [2/2] Run ASM validation suite"
VALIDATION_START_NS=$(now_ns)
PASS=0
FAIL=0
FAIL_LIST=()
if [[ "$FAILED_ONLY" == true ]]; then
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
else
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
log_color "$YELLOW" "skip missing dir: $test_dir"
continue
fi
while IFS= read -r -d '' sy_file; do
run_case "$sy_file"
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done
fi
rm -f "$LAST_FAILED_FILE"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
for f in "${FAIL_LIST[@]}"; do
printf '%s/%s\n' "$REPO_ROOT" "$f" >> "$LAST_FAILED_FILE"
done
fi
prune_empty_run_dirs
VALIDATION_END_NS=$(now_ns)
VALIDATION_ELAPSED_NS=$((VALIDATION_END_NS - VALIDATION_START_NS))
TOTAL_END_NS=$(now_ns)
TOTAL_ELAPSED_NS=$((TOTAL_END_NS - TOTAL_START_NS))
log_plain ""
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
log_plain "build elapsed: $(format_duration_ns "$BUILD_ELAPSED_NS")"
log_plain "validation elapsed: $(format_duration_ns "$VALIDATION_ELAPSED_NS")"
log_plain "total elapsed: $(format_duration_ns "$TOTAL_ELAPSED_NS")"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
log_plain "failed cases:"
for f in "${FAIL_LIST[@]}"; do
safe_name="${f//\//_}"
log_plain "- $f"
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
done
else
log_plain "all successful case artifacts have been deleted automatically."
fi
log_plain "whole log saved to: $WHOLE_LOG"
[[ $FAIL -eq 0 ]]

@ -0,0 +1,124 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "usage: $0 input.sy [output_dir] [--run]" >&2
exit 1
fi
input=$1
out_dir="$REPO_ROOT/test/test_result/asm"
run_exec=false
input_dir=$(dirname "$input")
shift
while [[ $# -gt 0 ]]; do
case "$1" in
--run)
run_exec=true
;;
*)
out_dir="$1"
;;
esac
shift
done
if [[ ! -f "$input" ]]; then
echo "input file not found: $input" >&2
exit 1
fi
compiler=""
for candidate in "$REPO_ROOT/build_lab3/bin/compiler" "$REPO_ROOT/build_lab2/bin/compiler" "$REPO_ROOT/build/bin/compiler"; do
if [[ -x "$candidate" ]]; then
compiler="$candidate"
break
fi
done
if [[ -z "$compiler" ]]; then
echo "compiler not found; try: cmake -S . -B build_lab3 && cmake --build build_lab3 -j" >&2
exit 1
fi
if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
echo "aarch64-linux-gnu-gcc not found" >&2
exit 1
fi
mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
"$compiler" --emit-asm "$input" > "$asm_file"
echo "asm generated: $asm_file"
aarch64-linux-gnu-gcc "$asm_file" "$REPO_ROOT/sylib/sylib.c" -O2 -o "$exe"
echo "executable generated: $exe"
if [[ "$run_exec" == true ]]; then
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
echo "qemu-aarch64 not found" >&2
exit 1
fi
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
timeout_sec="${RUN_TIMEOUT_SEC:-60}"
if [[ "$input" == *"/performance/"* || "$input" == *"/h_performance/"* ]]; then
timeout_sec="${PERF_TIMEOUT_SEC:-300}"
fi
set +e
if command -v timeout >/dev/null 2>&1; then
if [[ -f "$stdin_file" ]]; then
timeout "$timeout_sec" qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else
timeout "$timeout_sec" qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi
else
if [[ -f "$stdin_file" ]]; then
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi
fi
status=$?
set -e
if [[ $status -eq 124 ]]; then
echo "timeout after ${timeout_sec}s: $exe" >&2
fi
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
echo "exit code: $status"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u <(awk '{ sub(/\r$/, ""); print }' "$expected_file") <(awk '{ sub(/\r$/, ""); print }' "$actual_file"); then
echo "matched: $expected_file"
else
echo "mismatch: $expected_file" >&2
echo "actual saved to: $actual_file" >&2
exit 1
fi
else
echo "expected output not found, skipped diff: $expected_file"
fi
fi

@ -0,0 +1,145 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ $# -lt 1 || $# -gt 2 ]]; then
echo "用法: $0 <test_dir> [output_dir]" >&2
exit 1
fi
test_dir=${1%/}
out_dir="test/test_result/function/asm"
shift
while [[ $# -gt 0 ]]; do
out_dir="$1"
shift
done
if [[ ! -d "$test_dir" ]]; then
echo "测试目录不存在: $test_dir" >&2
exit 1
fi
compiler="./build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建。" >&2
exit 1
fi
if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
echo "未找到 aarch64-linux-gnu-gcc无法汇编/链接。" >&2
exit 1
fi
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
echo "未找到 qemu-aarch64无法运行生成的可执行文件。" >&2
exit 1
fi
sylib_c="sylib/sylib.c"
if [[ ! -f "$sylib_c" ]]; then
echo "未找到 sylib: $sylib_c" >&2
exit 1
fi
mkdir -p "$out_dir"
sylib_obj="$out_dir/sylib.o"
aarch64-linux-gnu-gcc -c "$sylib_c" -I sylib -o "$sylib_obj"
mapfile -t inputs < <(find "$test_dir" -type f -name '*.sy' | sort)
if [[ ${#inputs[@]} -eq 0 ]]; then
echo "测试目录下未找到 .sy 文件: $test_dir" >&2
exit 1
fi
failures=0
normalize() {
# strip CR, then strip a single trailing newline so both files
# are comparable regardless of CRLF / trailing-newline differences
tr -d '\r' < "$1" | sed -e '${ /^$/d; }' | perl -pe 'chomp if eof'
}
run_case() {
local input=$1
local input_dir base stem rel_path rel_dir case_out_dir asm_file exe
local stdin_file expected_file stdout_file actual_file status
input_dir=$(dirname "$input")
base=$(basename "$input")
stem=${base%.sy}
rel_path=${input#"$test_dir"/}
rel_dir=$(dirname "$rel_path")
case_out_dir="$out_dir"
if [[ "$rel_dir" != "." ]]; then
case_out_dir="$out_dir/$rel_dir"
fi
mkdir -p "$case_out_dir"
asm_file="$case_out_dir/$stem.s"
exe="$case_out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
stdout_file="$case_out_dir/$stem.stdout"
actual_file="$case_out_dir/$stem.actual.out"
if ! "$compiler" --emit-asm "$input" > "$asm_file" 2>"$case_out_dir/$stem.err"; then
echo "$stem: 编译失败"
cat "$case_out_dir/$stem.err" >&2
return 1
fi
if ! aarch64-linux-gnu-gcc "$asm_file" "$sylib_obj" -o "$exe" 2>"$case_out_dir/$stem.link.err"; then
echo "$stem: 链接失败"
cat "$case_out_dir/$stem.link.err" >&2
return 1
fi
set +e
if [[ -f "$stdin_file" ]]; then
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi
status=$?
set -e
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff <(normalize "$expected_file") <(normalize "$actual_file") >/dev/null 2>&1; then
echo "$stem: PASS"
return 0
else
echo "$stem: FAIL (退出码: $status)"
diff -u --strip-trailing-cr "$expected_file" "$actual_file" >&2 || true
return 1
fi
else
echo "$stem: SKIP (无预期输出, 退出码: $status)"
return 0
fi
}
for input in "${inputs[@]}"; do
if ! run_case "$input"; then
((failures+=1))
fi
done
total=${#inputs[@]}
passed=$((total - failures))
echo "总计: $total, 通过: $passed, 失败: $failures"
if (( failures > 0 )); then
exit 1
fi

@ -0,0 +1,160 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ $# -lt 1 || $# -gt 2 ]]; then
echo "用法: $0 <test_dir> [output_dir]" >&2
exit 1
fi
test_dir=${1%/}
out_dir="test/test_result/function/asm_time"
shift
while [[ $# -gt 0 ]]; do
out_dir="$1"
shift
done
if [[ ! -d "$test_dir" ]]; then
echo "测试目录不存在: $test_dir" >&2
exit 1
fi
compiler="./build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建。" >&2
exit 1
fi
if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
echo "未找到 aarch64-linux-gnu-gcc无法汇编/链接。" >&2
exit 1
fi
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
echo "未找到 qemu-aarch64无法运行生成的可执行文件。" >&2
exit 1
fi
sylib_c="sylib/sylib.c"
if [[ ! -f "$sylib_c" ]]; then
echo "未找到 sylib: $sylib_c" >&2
exit 1
fi
mkdir -p "$out_dir"
sylib_obj="$out_dir/sylib.o"
aarch64-linux-gnu-gcc -c "$sylib_c" -I sylib -o "$sylib_obj"
mapfile -t inputs < <(find "$test_dir" -type f -name '*.sy' | sort)
if [[ ${#inputs[@]} -eq 0 ]]; then
echo "测试目录下未找到 .sy 文件: $test_dir" >&2
exit 1
fi
failures=0
normalize() {
tr -d '\r' < "$1" | sed -e '${ /^$/d; }' | perl -pe 'chomp if eof'
}
run_case() {
local input=$1
local input_dir base stem rel_path rel_dir case_out_dir asm_file exe
local stdin_file expected_file stdout_file actual_file time_file elapsed status
input_dir=$(dirname "$input")
base=$(basename "$input")
stem=${base%.sy}
rel_path=${input#"$test_dir"/}
rel_dir=$(dirname "$rel_path")
case_out_dir="$out_dir"
if [[ "$rel_dir" != "." ]]; then
case_out_dir="$out_dir/$rel_dir"
fi
mkdir -p "$case_out_dir"
asm_file="$case_out_dir/$stem.s"
exe="$case_out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
stdout_file="$case_out_dir/$stem.stdout"
actual_file="$case_out_dir/$stem.actual.out"
time_file="$case_out_dir/$stem.time"
if ! "$compiler" --emit-asm "$input" > "$asm_file" 2>"$case_out_dir/$stem.err"; then
echo "$stem: 编译失败"
cat "$case_out_dir/$stem.err" >&2
return 1
fi
if ! aarch64-linux-gnu-gcc "$asm_file" "$sylib_obj" -o "$exe" 2>"$case_out_dir/$stem.link.err"; then
echo "$stem: 链接失败"
cat "$case_out_dir/$stem.link.err" >&2
return 1
fi
set +e
if [[ -f "$stdin_file" ]]; then
/usr/bin/time -f "%e" -o "$time_file" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else
/usr/bin/time -f "%e" -o "$time_file" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi
status=$?
set -e
elapsed=$(tail -1 "$time_file")
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff <(normalize "$expected_file") <(normalize "$actual_file") >/dev/null 2>&1; then
printf "%s: PASS (%.3fs)\n" "$stem" "$elapsed"
printf '%s\t%s\n' "$elapsed" "$stem" >> "$out_dir/elapsed.log"
return 0
else
printf "%s: FAIL (退出码: %d, 耗时: %.3fs)\n" "$stem" "$status" "$elapsed"
diff -u --strip-trailing-cr "$expected_file" "$actual_file" >&2 || true
return 1
fi
else
printf "%s: SKIP (无预期输出, %.3fs, 退出码: %d)\n" "$stem" "$elapsed" "$status"
printf '%s\t%s\n' "$elapsed" "$stem" >> "$out_dir/elapsed.log"
return 0
fi
}
rm -f "$out_dir/elapsed.log"
for input in "${inputs[@]}"; do
if ! run_case "$input"; then
((failures+=1))
fi
done
total=${#inputs[@]}
passed=$((total - failures))
if [[ -f "$out_dir/elapsed.log" && -s "$out_dir/elapsed.log" ]]; then
total_elapsed=$(awk '{s+=$1} END {printf "%.3f", s}' "$out_dir/elapsed.log")
else
total_elapsed="0.000"
fi
echo "总计: $total, 通过: $passed, 失败: $failures"
echo "通过用例总耗时: $total_elapsed s"
if (( failures > 0 )); then
exit 1
fi

@ -0,0 +1,126 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "usage: $0 input.sy [output_dir] [--run]" >&2
exit 1
fi
input=$1
out_dir="$REPO_ROOT/test/test_result/ir"
run_exec=false
input_dir=$(dirname "$input")
shift
while [[ $# -gt 0 ]]; do
case "$1" in
--run)
run_exec=true
;;
*)
out_dir="$1"
;;
esac
shift
done
if [[ ! -f "$input" ]]; then
echo "input file not found: $input" >&2
exit 1
fi
compiler=""
for candidate in "$REPO_ROOT/build_lab2/bin/compiler" "$REPO_ROOT/build/bin/compiler"; do
if [[ -x "$candidate" ]]; then
compiler="$candidate"
break
fi
done
if [[ -z "$compiler" ]]; then
echo "compiler not found; try: cmake -S . -B build_lab2 && cmake --build build_lab2 -j" >&2
exit 1
fi
mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
out_file="$out_dir/$stem.ll"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
"$compiler" --emit-ir "$input" > "$out_file"
echo "IR generated: $out_file"
if [[ "$run_exec" == true ]]; then
if ! command -v llc >/dev/null 2>&1; then
echo "llc not found" >&2
exit 1
fi
if ! command -v clang >/dev/null 2>&1; then
echo "clang not found" >&2
exit 1
fi
obj="$out_dir/$stem.o"
exe="$out_dir/$stem"
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
llc -opaque-pointers -filetype=obj "$out_file" -o "$obj"
clang "$obj" "$REPO_ROOT/sylib/sylib.c" -o "$exe"
# Optional timeout to prevent hanging test cases.
# Override with RUN_TIMEOUT_SEC/PERF_TIMEOUT_SEC env vars.
timeout_sec="${RUN_TIMEOUT_SEC:-60}"
if [[ "$input" == *"/performance/"* || "$input" == *"/h_performance/"* ]]; then
timeout_sec="${PERF_TIMEOUT_SEC:-300}"
fi
set +e
if command -v timeout >/dev/null 2>&1; then
if [[ -f "$stdin_file" ]]; then
timeout "$timeout_sec" "$exe" < "$stdin_file" > "$stdout_file"
else
timeout "$timeout_sec" "$exe" > "$stdout_file"
fi
else
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
else
"$exe" > "$stdout_file"
fi
fi
status=$?
set -e
if [[ $status -eq 124 ]]; then
echo "timeout after ${timeout_sec}s: $exe" >&2
fi
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
echo "exit code: $status"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u <(awk '{ sub(/\r$/, ""); print }' "$expected_file") <(awk '{ sub(/\r$/, ""); print }' "$actual_file"); then
echo "matched: $expected_file"
else
echo "mismatch: $expected_file" >&2
echo "actual saved to: $actual_file" >&2
exit 1
fi
else
echo "expected output not found, skipped diff: $expected_file"
fi
fi

@ -0,0 +1,29 @@
# src/ CMakeLists.txt
add_subdirectory(utils)
add_subdirectory(ir)
add_subdirectory(frontend)
if(NOT COMPILER_PARSE_ONLY)
add_subdirectory(sem)
add_subdirectory(irgen)
add_subdirectory(mir)
endif()
add_executable(compiler
main.cpp
)
target_link_libraries(compiler PRIVATE
frontend
utils
)
if(NOT COMPILER_PARSE_ONLY)
target_link_libraries(compiler PRIVATE
sem
irgen
mir
)
target_compile_definitions(compiler PRIVATE COMPILER_PARSE_ONLY=0)
else()
target_compile_definitions(compiler PRIVATE COMPILER_PARSE_ONLY=1)
endif()

@ -0,0 +1,591 @@
// Generated from /root/sysy2026/nudt-compiler-cpp/src/antlr4/SysY.g4 by ANTLR 4.13.1
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.tree.ErrorNode;
import org.antlr.v4.runtime.tree.TerminalNode;
/**
* This class provides an empty implementation of {@link SysYListener},
* which can be extended to create a listener which only needs to handle a subset
* of the available methods.
*/
@SuppressWarnings("CheckReturnValue")
public class SysYBaseListener implements SysYListener {
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterCompUnit(SysYParser.CompUnitContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitCompUnit(SysYParser.CompUnitContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterDecl(SysYParser.DeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitDecl(SysYParser.DeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterConstDecl(SysYParser.ConstDeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitConstDecl(SysYParser.ConstDeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBType(SysYParser.BTypeContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBType(SysYParser.BTypeContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterConstDef(SysYParser.ConstDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitConstDef(SysYParser.ConstDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterConstInitVal(SysYParser.ConstInitValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitConstInitVal(SysYParser.ConstInitValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterVarDecl(SysYParser.VarDeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitVarDecl(SysYParser.VarDeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterVarDef(SysYParser.VarDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitVarDef(SysYParser.VarDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterInitVal(SysYParser.InitValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitInitVal(SysYParser.InitValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncDef(SysYParser.FuncDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncDef(SysYParser.FuncDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncType(SysYParser.FuncTypeContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncType(SysYParser.FuncTypeContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncFParams(SysYParser.FuncFParamsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncFParams(SysYParser.FuncFParamsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncFParam(SysYParser.FuncFParamContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncFParam(SysYParser.FuncFParamContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBlock(SysYParser.BlockContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBlock(SysYParser.BlockContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBlockItem(SysYParser.BlockItemContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBlockItem(SysYParser.BlockItemContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterAssignStmt(SysYParser.AssignStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitAssignStmt(SysYParser.AssignStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterExpStmt(SysYParser.ExpStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitExpStmt(SysYParser.ExpStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBlockStmt(SysYParser.BlockStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBlockStmt(SysYParser.BlockStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterIfStmt(SysYParser.IfStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitIfStmt(SysYParser.IfStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterWhileStmt(SysYParser.WhileStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitWhileStmt(SysYParser.WhileStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBreakStmt(SysYParser.BreakStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBreakStmt(SysYParser.BreakStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterContinueStmt(SysYParser.ContinueStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitContinueStmt(SysYParser.ContinueStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterReturnStmt(SysYParser.ReturnStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitReturnStmt(SysYParser.ReturnStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterExp(SysYParser.ExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitExp(SysYParser.ExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterCond(SysYParser.CondContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitCond(SysYParser.CondContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterLVal(SysYParser.LValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitLVal(SysYParser.LValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterPrimaryExp(SysYParser.PrimaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitPrimaryExp(SysYParser.PrimaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterNumber(SysYParser.NumberContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitNumber(SysYParser.NumberContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterCallUnaryExp(SysYParser.CallUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitCallUnaryExp(SysYParser.CallUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterOpUnaryExp(SysYParser.OpUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitOpUnaryExp(SysYParser.OpUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterUnaryOp(SysYParser.UnaryOpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitUnaryOp(SysYParser.UnaryOpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncRParams(SysYParser.FuncRParamsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncRParams(SysYParser.FuncRParamsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryMulExp(SysYParser.BinaryMulExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryMulExp(SysYParser.BinaryMulExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterUnaryMulExp(SysYParser.UnaryMulExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitUnaryMulExp(SysYParser.UnaryMulExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryAddExp(SysYParser.BinaryAddExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryAddExp(SysYParser.BinaryAddExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterMulAddExp(SysYParser.MulAddExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitMulAddExp(SysYParser.MulAddExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterAddRelExp(SysYParser.AddRelExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitAddRelExp(SysYParser.AddRelExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryRelExp(SysYParser.BinaryRelExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryRelExp(SysYParser.BinaryRelExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryEqExp(SysYParser.BinaryEqExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryEqExp(SysYParser.BinaryEqExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterRelEqExp(SysYParser.RelEqExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitRelEqExp(SysYParser.RelEqExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterEqLAndExp(SysYParser.EqLAndExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitEqLAndExp(SysYParser.EqLAndExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterAndLOrExp(SysYParser.AndLOrExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitAndLOrExp(SysYParser.AndLOrExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterConstExp(SysYParser.ConstExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitConstExp(SysYParser.ConstExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterEveryRule(ParserRuleContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitEveryRule(ParserRuleContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void visitTerminal(TerminalNode node) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void visitErrorNode(ErrorNode node) { }
}

@ -0,0 +1,358 @@
// Generated from /root/sysy2026/nudt-compiler-cpp/src/antlr4/SysY.g4 by ANTLR 4.13.1
import org.antlr.v4.runtime.Lexer;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.*;
import org.antlr.v4.runtime.atn.*;
import org.antlr.v4.runtime.dfa.DFA;
import org.antlr.v4.runtime.misc.*;
@SuppressWarnings({"all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue", "this-escape"})
public class SysYLexer extends Lexer {
static { RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); }
protected static final DFA[] _decisionToDFA;
protected static final PredictionContextCache _sharedContextCache =
new PredictionContextCache();
public static final int
CONST=1, INT=2, FLOAT=3, VOID=4, IF=5, ELSE=6, WHILE=7, BREAK=8, CONTINUE=9,
RETURN=10, ADD=11, SUB=12, MUL=13, DIV=14, MOD=15, ASSIGN=16, EQ=17, NE=18,
LT=19, LE=20, GT=21, GE=22, NOT=23, AND=24, OR=25, LPAREN=26, RPAREN=27,
LBRACK=28, RBRACK=29, LBRACE=30, RBRACE=31, COMMA=32, SEMI=33, Ident=34,
IntConst=35, FloatConst=36, WS=37, LINE_COMMENT=38, BLOCK_COMMENT=39;
public static String[] channelNames = {
"DEFAULT_TOKEN_CHANNEL", "HIDDEN"
};
public static String[] modeNames = {
"DEFAULT_MODE"
};
private static String[] makeRuleNames() {
return new String[] {
"CONST", "INT", "FLOAT", "VOID", "IF", "ELSE", "WHILE", "BREAK", "CONTINUE",
"RETURN", "ADD", "SUB", "MUL", "DIV", "MOD", "ASSIGN", "EQ", "NE", "LT",
"LE", "GT", "GE", "NOT", "AND", "OR", "LPAREN", "RPAREN", "LBRACK", "RBRACK",
"LBRACE", "RBRACE", "COMMA", "SEMI", "Ident", "Digit", "NonzeroDigit",
"OctDigit", "HexDigit", "DecInteger", "OctInteger", "HexInteger", "DecFraction",
"DecExponent", "DecFloat", "HexFraction", "BinExponent", "HexFloat",
"IntConst", "FloatConst", "WS", "LINE_COMMENT", "BLOCK_COMMENT"
};
}
public static final String[] ruleNames = makeRuleNames();
private static String[] makeLiteralNames() {
return new String[] {
null, "'const'", "'int'", "'float'", "'void'", "'if'", "'else'", "'while'",
"'break'", "'continue'", "'return'", "'+'", "'-'", "'*'", "'/'", "'%'",
"'='", "'=='", "'!='", "'<'", "'<='", "'>'", "'>='", "'!'", "'&&'", "'||'",
"'('", "')'", "'['", "']'", "'{'", "'}'", "','", "';'"
};
}
private static final String[] _LITERAL_NAMES = makeLiteralNames();
private static String[] makeSymbolicNames() {
return new String[] {
null, "CONST", "INT", "FLOAT", "VOID", "IF", "ELSE", "WHILE", "BREAK",
"CONTINUE", "RETURN", "ADD", "SUB", "MUL", "DIV", "MOD", "ASSIGN", "EQ",
"NE", "LT", "LE", "GT", "GE", "NOT", "AND", "OR", "LPAREN", "RPAREN",
"LBRACK", "RBRACK", "LBRACE", "RBRACE", "COMMA", "SEMI", "Ident", "IntConst",
"FloatConst", "WS", "LINE_COMMENT", "BLOCK_COMMENT"
};
}
private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames();
public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES);
/**
* @deprecated Use {@link #VOCABULARY} instead.
*/
@Deprecated
public static final String[] tokenNames;
static {
tokenNames = new String[_SYMBOLIC_NAMES.length];
for (int i = 0; i < tokenNames.length; i++) {
tokenNames[i] = VOCABULARY.getLiteralName(i);
if (tokenNames[i] == null) {
tokenNames[i] = VOCABULARY.getSymbolicName(i);
}
if (tokenNames[i] == null) {
tokenNames[i] = "<INVALID>";
}
}
}
@Override
@Deprecated
public String[] getTokenNames() {
return tokenNames;
}
@Override
public Vocabulary getVocabulary() {
return VOCABULARY;
}
public SysYLexer(CharStream input) {
super(input);
_interp = new LexerATNSimulator(this,_ATN,_decisionToDFA,_sharedContextCache);
}
@Override
public String getGrammarFileName() { return "SysY.g4"; }
@Override
public String[] getRuleNames() { return ruleNames; }
@Override
public String getSerializedATN() { return _serializedATN; }
@Override
public String[] getChannelNames() { return channelNames; }
@Override
public String[] getModeNames() { return modeNames; }
@Override
public ATN getATN() { return _ATN; }
public static final String _serializedATN =
"\u0004\u0000\'\u0171\u0006\uffff\uffff\u0002\u0000\u0007\u0000\u0002\u0001"+
"\u0007\u0001\u0002\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004"+
"\u0007\u0004\u0002\u0005\u0007\u0005\u0002\u0006\u0007\u0006\u0002\u0007"+
"\u0007\u0007\u0002\b\u0007\b\u0002\t\u0007\t\u0002\n\u0007\n\u0002\u000b"+
"\u0007\u000b\u0002\f\u0007\f\u0002\r\u0007\r\u0002\u000e\u0007\u000e\u0002"+
"\u000f\u0007\u000f\u0002\u0010\u0007\u0010\u0002\u0011\u0007\u0011\u0002"+
"\u0012\u0007\u0012\u0002\u0013\u0007\u0013\u0002\u0014\u0007\u0014\u0002"+
"\u0015\u0007\u0015\u0002\u0016\u0007\u0016\u0002\u0017\u0007\u0017\u0002"+
"\u0018\u0007\u0018\u0002\u0019\u0007\u0019\u0002\u001a\u0007\u001a\u0002"+
"\u001b\u0007\u001b\u0002\u001c\u0007\u001c\u0002\u001d\u0007\u001d\u0002"+
"\u001e\u0007\u001e\u0002\u001f\u0007\u001f\u0002 \u0007 \u0002!\u0007"+
"!\u0002\"\u0007\"\u0002#\u0007#\u0002$\u0007$\u0002%\u0007%\u0002&\u0007"+
"&\u0002\'\u0007\'\u0002(\u0007(\u0002)\u0007)\u0002*\u0007*\u0002+\u0007"+
"+\u0002,\u0007,\u0002-\u0007-\u0002.\u0007.\u0002/\u0007/\u00020\u0007"+
"0\u00021\u00071\u00022\u00072\u00023\u00073\u0001\u0000\u0001\u0000\u0001"+
"\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0001\u0001\u0001\u0001"+
"\u0001\u0001\u0001\u0001\u0002\u0001\u0002\u0001\u0002\u0001\u0002\u0001"+
"\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0003\u0001\u0003\u0001"+
"\u0003\u0001\u0004\u0001\u0004\u0001\u0004\u0001\u0005\u0001\u0005\u0001"+
"\u0005\u0001\u0005\u0001\u0005\u0001\u0006\u0001\u0006\u0001\u0006\u0001"+
"\u0006\u0001\u0006\u0001\u0006\u0001\u0007\u0001\u0007\u0001\u0007\u0001"+
"\u0007\u0001\u0007\u0001\u0007\u0001\b\u0001\b\u0001\b\u0001\b\u0001\b"+
"\u0001\b\u0001\b\u0001\b\u0001\b\u0001\t\u0001\t\u0001\t\u0001\t\u0001"+
"\t\u0001\t\u0001\t\u0001\n\u0001\n\u0001\u000b\u0001\u000b\u0001\f\u0001"+
"\f\u0001\r\u0001\r\u0001\u000e\u0001\u000e\u0001\u000f\u0001\u000f\u0001"+
"\u0010\u0001\u0010\u0001\u0010\u0001\u0011\u0001\u0011\u0001\u0011\u0001"+
"\u0012\u0001\u0012\u0001\u0013\u0001\u0013\u0001\u0013\u0001\u0014\u0001"+
"\u0014\u0001\u0015\u0001\u0015\u0001\u0015\u0001\u0016\u0001\u0016\u0001"+
"\u0017\u0001\u0017\u0001\u0017\u0001\u0018\u0001\u0018\u0001\u0018\u0001"+
"\u0019\u0001\u0019\u0001\u001a\u0001\u001a\u0001\u001b\u0001\u001b\u0001"+
"\u001c\u0001\u001c\u0001\u001d\u0001\u001d\u0001\u001e\u0001\u001e\u0001"+
"\u001f\u0001\u001f\u0001 \u0001 \u0001!\u0001!\u0005!\u00d9\b!\n!\f!\u00dc"+
"\t!\u0001\"\u0001\"\u0001#\u0001#\u0001$\u0001$\u0001%\u0001%\u0001&\u0001"+
"&\u0005&\u00e8\b&\n&\f&\u00eb\t&\u0001\'\u0001\'\u0005\'\u00ef\b\'\n\'"+
"\f\'\u00f2\t\'\u0001(\u0001(\u0001(\u0004(\u00f7\b(\u000b(\f(\u00f8\u0001"+
")\u0004)\u00fc\b)\u000b)\f)\u00fd\u0001)\u0001)\u0005)\u0102\b)\n)\f)"+
"\u0105\t)\u0001)\u0001)\u0004)\u0109\b)\u000b)\f)\u010a\u0003)\u010d\b"+
")\u0001*\u0001*\u0003*\u0111\b*\u0001*\u0004*\u0114\b*\u000b*\f*\u0115"+
"\u0001+\u0001+\u0003+\u011a\b+\u0001+\u0001+\u0001+\u0003+\u011f\b+\u0001"+
",\u0005,\u0122\b,\n,\f,\u0125\t,\u0001,\u0001,\u0004,\u0129\b,\u000b,"+
"\f,\u012a\u0001,\u0004,\u012e\b,\u000b,\f,\u012f\u0001,\u0001,\u0003,"+
"\u0134\b,\u0001-\u0001-\u0003-\u0138\b-\u0001-\u0004-\u013b\b-\u000b-"+
"\f-\u013c\u0001.\u0001.\u0001.\u0001.\u0001.\u0001.\u0001.\u0001.\u0003"+
".\u0147\b.\u0001/\u0001/\u0001/\u0003/\u014c\b/\u00010\u00010\u00030\u0150"+
"\b0\u00011\u00041\u0153\b1\u000b1\f1\u0154\u00011\u00011\u00012\u0001"+
"2\u00012\u00012\u00052\u015d\b2\n2\f2\u0160\t2\u00012\u00012\u00013\u0001"+
"3\u00013\u00013\u00053\u0168\b3\n3\f3\u016b\t3\u00013\u00013\u00013\u0001"+
"3\u00013\u0001\u0169\u00004\u0001\u0001\u0003\u0002\u0005\u0003\u0007"+
"\u0004\t\u0005\u000b\u0006\r\u0007\u000f\b\u0011\t\u0013\n\u0015\u000b"+
"\u0017\f\u0019\r\u001b\u000e\u001d\u000f\u001f\u0010!\u0011#\u0012%\u0013"+
"\'\u0014)\u0015+\u0016-\u0017/\u00181\u00193\u001a5\u001b7\u001c9\u001d"+
";\u001e=\u001f? A!C\"E\u0000G\u0000I\u0000K\u0000M\u0000O\u0000Q\u0000"+
"S\u0000U\u0000W\u0000Y\u0000[\u0000]\u0000_#a$c%e&g\'\u0001\u0000\f\u0003"+
"\u0000AZ__az\u0004\u000009AZ__az\u0001\u000009\u0001\u000019\u0001\u0000"+
"07\u0003\u000009AFaf\u0002\u0000XXxx\u0002\u0000EEee\u0002\u0000++--\u0002"+
"\u0000PPpp\u0003\u0000\t\n\r\r \u0002\u0000\n\n\r\r\u017c\u0000\u0001"+
"\u0001\u0000\u0000\u0000\u0000\u0003\u0001\u0000\u0000\u0000\u0000\u0005"+
"\u0001\u0000\u0000\u0000\u0000\u0007\u0001\u0000\u0000\u0000\u0000\t\u0001"+
"\u0000\u0000\u0000\u0000\u000b\u0001\u0000\u0000\u0000\u0000\r\u0001\u0000"+
"\u0000\u0000\u0000\u000f\u0001\u0000\u0000\u0000\u0000\u0011\u0001\u0000"+
"\u0000\u0000\u0000\u0013\u0001\u0000\u0000\u0000\u0000\u0015\u0001\u0000"+
"\u0000\u0000\u0000\u0017\u0001\u0000\u0000\u0000\u0000\u0019\u0001\u0000"+
"\u0000\u0000\u0000\u001b\u0001\u0000\u0000\u0000\u0000\u001d\u0001\u0000"+
"\u0000\u0000\u0000\u001f\u0001\u0000\u0000\u0000\u0000!\u0001\u0000\u0000"+
"\u0000\u0000#\u0001\u0000\u0000\u0000\u0000%\u0001\u0000\u0000\u0000\u0000"+
"\'\u0001\u0000\u0000\u0000\u0000)\u0001\u0000\u0000\u0000\u0000+\u0001"+
"\u0000\u0000\u0000\u0000-\u0001\u0000\u0000\u0000\u0000/\u0001\u0000\u0000"+
"\u0000\u00001\u0001\u0000\u0000\u0000\u00003\u0001\u0000\u0000\u0000\u0000"+
"5\u0001\u0000\u0000\u0000\u00007\u0001\u0000\u0000\u0000\u00009\u0001"+
"\u0000\u0000\u0000\u0000;\u0001\u0000\u0000\u0000\u0000=\u0001\u0000\u0000"+
"\u0000\u0000?\u0001\u0000\u0000\u0000\u0000A\u0001\u0000\u0000\u0000\u0000"+
"C\u0001\u0000\u0000\u0000\u0000_\u0001\u0000\u0000\u0000\u0000a\u0001"+
"\u0000\u0000\u0000\u0000c\u0001\u0000\u0000\u0000\u0000e\u0001\u0000\u0000"+
"\u0000\u0000g\u0001\u0000\u0000\u0000\u0001i\u0001\u0000\u0000\u0000\u0003"+
"o\u0001\u0000\u0000\u0000\u0005s\u0001\u0000\u0000\u0000\u0007y\u0001"+
"\u0000\u0000\u0000\t~\u0001\u0000\u0000\u0000\u000b\u0081\u0001\u0000"+
"\u0000\u0000\r\u0086\u0001\u0000\u0000\u0000\u000f\u008c\u0001\u0000\u0000"+
"\u0000\u0011\u0092\u0001\u0000\u0000\u0000\u0013\u009b\u0001\u0000\u0000"+
"\u0000\u0015\u00a2\u0001\u0000\u0000\u0000\u0017\u00a4\u0001\u0000\u0000"+
"\u0000\u0019\u00a6\u0001\u0000\u0000\u0000\u001b\u00a8\u0001\u0000\u0000"+
"\u0000\u001d\u00aa\u0001\u0000\u0000\u0000\u001f\u00ac\u0001\u0000\u0000"+
"\u0000!\u00ae\u0001\u0000\u0000\u0000#\u00b1\u0001\u0000\u0000\u0000%"+
"\u00b4\u0001\u0000\u0000\u0000\'\u00b6\u0001\u0000\u0000\u0000)\u00b9"+
"\u0001\u0000\u0000\u0000+\u00bb\u0001\u0000\u0000\u0000-\u00be\u0001\u0000"+
"\u0000\u0000/\u00c0\u0001\u0000\u0000\u00001\u00c3\u0001\u0000\u0000\u0000"+
"3\u00c6\u0001\u0000\u0000\u00005\u00c8\u0001\u0000\u0000\u00007\u00ca"+
"\u0001\u0000\u0000\u00009\u00cc\u0001\u0000\u0000\u0000;\u00ce\u0001\u0000"+
"\u0000\u0000=\u00d0\u0001\u0000\u0000\u0000?\u00d2\u0001\u0000\u0000\u0000"+
"A\u00d4\u0001\u0000\u0000\u0000C\u00d6\u0001\u0000\u0000\u0000E\u00dd"+
"\u0001\u0000\u0000\u0000G\u00df\u0001\u0000\u0000\u0000I\u00e1\u0001\u0000"+
"\u0000\u0000K\u00e3\u0001\u0000\u0000\u0000M\u00e5\u0001\u0000\u0000\u0000"+
"O\u00ec\u0001\u0000\u0000\u0000Q\u00f3\u0001\u0000\u0000\u0000S\u010c"+
"\u0001\u0000\u0000\u0000U\u010e\u0001\u0000\u0000\u0000W\u011e\u0001\u0000"+
"\u0000\u0000Y\u0133\u0001\u0000\u0000\u0000[\u0135\u0001\u0000\u0000\u0000"+
"]\u0146\u0001\u0000\u0000\u0000_\u014b\u0001\u0000\u0000\u0000a\u014f"+
"\u0001\u0000\u0000\u0000c\u0152\u0001\u0000\u0000\u0000e\u0158\u0001\u0000"+
"\u0000\u0000g\u0163\u0001\u0000\u0000\u0000ij\u0005c\u0000\u0000jk\u0005"+
"o\u0000\u0000kl\u0005n\u0000\u0000lm\u0005s\u0000\u0000mn\u0005t\u0000"+
"\u0000n\u0002\u0001\u0000\u0000\u0000op\u0005i\u0000\u0000pq\u0005n\u0000"+
"\u0000qr\u0005t\u0000\u0000r\u0004\u0001\u0000\u0000\u0000st\u0005f\u0000"+
"\u0000tu\u0005l\u0000\u0000uv\u0005o\u0000\u0000vw\u0005a\u0000\u0000"+
"wx\u0005t\u0000\u0000x\u0006\u0001\u0000\u0000\u0000yz\u0005v\u0000\u0000"+
"z{\u0005o\u0000\u0000{|\u0005i\u0000\u0000|}\u0005d\u0000\u0000}\b\u0001"+
"\u0000\u0000\u0000~\u007f\u0005i\u0000\u0000\u007f\u0080\u0005f\u0000"+
"\u0000\u0080\n\u0001\u0000\u0000\u0000\u0081\u0082\u0005e\u0000\u0000"+
"\u0082\u0083\u0005l\u0000\u0000\u0083\u0084\u0005s\u0000\u0000\u0084\u0085"+
"\u0005e\u0000\u0000\u0085\f\u0001\u0000\u0000\u0000\u0086\u0087\u0005"+
"w\u0000\u0000\u0087\u0088\u0005h\u0000\u0000\u0088\u0089\u0005i\u0000"+
"\u0000\u0089\u008a\u0005l\u0000\u0000\u008a\u008b\u0005e\u0000\u0000\u008b"+
"\u000e\u0001\u0000\u0000\u0000\u008c\u008d\u0005b\u0000\u0000\u008d\u008e"+
"\u0005r\u0000\u0000\u008e\u008f\u0005e\u0000\u0000\u008f\u0090\u0005a"+
"\u0000\u0000\u0090\u0091\u0005k\u0000\u0000\u0091\u0010\u0001\u0000\u0000"+
"\u0000\u0092\u0093\u0005c\u0000\u0000\u0093\u0094\u0005o\u0000\u0000\u0094"+
"\u0095\u0005n\u0000\u0000\u0095\u0096\u0005t\u0000\u0000\u0096\u0097\u0005"+
"i\u0000\u0000\u0097\u0098\u0005n\u0000\u0000\u0098\u0099\u0005u\u0000"+
"\u0000\u0099\u009a\u0005e\u0000\u0000\u009a\u0012\u0001\u0000\u0000\u0000"+
"\u009b\u009c\u0005r\u0000\u0000\u009c\u009d\u0005e\u0000\u0000\u009d\u009e"+
"\u0005t\u0000\u0000\u009e\u009f\u0005u\u0000\u0000\u009f\u00a0\u0005r"+
"\u0000\u0000\u00a0\u00a1\u0005n\u0000\u0000\u00a1\u0014\u0001\u0000\u0000"+
"\u0000\u00a2\u00a3\u0005+\u0000\u0000\u00a3\u0016\u0001\u0000\u0000\u0000"+
"\u00a4\u00a5\u0005-\u0000\u0000\u00a5\u0018\u0001\u0000\u0000\u0000\u00a6"+
"\u00a7\u0005*\u0000\u0000\u00a7\u001a\u0001\u0000\u0000\u0000\u00a8\u00a9"+
"\u0005/\u0000\u0000\u00a9\u001c\u0001\u0000\u0000\u0000\u00aa\u00ab\u0005"+
"%\u0000\u0000\u00ab\u001e\u0001\u0000\u0000\u0000\u00ac\u00ad\u0005=\u0000"+
"\u0000\u00ad \u0001\u0000\u0000\u0000\u00ae\u00af\u0005=\u0000\u0000\u00af"+
"\u00b0\u0005=\u0000\u0000\u00b0\"\u0001\u0000\u0000\u0000\u00b1\u00b2"+
"\u0005!\u0000\u0000\u00b2\u00b3\u0005=\u0000\u0000\u00b3$\u0001\u0000"+
"\u0000\u0000\u00b4\u00b5\u0005<\u0000\u0000\u00b5&\u0001\u0000\u0000\u0000"+
"\u00b6\u00b7\u0005<\u0000\u0000\u00b7\u00b8\u0005=\u0000\u0000\u00b8("+
"\u0001\u0000\u0000\u0000\u00b9\u00ba\u0005>\u0000\u0000\u00ba*\u0001\u0000"+
"\u0000\u0000\u00bb\u00bc\u0005>\u0000\u0000\u00bc\u00bd\u0005=\u0000\u0000"+
"\u00bd,\u0001\u0000\u0000\u0000\u00be\u00bf\u0005!\u0000\u0000\u00bf."+
"\u0001\u0000\u0000\u0000\u00c0\u00c1\u0005&\u0000\u0000\u00c1\u00c2\u0005"+
"&\u0000\u0000\u00c20\u0001\u0000\u0000\u0000\u00c3\u00c4\u0005|\u0000"+
"\u0000\u00c4\u00c5\u0005|\u0000\u0000\u00c52\u0001\u0000\u0000\u0000\u00c6"+
"\u00c7\u0005(\u0000\u0000\u00c74\u0001\u0000\u0000\u0000\u00c8\u00c9\u0005"+
")\u0000\u0000\u00c96\u0001\u0000\u0000\u0000\u00ca\u00cb\u0005[\u0000"+
"\u0000\u00cb8\u0001\u0000\u0000\u0000\u00cc\u00cd\u0005]\u0000\u0000\u00cd"+
":\u0001\u0000\u0000\u0000\u00ce\u00cf\u0005{\u0000\u0000\u00cf<\u0001"+
"\u0000\u0000\u0000\u00d0\u00d1\u0005}\u0000\u0000\u00d1>\u0001\u0000\u0000"+
"\u0000\u00d2\u00d3\u0005,\u0000\u0000\u00d3@\u0001\u0000\u0000\u0000\u00d4"+
"\u00d5\u0005;\u0000\u0000\u00d5B\u0001\u0000\u0000\u0000\u00d6\u00da\u0007"+
"\u0000\u0000\u0000\u00d7\u00d9\u0007\u0001\u0000\u0000\u00d8\u00d7\u0001"+
"\u0000\u0000\u0000\u00d9\u00dc\u0001\u0000\u0000\u0000\u00da\u00d8\u0001"+
"\u0000\u0000\u0000\u00da\u00db\u0001\u0000\u0000\u0000\u00dbD\u0001\u0000"+
"\u0000\u0000\u00dc\u00da\u0001\u0000\u0000\u0000\u00dd\u00de\u0007\u0002"+
"\u0000\u0000\u00deF\u0001\u0000\u0000\u0000\u00df\u00e0\u0007\u0003\u0000"+
"\u0000\u00e0H\u0001\u0000\u0000\u0000\u00e1\u00e2\u0007\u0004\u0000\u0000"+
"\u00e2J\u0001\u0000\u0000\u0000\u00e3\u00e4\u0007\u0005\u0000\u0000\u00e4"+
"L\u0001\u0000\u0000\u0000\u00e5\u00e9\u0003G#\u0000\u00e6\u00e8\u0003"+
"E\"\u0000\u00e7\u00e6\u0001\u0000\u0000\u0000\u00e8\u00eb\u0001\u0000"+
"\u0000\u0000\u00e9\u00e7\u0001\u0000\u0000\u0000\u00e9\u00ea\u0001\u0000"+
"\u0000\u0000\u00eaN\u0001\u0000\u0000\u0000\u00eb\u00e9\u0001\u0000\u0000"+
"\u0000\u00ec\u00f0\u00050\u0000\u0000\u00ed\u00ef\u0003I$\u0000\u00ee"+
"\u00ed\u0001\u0000\u0000\u0000\u00ef\u00f2\u0001\u0000\u0000\u0000\u00f0"+
"\u00ee\u0001\u0000\u0000\u0000\u00f0\u00f1\u0001\u0000\u0000\u0000\u00f1"+
"P\u0001\u0000\u0000\u0000\u00f2\u00f0\u0001\u0000\u0000\u0000\u00f3\u00f4"+
"\u00050\u0000\u0000\u00f4\u00f6\u0007\u0006\u0000\u0000\u00f5\u00f7\u0003"+
"K%\u0000\u00f6\u00f5\u0001\u0000\u0000\u0000\u00f7\u00f8\u0001\u0000\u0000"+
"\u0000\u00f8\u00f6\u0001\u0000\u0000\u0000\u00f8\u00f9\u0001\u0000\u0000"+
"\u0000\u00f9R\u0001\u0000\u0000\u0000\u00fa\u00fc\u0003E\"\u0000\u00fb"+
"\u00fa\u0001\u0000\u0000\u0000\u00fc\u00fd\u0001\u0000\u0000\u0000\u00fd"+
"\u00fb\u0001\u0000\u0000\u0000\u00fd\u00fe\u0001\u0000\u0000\u0000\u00fe"+
"\u00ff\u0001\u0000\u0000\u0000\u00ff\u0103\u0005.\u0000\u0000\u0100\u0102"+
"\u0003E\"\u0000\u0101\u0100\u0001\u0000\u0000\u0000\u0102\u0105\u0001"+
"\u0000\u0000\u0000\u0103\u0101\u0001\u0000\u0000\u0000\u0103\u0104\u0001"+
"\u0000\u0000\u0000\u0104\u010d\u0001\u0000\u0000\u0000\u0105\u0103\u0001"+
"\u0000\u0000\u0000\u0106\u0108\u0005.\u0000\u0000\u0107\u0109\u0003E\""+
"\u0000\u0108\u0107\u0001\u0000\u0000\u0000\u0109\u010a\u0001\u0000\u0000"+
"\u0000\u010a\u0108\u0001\u0000\u0000\u0000\u010a\u010b\u0001\u0000\u0000"+
"\u0000\u010b\u010d\u0001\u0000\u0000\u0000\u010c\u00fb\u0001\u0000\u0000"+
"\u0000\u010c\u0106\u0001\u0000\u0000\u0000\u010dT\u0001\u0000\u0000\u0000"+
"\u010e\u0110\u0007\u0007\u0000\u0000\u010f\u0111\u0007\b\u0000\u0000\u0110"+
"\u010f\u0001\u0000\u0000\u0000\u0110\u0111\u0001\u0000\u0000\u0000\u0111"+
"\u0113\u0001\u0000\u0000\u0000\u0112\u0114\u0003E\"\u0000\u0113\u0112"+
"\u0001\u0000\u0000\u0000\u0114\u0115\u0001\u0000\u0000\u0000\u0115\u0113"+
"\u0001\u0000\u0000\u0000\u0115\u0116\u0001\u0000\u0000\u0000\u0116V\u0001"+
"\u0000\u0000\u0000\u0117\u0119\u0003S)\u0000\u0118\u011a\u0003U*\u0000"+
"\u0119\u0118\u0001\u0000\u0000\u0000\u0119\u011a\u0001\u0000\u0000\u0000"+
"\u011a\u011f\u0001\u0000\u0000\u0000\u011b\u011c\u0003M&\u0000\u011c\u011d"+
"\u0003U*\u0000\u011d\u011f\u0001\u0000\u0000\u0000\u011e\u0117\u0001\u0000"+
"\u0000\u0000\u011e\u011b\u0001\u0000\u0000\u0000\u011fX\u0001\u0000\u0000"+
"\u0000\u0120\u0122\u0003K%\u0000\u0121\u0120\u0001\u0000\u0000\u0000\u0122"+
"\u0125\u0001\u0000\u0000\u0000\u0123\u0121\u0001\u0000\u0000\u0000\u0123"+
"\u0124\u0001\u0000\u0000\u0000\u0124\u0126\u0001\u0000\u0000\u0000\u0125"+
"\u0123\u0001\u0000\u0000\u0000\u0126\u0128\u0005.\u0000\u0000\u0127\u0129"+
"\u0003K%\u0000\u0128\u0127\u0001\u0000\u0000\u0000\u0129\u012a\u0001\u0000"+
"\u0000\u0000\u012a\u0128\u0001\u0000\u0000\u0000\u012a\u012b\u0001\u0000"+
"\u0000\u0000\u012b\u0134\u0001\u0000\u0000\u0000\u012c\u012e\u0003K%\u0000"+
"\u012d\u012c\u0001\u0000\u0000\u0000\u012e\u012f\u0001\u0000\u0000\u0000"+
"\u012f\u012d\u0001\u0000\u0000\u0000\u012f\u0130\u0001\u0000\u0000\u0000"+
"\u0130\u0131\u0001\u0000\u0000\u0000\u0131\u0132\u0005.\u0000\u0000\u0132"+
"\u0134\u0001\u0000\u0000\u0000\u0133\u0123\u0001\u0000\u0000\u0000\u0133"+
"\u012d\u0001\u0000\u0000\u0000\u0134Z\u0001\u0000\u0000\u0000\u0135\u0137"+
"\u0007\t\u0000\u0000\u0136\u0138\u0007\b\u0000\u0000\u0137\u0136\u0001"+
"\u0000\u0000\u0000\u0137\u0138\u0001\u0000\u0000\u0000\u0138\u013a\u0001"+
"\u0000\u0000\u0000\u0139\u013b\u0003E\"\u0000\u013a\u0139\u0001\u0000"+
"\u0000\u0000\u013b\u013c\u0001\u0000\u0000\u0000\u013c\u013a\u0001\u0000"+
"\u0000\u0000\u013c\u013d\u0001\u0000\u0000\u0000\u013d\\\u0001\u0000\u0000"+
"\u0000\u013e\u013f\u00050\u0000\u0000\u013f\u0140\u0007\u0006\u0000\u0000"+
"\u0140\u0141\u0003Y,\u0000\u0141\u0142\u0003[-\u0000\u0142\u0147\u0001"+
"\u0000\u0000\u0000\u0143\u0144\u0003Q(\u0000\u0144\u0145\u0003[-\u0000"+
"\u0145\u0147\u0001\u0000\u0000\u0000\u0146\u013e\u0001\u0000\u0000\u0000"+
"\u0146\u0143\u0001\u0000\u0000\u0000\u0147^\u0001\u0000\u0000\u0000\u0148"+
"\u014c\u0003M&\u0000\u0149\u014c\u0003O\'\u0000\u014a\u014c\u0003Q(\u0000"+
"\u014b\u0148\u0001\u0000\u0000\u0000\u014b\u0149\u0001\u0000\u0000\u0000"+
"\u014b\u014a\u0001\u0000\u0000\u0000\u014c`\u0001\u0000\u0000\u0000\u014d"+
"\u0150\u0003W+\u0000\u014e\u0150\u0003].\u0000\u014f\u014d\u0001\u0000"+
"\u0000\u0000\u014f\u014e\u0001\u0000\u0000\u0000\u0150b\u0001\u0000\u0000"+
"\u0000\u0151\u0153\u0007\n\u0000\u0000\u0152\u0151\u0001\u0000\u0000\u0000"+
"\u0153\u0154\u0001\u0000\u0000\u0000\u0154\u0152\u0001\u0000\u0000\u0000"+
"\u0154\u0155\u0001\u0000\u0000\u0000\u0155\u0156\u0001\u0000\u0000\u0000"+
"\u0156\u0157\u00061\u0000\u0000\u0157d\u0001\u0000\u0000\u0000\u0158\u0159"+
"\u0005/\u0000\u0000\u0159\u015a\u0005/\u0000\u0000\u015a\u015e\u0001\u0000"+
"\u0000\u0000\u015b\u015d\b\u000b\u0000\u0000\u015c\u015b\u0001\u0000\u0000"+
"\u0000\u015d\u0160\u0001\u0000\u0000\u0000\u015e\u015c\u0001\u0000\u0000"+
"\u0000\u015e\u015f\u0001\u0000\u0000\u0000\u015f\u0161\u0001\u0000\u0000"+
"\u0000\u0160\u015e\u0001\u0000\u0000\u0000\u0161\u0162\u00062\u0000\u0000"+
"\u0162f\u0001\u0000\u0000\u0000\u0163\u0164\u0005/\u0000\u0000\u0164\u0165"+
"\u0005*\u0000\u0000\u0165\u0169\u0001\u0000\u0000\u0000\u0166\u0168\t"+
"\u0000\u0000\u0000\u0167\u0166\u0001\u0000\u0000\u0000\u0168\u016b\u0001"+
"\u0000\u0000\u0000\u0169\u016a\u0001\u0000\u0000\u0000\u0169\u0167\u0001"+
"\u0000\u0000\u0000\u016a\u016c\u0001\u0000\u0000\u0000\u016b\u0169\u0001"+
"\u0000\u0000\u0000\u016c\u016d\u0005*\u0000\u0000\u016d\u016e\u0005/\u0000"+
"\u0000\u016e\u016f\u0001\u0000\u0000\u0000\u016f\u0170\u00063\u0000\u0000"+
"\u0170h\u0001\u0000\u0000\u0000\u0019\u0000\u00da\u00e9\u00f0\u00f8\u00fd"+
"\u0103\u010a\u010c\u0110\u0115\u0119\u011e\u0123\u012a\u012f\u0133\u0137"+
"\u013c\u0146\u014b\u014f\u0154\u015e\u0169\u0001\u0006\u0000\u0000";
public static final ATN _ATN =
new ATNDeserializer().deserialize(_serializedATN.toCharArray());
static {
_decisionToDFA = new DFA[_ATN.getNumberOfDecisions()];
for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) {
_decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i);
}
}
}

@ -0,0 +1,515 @@
// Generated from /root/sysy2026/nudt-compiler-cpp/src/antlr4/SysY.g4 by ANTLR 4.13.1
import org.antlr.v4.runtime.tree.ParseTreeListener;
/**
* This interface defines a complete listener for a parse tree produced by
* {@link SysYParser}.
*/
public interface SysYListener extends ParseTreeListener {
/**
* Enter a parse tree produced by {@link SysYParser#compUnit}.
* @param ctx the parse tree
*/
void enterCompUnit(SysYParser.CompUnitContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#compUnit}.
* @param ctx the parse tree
*/
void exitCompUnit(SysYParser.CompUnitContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#decl}.
* @param ctx the parse tree
*/
void enterDecl(SysYParser.DeclContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#decl}.
* @param ctx the parse tree
*/
void exitDecl(SysYParser.DeclContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#constDecl}.
* @param ctx the parse tree
*/
void enterConstDecl(SysYParser.ConstDeclContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#constDecl}.
* @param ctx the parse tree
*/
void exitConstDecl(SysYParser.ConstDeclContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#bType}.
* @param ctx the parse tree
*/
void enterBType(SysYParser.BTypeContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#bType}.
* @param ctx the parse tree
*/
void exitBType(SysYParser.BTypeContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#constDef}.
* @param ctx the parse tree
*/
void enterConstDef(SysYParser.ConstDefContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#constDef}.
* @param ctx the parse tree
*/
void exitConstDef(SysYParser.ConstDefContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#constInitVal}.
* @param ctx the parse tree
*/
void enterConstInitVal(SysYParser.ConstInitValContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#constInitVal}.
* @param ctx the parse tree
*/
void exitConstInitVal(SysYParser.ConstInitValContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#varDecl}.
* @param ctx the parse tree
*/
void enterVarDecl(SysYParser.VarDeclContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#varDecl}.
* @param ctx the parse tree
*/
void exitVarDecl(SysYParser.VarDeclContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#varDef}.
* @param ctx the parse tree
*/
void enterVarDef(SysYParser.VarDefContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#varDef}.
* @param ctx the parse tree
*/
void exitVarDef(SysYParser.VarDefContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#initVal}.
* @param ctx the parse tree
*/
void enterInitVal(SysYParser.InitValContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#initVal}.
* @param ctx the parse tree
*/
void exitInitVal(SysYParser.InitValContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcDef}.
* @param ctx the parse tree
*/
void enterFuncDef(SysYParser.FuncDefContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcDef}.
* @param ctx the parse tree
*/
void exitFuncDef(SysYParser.FuncDefContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcType}.
* @param ctx the parse tree
*/
void enterFuncType(SysYParser.FuncTypeContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcType}.
* @param ctx the parse tree
*/
void exitFuncType(SysYParser.FuncTypeContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcFParams}.
* @param ctx the parse tree
*/
void enterFuncFParams(SysYParser.FuncFParamsContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcFParams}.
* @param ctx the parse tree
*/
void exitFuncFParams(SysYParser.FuncFParamsContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcFParam}.
* @param ctx the parse tree
*/
void enterFuncFParam(SysYParser.FuncFParamContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcFParam}.
* @param ctx the parse tree
*/
void exitFuncFParam(SysYParser.FuncFParamContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#block}.
* @param ctx the parse tree
*/
void enterBlock(SysYParser.BlockContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#block}.
* @param ctx the parse tree
*/
void exitBlock(SysYParser.BlockContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#blockItem}.
* @param ctx the parse tree
*/
void enterBlockItem(SysYParser.BlockItemContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#blockItem}.
* @param ctx the parse tree
*/
void exitBlockItem(SysYParser.BlockItemContext ctx);
/**
* Enter a parse tree produced by the {@code assignStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterAssignStmt(SysYParser.AssignStmtContext ctx);
/**
* Exit a parse tree produced by the {@code assignStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitAssignStmt(SysYParser.AssignStmtContext ctx);
/**
* Enter a parse tree produced by the {@code expStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterExpStmt(SysYParser.ExpStmtContext ctx);
/**
* Exit a parse tree produced by the {@code expStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitExpStmt(SysYParser.ExpStmtContext ctx);
/**
* Enter a parse tree produced by the {@code blockStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterBlockStmt(SysYParser.BlockStmtContext ctx);
/**
* Exit a parse tree produced by the {@code blockStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitBlockStmt(SysYParser.BlockStmtContext ctx);
/**
* Enter a parse tree produced by the {@code ifStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterIfStmt(SysYParser.IfStmtContext ctx);
/**
* Exit a parse tree produced by the {@code ifStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitIfStmt(SysYParser.IfStmtContext ctx);
/**
* Enter a parse tree produced by the {@code whileStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterWhileStmt(SysYParser.WhileStmtContext ctx);
/**
* Exit a parse tree produced by the {@code whileStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitWhileStmt(SysYParser.WhileStmtContext ctx);
/**
* Enter a parse tree produced by the {@code breakStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterBreakStmt(SysYParser.BreakStmtContext ctx);
/**
* Exit a parse tree produced by the {@code breakStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitBreakStmt(SysYParser.BreakStmtContext ctx);
/**
* Enter a parse tree produced by the {@code continueStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterContinueStmt(SysYParser.ContinueStmtContext ctx);
/**
* Exit a parse tree produced by the {@code continueStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitContinueStmt(SysYParser.ContinueStmtContext ctx);
/**
* Enter a parse tree produced by the {@code returnStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterReturnStmt(SysYParser.ReturnStmtContext ctx);
/**
* Exit a parse tree produced by the {@code returnStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitReturnStmt(SysYParser.ReturnStmtContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#exp}.
* @param ctx the parse tree
*/
void enterExp(SysYParser.ExpContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#exp}.
* @param ctx the parse tree
*/
void exitExp(SysYParser.ExpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#cond}.
* @param ctx the parse tree
*/
void enterCond(SysYParser.CondContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#cond}.
* @param ctx the parse tree
*/
void exitCond(SysYParser.CondContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#lVal}.
* @param ctx the parse tree
*/
void enterLVal(SysYParser.LValContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#lVal}.
* @param ctx the parse tree
*/
void exitLVal(SysYParser.LValContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#primaryExp}.
* @param ctx the parse tree
*/
void enterPrimaryExp(SysYParser.PrimaryExpContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#primaryExp}.
* @param ctx the parse tree
*/
void exitPrimaryExp(SysYParser.PrimaryExpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#number}.
* @param ctx the parse tree
*/
void enterNumber(SysYParser.NumberContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#number}.
* @param ctx the parse tree
*/
void exitNumber(SysYParser.NumberContext ctx);
/**
* Enter a parse tree produced by the {@code primaryUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void enterPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx);
/**
* Exit a parse tree produced by the {@code primaryUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void exitPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx);
/**
* Enter a parse tree produced by the {@code callUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void enterCallUnaryExp(SysYParser.CallUnaryExpContext ctx);
/**
* Exit a parse tree produced by the {@code callUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void exitCallUnaryExp(SysYParser.CallUnaryExpContext ctx);
/**
* Enter a parse tree produced by the {@code opUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void enterOpUnaryExp(SysYParser.OpUnaryExpContext ctx);
/**
* Exit a parse tree produced by the {@code opUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void exitOpUnaryExp(SysYParser.OpUnaryExpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#unaryOp}.
* @param ctx the parse tree
*/
void enterUnaryOp(SysYParser.UnaryOpContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#unaryOp}.
* @param ctx the parse tree
*/
void exitUnaryOp(SysYParser.UnaryOpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcRParams}.
* @param ctx the parse tree
*/
void enterFuncRParams(SysYParser.FuncRParamsContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcRParams}.
* @param ctx the parse tree
*/
void exitFuncRParams(SysYParser.FuncRParamsContext ctx);
/**
* Enter a parse tree produced by the {@code binaryMulExp}
* labeled alternative in {@link SysYParser#mulExp}.
* @param ctx the parse tree
*/
void enterBinaryMulExp(SysYParser.BinaryMulExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryMulExp}
* labeled alternative in {@link SysYParser#mulExp}.
* @param ctx the parse tree
*/
void exitBinaryMulExp(SysYParser.BinaryMulExpContext ctx);
/**
* Enter a parse tree produced by the {@code unaryMulExp}
* labeled alternative in {@link SysYParser#mulExp}.
* @param ctx the parse tree
*/
void enterUnaryMulExp(SysYParser.UnaryMulExpContext ctx);
/**
* Exit a parse tree produced by the {@code unaryMulExp}
* labeled alternative in {@link SysYParser#mulExp}.
* @param ctx the parse tree
*/
void exitUnaryMulExp(SysYParser.UnaryMulExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryAddExp}
* labeled alternative in {@link SysYParser#addExp}.
* @param ctx the parse tree
*/
void enterBinaryAddExp(SysYParser.BinaryAddExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryAddExp}
* labeled alternative in {@link SysYParser#addExp}.
* @param ctx the parse tree
*/
void exitBinaryAddExp(SysYParser.BinaryAddExpContext ctx);
/**
* Enter a parse tree produced by the {@code mulAddExp}
* labeled alternative in {@link SysYParser#addExp}.
* @param ctx the parse tree
*/
void enterMulAddExp(SysYParser.MulAddExpContext ctx);
/**
* Exit a parse tree produced by the {@code mulAddExp}
* labeled alternative in {@link SysYParser#addExp}.
* @param ctx the parse tree
*/
void exitMulAddExp(SysYParser.MulAddExpContext ctx);
/**
* Enter a parse tree produced by the {@code addRelExp}
* labeled alternative in {@link SysYParser#relExp}.
* @param ctx the parse tree
*/
void enterAddRelExp(SysYParser.AddRelExpContext ctx);
/**
* Exit a parse tree produced by the {@code addRelExp}
* labeled alternative in {@link SysYParser#relExp}.
* @param ctx the parse tree
*/
void exitAddRelExp(SysYParser.AddRelExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryRelExp}
* labeled alternative in {@link SysYParser#relExp}.
* @param ctx the parse tree
*/
void enterBinaryRelExp(SysYParser.BinaryRelExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryRelExp}
* labeled alternative in {@link SysYParser#relExp}.
* @param ctx the parse tree
*/
void exitBinaryRelExp(SysYParser.BinaryRelExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryEqExp}
* labeled alternative in {@link SysYParser#eqExp}.
* @param ctx the parse tree
*/
void enterBinaryEqExp(SysYParser.BinaryEqExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryEqExp}
* labeled alternative in {@link SysYParser#eqExp}.
* @param ctx the parse tree
*/
void exitBinaryEqExp(SysYParser.BinaryEqExpContext ctx);
/**
* Enter a parse tree produced by the {@code relEqExp}
* labeled alternative in {@link SysYParser#eqExp}.
* @param ctx the parse tree
*/
void enterRelEqExp(SysYParser.RelEqExpContext ctx);
/**
* Exit a parse tree produced by the {@code relEqExp}
* labeled alternative in {@link SysYParser#eqExp}.
* @param ctx the parse tree
*/
void exitRelEqExp(SysYParser.RelEqExpContext ctx);
/**
* Enter a parse tree produced by the {@code eqLAndExp}
* labeled alternative in {@link SysYParser#lAndExp}.
* @param ctx the parse tree
*/
void enterEqLAndExp(SysYParser.EqLAndExpContext ctx);
/**
* Exit a parse tree produced by the {@code eqLAndExp}
* labeled alternative in {@link SysYParser#lAndExp}.
* @param ctx the parse tree
*/
void exitEqLAndExp(SysYParser.EqLAndExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryLAndExp}
* labeled alternative in {@link SysYParser#lAndExp}.
* @param ctx the parse tree
*/
void enterBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryLAndExp}
* labeled alternative in {@link SysYParser#lAndExp}.
* @param ctx the parse tree
*/
void exitBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx);
/**
* Enter a parse tree produced by the {@code andLOrExp}
* labeled alternative in {@link SysYParser#lOrExp}.
* @param ctx the parse tree
*/
void enterAndLOrExp(SysYParser.AndLOrExpContext ctx);
/**
* Exit a parse tree produced by the {@code andLOrExp}
* labeled alternative in {@link SysYParser#lOrExp}.
* @param ctx the parse tree
*/
void exitAndLOrExp(SysYParser.AndLOrExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryLOrExp}
* labeled alternative in {@link SysYParser#lOrExp}.
* @param ctx the parse tree
*/
void enterBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryLOrExp}
* labeled alternative in {@link SysYParser#lOrExp}.
* @param ctx the parse tree
*/
void exitBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#constExp}.
* @param ctx the parse tree
*/
void enterConstExp(SysYParser.ConstExpContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#constExp}.
* @param ctx the parse tree
*/
void exitConstExp(SysYParser.ConstExpContext ctx);
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,178 @@
grammar SysY;
////Grammer
module: compUnit EOF;
compUnit: (decl | funcDef)+;
decl: constDecl | varDecl;
constDecl: CONST bType constDef (COMMA constDef)* SEMI;
bType: INT | FLOAT;
constDef: Ident (LBRACK constExp RBRACK)* ASSIGN constInitVal;
constInitVal:
constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE;
varDecl: bType varDef (COMMA varDef)* SEMI;
varDef:
Ident (LBRACK constExp RBRACK)*
| Ident (LBRACK constExp RBRACK)* ASSIGN initVal;
initVal: exp | LBRACE (initVal (COMMA initVal)*)? RBRACE;
funcDef: funcType Ident LPAREN funcFParams? RPAREN block;
funcType: VOID | INT | FLOAT;
funcFParams: funcFParam (COMMA funcFParam)*;
funcFParam:
bType Ident
| bType Ident LBRACK RBRACK (LBRACK exp RBRACK)*;
block: LBRACE blockItem* RBRACE;
blockItem: decl | stmt;
stmt:
lVal ASSIGN exp SEMI
| exp? SEMI
| block
| IF LPAREN cond RPAREN stmt (ELSE stmt)?
| WHILE LPAREN cond RPAREN stmt
| BREAK SEMI
| CONTINUE SEMI
| RETURN exp? SEMI;
exp: addExp;
cond: lOrExp;
lVal: Ident (LBRACK exp RBRACK)*;
primaryExp: LPAREN exp RPAREN | lVal | number;
number: IntConst | FloatConst;
unaryExp:
primaryExp
| Ident LPAREN funcRParams? RPAREN
| unaryOp unaryExp;
unaryOp: ADD | SUB | NOT;
funcRParams: exp (COMMA exp)*;
mulExp:
unaryExp
| mulExp op = (MUL | DIV | MOD) unaryExp;
addExp:
mulExp
| addExp op = (ADD | SUB) mulExp;
relExp:
addExp
| relExp op = (LT | GT | LE | GE) addExp;
eqExp:
relExp
| eqExp op = (EQ | NE) relExp;
lAndExp: eqExp | lAndExp AND eqExp ;
lOrExp: lAndExp | lOrExp OR lAndExp ;
constExp: addExp;
////Lexer
//keywords
CONST: 'const';
INT: 'int';
FLOAT: 'float';
VOID: 'void';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
RETURN: 'return';
//operators
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
ASSIGN: '=';
EQ: '==';
NE: '!=';
LT: '<';
LE: '<=';
GT: '>';
GE: '>=';
NOT: '!';
AND: '&&';
OR: '||';
//括号
LPAREN: '(';
RPAREN: ')';
LBRACK: '[';
RBRACK: ']';
LBRACE: '{';
RBRACE: '}';
COMMA: ',';
SEMI: ';';
//标识符
Ident: [a-zA-Z_] [a-zA-Z_0-9]*;
//数字常量片段
// 十进制数字
fragment Digit: [0-9];
// 非零十进制数字
fragment NonzeroDigit: [1-9];
// 八进制数字
fragment OctDigit: [0-7];
// 十六进制数字
fragment HexDigit: [0-9a-fA-F];
// 十进制整数:非零开头,后接若干十进制数字
fragment DecInteger: NonzeroDigit Digit*;
// 八进制整数:以 0 开头
fragment OctInteger: '0' OctDigit*;
// 十六进制整数:以 0x 或 0X 开头
fragment HexInteger: '0' [xX] HexDigit+;
// 十进制小数部分
fragment DecFraction: Digit+ '.' Digit* | '.' Digit+;
// 十进制指数部分
fragment DecExponent: [eE] [+\-]? Digit+;
// 十进制浮点数
fragment DecFloat:
DecFraction DecExponent?
| DecInteger DecExponent;
// 十六进制小数部分
fragment HexFraction: HexDigit* '.' HexDigit+ | HexDigit+ '.';
// 十六进制浮点数的二进制指数部分
fragment BinExponent: [pP] [+\-]? Digit+;
// 十六进制浮点数
fragment HexFloat:
'0' [xX] HexFraction BinExponent
| HexInteger BinExponent;
//整型常量
IntConst: DecInteger | OctInteger | HexInteger;
//浮点常量
FloatConst: DecFloat | HexFloat;
//空白符规则
WS: [ \t\r\n]+ -> skip;
// 单行注释
LINE_COMMENT: '//' ~[\r\n]* -> skip;
// 跨行注释
BLOCK_COMMENT: '/*' .*? '*/' -> skip;

@ -0,0 +1,74 @@
// 调用前端解析流程,返回语法树。
#include "frontend/AntlrDriver.h"
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>
#include "SysYLexer.h"
#include "SysYParser.h"
#include "antlr4-runtime.h"
#include "utils/Log.h"
namespace {
class ParseErrorListener : public antlr4::BaseErrorListener {
public:
void syntaxError(antlr4::Recognizer* /*recognizer*/, antlr4::Token* /*offendingSymbol*/,
size_t line, size_t charPositionInLine,
const std::string& msg, std::exception_ptr /*e*/) override {
throw std::runtime_error(FormatErrorAt("parse", line, charPositionInLine,
"暂不支持的语法/词法 - " + msg));
}
};
} // namespace
AntlrResult ParseFileWithAntlr(const std::string& path) {
std::ifstream fin(path);
if (!fin.is_open()) {
throw std::runtime_error(FormatError("parse", "无法打开输入文件: " + path));
}
std::ostringstream ss;
ss << fin.rdbuf();
auto input = std::make_unique<antlr4::ANTLRInputStream>(ss.str());
auto lexer = std::make_unique<SysYLexer>(input.get());
auto tokens = std::make_unique<antlr4::CommonTokenStream>(lexer.get());
auto parser = std::make_unique<SysYParser>(tokens.get());
ParseErrorListener error_listener;
lexer->removeErrorListeners();
lexer->addErrorListener(&error_listener);
parser->removeErrorListeners();
parser->addErrorListener(&error_listener);
parser->setErrorHandler(std::make_shared<antlr4::BailErrorStrategy>());
antlr4::tree::ParseTree* tree = nullptr;
try {
tree = parser->compUnit();
} catch (const std::exception& ex) {
const std::string msg = ex.what();
if (!msg.empty()) {
if (HasErrorPrefix(msg, "parse")) {
throw;
}
throw std::runtime_error(
FormatError("parse", "暂不支持的语法/词法 - " + msg));
}
if (auto* tok = parser->getCurrentToken()) {
throw std::runtime_error(
FormatErrorAt("parse", tok->getLine(), tok->getCharPositionInLine(),
"暂不支持的语法/词法 near token '" + tok->getText() + "'"));
}
throw std::runtime_error(FormatError("parse", "暂不支持的语法/词法"));
}
AntlrResult result;
result.input = std::move(input);
result.lexer = std::move(lexer);
result.tokens = std::move(tokens);
result.parser = std::move(parser);
result.tree = tree;
return result;
}

@ -0,0 +1,34 @@
find_package(Java REQUIRED COMPONENTS Runtime)
set(SYSY_GRAMMAR "${PROJECT_SOURCE_DIR}/src/antlr4/SysY.g4")
set(SYSY_ANTLR_JAR "${PROJECT_SOURCE_DIR}/third_party/antlr-4.13.2-complete.jar")
set(SYSY_ANTLR_OUTPUTS
"${ANTLR4_GENERATED_DIR}/SysYLexer.cpp"
"${ANTLR4_GENERATED_DIR}/SysYLexer.h"
"${ANTLR4_GENERATED_DIR}/SysYParser.cpp"
"${ANTLR4_GENERATED_DIR}/SysYParser.h"
"${ANTLR4_GENERATED_DIR}/SysYVisitor.cpp"
"${ANTLR4_GENERATED_DIR}/SysYVisitor.h"
"${ANTLR4_GENERATED_DIR}/SysYBaseVisitor.cpp"
"${ANTLR4_GENERATED_DIR}/SysYBaseVisitor.h"
)
add_custom_command(
OUTPUT ${SYSY_ANTLR_OUTPUTS}
COMMAND ${CMAKE_COMMAND} -E make_directory "${ANTLR4_GENERATED_DIR}"
COMMAND ${Java_JAVA_EXECUTABLE} -jar "${SYSY_ANTLR_JAR}" -Dlanguage=Cpp -visitor -no-listener -o "${ANTLR4_GENERATED_DIR}" -Xexact-output-dir "${SYSY_GRAMMAR}"
DEPENDS "${SYSY_GRAMMAR}" "${SYSY_ANTLR_JAR}"
COMMENT "Generating SysY parser with ANTLR4"
VERBATIM
)
add_library(frontend STATIC
AntlrDriver.cpp
SyntaxTreePrinter.cpp
${SYSY_ANTLR_OUTPUTS}
)
target_link_libraries(frontend PUBLIC
build_options
${ANTLR4_RUNTIME_TARGET}
)

@ -0,0 +1,71 @@
#include "frontend/SyntaxTreePrinter.h"
#include <string>
namespace {
std::string GetTokenName(const antlr4::Token* tok, antlr4::Parser* parser) {
if (!tok || !parser) {
return "UNKNOWN";
}
const int token_type = tok->getType();
const auto& vocab = parser->getVocabulary();
std::string token_name(vocab.getSymbolicName(token_type));
if (token_name.empty()) {
token_name = std::string(vocab.getLiteralName(token_type));
}
if (token_name.empty()) {
token_name = std::to_string(token_type);
}
return token_name;
}
std::string RuleName(antlr4::ParserRuleContext* rule, antlr4::Parser* parser) {
if (!rule || !parser) {
return "unknown";
}
const int idx = rule->getRuleIndex();
const auto& names = parser->getRuleNames();
if (idx >= 0 && idx < static_cast<int>(names.size())) {
return names[static_cast<size_t>(idx)];
}
return "unknown";
}
void PrintSyntaxTreeImpl(antlr4::tree::ParseTree* node, antlr4::Parser* parser,
std::ostream& os, const std::string& prefix,
bool is_last, bool is_root) {
if (!node) {
return;
}
std::string label;
if (auto* terminal = dynamic_cast<antlr4::tree::TerminalNode*>(node)) {
label = GetTokenName(terminal->getSymbol(), parser) + ": " + terminal->getText();
} else if (auto* rule = dynamic_cast<antlr4::ParserRuleContext*>(node)) {
label = RuleName(rule, parser);
} else {
label = "unknown";
}
if (is_root) {
os << label << "\n";
} else {
os << prefix << (is_last ? "`-- " : "|-- ") << label << "\n";
}
const std::string child_prefix =
is_root ? "" : prefix + (is_last ? " " : "| ");
const size_t child_count = node->children.size();
for (size_t i = 0; i < child_count; ++i) {
PrintSyntaxTreeImpl(node->children[i], parser, os, child_prefix,
i + 1 == child_count, false);
}
}
} // namespace
void PrintSyntaxTree(antlr4::tree::ParseTree* tree, antlr4::Parser* parser,
std::ostream& os) {
PrintSyntaxTreeImpl(tree, parser, os, "", true, true);
}

@ -0,0 +1,62 @@
#include "ir/IR.h"
#include <algorithm>
#include <stdexcept>
namespace ir {
BasicBlock::BasicBlock(const std::string& name)
: Value(Type::GetLabelType(), name) {}
BasicBlock::BasicBlock(Function* parent, const std::string& name)
: Value(Type::GetLabelType(), name), parent_(parent) {}
bool BasicBlock::HasTerminator() const {
return !instructions_.empty() && instructions_.back()->IsTerminator();
}
void BasicBlock::EraseInstruction(Instruction* inst) {
if (!inst) {
return;
}
if (inst->IsTerminator()) {
throw std::runtime_error("cannot erase terminator instruction");
}
auto it = std::find_if(instructions_.begin(), instructions_.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == inst;
});
if (it == instructions_.end()) {
return;
}
(*it)->ClearAllOperands();
instructions_.erase(it);
}
void BasicBlock::AddPredecessor(BasicBlock* pred) {
if (pred &&
std::find(predecessors_.begin(), predecessors_.end(), pred) ==
predecessors_.end()) {
predecessors_.push_back(pred);
}
}
void BasicBlock::AddSuccessor(BasicBlock* succ) {
if (succ && std::find(successors_.begin(), successors_.end(), succ) ==
successors_.end()) {
successors_.push_back(succ);
}
}
void BasicBlock::RemovePredecessor(BasicBlock* pred) {
predecessors_.erase(
std::remove(predecessors_.begin(), predecessors_.end(), pred),
predecessors_.end());
}
void BasicBlock::RemoveSuccessor(BasicBlock* succ) {
successors_.erase(std::remove(successors_.begin(), successors_.end(), succ),
successors_.end());
}
} // namespace ir

@ -0,0 +1,27 @@
add_library(ir_core STATIC
Context.cpp
Module.cpp
Function.cpp
BasicBlock.cpp
GlobalValue.cpp
Type.cpp
Value.cpp
Instruction.cpp
IRBuilder.cpp
IRPrinter.cpp
)
target_link_libraries(ir_core PUBLIC
build_options
)
add_subdirectory(analysis)
add_subdirectory(passes)
# IR
add_library(ir INTERFACE)
target_link_libraries(ir INTERFACE
ir_core
ir_analysis
ir_passes
)

@ -0,0 +1,41 @@
#include "ir/IR.h"
#include <sstream>
namespace ir {
Context::~Context() = default;
ConstantInt* Context::GetConstInt(int v) {
auto it = const_ints_.find(v);
if (it != const_ints_.end()) {
return it->second.get();
}
auto inserted = const_ints_.emplace(
v, std::make_unique<ConstantInt>(Type::GetInt32Type(), v));
return inserted.first->second.get();
}
ConstantI1* Context::GetConstBool(bool v) {
auto it = const_bools_.find(v);
if (it != const_bools_.end()) {
return it->second.get();
}
auto inserted = const_bools_.emplace(
v, std::make_unique<ConstantI1>(Type::GetInt1Type(), v));
return inserted.first->second.get();
}
std::string Context::NextTemp() {
std::ostringstream oss;
oss << "%t" << ++temp_index_;
return oss.str();
}
std::string Context::NextBlockName(const std::string& prefix) {
std::ostringstream oss;
oss << prefix << "." << ++block_index_;
return oss.str();
}
} // namespace ir

@ -0,0 +1,51 @@
#include "ir/IR.h"
namespace ir {
Argument::Argument(std::shared_ptr<Type> type, std::string name, size_t index)
: Value(std::move(type), std::move(name)), index_(index) {}
Function::Function(std::string name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types,
const std::vector<std::string>& param_names,
bool is_external)
: Value(Type::GetPointerType(), std::move(name)),
return_type_(std::move(ret_type)),
param_types_(param_types),
is_external_(is_external) {
for (size_t i = 0; i < param_types_.size(); ++i) {
std::string arg_name = i < param_names.size() && !param_names[i].empty()
? param_names[i]
: "%arg" + std::to_string(i);
arguments_.push_back(
std::make_unique<Argument>(param_types_[i], std::move(arg_name), i));
}
}
Argument* Function::GetArgument(size_t index) const {
return index < arguments_.size() ? arguments_[index].get() : nullptr;
}
BasicBlock* Function::EnsureEntryBlock() {
if (!entry_) {
entry_ = CreateBlock("entry");
}
return entry_;
}
BasicBlock* Function::CreateBlock(const std::string& name) {
auto block = std::make_unique<BasicBlock>(this, name);
return AddBlock(std::move(block));
}
BasicBlock* Function::AddBlock(std::unique_ptr<BasicBlock> block) {
auto* ptr = block.get();
ptr->SetParent(this);
blocks_.push_back(std::move(block));
if (!entry_) {
entry_ = ptr;
}
return ptr;
}
} // namespace ir

@ -0,0 +1,12 @@
#include "ir/IR.h"
namespace ir {
GlobalValue::GlobalValue(std::shared_ptr<Type> object_type,
const std::string& name, bool is_const, Value* init)
: User(Type::GetPointerType(object_type), name),
object_type_(std::move(object_type)),
is_const_(is_const),
init_(init) {}
} // namespace ir

@ -0,0 +1,213 @@
#include "ir/IR.h"
#include <stdexcept>
namespace ir {
namespace {
BasicBlock* RequireInsertBlock(BasicBlock* block) {
if (!block) {
throw std::runtime_error("IRBuilder has no insert block");
}
return block;
}
bool IsFloatBinaryOp(Opcode op) {
return op == Opcode::FAdd || op == Opcode::FSub || op == Opcode::FMul ||
op == Opcode::FDiv || op == Opcode::FRem || op == Opcode::FCmpEQ ||
op == Opcode::FCmpNE || op == Opcode::FCmpLT || op == Opcode::FCmpGT ||
op == Opcode::FCmpLE || op == Opcode::FCmpGE;
}
bool IsCompareOp(Opcode op) {
return (op >= Opcode::ICmpEQ && op <= Opcode::ICmpGE) ||
(op >= Opcode::FCmpEQ && op <= Opcode::FCmpGE);
}
std::shared_ptr<Type> ResultTypeForBinary(Opcode op, Value* lhs) {
if (IsCompareOp(op)) {
return Type::GetInt1Type();
}
if (IsFloatBinaryOp(op)) {
return Type::GetFloatType();
}
return lhs->GetType();
}
} // namespace
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) : ctx_(ctx), insert_block_(bb) {}
void IRBuilder::SetInsertPoint(BasicBlock* bb) { insert_block_ = bb; }
ConstantInt* IRBuilder::CreateConstInt(int v) { return ctx_.GetConstInt(v); }
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
return new ConstantFloat(Type::GetFloatType(), v);
}
ConstantI1* IRBuilder::CreateConstBool(bool v) { return ctx_.GetConstBool(v); }
ConstantArrayValue* IRBuilder::CreateConstArray(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name) {
return new ConstantArrayValue(std::move(array_type), elements, dims, name);
}
BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<BinaryInst>(op, ResultTypeForBinary(op, lhs), lhs, rhs,
nullptr, name);
}
BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Add, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Sub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Mul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Div, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateRem(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Rem, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateAnd(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::And, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Or, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateXor(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Xor, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateShl(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Shl, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateAShr(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::AShr, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateLShr(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::LShr, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateICmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(op, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFCmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(op, lhs, rhs, name);
}
UnaryInst* IRBuilder::CreateNeg(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::Neg, operand->GetType(), operand, nullptr,
name);
}
UnaryInst* IRBuilder::CreateNot(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::Not, operand->GetType(), operand, nullptr,
name);
}
UnaryInst* IRBuilder::CreateFNeg(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::FNeg, operand->GetType(), operand,
nullptr, name);
}
UnaryInst* IRBuilder::CreateFtoI(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::FtoI, Type::GetInt32Type(), operand,
nullptr, name);
}
UnaryInst* IRBuilder::CreateIToF(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::IToF, Type::GetFloatType(), operand,
nullptr, name);
}
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> allocated_type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<AllocaInst>(std::move(allocated_type), nullptr, name);
}
LoadInst* IRBuilder::CreateLoad(Value* ptr, std::shared_ptr<Type> value_type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<LoadInst>(std::move(value_type), ptr, nullptr, name);
}
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<StoreInst>(val, ptr, nullptr);
}
UncondBrInst* IRBuilder::CreateBr(BasicBlock* dest) {
auto* block = RequireInsertBlock(insert_block_);
auto* inst = block->Append<UncondBrInst>(dest, nullptr);
block->AddSuccessor(dest);
dest->AddPredecessor(block);
return inst;
}
CondBrInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* then_bb,
BasicBlock* else_bb) {
auto* block = RequireInsertBlock(insert_block_);
auto* inst = block->Append<CondBrInst>(cond, then_bb, else_bb, nullptr);
block->AddSuccessor(then_bb);
block->AddSuccessor(else_bb);
then_bb->AddPredecessor(block);
else_bb->AddPredecessor(block);
return inst;
}
ReturnInst* IRBuilder::CreateRet(Value* val) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<ReturnInst>(val, nullptr);
}
UnreachableInst* IRBuilder::CreateUnreachable() {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnreachableInst>(nullptr);
}
CallInst* IRBuilder::CreateCall(Function* callee, const std::vector<Value*>& args,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
std::string real_name = callee->GetReturnType()->IsVoid() ? std::string() : name;
return block->Append<CallInst>(callee, args, nullptr, real_name);
}
GetElementPtrInst* IRBuilder::CreateGEP(Value* ptr,
std::shared_ptr<Type> source_type,
const std::vector<Value*>& indices,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<GetElementPtrInst>(std::move(source_type), ptr, indices,
nullptr, name);
}
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<PhiInst>(std::move(type), nullptr, name);
}
ZextInst* IRBuilder::CreateZext(Value* val, std::shared_ptr<Type> target_type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<ZextInst>(val, std::move(target_type), nullptr, name);
}
MemsetInst* IRBuilder::CreateMemset(Value* dst, Value* val, Value* len,
Value* is_volatile) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<MemsetInst>(dst, val, len, is_volatile, nullptr);
}
} // namespace ir

@ -0,0 +1,484 @@
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
namespace ir {
namespace {
std::string TypeToString(const Type& ty) {
std::ostringstream oss;
ty.Print(oss);
return oss.str();
}
std::string FloatToString(float value) {
double promoted = static_cast<double>(value);
std::uint64_t bits = 0;
std::memcpy(&bits, &promoted, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::uppercase << std::hex << std::setw(16)
<< std::setfill('0') << bits;
return oss.str();
}
std::string ValueRef(const Value* value) {
if (!value) {
return "<null>";
}
if (auto* ci = dyncast<ConstantInt>(value)) {
return std::to_string(ci->GetValue());
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return FloatToString(cf->GetValue());
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return cb->GetValue() ? "1" : "0";
}
if (isa<Function>(value) || isa<GlobalValue>(value)) {
return "@" + value->GetName();
}
return value->GetName();
}
std::string BlockRef(const BasicBlock* block) {
if (!block) {
return "%<null>";
}
return "%" + block->GetName();
}
bool IsZeroScalarConstant(const Value* value) {
if (!value) {
return true;
}
if (auto* ci = dyncast<ConstantInt>(value)) {
return ci->GetValue() == 0;
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return cf->GetValue() == 0.0f;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return !cb->GetValue();
}
return false;
}
size_t CountScalarElements(const Type& type) {
if (!type.IsArray()) {
return 1;
}
return type.GetNumElements() * CountScalarElements(*type.GetElementType());
}
bool IsZeroArrayRange(const std::vector<Value*>& elements, const Type& type,
size_t offset) {
const auto count = CountScalarElements(type);
for (size_t i = 0; i < count; ++i) {
if (offset + i < elements.size() &&
!IsZeroScalarConstant(elements[offset + i])) {
return false;
}
}
return true;
}
void PrintConstantForType(std::ostream& os, const Type& type, Value* value);
void PrintArrayConstant(std::ostream& os, const Type& type,
const std::vector<Value*>& elements, size_t& offset) {
if (IsZeroArrayRange(elements, type, offset)) {
os << "zeroinitializer";
offset += CountScalarElements(type);
return;
}
const auto elem_type = type.GetElementType();
os << "[";
for (size_t i = 0; i < type.GetNumElements(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*elem_type) << " ";
if (elem_type->IsArray()) {
PrintArrayConstant(os, *elem_type, elements, offset);
} else {
Value* elem = offset < elements.size() ? elements[offset] : nullptr;
PrintConstantForType(os, *elem_type, elem);
++offset;
}
}
os << "]";
}
void PrintConstantForType(std::ostream& os, const Type& type, Value* value) {
if (type.IsArray()) {
auto* array_value = dyncast<ConstantArrayValue>(value);
size_t offset = 0;
if (array_value) {
PrintArrayConstant(os, type, array_value->GetElements(), offset);
} else {
os << "zeroinitializer";
}
return;
}
if (!value) {
if (type.IsFloat()) {
os << FloatToString(0.0f);
} else {
os << "0";
}
return;
}
if (auto* ci = dyncast<ConstantInt>(value)) {
os << ci->GetValue();
return;
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
os << FloatToString(cf->GetValue());
return;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
os << (cb->GetValue() ? "1" : "0");
return;
}
throw std::runtime_error("global initializer must be constant");
}
const char* BinaryOpcodeMnemonic(Opcode opcode) {
switch (opcode) {
case Opcode::Add:
return "add";
case Opcode::Sub:
return "sub";
case Opcode::Mul:
return "mul";
case Opcode::Div:
return "sdiv";
case Opcode::Rem:
return "srem";
case Opcode::FAdd:
return "fadd";
case Opcode::FSub:
return "fsub";
case Opcode::FMul:
return "fmul";
case Opcode::FDiv:
return "fdiv";
case Opcode::FRem:
return "frem";
case Opcode::And:
return "and";
case Opcode::Or:
return "or";
case Opcode::Xor:
return "xor";
case Opcode::Shl:
return "shl";
case Opcode::AShr:
return "ashr";
case Opcode::LShr:
return "lshr";
case Opcode::ICmpEQ:
return "icmp eq";
case Opcode::ICmpNE:
return "icmp ne";
case Opcode::ICmpLT:
return "icmp slt";
case Opcode::ICmpGT:
return "icmp sgt";
case Opcode::ICmpLE:
return "icmp sle";
case Opcode::ICmpGE:
return "icmp sge";
case Opcode::FCmpEQ:
return "fcmp oeq";
case Opcode::FCmpNE:
return "fcmp one";
case Opcode::FCmpLT:
return "fcmp olt";
case Opcode::FCmpGT:
return "fcmp ogt";
case Opcode::FCmpLE:
return "fcmp ole";
case Opcode::FCmpGE:
return "fcmp oge";
default:
throw std::runtime_error("unsupported binary opcode");
}
}
bool NeedsMemsetDeclaration(const Module& module) {
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsExternal()) {
continue;
}
for (const auto& bb : func->GetBlocks()) {
for (const auto& inst : bb->GetInstructions()) {
if (inst->GetOpcode() == Opcode::Memset) {
return true;
}
}
}
}
return false;
}
void PrintInstruction(const Instruction& inst, std::ostream& os) {
switch (inst.GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto& bin = static_cast<const BinaryInst&>(inst);
os << " " << bin.GetName() << " = " << BinaryOpcodeMnemonic(bin.GetOpcode())
<< " " << TypeToString(*bin.GetLhs()->GetType()) << " "
<< ValueRef(bin.GetLhs()) << ", " << ValueRef(bin.GetRhs()) << "\n";
return;
}
case Opcode::Neg: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = sub " << TypeToString(*un.GetType())
<< " 0, " << ValueRef(un.GetOprd()) << "\n";
return;
}
case Opcode::Not: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = xor " << TypeToString(*un.GetType())
<< " " << ValueRef(un.GetOprd()) << ", 1\n";
return;
}
case Opcode::FNeg: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = fneg " << TypeToString(*un.GetType())
<< " " << ValueRef(un.GetOprd()) << "\n";
return;
}
case Opcode::FtoI: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = fptosi "
<< TypeToString(*un.GetOprd()->GetType()) << " " << ValueRef(un.GetOprd())
<< " to " << TypeToString(*un.GetType()) << "\n";
return;
}
case Opcode::IToF: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = sitofp "
<< TypeToString(*un.GetOprd()->GetType()) << " " << ValueRef(un.GetOprd())
<< " to " << TypeToString(*un.GetType()) << "\n";
return;
}
case Opcode::Alloca: {
auto& alloca_inst = static_cast<const AllocaInst&>(inst);
os << " " << alloca_inst.GetName() << " = alloca "
<< TypeToString(*alloca_inst.GetAllocatedType()) << "\n";
return;
}
case Opcode::Load: {
auto& load = static_cast<const LoadInst&>(inst);
os << " " << load.GetName() << " = load "
<< TypeToString(*load.GetType()) << ", ptr " << ValueRef(load.GetPtr())
<< "\n";
return;
}
case Opcode::Store: {
auto& store = static_cast<const StoreInst&>(inst);
os << " store " << TypeToString(*store.GetValue()->GetType()) << " "
<< ValueRef(store.GetValue()) << ", ptr " << ValueRef(store.GetPtr())
<< "\n";
return;
}
case Opcode::Br: {
auto& br = static_cast<const UncondBrInst&>(inst);
os << " br label " << BlockRef(br.GetDest()) << "\n";
return;
}
case Opcode::CondBr: {
auto& br = static_cast<const CondBrInst&>(inst);
os << " br i1 " << ValueRef(br.GetCondition()) << ", label "
<< BlockRef(br.GetThenBlock()) << ", label "
<< BlockRef(br.GetElseBlock()) << "\n";
return;
}
case Opcode::Return: {
auto& ret = static_cast<const ReturnInst&>(inst);
if (!ret.HasReturnValue()) {
os << " ret void\n";
} else {
os << " ret " << TypeToString(*ret.GetReturnValue()->GetType()) << " "
<< ValueRef(ret.GetReturnValue()) << "\n";
}
return;
}
case Opcode::Unreachable:
os << " unreachable\n";
return;
case Opcode::Call: {
auto& call = static_cast<const CallInst&>(inst);
if (!call.GetType()->IsVoid()) {
os << " " << call.GetName() << " = ";
} else {
os << " ";
}
os << "call " << TypeToString(*call.GetCallee()->GetReturnType()) << " @"
<< call.GetCallee()->GetName() << "(";
const auto args = call.GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*args[i]->GetType()) << " " << ValueRef(args[i]);
}
os << ")\n";
return;
}
case Opcode::GetElementPtr: {
auto& gep = static_cast<const GetElementPtrInst&>(inst);
os << " " << gep.GetName() << " = getelementptr "
<< TypeToString(*gep.GetSourceType()) << ", ptr "
<< ValueRef(gep.GetPointer());
for (size_t i = 0; i < gep.GetNumIndices(); ++i) {
auto* index = gep.GetIndex(i);
os << ", " << TypeToString(*index->GetType()) << " " << ValueRef(index);
}
os << "\n";
return;
}
case Opcode::Phi: {
auto& phi = static_cast<const PhiInst&>(inst);
os << " " << phi.GetName() << " = phi " << TypeToString(*phi.GetType())
<< " ";
for (int i = 0; i < phi.GetNumIncomings(); ++i) {
if (i > 0) {
os << ", ";
}
os << "[ " << ValueRef(phi.GetIncomingValue(i)) << ", "
<< BlockRef(phi.GetIncomingBlock(i)) << " ]";
}
os << "\n";
return;
}
case Opcode::Zext: {
auto& zext = static_cast<const ZextInst&>(inst);
os << " " << zext.GetName() << " = zext "
<< TypeToString(*zext.GetValue()->GetType()) << " "
<< ValueRef(zext.GetValue()) << " to " << TypeToString(*zext.GetType())
<< "\n";
return;
}
case Opcode::Memset: {
auto& memset = static_cast<const MemsetInst&>(inst);
os << " call void @llvm.memset.p0.i32(ptr " << ValueRef(memset.GetDest())
<< ", i8 " << ValueRef(memset.GetValue()) << ", i32 "
<< ValueRef(memset.GetLength()) << ", i1 "
<< ValueRef(memset.GetIsVolatile()) << ")\n";
return;
}
}
throw std::runtime_error("unsupported instruction in printer");
}
} // namespace
void IRPrinter::Print(const Module& module, std::ostream& os) {
if (NeedsMemsetDeclaration(module)) {
os << "declare void @llvm.memset.p0.i32(ptr, i8, i32, i1)\n\n";
}
for (const auto& global : module.GetGlobalValues()) {
os << "@" << global->GetName() << " = "
<< (global->IsConstant() ? "constant " : "global ")
<< TypeToString(*global->GetObjectType()) << " ";
PrintConstantForType(os, *global->GetObjectType(), global->GetInitializer());
os << "\n";
}
if (!module.GetGlobalValues().empty()) {
os << "\n";
}
for (const auto& func : module.GetFunctions()) {
if (!func->IsExternal()) {
continue;
}
os << "declare " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*args[i]->GetType());
}
os << ")\n";
}
bool printed_decl = false;
for (const auto& func : module.GetFunctions()) {
if (func->IsExternal()) {
printed_decl = true;
}
}
if (printed_decl) {
os << "\n";
}
for (const auto& func : module.GetFunctions()) {
if (func->IsExternal()) {
continue;
}
os << "define " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*args[i]->GetType()) << " " << args[i]->GetName();
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) {
os << bb->GetName() << ":\n";
for (const auto& inst : bb->GetInstructions()) {
PrintInstruction(*inst, os);
}
}
os << "}\n\n";
}
}
} // namespace ir

@ -0,0 +1,263 @@
#include "ir/IR.h"
#include <stdexcept>
namespace ir {
User::User(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {}
Value* User::GetOperand(size_t index) const {
if (index >= operands_.size()) {
throw std::out_of_range("operand index out of range");
}
return operands_[index].GetValue();
}
void User::SetOperand(size_t index, Value* value) {
if (index >= operands_.size()) {
throw std::out_of_range("operand index out of range");
}
auto* old_value = operands_[index].GetValue();
if (old_value == value) {
return;
}
if (old_value) {
old_value->RemoveUse(this, index);
}
operands_[index].SetValue(value);
if (value) {
value->AddUse(this, index);
}
}
void User::AddOperand(Value* value) {
if (!value) {
throw std::runtime_error("operand cannot be null");
}
operands_.emplace_back(value, this, operands_.size());
value->AddUse(this, operands_.size() - 1);
}
void User::AddOperands(const std::vector<Value*>& values) {
for (auto* value : values) {
AddOperand(value);
}
}
void User::RemoveOperand(size_t index) {
if (index >= operands_.size()) {
throw std::out_of_range("operand index out of range");
}
if (auto* value = operands_[index].GetValue()) {
value->RemoveUse(this, index);
}
operands_.erase(operands_.begin() + static_cast<long long>(index));
for (size_t i = index; i < operands_.size(); ++i) {
if (auto* value = operands_[i].GetValue()) {
value->RemoveUse(this, i + 1);
value->AddUse(this, i);
}
operands_[i].SetOperandIndex(i);
}
}
void User::ClearAllOperands() {
for (size_t i = 0; i < operands_.size(); ++i) {
if (auto* value = operands_[i].GetValue()) {
value->RemoveUse(this, i);
}
}
operands_.clear();
}
Instruction::Instruction(Opcode opcode, std::shared_ptr<Type> ty,
BasicBlock* parent, const std::string& name)
: User(std::move(ty), name), opcode_(opcode), parent_(parent) {}
bool Instruction::IsTerminator() const {
return opcode_ == Opcode::Br || opcode_ == Opcode::CondBr ||
opcode_ == Opcode::Return || opcode_ == Opcode::Unreachable;
}
static bool IsBinaryOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
return true;
default:
return false;
}
}
bool BinaryInst::classof(const Value* value) {
return value && Instruction::classof(value) &&
IsBinaryOpcode(static_cast<const Instruction*>(value)->GetOpcode());
}
BinaryInst::BinaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, BasicBlock* parent,
const std::string& name)
: Instruction(opcode, std::move(ty), parent, name) {
AddOperand(lhs);
AddOperand(rhs);
}
bool UnaryInst::classof(const Value* value) {
if (!value || !Instruction::classof(value)) {
return false;
}
switch (static_cast<const Instruction*>(value)->GetOpcode()) {
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
return true;
default:
return false;
}
}
UnaryInst::UnaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* operand,
BasicBlock* parent, const std::string& name)
: Instruction(opcode, std::move(ty), parent, name) {
AddOperand(operand);
}
ReturnInst::ReturnInst(Value* value, BasicBlock* parent)
: Instruction(Opcode::Return, Type::GetVoidType(), parent, "") {
if (value) {
AddOperand(value);
}
}
AllocaInst::AllocaInst(std::shared_ptr<Type> allocated_type,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Alloca, Type::GetPointerType(allocated_type), parent,
name),
allocated_type_(std::move(allocated_type)) {}
LoadInst::LoadInst(std::shared_ptr<Type> value_type, Value* ptr,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Load, std::move(value_type), parent, name) {
AddOperand(ptr);
}
StoreInst::StoreInst(Value* value, Value* ptr, BasicBlock* parent)
: Instruction(Opcode::Store, Type::GetVoidType(), parent, "") {
AddOperand(value);
AddOperand(ptr);
}
UncondBrInst::UncondBrInst(BasicBlock* dest, BasicBlock* parent)
: Instruction(Opcode::Br, Type::GetVoidType(), parent, "") {
AddOperand(dest);
}
BasicBlock* UncondBrInst::GetDest() const {
return dyncast<BasicBlock>(GetOperand(0));
}
CondBrInst::CondBrInst(Value* cond, BasicBlock* then_block,
BasicBlock* else_block, BasicBlock* parent)
: Instruction(Opcode::CondBr, Type::GetVoidType(), parent, "") {
AddOperand(cond);
AddOperand(then_block);
AddOperand(else_block);
}
BasicBlock* CondBrInst::GetThenBlock() const {
return dyncast<BasicBlock>(GetOperand(1));
}
BasicBlock* CondBrInst::GetElseBlock() const {
return dyncast<BasicBlock>(GetOperand(2));
}
UnreachableInst::UnreachableInst(BasicBlock* parent)
: Instruction(Opcode::Unreachable, Type::GetVoidType(), parent, "") {}
CallInst::CallInst(Function* callee, const std::vector<Value*>& args,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Call, callee->GetReturnType(), parent, name) {
AddOperand(callee);
AddOperands(args);
}
Function* CallInst::GetCallee() const { return dyncast<Function>(GetOperand(0)); }
std::vector<Value*> CallInst::GetArguments() const {
std::vector<Value*> args;
for (size_t i = 1; i < GetNumOperands(); ++i) {
args.push_back(GetOperand(i));
}
return args;
}
GetElementPtrInst::GetElementPtrInst(std::shared_ptr<Type> source_type,
Value* ptr,
const std::vector<Value*>& indices,
BasicBlock* parent,
const std::string& name)
: Instruction(Opcode::GetElementPtr, Type::GetPointerType(), parent, name),
source_type_(std::move(source_type)) {
AddOperand(ptr);
AddOperands(indices);
}
PhiInst::PhiInst(std::shared_ptr<Type> type, BasicBlock* parent,
const std::string& name)
: Instruction(Opcode::Phi, std::move(type), parent, name) {}
void PhiInst::AddIncoming(Value* value, BasicBlock* block) {
AddOperand(value);
AddOperand(block);
}
BasicBlock* PhiInst::GetIncomingBlock(int index) const {
return dyncast<BasicBlock>(GetOperand(static_cast<size_t>(2 * index + 1)));
}
ZextInst::ZextInst(Value* value, std::shared_ptr<Type> target_type,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Zext, std::move(target_type), parent, name) {
AddOperand(value);
}
MemsetInst::MemsetInst(Value* dst, Value* value, Value* len,
Value* is_volatile, BasicBlock* parent)
: Instruction(Opcode::Memset, Type::GetVoidType(), parent, "") {
AddOperand(dst);
AddOperand(value);
AddOperand(len);
AddOperand(is_volatile);
}
} // namespace ir

@ -0,0 +1,45 @@
#include "ir/IR.h"
namespace ir {
Function* Module::CreateFunction(
const std::string& name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types,
const std::vector<std::string>& param_names, bool is_external) {
if (auto* existing = GetFunction(name)) {
existing->SetExternal(existing->IsExternal() && is_external);
return existing;
}
auto func = std::make_unique<Function>(name, std::move(ret_type), param_types,
param_names, is_external);
auto* ptr = func.get();
functions_.push_back(std::move(func));
function_map_[name] = ptr;
return ptr;
}
Function* Module::GetFunction(const std::string& name) const {
auto it = function_map_.find(name);
return it == function_map_.end() ? nullptr : it->second;
}
GlobalValue* Module::CreateGlobalValue(const std::string& name,
std::shared_ptr<Type> object_type,
bool is_const, Value* init) {
if (auto* existing = GetGlobalValue(name)) {
return existing;
}
auto global =
std::make_unique<GlobalValue>(std::move(object_type), name, is_const, init);
auto* ptr = global.get();
globals_.push_back(std::move(global));
global_map_[name] = ptr;
return ptr;
}
GlobalValue* Module::GetGlobalValue(const std::string& name) const {
auto it = global_map_.find(name);
return it == global_map_.end() ? nullptr : it->second;
}
} // namespace ir

@ -0,0 +1,111 @@
#include "ir/IR.h"
#include <ostream>
#include <stdexcept>
namespace ir {
Type::Type(Kind kind) : kind_(kind) {}
Type::Type(Kind kind, std::shared_ptr<Type> element_type, size_t num_elements)
: kind_(kind),
element_type_(std::move(element_type)),
num_elements_(num_elements) {}
const std::shared_ptr<Type>& Type::GetVoidType() {
static const auto type = std::make_shared<Type>(Kind::Void);
return type;
}
const std::shared_ptr<Type>& Type::GetInt1Type() {
static const auto type = std::make_shared<Type>(Kind::Int1);
return type;
}
const std::shared_ptr<Type>& Type::GetInt32Type() {
static const auto type = std::make_shared<Type>(Kind::Int32);
return type;
}
const std::shared_ptr<Type>& Type::GetFloatType() {
static const auto type = std::make_shared<Type>(Kind::Float);
return type;
}
const std::shared_ptr<Type>& Type::GetLabelType() {
static const auto type = std::make_shared<Type>(Kind::Label);
return type;
}
const std::shared_ptr<Type>& Type::GetBoolType() { return GetInt1Type(); }
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> pointee) {
return std::make_shared<Type>(Kind::Pointer, std::move(pointee));
}
const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
static const auto type = std::make_shared<Type>(Kind::Pointer);
return type;
}
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> element_type,
size_t num_elements) {
return std::make_shared<Type>(Kind::Array, std::move(element_type), num_elements);
}
int Type::GetSize() const {
switch (kind_) {
case Kind::Void:
case Kind::Label:
case Kind::Function:
return 0;
case Kind::Int1:
return 1;
case Kind::Int32:
case Kind::Float:
return 4;
case Kind::Pointer:
return 8;
case Kind::Array:
return static_cast<int>(num_elements_) *
(element_type_ ? element_type_->GetSize() : 0);
}
throw std::runtime_error("unknown IR type kind");
}
void Type::Print(std::ostream& os) const {
switch (kind_) {
case Kind::Void:
os << "void";
return;
case Kind::Int1:
os << "i1";
return;
case Kind::Int32:
os << "i32";
return;
case Kind::Float:
os << "float";
return;
case Kind::Label:
os << "label";
return;
case Kind::Function:
os << "fn";
return;
case Kind::Pointer:
os << "ptr";
return;
case Kind::Array:
os << "[" << num_elements_ << " x ";
if (element_type_) {
element_type_->Print(os);
} else {
os << "void";
}
os << "]";
return;
}
}
} // namespace ir

@ -0,0 +1,66 @@
#include "ir/IR.h"
#include <algorithm>
#include <ostream>
#include <stdexcept>
namespace ir {
Value::Value(std::shared_ptr<Type> ty, std::string name)
: type_(std::move(ty)), name_(std::move(name)) {}
void Value::AddUse(User* user, size_t operand_index) {
if (!user) {
return;
}
uses_.emplace_back(this, user, operand_index);
}
void Value::RemoveUse(User* user, size_t operand_index) {
uses_.erase(
std::remove_if(uses_.begin(), uses_.end(),
[&](const Use& use) {
return use.GetUser() == user &&
use.GetOperandIndex() == operand_index;
}),
uses_.end());
}
void Value::ReplaceAllUsesWith(Value* new_value) {
if (!new_value) {
throw std::runtime_error("ReplaceAllUsesWith requires a new value");
}
if (new_value == this) {
return;
}
auto uses = uses_;
for (const auto& use : uses) {
if (auto* user = use.GetUser()) {
user->SetOperand(use.GetOperandIndex(), new_value);
}
}
}
void Value::Print(std::ostream& os) const { os << name_; }
ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int value)
: ConstantValue(std::move(ty), ""), value_(value) {}
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float value)
: ConstantValue(std::move(ty), ""), value_(value) {}
ConstantI1::ConstantI1(std::shared_ptr<Type> ty, bool value)
: ConstantValue(std::move(ty), ""), value_(value) {}
ConstantArrayValue::ConstantArrayValue(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name)
: Value(std::move(array_type), name), elements_(elements), dims_(dims) {}
void ConstantArrayValue::Print(std::ostream& os) const { os << name_; }
} // namespace ir

@ -0,0 +1,9 @@
add_library(ir_analysis STATIC
DominatorTree.cpp
LoopInfo.cpp
)
target_link_libraries(ir_analysis PUBLIC
build_options
ir_core
)

@ -0,0 +1,167 @@
#include "ir/Analysis.h"
#include <algorithm>
#include <functional>
namespace ir {
namespace {
std::vector<BasicBlock*> BuildReversePostOrder(Function& function) {
std::vector<BasicBlock*> post_order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return post_order;
}
std::unordered_set<BasicBlock*> visited;
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* block) {
if (!block || !visited.insert(block).second) {
return;
}
for (auto* succ : block->GetSuccessors()) {
dfs(succ);
}
post_order.push_back(block);
};
dfs(entry);
std::reverse(post_order.begin(), post_order.end());
return post_order;
}
} // namespace
DominatorTree::DominatorTree(Function& function) : function_(&function) {
Recalculate();
}
void DominatorTree::Recalculate() {
reverse_post_order_ = BuildReversePostOrder(*function_);
block_index_.clear();
dominates_.clear();
immediate_dominator_.clear();
dom_children_.clear();
const auto num_blocks = reverse_post_order_.size();
for (std::size_t i = 0; i < num_blocks; ++i) {
block_index_.emplace(reverse_post_order_[i], i);
}
if (num_blocks == 0) {
return;
}
dominates_.assign(num_blocks, std::vector<std::uint8_t>(num_blocks, 1));
dominates_[0].assign(num_blocks, 0);
dominates_[0][0] = 1;
bool changed = true;
while (changed) {
changed = false;
for (std::size_t i = 1; i < num_blocks; ++i) {
auto* block = reverse_post_order_[i];
std::vector<std::uint8_t> next(num_blocks, 1);
bool has_reachable_pred = false;
for (auto* pred : block->GetPredecessors()) {
auto pred_it = block_index_.find(pred);
if (pred_it == block_index_.end()) {
continue;
}
has_reachable_pred = true;
const auto& pred_dom = dominates_[pred_it->second];
for (std::size_t bit = 0; bit < num_blocks; ++bit) {
next[bit] &= pred_dom[bit];
}
}
if (!has_reachable_pred) {
next.assign(num_blocks, 0);
}
next[i] = 1;
if (next != dominates_[i]) {
dominates_[i] = std::move(next);
changed = true;
}
}
}
for (std::size_t i = 1; i < num_blocks; ++i) {
auto* block = reverse_post_order_[i];
BasicBlock* idom = nullptr;
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
if (candidate == i || !dominates_[i][candidate]) {
continue;
}
auto* candidate_block = reverse_post_order_[candidate];
bool immediate = true;
for (std::size_t other = 0; other < num_blocks; ++other) {
if (other == i || other == candidate || !dominates_[i][other]) {
continue;
}
if (Dominates(reverse_post_order_[other], candidate_block)) {
immediate = false;
break;
}
}
if (immediate) {
idom = candidate_block;
break;
}
}
immediate_dominator_.emplace(block, idom);
if (idom) {
dom_children_[idom].push_back(block);
}
}
}
bool DominatorTree::IsReachable(BasicBlock* block) const {
return block != nullptr && block_index_.find(block) != block_index_.end();
}
bool DominatorTree::Dominates(BasicBlock* dom, BasicBlock* node) const {
if (!dom || !node) {
return false;
}
const auto dom_it = block_index_.find(dom);
const auto node_it = block_index_.find(node);
if (dom_it == block_index_.end() || node_it == block_index_.end()) {
return false;
}
return dominates_[node_it->second][dom_it->second] != 0;
}
bool DominatorTree::Dominates(Instruction* dom, Instruction* user) const {
if (!dom || !user) {
return false;
}
if (dom == user) {
return true;
}
auto* dom_block = dom->GetParent();
auto* user_block = user->GetParent();
if (dom_block != user_block) {
return Dominates(dom_block, user_block);
}
for (const auto& inst_ptr : dom_block->GetInstructions()) {
if (inst_ptr.get() == dom) {
return true;
}
if (inst_ptr.get() == user) {
return false;
}
}
return false;
}
BasicBlock* DominatorTree::GetIDom(BasicBlock* block) const {
auto it = immediate_dominator_.find(block);
return it == immediate_dominator_.end() ? nullptr : it->second;
}
const std::vector<BasicBlock*>& DominatorTree::GetChildren(BasicBlock* block) const {
static const std::vector<BasicBlock*> kEmpty;
auto it = dom_children_.find(block);
return it == dom_children_.end() ? kEmpty : it->second;
}
} // namespace ir

@ -0,0 +1,214 @@
#include "ir/Analysis.h"
#include <algorithm>
#include <functional>
namespace ir {
namespace {
std::vector<BasicBlock*> CollectNaturalLoopBlocks(BasicBlock* header,
BasicBlock* latch) {
std::vector<BasicBlock*> stack{latch};
std::unordered_set<BasicBlock*> loop_blocks{header, latch};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
for (auto* pred : block->GetPredecessors()) {
if (!pred || !loop_blocks.insert(pred).second) {
continue;
}
stack.push_back(pred);
}
}
return {loop_blocks.begin(), loop_blocks.end()};
}
} // namespace
bool Loop::Contains(BasicBlock* block) const {
return block != nullptr && blocks.find(block) != blocks.end();
}
bool Loop::Contains(const Loop* other) const {
if (!other) {
return false;
}
for (auto* block : other->blocks) {
if (!Contains(block)) {
return false;
}
}
return true;
}
bool Loop::IsInnermost() const { return subloops.empty(); }
LoopInfo::LoopInfo(Function& function, const DominatorTree& dom_tree)
: function_(&function), dom_tree_(&dom_tree) {
Recalculate();
}
void LoopInfo::Recalculate() {
loops_.clear();
top_level_loops_.clear();
block_to_loop_.clear();
std::unordered_map<BasicBlock*, Loop*> loops_by_header;
for (auto* block : dom_tree_->GetReversePostOrder()) {
for (auto* succ : block->GetSuccessors()) {
if (!dom_tree_->Dominates(succ, block)) {
continue;
}
Loop* loop = nullptr;
auto it = loops_by_header.find(succ);
if (it == loops_by_header.end()) {
auto new_loop = std::make_unique<Loop>();
new_loop->header = succ;
loop = new_loop.get();
loops_.push_back(std::move(new_loop));
loops_by_header.emplace(succ, loop);
} else {
loop = it->second;
}
if (std::find(loop->latches.begin(), loop->latches.end(), block) ==
loop->latches.end()) {
loop->latches.push_back(block);
}
for (auto* natural_block : CollectNaturalLoopBlocks(succ, block)) {
loop->blocks.insert(natural_block);
}
}
}
std::unordered_map<BasicBlock*, std::size_t> function_order;
for (std::size_t i = 0; i < function_->GetBlocks().size(); ++i) {
function_order.emplace(function_->GetBlocks()[i].get(), i);
}
for (const auto& loop_ptr : loops_) {
auto& loop = *loop_ptr;
loop.block_list.clear();
loop.exiting_blocks.clear();
loop.exit_blocks.clear();
loop.subloops.clear();
loop.parent = nullptr;
for (const auto& block_ptr : function_->GetBlocks()) {
if (loop.Contains(block_ptr.get())) {
loop.block_list.push_back(block_ptr.get());
}
}
std::sort(loop.latches.begin(), loop.latches.end(),
[&](BasicBlock* lhs, BasicBlock* rhs) {
return function_order[lhs] < function_order[rhs];
});
std::vector<BasicBlock*> outside_preds;
for (auto* pred : loop.header->GetPredecessors()) {
if (!loop.Contains(pred)) {
outside_preds.push_back(pred);
}
}
if (outside_preds.size() == 1 &&
outside_preds.front()->GetSuccessors().size() == 1) {
loop.preheader = outside_preds.front();
} else {
loop.preheader = nullptr;
}
std::unordered_set<BasicBlock*> exiting_seen;
std::unordered_set<BasicBlock*> exit_seen;
for (auto* block : loop.block_list) {
for (auto* succ : block->GetSuccessors()) {
if (loop.Contains(succ)) {
continue;
}
if (exiting_seen.insert(block).second) {
loop.exiting_blocks.push_back(block);
}
if (exit_seen.insert(succ).second) {
loop.exit_blocks.push_back(succ);
}
}
}
std::sort(loop.exiting_blocks.begin(), loop.exiting_blocks.end(),
[&](BasicBlock* lhs, BasicBlock* rhs) {
return function_order[lhs] < function_order[rhs];
});
std::sort(loop.exit_blocks.begin(), loop.exit_blocks.end(),
[&](BasicBlock* lhs, BasicBlock* rhs) {
return function_order[lhs] < function_order[rhs];
});
}
for (const auto& loop_ptr : loops_) {
auto* loop = loop_ptr.get();
Loop* parent = nullptr;
for (const auto& candidate_ptr : loops_) {
auto* candidate = candidate_ptr.get();
if (candidate == loop || !candidate->Contains(loop)) {
continue;
}
if (!parent || candidate->blocks.size() < parent->blocks.size()) {
parent = candidate;
}
}
loop->parent = parent;
if (parent) {
parent->subloops.push_back(loop);
} else {
top_level_loops_.push_back(loop);
}
}
auto loop_order = [&](Loop* lhs, Loop* rhs) {
return function_order[lhs->header] < function_order[rhs->header];
};
std::sort(top_level_loops_.begin(), top_level_loops_.end(), loop_order);
for (const auto& loop_ptr : loops_) {
std::sort(loop_ptr->subloops.begin(), loop_ptr->subloops.end(), loop_order);
}
for (const auto& block_ptr : function_->GetBlocks()) {
Loop* innermost = nullptr;
for (const auto& loop_ptr : loops_) {
auto* loop = loop_ptr.get();
if (!loop->Contains(block_ptr.get())) {
continue;
}
if (!innermost || loop->blocks.size() < innermost->blocks.size()) {
innermost = loop;
}
}
if (innermost) {
block_to_loop_.emplace(block_ptr.get(), innermost);
}
}
}
std::vector<Loop*> LoopInfo::GetTopLevelLoops() const { return top_level_loops_; }
std::vector<Loop*> LoopInfo::GetLoopsInPostOrder() const {
std::vector<Loop*> ordered;
std::function<void(Loop*)> dfs = [&](Loop* loop) {
for (auto* subloop : loop->subloops) {
dfs(subloop);
}
ordered.push_back(loop);
};
for (auto* loop : top_level_loops_) {
dfs(loop);
}
return ordered;
}
Loop* LoopInfo::GetLoopFor(BasicBlock* block) const {
auto it = block_to_loop_.find(block);
return it == block_to_loop_.end() ? nullptr : it->second;
}
} // namespace ir

@ -0,0 +1,107 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <vector>
namespace ir {
namespace {
bool TryGetConstBranchTarget(CondBrInst* br, BasicBlock*& target, BasicBlock*& removed) {
if (!br) {
return false;
}
auto* then_block = br->GetThenBlock();
auto* else_block = br->GetElseBlock();
if (then_block == else_block) {
target = then_block;
removed = nullptr;
return true;
}
if (auto* cond = dyncast<ConstantI1>(br->GetCondition())) {
target = cond->GetValue() ? then_block : else_block;
removed = cond->GetValue() ? else_block : then_block;
return true;
}
return false;
}
bool SimplifyBlockTerminator(BasicBlock* block) {
if (!block || block->GetInstructions().empty()) {
return false;
}
auto* term = block->GetInstructions().back().get();
auto* condbr = dyncast<CondBrInst>(term);
if (!condbr) {
return false;
}
BasicBlock* target = nullptr;
BasicBlock* removed = nullptr;
if (!TryGetConstBranchTarget(condbr, target, removed)) {
return false;
}
if (removed) {
passutils::RemoveIncomingFromSuccessor(removed, block);
removed->RemovePredecessor(block);
block->RemoveSuccessor(removed);
}
passutils::ReplaceTerminatorWithBr(block, target);
return true;
}
bool SimplifyPhiNodes(Function& function) {
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (!passutils::SimplifyPhiInst(phi)) {
continue;
}
local_changed = true;
changed = true;
break;
}
}
}
return changed;
}
bool RunCFGSimplifyOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
changed |= SimplifyBlockTerminator(block_ptr.get());
}
changed |= passutils::RemoveUnreachableBlocks(function);
changed |= SimplifyPhiNodes(function);
return changed;
}
} // namespace
bool RunCFGSimplify(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunCFGSimplifyOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,24 @@
add_library(ir_passes STATIC
PassManager.cpp
Mem2Reg.cpp
ConstFold.cpp
ConstProp.cpp
Inline.cpp
CSE.cpp
GVN.cpp
LoadStoreElim.cpp
DCE.cpp
CFGSimplify.cpp
LICM.cpp
LoopMemoryPromotion.cpp
LoopUnswitch.cpp
LoopStrengthReduction.cpp
LoopUnroll.cpp
LoopFission.cpp
)
target_link_libraries(ir_passes PUBLIC
build_options
ir_core
ir_analysis
)

@ -0,0 +1,141 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct ExprKey {
Opcode opcode = Opcode::Add;
std::vector<std::uintptr_t> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && operands == rhs.operands;
}
};
struct ExprKeyHash {
std::size_t operator()(const ExprKey& key) const {
std::size_t h = static_cast<std::size_t>(key.opcode);
for (auto operand : key.operands) {
h ^= std::hash<std::uintptr_t>{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
bool IsSupportedCSEInstruction(Instruction* inst) {
if (!inst || inst->IsVoid()) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::Zext:
return true;
default:
return false;
}
}
ExprKey BuildExprKey(Instruction* inst) {
ExprKey key;
key.opcode = inst->GetOpcode();
key.operands.reserve(inst->GetNumOperands());
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(
reinterpret_cast<std::uintptr_t>(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 && passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
key.operands[1] < key.operands[0]) {
std::swap(key.operands[0], key.operands[1]);
}
return key;
}
bool RunCSEOnFunction(Function& function) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::unordered_map<ExprKey, Value*, ExprKeyHash> available_exprs;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!IsSupportedCSEInstruction(inst)) {
continue;
}
const auto key = BuildExprKey(inst);
auto it = available_exprs.find(key);
if (it == available_exprs.end()) {
available_exprs.emplace(key, inst);
continue;
}
inst->ReplaceAllUsesWith(it->second);
to_remove.push_back(inst);
changed = true;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
}
return changed;
}
} // namespace
bool RunCSE(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunCSEOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,469 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>
namespace ir {
namespace {
Value* GetInt32Const(Context& ctx, std::int32_t value) {
return ctx.GetConstInt(static_cast<int>(value));
}
Value* GetBoolConst(Context& ctx, bool value) { return ctx.GetConstBool(value); }
Value* GetFloatConst(float value) {
return new ConstantFloat(Type::GetFloatType(), value);
}
bool TryGetInt32(Value* value, std::int32_t& out) {
if (auto* ci = dyncast<ConstantInt>(value)) {
out = static_cast<std::int32_t>(ci->GetValue());
return true;
}
return false;
}
bool TryGetBool(Value* value, bool& out) {
if (auto* cb = dyncast<ConstantI1>(value)) {
out = cb->GetValue();
return true;
}
return false;
}
bool TryGetFloat(Value* value, float& out) {
if (auto* cf = dyncast<ConstantFloat>(value)) {
out = cf->GetValue();
return true;
}
return false;
}
bool IsZeroValue(Value* value) {
std::int32_t i32 = 0;
bool i1 = false;
float f32 = 0.0f;
return (TryGetInt32(value, i32) && i32 == 0) || (TryGetBool(value, i1) && !i1) ||
(TryGetFloat(value, f32) && passutils::FloatBits(f32) == 0);
}
bool IsOneValue(Value* value) {
std::int32_t i32 = 0;
bool i1 = false;
float f32 = 0.0f;
return (TryGetInt32(value, i32) && i32 == 1) || (TryGetBool(value, i1) && i1) ||
(TryGetFloat(value, f32) &&
passutils::FloatBits(f32) == passutils::FloatBits(1.0f));
}
bool IsAllOnesInt(Value* value) {
std::int32_t i32 = 0;
return TryGetInt32(value, i32) && i32 == -1;
}
std::int32_t WrapInt32(std::uint32_t value) {
return static_cast<std::int32_t>(value);
}
Value* FoldBinary(Context& ctx, BinaryInst* inst) {
const auto opcode = inst->GetOpcode();
auto* lhs = inst->GetLhs();
auto* rhs = inst->GetRhs();
std::int32_t lhs_i32 = 0;
std::int32_t rhs_i32 = 0;
bool lhs_i1 = false;
bool rhs_i1 = false;
float lhs_f32 = 0.0f;
float rhs_f32 = 0.0f;
const bool has_lhs_i32 = TryGetInt32(lhs, lhs_i32);
const bool has_rhs_i32 = TryGetInt32(rhs, rhs_i32);
const bool has_lhs_i1 = TryGetBool(lhs, lhs_i1);
const bool has_rhs_i1 = TryGetBool(rhs, rhs_i1);
const bool has_lhs_f32 = TryGetFloat(lhs, lhs_f32);
const bool has_rhs_f32 = TryGetFloat(rhs, rhs_f32);
if (has_lhs_i32 && has_rhs_i32) {
switch (opcode) {
case Opcode::Add:
return GetInt32Const(
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) +
static_cast<std::uint32_t>(rhs_i32)));
case Opcode::Sub:
return GetInt32Const(
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) -
static_cast<std::uint32_t>(rhs_i32)));
case Opcode::Mul:
return GetInt32Const(
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) *
static_cast<std::uint32_t>(rhs_i32)));
case Opcode::Div:
if (rhs_i32 == 0 ||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
return nullptr;
}
return GetInt32Const(ctx, lhs_i32 / rhs_i32);
case Opcode::Rem:
if (rhs_i32 == 0 ||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
return nullptr;
}
return GetInt32Const(ctx, lhs_i32 % rhs_i32);
case Opcode::And:
return GetInt32Const(ctx, lhs_i32 & rhs_i32);
case Opcode::Or:
return GetInt32Const(ctx, lhs_i32 | rhs_i32);
case Opcode::Xor:
return GetInt32Const(ctx, lhs_i32 ^ rhs_i32);
case Opcode::Shl:
if (rhs_i32 < 0 || rhs_i32 >= 32) {
return nullptr;
}
return GetInt32Const(ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32)
<< rhs_i32));
case Opcode::AShr:
if (rhs_i32 < 0 || rhs_i32 >= 32) {
return nullptr;
}
return GetInt32Const(ctx, lhs_i32 >> rhs_i32);
case Opcode::LShr:
if (rhs_i32 < 0 || rhs_i32 >= 32) {
return nullptr;
}
return GetInt32Const(
ctx,
WrapInt32(static_cast<std::uint32_t>(lhs_i32) >> rhs_i32));
case Opcode::ICmpEQ:
return GetBoolConst(ctx, lhs_i32 == rhs_i32);
case Opcode::ICmpNE:
return GetBoolConst(ctx, lhs_i32 != rhs_i32);
case Opcode::ICmpLT:
return GetBoolConst(ctx, lhs_i32 < rhs_i32);
case Opcode::ICmpGT:
return GetBoolConst(ctx, lhs_i32 > rhs_i32);
case Opcode::ICmpLE:
return GetBoolConst(ctx, lhs_i32 <= rhs_i32);
case Opcode::ICmpGE:
return GetBoolConst(ctx, lhs_i32 >= rhs_i32);
default:
break;
}
}
if (has_lhs_i1 && has_rhs_i1) {
switch (opcode) {
case Opcode::And:
return GetBoolConst(ctx, lhs_i1 && rhs_i1);
case Opcode::Or:
return GetBoolConst(ctx, lhs_i1 || rhs_i1);
case Opcode::Xor:
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
case Opcode::ICmpEQ:
return GetBoolConst(ctx, lhs_i1 == rhs_i1);
case Opcode::ICmpNE:
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
case Opcode::ICmpLT:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) < static_cast<int>(rhs_i1));
case Opcode::ICmpGT:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) > static_cast<int>(rhs_i1));
case Opcode::ICmpLE:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) <= static_cast<int>(rhs_i1));
case Opcode::ICmpGE:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) >= static_cast<int>(rhs_i1));
default:
break;
}
}
if (has_lhs_f32 && has_rhs_f32) {
switch (opcode) {
case Opcode::FAdd:
return GetFloatConst(lhs_f32 + rhs_f32);
case Opcode::FSub:
return GetFloatConst(lhs_f32 - rhs_f32);
case Opcode::FMul:
return GetFloatConst(lhs_f32 * rhs_f32);
case Opcode::FDiv:
return GetFloatConst(lhs_f32 / rhs_f32);
case Opcode::FRem:
return GetFloatConst(std::fmod(lhs_f32, rhs_f32));
case Opcode::FCmpEQ:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 == rhs_f32);
case Opcode::FCmpNE:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 != rhs_f32);
case Opcode::FCmpLT:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 < rhs_f32);
case Opcode::FCmpGT:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 > rhs_f32);
case Opcode::FCmpLE:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 <= rhs_f32);
case Opcode::FCmpGE:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 >= rhs_f32);
default:
break;
}
}
switch (opcode) {
case Opcode::Add:
if (IsZeroValue(rhs)) {
return lhs;
}
if (IsZeroValue(lhs)) {
return rhs;
}
break;
case Opcode::Sub:
if (IsZeroValue(rhs)) {
return lhs;
}
break;
case Opcode::Mul:
if (IsOneValue(rhs)) {
return lhs;
}
if (IsOneValue(lhs)) {
return rhs;
}
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::Div:
if (IsOneValue(rhs)) {
return lhs;
}
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::Rem:
if ((has_rhs_i32 && (rhs_i32 == 1 || rhs_i32 == -1)) ||
(has_rhs_i1 && rhs_i1)) {
return GetInt32Const(ctx, 0);
}
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::And:
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
: GetInt32Const(ctx, 0);
}
if (has_lhs_i1 && lhs_i1) {
return rhs;
}
if (has_rhs_i1 && rhs_i1) {
return lhs;
}
if (IsAllOnesInt(lhs)) {
return rhs;
}
if (IsAllOnesInt(rhs)) {
return lhs;
}
if (passutils::AreEquivalentValues(lhs, rhs)) {
return lhs;
}
break;
case Opcode::Or:
if (IsZeroValue(lhs)) {
return rhs;
}
if (IsZeroValue(rhs)) {
return lhs;
}
if (has_lhs_i1 && lhs_i1) {
return GetBoolConst(ctx, true);
}
if (has_rhs_i1 && rhs_i1) {
return GetBoolConst(ctx, true);
}
if (passutils::AreEquivalentValues(lhs, rhs)) {
return lhs;
}
break;
case Opcode::Xor:
if (IsZeroValue(lhs)) {
return rhs;
}
if (IsZeroValue(rhs)) {
return lhs;
}
if (passutils::AreEquivalentValues(lhs, rhs)) {
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
: GetInt32Const(ctx, 0);
}
break;
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
if (IsZeroValue(rhs)) {
return lhs;
}
if (IsZeroValue(lhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::FAdd:
if (IsZeroValue(lhs)) {
return rhs;
}
if (IsZeroValue(rhs)) {
return lhs;
}
break;
case Opcode::FSub:
if (IsZeroValue(rhs)) {
return lhs;
}
break;
case Opcode::FMul:
if (IsOneValue(lhs)) {
return rhs;
}
if (IsOneValue(rhs)) {
return lhs;
}
break;
case Opcode::FDiv:
if (IsOneValue(rhs)) {
return lhs;
}
break;
default:
break;
}
return nullptr;
}
Value* FoldUnary(Context& ctx, UnaryInst* inst) {
auto* operand = inst->GetOprd();
std::int32_t i32 = 0;
bool i1 = false;
float f32 = 0.0f;
switch (inst->GetOpcode()) {
case Opcode::Neg:
if (TryGetInt32(operand, i32)) {
return GetInt32Const(ctx, WrapInt32(0u - static_cast<std::uint32_t>(i32)));
}
break;
case Opcode::Not:
if (TryGetBool(operand, i1)) {
return GetBoolConst(ctx, !i1);
}
if (TryGetInt32(operand, i32)) {
return GetInt32Const(ctx, i32 ^ 1);
}
break;
case Opcode::FNeg:
if (TryGetFloat(operand, f32)) {
return GetFloatConst(-f32);
}
break;
case Opcode::FtoI:
if (TryGetFloat(operand, f32)) {
return GetInt32Const(ctx, static_cast<std::int32_t>(f32));
}
break;
case Opcode::IToF:
if (TryGetInt32(operand, i32)) {
return GetFloatConst(static_cast<float>(i32));
}
if (TryGetBool(operand, i1)) {
return GetFloatConst(i1 ? 1.0f : 0.0f);
}
break;
default:
break;
}
return nullptr;
}
Value* FoldZext(Context& ctx, ZextInst* inst) {
auto* value = inst->GetValue();
bool i1 = false;
std::int32_t i32 = 0;
if (inst->GetType()->IsInt1()) {
if (TryGetBool(value, i1)) {
return GetBoolConst(ctx, i1);
}
if (TryGetInt32(value, i32)) {
return GetBoolConst(ctx, i32 != 0);
}
}
if (inst->GetType()->IsInt32()) {
if (TryGetBool(value, i1)) {
return GetInt32Const(ctx, i1 ? 1 : 0);
}
if (TryGetInt32(value, i32)) {
return GetInt32Const(ctx, i32);
}
}
return nullptr;
}
bool FoldFunction(Function& function, Context& ctx) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
Value* replacement = nullptr;
if (auto* binary = dyncast<BinaryInst>(inst)) {
replacement = FoldBinary(ctx, binary);
} else if (auto* unary = dyncast<UnaryInst>(inst)) {
replacement = FoldUnary(ctx, unary);
} else if (auto* zext = dyncast<ZextInst>(inst)) {
replacement = FoldZext(ctx, zext);
}
if (!replacement || replacement == inst) {
continue;
}
inst->ReplaceAllUsesWith(replacement);
to_remove.push_back(inst);
changed = true;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
}
return changed;
}
} // namespace
bool RunConstFold(Module& module) {
bool changed = false;
auto& ctx = module.GetContext();
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= FoldFunction(*function, ctx);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,550 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <cmath>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
enum class LatticeKind { Unknown, Constant, Overdefined };
struct ConstantValue {
enum class Kind { Int32, Bool, Float };
Kind kind = Kind::Int32;
std::int32_t int32_value = 0;
bool bool_value = false;
float float_value = 0.0f;
};
struct LatticeValue {
LatticeKind kind = LatticeKind::Unknown;
ConstantValue constant;
};
bool EqualConstants(const ConstantValue& lhs, const ConstantValue& rhs) {
if (lhs.kind != rhs.kind) {
return false;
}
switch (lhs.kind) {
case ConstantValue::Kind::Int32:
return lhs.int32_value == rhs.int32_value;
case ConstantValue::Kind::Bool:
return lhs.bool_value == rhs.bool_value;
case ConstantValue::Kind::Float:
return passutils::FloatBits(lhs.float_value) ==
passutils::FloatBits(rhs.float_value);
}
return false;
}
Value* MaterializeConstant(Context& ctx, const ConstantValue& constant) {
switch (constant.kind) {
case ConstantValue::Kind::Int32:
return ctx.GetConstInt(static_cast<int>(constant.int32_value));
case ConstantValue::Kind::Bool:
return ctx.GetConstBool(constant.bool_value);
case ConstantValue::Kind::Float:
return new ConstantFloat(Type::GetFloatType(), constant.float_value);
}
return nullptr;
}
bool TryGetConstantValue(Value* value, ConstantValue& out) {
if (auto* ci = dyncast<ConstantInt>(value)) {
out.kind = ConstantValue::Kind::Int32;
out.int32_value = static_cast<std::int32_t>(ci->GetValue());
return true;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
out.kind = ConstantValue::Kind::Bool;
out.bool_value = cb->GetValue();
return true;
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
out.kind = ConstantValue::Kind::Float;
out.float_value = cf->GetValue();
return true;
}
return false;
}
LatticeValue ConstantLattice(const ConstantValue& constant) {
LatticeValue value;
value.kind = LatticeKind::Constant;
value.constant = constant;
return value;
}
LatticeValue OverdefinedLattice() {
LatticeValue value;
value.kind = LatticeKind::Overdefined;
return value;
}
LatticeValue GetValueState(
Value* value, const std::unordered_map<Value*, LatticeValue>& states) {
ConstantValue constant;
if (TryGetConstantValue(value, constant)) {
return ConstantLattice(constant);
}
auto it = states.find(value);
if (it != states.end()) {
return it->second;
}
return OverdefinedLattice();
}
LatticeValue Meet(LatticeValue lhs, const LatticeValue& rhs) {
if (lhs.kind == LatticeKind::Overdefined || rhs.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (lhs.kind == LatticeKind::Unknown) {
return rhs;
}
if (rhs.kind == LatticeKind::Unknown) {
return lhs;
}
if (EqualConstants(lhs.constant, rhs.constant)) {
return lhs;
}
return OverdefinedLattice();
}
bool EvaluateUnary(Opcode opcode, const ConstantValue& operand,
ConstantValue& result) {
switch (opcode) {
case Opcode::Neg:
if (operand.kind != ConstantValue::Kind::Int32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
0u - static_cast<std::uint32_t>(operand.int32_value));
return true;
case Opcode::Not:
if (operand.kind == ConstantValue::Kind::Bool) {
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !operand.bool_value;
return true;
}
if (operand.kind == ConstantValue::Kind::Int32) {
result.kind = ConstantValue::Kind::Int32;
result.int32_value = operand.int32_value ^ 1;
return true;
}
return false;
case Opcode::FNeg:
if (operand.kind != ConstantValue::Kind::Float) {
return false;
}
result.kind = ConstantValue::Kind::Float;
result.float_value = -operand.float_value;
return true;
case Opcode::FtoI:
if (operand.kind != ConstantValue::Kind::Float) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(operand.float_value);
return true;
case Opcode::IToF:
if (operand.kind == ConstantValue::Kind::Int32) {
result.kind = ConstantValue::Kind::Float;
result.float_value = static_cast<float>(operand.int32_value);
return true;
}
if (operand.kind == ConstantValue::Kind::Bool) {
result.kind = ConstantValue::Kind::Float;
result.float_value = operand.bool_value ? 1.0f : 0.0f;
return true;
}
return false;
default:
return false;
}
}
bool EvaluateBinary(Opcode opcode, const ConstantValue& lhs,
const ConstantValue& rhs, ConstantValue& result) {
if (lhs.kind == ConstantValue::Kind::Int32 &&
rhs.kind == ConstantValue::Kind::Int32) {
const auto left = lhs.int32_value;
const auto right = rhs.int32_value;
switch (opcode) {
case Opcode::Add:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) + static_cast<std::uint32_t>(right));
return true;
case Opcode::Sub:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) - static_cast<std::uint32_t>(right));
return true;
case Opcode::Mul:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) * static_cast<std::uint32_t>(right));
return true;
case Opcode::Div:
if (right == 0 ||
(left == std::numeric_limits<std::int32_t>::min() && right == -1)) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left / right;
return true;
case Opcode::Rem:
if (right == 0 ||
(left == std::numeric_limits<std::int32_t>::min() && right == -1)) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left % right;
return true;
case Opcode::And:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left & right;
return true;
case Opcode::Or:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left | right;
return true;
case Opcode::Xor:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left ^ right;
return true;
case Opcode::Shl:
if (right < 0 || right >= 32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value =
static_cast<std::int32_t>(static_cast<std::uint32_t>(left) << right);
return true;
case Opcode::AShr:
if (right < 0 || right >= 32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left >> right;
return true;
case Opcode::LShr:
if (right < 0 || right >= 32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) >> right);
return true;
case Opcode::ICmpEQ:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left == right;
return true;
case Opcode::ICmpNE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left != right;
return true;
case Opcode::ICmpLT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left < right;
return true;
case Opcode::ICmpGT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left > right;
return true;
case Opcode::ICmpLE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left <= right;
return true;
case Opcode::ICmpGE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left >= right;
return true;
default:
break;
}
}
if (lhs.kind == ConstantValue::Kind::Bool && rhs.kind == ConstantValue::Kind::Bool) {
const auto left = lhs.bool_value;
const auto right = rhs.bool_value;
switch (opcode) {
case Opcode::And:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left && right;
return true;
case Opcode::Or:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left || right;
return true;
case Opcode::Xor:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left != right;
return true;
case Opcode::ICmpEQ:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left == right;
return true;
case Opcode::ICmpNE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left != right;
return true;
case Opcode::ICmpLT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) < static_cast<int>(right);
return true;
case Opcode::ICmpGT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) > static_cast<int>(right);
return true;
case Opcode::ICmpLE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) <= static_cast<int>(right);
return true;
case Opcode::ICmpGE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) >= static_cast<int>(right);
return true;
default:
break;
}
}
if (lhs.kind == ConstantValue::Kind::Float &&
rhs.kind == ConstantValue::Kind::Float) {
const auto left = lhs.float_value;
const auto right = rhs.float_value;
switch (opcode) {
case Opcode::FAdd:
result.kind = ConstantValue::Kind::Float;
result.float_value = left + right;
return true;
case Opcode::FSub:
result.kind = ConstantValue::Kind::Float;
result.float_value = left - right;
return true;
case Opcode::FMul:
result.kind = ConstantValue::Kind::Float;
result.float_value = left * right;
return true;
case Opcode::FDiv:
result.kind = ConstantValue::Kind::Float;
result.float_value = left / right;
return true;
case Opcode::FRem:
result.kind = ConstantValue::Kind::Float;
result.float_value = std::fmod(left, right);
return true;
case Opcode::FCmpEQ:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left == right;
return true;
case Opcode::FCmpNE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left != right;
return true;
case Opcode::FCmpLT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left < right;
return true;
case Opcode::FCmpGT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left > right;
return true;
case Opcode::FCmpLE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left <= right;
return true;
case Opcode::FCmpGE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left >= right;
return true;
default:
break;
}
}
return false;
}
LatticeValue EvaluateInstruction(
Instruction* inst, const std::unordered_map<Value*, LatticeValue>& states) {
if (!inst || inst->IsVoid()) {
return OverdefinedLattice();
}
if (auto* phi = dyncast<PhiInst>(inst)) {
LatticeValue merged;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
merged = Meet(merged, GetValueState(phi->GetIncomingValue(i), states));
if (merged.kind == LatticeKind::Overdefined) {
break;
}
}
return merged;
}
if (auto* binary = dyncast<BinaryInst>(inst)) {
const auto lhs = GetValueState(binary->GetLhs(), states);
const auto rhs = GetValueState(binary->GetRhs(), states);
if (lhs.kind == LatticeKind::Overdefined || rhs.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (lhs.kind != LatticeKind::Constant || rhs.kind != LatticeKind::Constant) {
return {};
}
ConstantValue folded;
if (!EvaluateBinary(binary->GetOpcode(), lhs.constant, rhs.constant, folded)) {
return OverdefinedLattice();
}
return ConstantLattice(folded);
}
if (auto* unary = dyncast<UnaryInst>(inst)) {
const auto operand = GetValueState(unary->GetOprd(), states);
if (operand.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (operand.kind != LatticeKind::Constant) {
return {};
}
ConstantValue folded;
if (!EvaluateUnary(unary->GetOpcode(), operand.constant, folded)) {
return OverdefinedLattice();
}
return ConstantLattice(folded);
}
if (auto* zext = dyncast<ZextInst>(inst)) {
const auto operand = GetValueState(zext->GetValue(), states);
if (operand.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (operand.kind != LatticeKind::Constant) {
return {};
}
ConstantValue folded;
if (zext->GetType()->IsInt1()) {
folded.kind = ConstantValue::Kind::Bool;
if (operand.constant.kind == ConstantValue::Kind::Bool) {
folded.bool_value = operand.constant.bool_value;
return ConstantLattice(folded);
}
if (operand.constant.kind == ConstantValue::Kind::Int32) {
folded.bool_value = operand.constant.int32_value != 0;
return ConstantLattice(folded);
}
return OverdefinedLattice();
}
if (zext->GetType()->IsInt32()) {
folded.kind = ConstantValue::Kind::Int32;
if (operand.constant.kind == ConstantValue::Kind::Bool) {
folded.int32_value = operand.constant.bool_value ? 1 : 0;
return ConstantLattice(folded);
}
if (operand.constant.kind == ConstantValue::Kind::Int32) {
folded.int32_value = operand.constant.int32_value;
return ConstantLattice(folded);
}
}
return OverdefinedLattice();
}
return OverdefinedLattice();
}
bool RewriteFunction(Function& function, Context& ctx) {
if (function.IsExternal()) {
return false;
}
std::unordered_map<Value*, LatticeValue> states;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst->IsVoid()) {
states[inst] = {};
}
}
}
bool changed = true;
while (changed) {
changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsVoid()) {
continue;
}
const auto evaluated = EvaluateInstruction(inst, states);
if (evaluated.kind != states[inst].kind ||
(evaluated.kind == LatticeKind::Constant &&
!EqualConstants(evaluated.constant, states[inst].constant))) {
states[inst] = evaluated;
changed = true;
}
}
}
}
bool rewritten = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
if (isa<BasicBlock>(operand) || isa<Function>(operand) || operand->IsConstant()) {
continue;
}
const auto state = GetValueState(operand, states);
if (state.kind != LatticeKind::Constant) {
continue;
}
inst->SetOperand(i, MaterializeConstant(ctx, state.constant));
rewritten = true;
}
if (inst->IsVoid()) {
continue;
}
const auto state = states[inst];
if (state.kind != LatticeKind::Constant) {
continue;
}
inst->ReplaceAllUsesWith(MaterializeConstant(ctx, state.constant));
to_remove.push_back(inst);
rewritten = true;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
}
return rewritten;
}
} // namespace
bool RunConstProp(Module& module) {
bool changed = false;
auto& ctx = module.GetContext();
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RewriteFunction(*function, ctx);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,55 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <vector>
namespace ir {
namespace {
bool RunDCEOnFunction(Function& function) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!passutils::IsTriviallyDead(inst)) {
continue;
}
to_remove.push_back(inst);
}
if (to_remove.empty()) {
continue;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
local_changed = true;
changed = true;
}
}
return changed;
}
} // namespace
bool RunDCE(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunDCEOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,196 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "MemoryUtils.h"
#include "PassUtils.h"
#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct ExprKey {
Opcode opcode = Opcode::Add;
std::uintptr_t result_type = 0;
std::uintptr_t aux_type = 0;
std::vector<std::uintptr_t> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && result_type == rhs.result_type &&
aux_type == rhs.aux_type && operands == rhs.operands;
}
};
struct ExprKeyHash {
std::size_t operator()(const ExprKey& key) const {
std::size_t h = static_cast<std::size_t>(key.opcode);
h ^= std::hash<std::uintptr_t>{}(key.result_type) + 0x9e3779b9 + (h << 6) +
(h >> 2);
h ^= std::hash<std::uintptr_t>{}(key.aux_type) + 0x9e3779b9 + (h << 6) +
(h >> 2);
for (auto operand : key.operands) {
h ^= std::hash<std::uintptr_t>{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
struct ScopedExpr {
ExprKey key;
Value* previous = nullptr;
bool had_previous = false;
};
bool IsSupportedGVNInstruction(Instruction* inst) {
if (!inst || inst->IsVoid()) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::GetElementPtr:
case Opcode::Zext:
return true;
case Opcode::Call:
return memutils::IsPureCall(dyncast<CallInst>(inst));
default:
return false;
}
}
ExprKey BuildExprKey(Instruction* inst) {
ExprKey key;
key.opcode = inst->GetOpcode();
key.result_type =
reinterpret_cast<std::uintptr_t>(inst->GetType().get());
if (auto* gep = dyncast<GetElementPtrInst>(inst)) {
key.aux_type = reinterpret_cast<std::uintptr_t>(gep->GetSourceType().get());
}
key.operands.reserve(inst->GetNumOperands());
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(
reinterpret_cast<std::uintptr_t>(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 &&
passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
key.operands[1] < key.operands[0]) {
std::swap(key.operands[0], key.operands[1]);
}
return key;
}
bool RunGVNInDomSubtree(
BasicBlock* block, const DominatorTree& dom_tree,
std::unordered_map<ExprKey, Value*, ExprKeyHash>& available) {
if (!block) {
return false;
}
bool changed = false;
std::vector<ScopedExpr> scope;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!IsSupportedGVNInstruction(inst)) {
continue;
}
const auto key = BuildExprKey(inst);
auto it = available.find(key);
if (it != available.end()) {
inst->ReplaceAllUsesWith(it->second);
to_remove.push_back(inst);
changed = true;
continue;
}
ScopedExpr scoped{key, nullptr, false};
auto existing = available.find(key);
if (existing != available.end()) {
scoped.previous = existing->second;
scoped.had_previous = true;
existing->second = inst;
} else {
available.emplace(key, inst);
}
scope.push_back(std::move(scoped));
}
for (auto* inst : to_remove) {
block->EraseInstruction(inst);
}
for (auto* child : dom_tree.GetChildren(block)) {
changed |= RunGVNInDomSubtree(child, dom_tree, available);
}
for (auto it = scope.rbegin(); it != scope.rend(); ++it) {
if (it->had_previous) {
available[it->key] = it->previous;
} else {
available.erase(it->key);
}
}
return changed;
}
bool RunGVNOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
DominatorTree dom_tree(function);
std::unordered_map<ExprKey, Value*, ExprKeyHash> available;
return RunGVNInDomSubtree(function.GetEntryBlock(), dom_tree, available);
}
} // namespace
bool RunGVN(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunGVNOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,692 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <algorithm>
#include <cstdint>
#include <unordered_set>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct InlineCandidateInfo {
bool valid = false;
int cost = 0;
bool has_nested_call = false;
bool has_control_flow = false;
};
bool IsInlineableInstruction(const Instruction* inst) {
if (!inst) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::Load:
case Opcode::Store:
case Opcode::GetElementPtr:
case Opcode::Zext:
case Opcode::Memset:
case Opcode::Call:
case Opcode::Return:
case Opcode::Br:
case Opcode::CondBr:
return true;
default:
return false;
}
}
int EstimateInstructionCost(const Instruction* inst) {
if (!inst) {
return 0;
}
switch (inst->GetOpcode()) {
case Opcode::Return:
return 0;
case Opcode::Load:
case Opcode::Store:
case Opcode::Memset:
return 3;
case Opcode::Call:
return 8;
case Opcode::GetElementPtr:
return 2;
default:
return 1;
}
}
InlineCandidateInfo AnalyzeInlineCandidate(const Function& function) {
InlineCandidateInfo info;
if (function.IsExternal() || function.IsRecursive()) {
return info;
}
if (function.GetBlocks().empty() || function.GetBlocks().size() > 4) {
return info;
}
DominatorTree dom_tree(const_cast<Function&>(function));
LoopInfo loop_info(const_cast<Function&>(function), dom_tree);
if (!loop_info.GetLoops().empty()) {
return info;
}
bool saw_return = false;
for (const auto& block : function.GetBlocks()) {
if (!block || block->GetInstructions().empty()) {
return info;
}
for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) {
auto* inst = block->GetInstructions()[i].get();
if (!IsInlineableInstruction(inst) || dyncast<PhiInst>(inst) ||
dyncast<AllocaInst>(inst) || dyncast<UnreachableInst>(inst)) {
return {};
}
if (dyncast<ReturnInst>(inst)) {
if (i + 1 != block->GetInstructions().size()) {
return {};
}
saw_return = true;
continue;
}
if ((dyncast<UncondBrInst>(inst) || dyncast<CondBrInst>(inst)) &&
i + 1 != block->GetInstructions().size()) {
return {};
}
if (dyncast<CondBrInst>(inst) || dyncast<UncondBrInst>(inst)) {
info.has_control_flow = true;
}
if (dyncast<CallInst>(inst)) {
info.has_nested_call = true;
}
info.cost += EstimateInstructionCost(inst);
}
}
if (!saw_return) {
return {};
}
info.valid = true;
return info;
}
std::unordered_map<Function*, int> CountDirectCalls(Module& module) {
std::unordered_map<Function*, int> counts;
for (const auto& function_ptr : module.GetFunctions()) {
if (!function_ptr) {
continue;
}
for (const auto& block_ptr : function_ptr->GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
if (auto* call = dyncast<CallInst>(inst_ptr.get())) {
if (auto* callee = call->GetCallee()) {
++counts[callee];
}
}
}
}
}
return counts;
}
bool ShouldInlineCallSite(const Function& caller, const CallInst& call,
const InlineCandidateInfo& callee_info, int call_count) {
auto* callee = call.GetCallee();
if (!callee || callee == &caller || !callee_info.valid) {
return false;
}
int budget = callee->CanDiscardUnusedCall() ? 40 : 24;
if (call_count <= 1) {
budget += 12;
}
if (callee_info.has_nested_call) {
budget -= 8;
}
if (callee_info.has_control_flow) {
budget -= 6;
}
if (callee->MayWriteMemory()) {
budget -= 4;
}
return callee_info.cost <= budget;
}
Instruction* CloneInstructionAt(Function& function, Instruction* inst, BasicBlock* dest,
std::size_t insert_index,
std::unordered_map<Value*, Value*>& remap) {
if (!inst || !dest) {
return nullptr;
}
const auto name = inst->IsVoid() ? std::string()
: looputils::NextSyntheticName(function, "inline.");
auto remap_operand = [&](Value* value) { return looputils::RemapValue(remap, value); };
auto remember = [&](Instruction* clone) {
if (clone && !inst->IsVoid()) {
remap[inst] = clone;
}
return clone;
};
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto* bin = static_cast<BinaryInst*>(inst);
return remember(dest->Insert<BinaryInst>(insert_index, inst->GetOpcode(), inst->GetType(),
remap_operand(bin->GetLhs()),
remap_operand(bin->GetRhs()), nullptr, name));
}
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF: {
auto* un = static_cast<UnaryInst*>(inst);
return remember(dest->Insert<UnaryInst>(insert_index, inst->GetOpcode(), inst->GetType(),
remap_operand(un->GetOprd()), nullptr, name));
}
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
return remember(dest->Insert<LoadInst>(insert_index, inst->GetType(),
remap_operand(load->GetPtr()), nullptr, name));
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
return dest->Insert<StoreInst>(insert_index, remap_operand(store->GetValue()),
remap_operand(store->GetPtr()), nullptr);
}
case Opcode::Memset: {
auto* memset = static_cast<MemsetInst*>(inst);
return dest->Insert<MemsetInst>(insert_index, remap_operand(memset->GetDest()),
remap_operand(memset->GetValue()),
remap_operand(memset->GetLength()),
remap_operand(memset->GetIsVolatile()), nullptr);
}
case Opcode::GetElementPtr: {
auto* gep = static_cast<GetElementPtrInst*>(inst);
std::vector<Value*> indices;
indices.reserve(gep->GetNumIndices());
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
indices.push_back(remap_operand(gep->GetIndex(i)));
}
return remember(dest->Insert<GetElementPtrInst>(
insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()), indices, nullptr,
name));
}
case Opcode::Zext: {
auto* zext = static_cast<ZextInst*>(inst);
return remember(dest->Insert<ZextInst>(insert_index, remap_operand(zext->GetValue()),
inst->GetType(), nullptr, name));
}
case Opcode::Call: {
auto* call = static_cast<CallInst*>(inst);
std::vector<Value*> args;
const auto original_args = call->GetArguments();
args.reserve(original_args.size());
for (auto* arg : original_args) {
args.push_back(remap_operand(arg));
}
return remember(
dest->Insert<CallInst>(insert_index, call->GetCallee(), args, nullptr, name));
}
case Opcode::Return:
case Opcode::Alloca:
case Opcode::Phi:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Unreachable:
break;
}
return nullptr;
}
bool InlineCallSite(Function& caller, CallInst* call) {
if (!call) {
return false;
}
auto* callee = call->GetCallee();
if (!callee || callee->GetBlocks().size() != 1) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto call_args = call->GetArguments();
if (callee_args.size() != call_args.size()) {
return false;
}
auto* block = call->GetParent();
if (!block) {
return false;
}
auto& instructions = block->GetInstructions();
auto call_it = std::find_if(instructions.begin(), instructions.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == call;
});
if (call_it == instructions.end()) {
return false;
}
std::size_t insert_index = static_cast<std::size_t>(call_it - instructions.begin());
std::unordered_map<Value*, Value*> remap;
for (std::size_t i = 0; i < call_args.size(); ++i) {
remap[callee_args[i].get()] = call_args[i];
}
Value* return_value = nullptr;
for (const auto& inst_ptr : callee->GetBlocks().front()->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* ret = dyncast<ReturnInst>(inst)) {
if (ret->HasReturnValue()) {
return_value = looputils::RemapValue(remap, ret->GetReturnValue());
}
break;
}
if (!CloneInstructionAt(caller, inst, block, insert_index, remap)) {
return false;
}
++insert_index;
}
if (!call->GetType()->IsVoid()) {
if (!return_value) {
return false;
}
call->ReplaceAllUsesWith(return_value);
}
block->EraseInstruction(call);
return true;
}
void ReplaceIncomingBlock(BasicBlock* block, BasicBlock* old_pred, BasicBlock* new_pred) {
if (!block || !old_pred || !new_pred) {
return;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int index = looputils::GetPhiIncomingIndex(phi, old_pred);
if (index >= 0) {
phi->SetOperand(static_cast<std::size_t>(2 * index + 1), new_pred);
}
}
}
std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::vector<BasicBlock*> stack{entry};
std::unordered_set<BasicBlock*> visited;
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it) {
stack.push_back(*it);
}
}
}
return order;
}
BasicBlock* SplitBlockAfterCall(Function& caller, BasicBlock* block, CallInst* call) {
if (!block || !call) {
return nullptr;
}
auto& instructions = block->GetInstructions();
auto call_it = std::find_if(instructions.begin(), instructions.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == call;
});
if (call_it == instructions.end() || std::next(call_it) == instructions.end()) {
return nullptr;
}
auto* continuation =
caller.CreateBlock(looputils::NextSyntheticBlockName(caller, "inline.cont"));
auto& continuation_insts = continuation->GetInstructions();
for (auto it = std::next(call_it); it != instructions.end(); ++it) {
(*it)->SetParent(continuation);
continuation_insts.push_back(std::move(*it));
}
instructions.erase(std::next(call_it), instructions.end());
auto old_succs = block->GetSuccessors();
for (auto* succ : old_succs) {
block->RemoveSuccessor(succ);
succ->RemovePredecessor(block);
succ->AddPredecessor(continuation);
continuation->AddSuccessor(succ);
ReplaceIncomingBlock(succ, block, continuation);
}
return continuation;
}
bool CanInlineCFGCallSite(Function& caller, CallInst* call,
std::vector<BasicBlock*>& callee_blocks) {
auto* callee = call ? call->GetCallee() : nullptr;
if (!call || !callee || callee->GetBlocks().size() <= 1 ||
callee == &caller) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto call_args = call->GetArguments();
if (callee_args.size() != call_args.size()) {
return false;
}
callee_blocks = CollectReachableBlocks(*callee);
if (callee_blocks.empty()) {
return false;
}
std::unordered_set<BasicBlock*> reachable(callee_blocks.begin(), callee_blocks.end());
for (auto* block : callee_blocks) {
if (!block || block->GetInstructions().empty()) {
return false;
}
for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) {
auto* inst = block->GetInstructions()[i].get();
if (dyncast<PhiInst>(inst) || dyncast<AllocaInst>(inst) ||
dyncast<UnreachableInst>(inst) || !IsInlineableInstruction(inst)) {
return false;
}
if (auto* br = dyncast<UncondBrInst>(inst)) {
if (i + 1 != block->GetInstructions().size() ||
reachable.count(br->GetDest()) == 0) {
return false;
}
continue;
}
if (auto* condbr = dyncast<CondBrInst>(inst)) {
if (i + 1 != block->GetInstructions().size() ||
reachable.count(condbr->GetThenBlock()) == 0 ||
reachable.count(condbr->GetElseBlock()) == 0) {
return false;
}
continue;
}
if (dyncast<ReturnInst>(inst)) {
if (i + 1 != block->GetInstructions().size()) {
return false;
}
continue;
}
if (inst->IsTerminator() || !looputils::IsCloneableInstruction(inst)) {
return false;
}
}
}
return true;
}
bool InlineCFGCallSite(Function& caller, CallInst* call) {
auto* callee = call ? call->GetCallee() : nullptr;
if (!call || !callee || callee->GetBlocks().size() <= 1) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto call_args = call->GetArguments();
if (callee_args.size() != call_args.size()) {
return false;
}
std::vector<BasicBlock*> callee_blocks;
if (!CanInlineCFGCallSite(caller, call, callee_blocks)) {
return false;
}
auto* call_block = call->GetParent();
auto* continuation = SplitBlockAfterCall(caller, call_block, call);
if (!call_block || !continuation) {
return false;
}
std::unordered_map<Value*, Value*> remap;
for (std::size_t i = 0; i < call_args.size(); ++i) {
remap[callee_args[i].get()] = call_args[i];
}
std::unordered_map<BasicBlock*, BasicBlock*> block_map;
for (auto* block : callee_blocks) {
block_map[block] =
caller.CreateBlock(looputils::NextSyntheticBlockName(caller, "inline.bb"));
}
std::vector<std::pair<BasicBlock*, Value*>> return_edges;
for (auto* block : callee_blocks) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst)) {
return false;
}
if (auto* ret = dyncast<ReturnInst>(inst)) {
clone->Append<UncondBrInst>(continuation, nullptr);
clone->AddSuccessor(continuation);
continuation->AddPredecessor(clone);
return_edges.emplace_back(
clone, ret->HasReturnValue() ? looputils::RemapValue(remap, ret->GetReturnValue())
: nullptr);
continue;
}
if (auto* br = dyncast<UncondBrInst>(inst)) {
auto* target = block_map.at(br->GetDest());
clone->Append<UncondBrInst>(target, nullptr);
clone->AddSuccessor(target);
target->AddPredecessor(clone);
continue;
}
if (auto* condbr = dyncast<CondBrInst>(inst)) {
auto* then_block = block_map.at(condbr->GetThenBlock());
auto* else_block = block_map.at(condbr->GetElseBlock());
clone->Append<CondBrInst>(looputils::RemapValue(remap, condbr->GetCondition()),
then_block, else_block, nullptr);
clone->AddSuccessor(then_block);
clone->AddSuccessor(else_block);
then_block->AddPredecessor(clone);
else_block->AddPredecessor(clone);
continue;
}
if (!CloneInstructionAt(caller, inst, clone,
looputils::GetTerminatorIndex(clone), remap)) {
return false;
}
}
}
call_block->Append<UncondBrInst>(block_map.at(callee->GetEntryBlock()), nullptr);
call_block->AddSuccessor(block_map.at(callee->GetEntryBlock()));
block_map.at(callee->GetEntryBlock())->AddPredecessor(call_block);
if (!call->GetType()->IsVoid()) {
Value* return_value = nullptr;
if (return_edges.size() == 1) {
return_value = return_edges.front().second;
} else {
auto* phi = continuation->Insert<PhiInst>(
looputils::GetFirstNonPhiIndex(continuation), call->GetType(), nullptr,
looputils::NextSyntheticName(caller, "inline.ret."));
for (const auto& [pred, value] : return_edges) {
if (!value) {
return false;
}
phi->AddIncoming(value, pred);
}
return_value = phi;
}
if (!return_value) {
return false;
}
call->ReplaceAllUsesWith(return_value);
}
call_block->EraseInstruction(call);
return true;
}
bool RunFunctionInliningOnFunction(
Function& function,
const std::unordered_map<Function*, InlineCandidateInfo>& callee_info,
const std::unordered_map<Function*, int>& call_counts) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
std::vector<BasicBlock*> block_snapshot;
block_snapshot.reserve(function.GetBlocks().size());
for (const auto& block_ptr : function.GetBlocks()) {
if (block_ptr) {
block_snapshot.push_back(block_ptr.get());
}
}
for (auto* block : block_snapshot) {
if (!block) {
continue;
}
std::vector<CallInst*> calls;
for (const auto& inst_ptr : block->GetInstructions()) {
if (auto* call = dyncast<CallInst>(inst_ptr.get())) {
calls.push_back(call);
}
}
for (auto* call : calls) {
auto* callee = call->GetCallee();
if (!callee) {
continue;
}
auto info_it = callee_info.find(callee);
if (info_it == callee_info.end()) {
continue;
}
const int call_count =
call_counts.count(callee) != 0 ? call_counts.at(callee) : 0;
if (!ShouldInlineCallSite(function, *call, info_it->second, call_count)) {
continue;
}
if (callee->GetBlocks().size() == 1) {
changed |= InlineCallSite(function, call);
} else {
changed |= InlineCFGCallSite(function, call);
}
}
}
return changed;
}
} // namespace
bool RunFunctionInlining(Module& module) {
std::unordered_map<Function*, InlineCandidateInfo> callee_info;
for (const auto& function_ptr : module.GetFunctions()) {
if (function_ptr) {
callee_info.emplace(function_ptr.get(), AnalyzeInlineCandidate(*function_ptr));
}
}
const auto call_counts = CountDirectCalls(module);
bool changed = false;
for (const auto& function_ptr : module.GetFunctions()) {
if (function_ptr) {
changed |= RunFunctionInliningOnFunction(*function_ptr, callee_info, call_counts);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,236 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include "MemoryUtils.h"
#include <cstdint>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct HoistedLoadKey {
memutils::AddressKey address;
std::uintptr_t type_id = 0;
bool operator==(const HoistedLoadKey& rhs) const {
return type_id == rhs.type_id && address == rhs.address;
}
};
struct HoistedLoadKeyHash {
std::size_t operator()(const HoistedLoadKey& key) const {
std::size_t h = memutils::AddressKeyHash{}(key.address);
h ^= std::hash<std::uintptr_t>{}(key.type_id) + 0x9e3779b9 + (h << 6) + (h >> 2);
return h;
}
};
bool IsHoistableInstruction(const Instruction* inst) {
if (!inst || inst->IsTerminator() || inst->IsVoid()) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::GetElementPtr:
case Opcode::Zext:
case Opcode::Load:
return true;
default:
return false;
}
}
bool IsLoopInvariantInstruction(
const Loop& loop, Instruction* inst,
const std::unordered_set<Instruction*>& invariant_insts,
PhiInst* iv, int iv_stride,
const std::vector<loopmem::MemoryAccessInfo>& accesses,
const memutils::EscapeSummary& escapes) {
if (!IsHoistableInstruction(inst)) {
return false;
}
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
auto* operand_inst = dyncast<Instruction>(operand);
if (!operand_inst) {
continue;
}
if (!loop.Contains(operand_inst->GetParent())) {
continue;
}
if (invariant_insts.find(operand_inst) == invariant_insts.end()) {
return false;
}
}
if (auto* load = dyncast<LoadInst>(inst)) {
return loopmem::IsSafeInvariantLoadToHoist(loop, load, iv, iv_stride, accesses, &escapes);
}
return true;
}
bool HoistLoopInvariants(Function& function, const Loop& loop,
BasicBlock* preheader) {
if (!preheader) {
return false;
}
loopmem::SimpleInductionVar induction_var;
PhiInst* iv = nullptr;
int iv_stride = 1;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (loopmem::MatchSimpleInductionVariable(loop, preheader, phi, induction_var)) {
iv = induction_var.phi;
iv_stride = induction_var.stride;
break;
}
}
const auto escapes = memutils::AnalyzeEscapes(function);
const auto accesses = loopmem::CollectMemoryAccesses(loop, iv, &escapes);
std::unordered_set<Instruction*> invariant_insts;
std::vector<Instruction*> hoist_list;
bool progress = true;
while (progress) {
progress = false;
for (const auto& block_ptr : function.GetBlocks()) {
auto* block = block_ptr.get();
if (!loop.Contains(block) || block == preheader) {
continue;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (invariant_insts.find(inst) != invariant_insts.end()) {
continue;
}
if (!IsLoopInvariantInstruction(loop, inst, invariant_insts, iv, iv_stride,
accesses, escapes)) {
continue;
}
invariant_insts.insert(inst);
hoist_list.push_back(inst);
progress = true;
}
}
}
bool changed = false;
std::unordered_map<HoistedLoadKey, LoadInst*, HoistedLoadKeyHash> hoisted_loads;
for (auto* inst : hoist_list) {
if (auto* load = dyncast<LoadInst>(inst)) {
auto ptr = loopmem::AnalyzePointer(load->GetPtr(), iv, loop,
load->GetType()->GetSize(), &escapes);
if (ptr.exact_key_valid) {
HoistedLoadKey key{ptr.exact_key,
reinterpret_cast<std::uintptr_t>(load->GetType().get())};
auto it = hoisted_loads.find(key);
if (it != hoisted_loads.end()) {
load->ReplaceAllUsesWith(it->second);
load->GetParent()->EraseInstruction(load);
changed = true;
continue;
}
auto* moved = dyncast<LoadInst>(
looputils::MoveInstructionBeforeTerminator(load, preheader));
if (moved) {
hoisted_loads.emplace(std::move(key), moved);
changed = true;
}
continue;
}
}
if (looputils::MoveInstructionBeforeTerminator(inst, preheader)) {
changed = true;
}
}
return changed;
}
bool RunLICMOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
auto* old_preheader = loop->preheader;
auto* preheader = looputils::EnsurePreheader(function, *loop);
bool loop_changed = preheader != old_preheader;
loop_changed |= HoistLoopInvariants(function, *loop, preheader);
if (!loop_changed) {
continue;
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLICM(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLICMOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,319 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "MemoryUtils.h"
#include "PassUtils.h"
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct AvailableValue {
Value* value = nullptr;
bool operator==(const AvailableValue& rhs) const {
return passutils::AreEquivalentValues(value, rhs.value) || value == rhs.value;
}
};
using MemoryState =
std::unordered_map<memutils::AddressKey, AvailableValue,
memutils::AddressKeyHash>;
bool SameAvailableValue(const AvailableValue& lhs, const AvailableValue& rhs) {
return lhs == rhs;
}
bool SameMemoryState(const MemoryState& lhs, const MemoryState& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto& [key, value] : lhs) {
auto it = rhs.find(key);
if (it == rhs.end() || !SameAvailableValue(value, it->second)) {
return false;
}
}
return true;
}
MemoryState MeetMemoryStates(const std::vector<MemoryState*>& predecessors) {
if (predecessors.empty()) {
return {};
}
MemoryState in = *predecessors.front();
for (auto it = in.begin(); it != in.end();) {
bool keep = true;
for (std::size_t i = 1; i < predecessors.size(); ++i) {
auto pred_it = predecessors[i]->find(it->first);
if (pred_it == predecessors[i]->end() ||
!SameAvailableValue(it->second, pred_it->second)) {
keep = false;
break;
}
}
if (!keep) {
it = in.erase(it);
continue;
}
++it;
}
return in;
}
void InvalidateAliasStates(MemoryState& state,
const memutils::AddressKey& key) {
for (auto it = state.begin(); it != state.end();) {
if (memutils::MayAliasConservatively(it->first, key)) {
it = state.erase(it);
continue;
}
++it;
}
}
void InvalidateStatesForCall(MemoryState& state, Function* callee) {
for (auto it = state.begin(); it != state.end();) {
if (memutils::CallMayWriteRoot(callee, it->first.kind)) {
it = state.erase(it);
continue;
}
++it;
}
}
void SimulateInstruction(const memutils::EscapeSummary& escapes, Instruction* inst,
MemoryState& state) {
if (!inst) {
return;
}
if (auto* load = dyncast<LoadInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
state.clear();
}
return;
}
if (auto* store = dyncast<StoreInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
state.clear();
return;
}
InvalidateAliasStates(state, key);
state[key] = {store->GetValue()};
return;
}
if (auto* call = dyncast<CallInst>(inst)) {
InvalidateStatesForCall(state, call->GetCallee());
return;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) {
state.clear();
return;
}
InvalidateAliasStates(state, key);
return;
}
}
MemoryState SimulateBlock(const memutils::EscapeSummary& escapes, BasicBlock* block,
const MemoryState& in_state) {
MemoryState state = in_state;
for (const auto& inst_ptr : block->GetInstructions()) {
SimulateInstruction(escapes, inst_ptr.get(), state);
}
return state;
}
bool MarkLoadObserved(
const memutils::AddressKey& key,
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
pending_stores) {
bool changed = false;
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
if (memutils::MayAliasConservatively(it->first, key)) {
it = pending_stores.erase(it);
changed = true;
continue;
}
++it;
}
return changed;
}
void InvalidatePendingForCall(
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
pending_stores,
Function* callee) {
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
if (memutils::CallMayReadRoot(callee, it->first.kind) ||
memutils::CallMayWriteRoot(callee, it->first.kind)) {
it = pending_stores.erase(it);
continue;
}
++it;
}
}
bool OptimizeBlock(
const memutils::EscapeSummary& escapes, BasicBlock* block,
const MemoryState& in_state) {
bool changed = false;
MemoryState state = in_state;
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>
pending_stores;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* load = dyncast<LoadInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
state.clear();
pending_stores.clear();
continue;
}
MarkLoadObserved(key, pending_stores);
auto it = state.find(key);
if (it != state.end() && it->second.value != load) {
load->ReplaceAllUsesWith(it->second.value);
to_remove.push_back(load);
changed = true;
continue;
}
// Keep block-local load reuse, but do not expose load results to cross-block
// dataflow because the defining load itself may be removed later.
state[key] = {load};
continue;
}
if (auto* store = dyncast<StoreInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
state.clear();
pending_stores.clear();
continue;
}
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
if (!memutils::MayAliasConservatively(it->first, key)) {
++it;
continue;
}
if (it->first == key) {
to_remove.push_back(it->second);
changed = true;
}
it = pending_stores.erase(it);
}
pending_stores.emplace(key, store);
InvalidateAliasStates(state, key);
state[key] = {store->GetValue()};
continue;
}
if (auto* call = dyncast<CallInst>(inst)) {
InvalidateStatesForCall(state, call->GetCallee());
InvalidatePendingForCall(pending_stores, call->GetCallee());
continue;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) {
state.clear();
pending_stores.clear();
continue;
}
InvalidateAliasStates(state, key);
MarkLoadObserved(key, pending_stores);
continue;
}
}
for (auto* inst : to_remove) {
if (inst->GetParent() == block) {
block->EraseInstruction(inst);
}
}
return changed;
}
bool RunLoadStoreElimOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
const auto escapes = memutils::AnalyzeEscapes(function);
const auto reachable_blocks = passutils::CollectReachableBlocks(function);
if (reachable_blocks.empty()) {
return false;
}
std::unordered_map<BasicBlock*, MemoryState> in_states;
std::unordered_map<BasicBlock*, MemoryState> out_states;
bool dataflow_changed = true;
while (dataflow_changed) {
dataflow_changed = false;
for (auto* block : reachable_blocks) {
MemoryState in_state;
if (block != function.GetEntryBlock()) {
std::vector<MemoryState*> predecessors;
for (auto* pred : block->GetPredecessors()) {
auto it = out_states.find(pred);
if (it != out_states.end()) {
predecessors.push_back(&it->second);
}
}
in_state = MeetMemoryStates(predecessors);
}
auto out_state = SimulateBlock(escapes, block, in_state);
auto in_it = in_states.find(block);
if (in_it == in_states.end() || !SameMemoryState(in_it->second, in_state)) {
in_states[block] = in_state;
dataflow_changed = true;
}
auto out_it = out_states.find(block);
if (out_it == out_states.end() || !SameMemoryState(out_it->second, out_state)) {
out_states[block] = std::move(out_state);
dataflow_changed = true;
}
}
}
bool changed = false;
for (auto* block : reachable_blocks) {
changed |= OptimizeBlock(escapes, block, in_states[block]);
}
return changed;
}
} // namespace
bool RunLoadStoreElim(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoadStoreElimOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,326 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
struct FissionLoopInfo {
Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* header = nullptr;
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
CondBrInst* branch = nullptr;
BinaryInst* compare = nullptr;
Opcode compare_opcode = Opcode::ICmpLT;
Value* bound = nullptr;
loopmem::SimpleInductionVar induction_var;
PhiInst* iv = nullptr;
BinaryInst* step_inst = nullptr;
};
bool HasSyntheticLoopTag(const std::string& name) {
return name.find("unroll.") != std::string::npos ||
name.find("fission.") != std::string::npos;
}
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
if (!loop.preheader || !loop.header || !body) {
return true;
}
return HasSyntheticLoopTag(loop.preheader->GetName()) ||
HasSyntheticLoopTag(loop.header->GetName()) ||
HasSyntheticLoopTag(body->GetName());
}
Opcode SwapCompareOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::ICmpLT:
return Opcode::ICmpGT;
case Opcode::ICmpLE:
return Opcode::ICmpGE;
case Opcode::ICmpGT:
return Opcode::ICmpLT;
case Opcode::ICmpGE:
return Opcode::ICmpLE;
default:
return opcode;
}
}
bool MatchFissionLoop(Loop& loop, FissionLoopInfo& info) {
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
return false;
}
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
return false;
}
if (IsAlreadyTransformedLoop(loop, body)) {
return false;
}
std::vector<PhiInst*> phis;
loopmem::SimpleInductionVar induction_var;
bool found_iv = false;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
phis.push_back(phi);
if (!found_iv &&
loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
found_iv = true;
}
}
if (!found_iv || phis.size() != 1) {
return false;
}
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
auto* compare = branch ? dyncast<BinaryInst>(branch->GetCondition()) : nullptr;
if (!branch || branch->GetThenBlock() != body || !compare) {
return false;
}
Opcode compare_opcode = compare->GetOpcode();
Value* bound = nullptr;
if (compare->GetLhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
bound = compare->GetRhs();
} else if (compare->GetRhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
bound = compare->GetLhs();
compare_opcode = SwapCompareOpcode(compare_opcode);
} else {
return false;
}
auto* step_inst = dyncast<BinaryInst>(induction_var.latch_value);
if (!step_inst || step_inst->GetParent() != body) {
return false;
}
for (const auto& inst_ptr : body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || inst == step_inst) {
continue;
}
if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
return false;
}
}
info.loop = &loop;
info.preheader = loop.preheader;
info.header = loop.header;
info.body = body;
info.exit = exit;
info.branch = branch;
info.compare = compare;
info.compare_opcode = compare_opcode;
info.bound = bound;
info.induction_var = induction_var;
info.iv = induction_var.phi;
info.step_inst = step_inst;
return true;
}
bool ContainsInterestingPayload(const std::vector<Instruction*>& group) {
bool has_memory = false;
for (auto* inst : group) {
if (dyncast<LoadInst>(inst) || dyncast<StoreInst>(inst)) {
has_memory = true;
}
}
return has_memory;
}
Value* RemapExitValue(Value* value, PhiInst* old_iv, PhiInst* new_iv) {
if (value == old_iv) {
return new_iv;
}
return value;
}
bool BuildSecondLoop(Function& function, const FissionLoopInfo& info,
const std::vector<Instruction*>& second_group) {
auto* second_header =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.header"));
auto* second_body =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.body"));
const int preheader_index = looputils::GetPhiIncomingIndex(info.iv, info.preheader);
if (preheader_index < 0) {
return false;
}
auto* second_iv = second_header->Append<PhiInst>(
info.iv->GetType(), nullptr,
looputils::NextSyntheticName(function, "fission.iv."));
second_iv->AddIncoming(info.iv->GetIncomingValue(preheader_index), info.header);
auto* second_cmp = second_header->Append<BinaryInst>(
info.compare_opcode, Type::GetBoolType(), second_iv, info.bound, nullptr,
looputils::NextSyntheticName(function, "fission.cmp."));
second_header->Append<CondBrInst>(second_cmp, second_body, info.exit, nullptr);
second_header->AddPredecessor(info.header);
second_header->AddSuccessor(second_body);
second_header->AddSuccessor(info.exit);
std::unordered_map<Value*, Value*> remap;
remap[info.iv] = second_iv;
std::unordered_set<Instruction*> selected(second_group.begin(), second_group.end());
selected.insert(info.step_inst);
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || selected.find(inst) == selected.end()) {
continue;
}
looputils::CloneInstruction(function, inst, second_body, remap, "fission.");
}
auto* cloned_step_value = looputils::RemapValue(remap, info.step_inst);
if (!cloned_step_value) {
return false;
}
second_iv->AddIncoming(cloned_step_value, second_body);
second_body->Append<UncondBrInst>(second_header, nullptr);
second_body->AddPredecessor(second_header);
second_body->AddSuccessor(second_header);
second_header->AddPredecessor(second_body);
if (!looputils::RedirectSuccessorEdge(info.header, info.exit, second_header)) {
return false;
}
info.exit->RemovePredecessor(info.header);
info.exit->AddPredecessor(second_header);
for (const auto& inst_ptr : info.exit->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int incoming = looputils::GetPhiIncomingIndex(phi, info.header);
if (incoming < 0) {
continue;
}
phi->SetOperand(static_cast<std::size_t>(2 * incoming),
RemapExitValue(phi->GetIncomingValue(incoming), info.iv, second_iv));
phi->SetOperand(static_cast<std::size_t>(2 * incoming + 1), second_header);
}
return true;
}
bool RunLoopFissionOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock()) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
FissionLoopInfo info;
if (!MatchFissionLoop(*loop, info)) {
continue;
}
const auto accesses = loopmem::CollectMemoryAccesses(*loop, info.iv);
std::vector<Instruction*> payload;
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || inst == info.step_inst) {
continue;
}
payload.push_back(inst);
}
if (payload.size() < 2) {
continue;
}
int chosen_cut = -1;
std::vector<Instruction*> first_group;
std::vector<Instruction*> second_group;
for (std::size_t cut = 1; cut < payload.size(); ++cut) {
std::vector<Instruction*> first(payload.begin(), payload.begin() + static_cast<long long>(cut));
std::vector<Instruction*> second(payload.begin() + static_cast<long long>(cut),
payload.end());
if (!ContainsInterestingPayload(first) || !ContainsInterestingPayload(second)) {
continue;
}
std::unordered_set<Instruction*> first_set(first.begin(), first.end());
std::unordered_set<Instruction*> second_set(second.begin(), second.end());
if (loopmem::HasScalarDependenceAcrossCut(first, second_set) ||
loopmem::HasMemoryDependenceAcrossCut(accesses, first_set, second_set,
info.induction_var.stride)) {
continue;
}
chosen_cut = static_cast<int>(cut);
first_group = std::move(first);
second_group = std::move(second);
break;
}
if (chosen_cut < 0) {
continue;
}
std::unordered_set<Instruction*> keep(first_group.begin(), first_group.end());
keep.insert(info.step_inst);
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || keep.find(inst) != keep.end()) {
continue;
}
to_remove.push_back(inst);
}
if (!BuildSecondLoop(function, info, second_group)) {
continue;
}
for (auto* inst : to_remove) {
info.body->EraseInstruction(inst);
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopFission(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopFissionOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,855 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include "MemoryUtils.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct DominatorInfo {
std::vector<BasicBlock*> blocks;
std::unordered_map<BasicBlock*, size_t> index;
std::vector<std::vector<bool>> dominates;
std::vector<BasicBlock*> idom;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_tree_children;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dominance_frontier;
};
enum class SeedStateKind { Unavailable, Available, Conflict };
struct SeedState {
SeedStateKind kind = SeedStateKind::Unavailable;
StoreInst* store = nullptr;
bool operator==(const SeedState& rhs) const {
return kind == rhs.kind && store == rhs.store;
}
bool operator!=(const SeedState& rhs) const { return !(*this == rhs); }
};
struct CandidateKey {
memutils::AddressKey address;
std::uintptr_t type_id = 0;
bool operator==(const CandidateKey& rhs) const {
return type_id == rhs.type_id && address == rhs.address;
}
};
struct CandidateKeyHash {
std::size_t operator()(const CandidateKey& key) const {
std::size_t h = memutils::AddressKeyHash{}(key.address);
h ^= std::hash<std::uintptr_t>{}(key.type_id) + 0x9e3779b9 + (h << 6) + (h >> 2);
return h;
}
};
struct PromotionCandidate {
CandidateKey key;
std::shared_ptr<Type> value_type;
loopmem::PointerInfo pointer_info;
std::vector<LoadInst*> loads;
std::vector<StoreInst*> stores;
StoreInst* seed_store = nullptr;
Value* canonical_ptr = nullptr;
Value* initial_value = nullptr;
std::unordered_set<BasicBlock*> def_blocks;
std::unordered_map<BasicBlock*, PhiInst*> phis;
int EstimatedBenefit() const {
return static_cast<int>(loads.size()) + 2 * static_cast<int>(stores.size()) - 1;
}
};
bool IsScalarPromotableType(const std::shared_ptr<Type>& type) {
return type && (type->IsInt1() || type->IsInt32() || type->IsFloat());
}
int CountFunctionInstructions(const Function& function) {
int count = 0;
for (const auto& block_ptr : function.GetBlocks()) {
if (!block_ptr) {
continue;
}
count += static_cast<int>(block_ptr->GetInstructions().size());
}
return count;
}
int CountLoopInstructions(const Loop& loop) {
int count = 0;
for (auto* block : loop.block_list) {
if (!block) {
continue;
}
count += static_cast<int>(block->GetInstructions().size());
}
return count;
}
bool ShouldAnalyzeFunction(const Function& function) {
constexpr int kMaxFunctionInstructions = 2000;
return CountFunctionInstructions(function) <= kMaxFunctionInstructions;
}
bool ShouldAnalyzeLoop(const Loop& loop) {
constexpr int kMaxLoopBlocks = 8;
constexpr int kMaxLoopInstructions = 96;
return static_cast<int>(loop.block_list.size()) <= kMaxLoopBlocks &&
CountLoopInstructions(loop) <= kMaxLoopInstructions;
}
bool DominatesBlock(const DominatorInfo& info, BasicBlock* dom, BasicBlock* block) {
if (!dom || !block) {
return false;
}
auto dom_it = info.index.find(dom);
auto block_it = info.index.find(block);
if (dom_it == info.index.end() || block_it == info.index.end()) {
return false;
}
return info.dominates[block_it->second][dom_it->second];
}
std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it) {
stack.push_back(*it);
}
}
}
return order;
}
std::vector<bool> IntersectDominators(const std::vector<std::vector<bool>>& doms,
const std::vector<size_t>& pred_indices,
size_t self_index) {
std::vector<bool> result(doms.size(), true);
if (pred_indices.empty()) {
std::fill(result.begin(), result.end(), false);
result[self_index] = true;
return result;
}
result = doms[pred_indices.front()];
for (size_t i = 1; i < pred_indices.size(); ++i) {
const auto& pred_dom = doms[pred_indices[i]];
for (size_t j = 0; j < result.size(); ++j) {
result[j] = result[j] && pred_dom[j];
}
}
result[self_index] = true;
return result;
}
DominatorInfo BuildDominatorInfo(Function& function) {
DominatorInfo info;
info.blocks = CollectReachableBlocks(function);
info.idom.resize(info.blocks.size(), nullptr);
info.dominates.assign(info.blocks.size(),
std::vector<bool>(info.blocks.size(), true));
if (info.blocks.empty()) {
return info;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
info.index[info.blocks[i]] = i;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
std::fill(info.dominates[i].begin(), info.dominates[i].end(), i != 0);
info.dominates[i][i] = true;
}
bool changed = true;
while (changed) {
changed = false;
for (size_t i = 1; i < info.blocks.size(); ++i) {
std::vector<size_t> pred_indices;
for (auto* pred : info.blocks[i]->GetPredecessors()) {
auto it = info.index.find(pred);
if (it != info.index.end()) {
pred_indices.push_back(it->second);
}
}
auto new_dom = IntersectDominators(info.dominates, pred_indices, i);
if (new_dom != info.dominates[i]) {
info.dominates[i] = std::move(new_dom);
changed = true;
}
}
}
for (size_t i = 1; i < info.blocks.size(); ++i) {
BasicBlock* candidate_idom = nullptr;
for (size_t j = 0; j < info.blocks.size(); ++j) {
if (i == j || !info.dominates[i][j]) {
continue;
}
bool is_immediate = true;
for (size_t k = 0; k < info.blocks.size(); ++k) {
if (k == i || k == j || !info.dominates[i][k]) {
continue;
}
if (info.dominates[k][j]) {
is_immediate = false;
break;
}
}
if (is_immediate) {
candidate_idom = info.blocks[j];
break;
}
}
info.idom[i] = candidate_idom;
if (candidate_idom) {
info.dom_tree_children[candidate_idom].push_back(info.blocks[i]);
}
}
for (auto* block : info.blocks) {
info.dominance_frontier[block] = {};
}
for (auto* block : info.blocks) {
std::vector<BasicBlock*> reachable_preds;
for (auto* pred : block->GetPredecessors()) {
if (info.index.find(pred) != info.index.end()) {
reachable_preds.push_back(pred);
}
}
if (reachable_preds.size() < 2) {
continue;
}
auto* idom_block = info.idom[info.index[block]];
for (auto* pred : reachable_preds) {
auto* runner = pred;
while (runner && runner != idom_block) {
auto& frontier = info.dominance_frontier[runner];
if (std::find(frontier.begin(), frontier.end(), block) == frontier.end()) {
frontier.push_back(block);
}
auto idom_it = info.index.find(runner);
if (idom_it == info.index.end()) {
break;
}
runner = info.idom[idom_it->second];
}
}
}
return info;
}
SeedState MergeSeedState(const SeedState& lhs, const SeedState& rhs) {
if (lhs == rhs) {
return lhs;
}
return {SeedStateKind::Conflict, nullptr};
}
SeedState TransferSeedState(const SeedState& in, BasicBlock* block,
const PromotionCandidate& candidate,
const memutils::EscapeSummary& escapes) {
SeedState state = in;
if (!block) {
return state;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* call = dyncast<CallInst>(inst)) {
if (memutils::CallMayWriteRoot(call->GetCallee(), candidate.pointer_info.root_kind)) {
state = {SeedStateKind::Unavailable, nullptr};
}
continue;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key) ||
memutils::MayAliasConservatively(key, candidate.key.address)) {
state = {SeedStateKind::Unavailable, nullptr};
}
continue;
}
auto* store = dyncast<StoreInst>(inst);
if (!store) {
continue;
}
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
state = {SeedStateKind::Unavailable, nullptr};
continue;
}
if (!memutils::MayAliasConservatively(key, candidate.key.address)) {
continue;
}
if (key == candidate.key.address && store->GetValue()->GetType() == candidate.value_type) {
state = {SeedStateKind::Available, store};
} else {
state = {SeedStateKind::Unavailable, nullptr};
}
}
return state;
}
StoreInst* FindSeedStoreInPreheader(const Loop& loop,
const PromotionCandidate& candidate,
const memutils::EscapeSummary& escapes) {
auto* preheader = loop.preheader;
if (!preheader) {
return nullptr;
}
StoreInst* seed = nullptr;
for (const auto& inst_ptr : preheader->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* call = dyncast<CallInst>(inst)) {
if (memutils::CallMayWriteRoot(call->GetCallee(), candidate.pointer_info.root_kind)) {
seed = nullptr;
}
continue;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key) ||
memutils::MayAliasConservatively(key, candidate.key.address)) {
seed = nullptr;
}
continue;
}
auto* store = dyncast<StoreInst>(inst);
if (!store) {
continue;
}
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
seed = nullptr;
continue;
}
if (!memutils::MayAliasConservatively(key, candidate.key.address)) {
continue;
}
if (key == candidate.key.address && store->GetValue()->GetType() == candidate.value_type) {
seed = store;
} else {
seed = nullptr;
}
}
return seed;
}
StoreInst* FindReachingSeedStoreAtLoopEntry(Function& function, const Loop& loop,
const PromotionCandidate& candidate,
const memutils::EscapeSummary& escapes) {
auto* preheader = loop.preheader;
if (!preheader) {
return nullptr;
}
const auto blocks = CollectReachableBlocks(function);
std::unordered_map<BasicBlock*, SeedState> in_state;
std::unordered_map<BasicBlock*, SeedState> out_state;
for (auto* block : blocks) {
in_state[block] = {SeedStateKind::Unavailable, nullptr};
out_state[block] = {SeedStateKind::Unavailable, nullptr};
}
bool changed = true;
while (changed) {
changed = false;
for (auto* block : blocks) {
SeedState merged{SeedStateKind::Unavailable, nullptr};
bool first_pred = true;
for (auto* pred : block->GetPredecessors()) {
auto it = out_state.find(pred);
if (it == out_state.end()) {
continue;
}
if (first_pred) {
merged = it->second;
first_pred = false;
} else {
merged = MergeSeedState(merged, it->second);
}
}
if (block == function.GetEntryBlock() && first_pred) {
merged = {SeedStateKind::Unavailable, nullptr};
}
SeedState next_out = TransferSeedState(merged, block, candidate, escapes);
if (merged != in_state[block] || next_out != out_state[block]) {
in_state[block] = merged;
out_state[block] = next_out;
changed = true;
}
}
}
const auto it = out_state.find(preheader);
if (it == out_state.end() || it->second.kind != SeedStateKind::Available) {
return nullptr;
}
return it->second.store;
}
bool ExitBlocksArePromotable(const Loop& loop) {
for (auto* exit : loop.exit_blocks) {
if (!exit) {
return false;
}
for (auto* pred : exit->GetPredecessors()) {
if (!loop.Contains(pred)) {
return false;
}
}
}
return !loop.exit_blocks.empty();
}
bool IsSafeToPromoteCandidate(const Loop& loop, const PromotionCandidate& candidate,
const std::vector<loopmem::MemoryAccessInfo>& accesses,
int iv_stride, const DominatorInfo& dom_info) {
if (!candidate.seed_store || !candidate.canonical_ptr || !candidate.initial_value) {
return false;
}
if (loop.parent != nullptr && candidate.seed_store->GetParent() != loop.preheader) {
return false;
}
if (!DominatesBlock(dom_info, candidate.seed_store->GetParent(), loop.preheader)) {
return false;
}
auto* ptr_inst = dyncast<Instruction>(candidate.canonical_ptr);
if (ptr_inst &&
(loop.Contains(ptr_inst->GetParent()) ||
!DominatesBlock(dom_info, ptr_inst->GetParent(), loop.preheader))) {
return false;
}
if (!ExitBlocksArePromotable(loop)) {
return false;
}
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* call = dyncast<CallInst>(inst_ptr.get());
if (!call) {
continue;
}
if (memutils::CallMayReadRoot(call->GetCallee(), candidate.pointer_info.root_kind) ||
memutils::CallMayWriteRoot(call->GetCallee(), candidate.pointer_info.root_kind)) {
return false;
}
}
}
for (const auto& access : accesses) {
if (!loopmem::MayAliasSameIteration(candidate.pointer_info, access.ptr) &&
!loopmem::HasCrossIterationDependence(candidate.pointer_info, access.ptr, iv_stride)) {
continue;
}
if (!access.ptr.exact_key_valid || !(access.ptr.exact_key == candidate.key.address)) {
return false;
}
if (isa<MemsetInst>(access.inst)) {
return false;
}
if (auto* load = dyncast<LoadInst>(access.inst)) {
if (load->GetType() != candidate.value_type) {
return false;
}
continue;
}
if (auto* store = dyncast<StoreInst>(access.inst)) {
if (store->GetValue()->GetType() != candidate.value_type) {
return false;
}
continue;
}
return false;
}
return true;
}
void InsertPhiNodes(const Loop& loop, PromotionCandidate& candidate,
const DominatorInfo& dom_info, Function& function) {
std::queue<BasicBlock*> worklist;
std::unordered_set<BasicBlock*> queued;
for (auto* block : candidate.def_blocks) {
worklist.push(block);
queued.insert(block);
}
while (!worklist.empty()) {
auto* block = worklist.front();
worklist.pop();
auto frontier_it = dom_info.dominance_frontier.find(block);
if (frontier_it == dom_info.dominance_frontier.end()) {
continue;
}
for (auto* frontier_block : frontier_it->second) {
if (!loop.Contains(frontier_block)) {
continue;
}
if (candidate.phis.find(frontier_block) != candidate.phis.end()) {
continue;
}
auto* phi = frontier_block->Insert<PhiInst>(
looputils::GetFirstNonPhiIndex(frontier_block), candidate.value_type, nullptr,
looputils::NextSyntheticName(function, "lmp.phi."));
candidate.phis[frontier_block] = phi;
if (candidate.def_blocks.insert(frontier_block).second && queued.insert(frontier_block).second) {
worklist.push(frontier_block);
}
}
}
}
void RenameCandidateInLoop(
BasicBlock* block, const Loop& loop, PromotionCandidate& candidate,
const DominatorInfo& dom_info, std::vector<Value*>& stack,
std::unordered_map<BasicBlock*, Value*>& block_out) {
if (!block || !loop.Contains(block)) {
return;
}
size_t pushed = 0;
PhiInst* block_phi = nullptr;
auto phi_it = candidate.phis.find(block);
if (phi_it != candidate.phis.end()) {
block_phi = phi_it->second;
stack.push_back(block_phi);
++pushed;
}
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == block_phi) {
continue;
}
if (auto* load = dyncast<LoadInst>(inst)) {
auto it = std::find(candidate.loads.begin(), candidate.loads.end(), load);
if (it == candidate.loads.end()) {
continue;
}
load->ReplaceAllUsesWith(stack.back());
to_remove.push_back(load);
continue;
}
auto* store = dyncast<StoreInst>(inst);
if (!store) {
continue;
}
auto it = std::find(candidate.stores.begin(), candidate.stores.end(), store);
if (it == candidate.stores.end()) {
continue;
}
stack.push_back(store->GetValue());
++pushed;
to_remove.push_back(store);
}
block_out[block] = stack.back();
for (auto* succ : block->GetSuccessors()) {
if (!loop.Contains(succ)) {
continue;
}
auto succ_phi_it = candidate.phis.find(succ);
if (succ_phi_it == candidate.phis.end()) {
continue;
}
succ_phi_it->second->AddIncoming(stack.back(), block);
}
auto child_it = dom_info.dom_tree_children.find(block);
if (child_it != dom_info.dom_tree_children.end()) {
for (auto* child : child_it->second) {
RenameCandidateInLoop(child, loop, candidate, dom_info, stack, block_out);
}
}
for (auto* inst : to_remove) {
if (inst->GetParent() == block) {
block->EraseInstruction(inst);
}
}
while (pushed > 0) {
stack.pop_back();
--pushed;
}
}
void InsertExitStores(Function& function, const Loop& loop, PromotionCandidate& candidate,
const std::unordered_map<BasicBlock*, Value*>& block_out) {
std::unordered_set<BasicBlock*> seen;
for (auto* exit : loop.exit_blocks) {
if (!exit || !seen.insert(exit).second) {
continue;
}
std::vector<BasicBlock*> preds;
preds.reserve(exit->GetPredecessors().size());
for (auto* pred : exit->GetPredecessors()) {
if (loop.Contains(pred)) {
preds.push_back(pred);
}
}
if (preds.empty()) {
continue;
}
Value* final_value = nullptr;
auto insert_index = looputils::GetFirstNonPhiIndex(exit);
if (preds.size() == 1) {
auto it = block_out.find(preds.front());
if (it == block_out.end()) {
continue;
}
final_value = it->second;
} else {
auto* phi = exit->Insert<PhiInst>(insert_index, candidate.value_type, nullptr,
looputils::NextSyntheticName(function, "lmp.exit."));
++insert_index;
for (auto* pred : preds) {
auto it = block_out.find(pred);
if (it != block_out.end()) {
phi->AddIncoming(it->second, pred);
}
}
final_value = phi;
}
exit->Insert<StoreInst>(insert_index, final_value, candidate.canonical_ptr, nullptr);
}
}
bool PromoteCandidate(Function& function, const Loop& loop, PromotionCandidate& candidate,
const DominatorInfo& dom_info) {
if (!candidate.seed_store || !candidate.initial_value) {
return false;
}
InsertPhiNodes(loop, candidate, dom_info, function);
auto header_phi_it = candidate.phis.find(loop.header);
if (header_phi_it != candidate.phis.end()) {
header_phi_it->second->AddIncoming(candidate.initial_value, loop.preheader);
}
std::vector<Value*> stack{candidate.initial_value};
std::unordered_map<BasicBlock*, Value*> block_out;
RenameCandidateInLoop(loop.header, loop, candidate, dom_info, stack, block_out);
InsertExitStores(function, loop, candidate, block_out);
return true;
}
std::vector<PromotionCandidate> CollectCandidates(
const Loop& loop, const std::vector<loopmem::MemoryAccessInfo>& accesses,
const memutils::EscapeSummary& escapes, int iv_stride, Function& function,
const DominatorInfo& dom_info) {
constexpr std::size_t kMaxLoopAccesses = 64;
if (accesses.size() > kMaxLoopAccesses) {
return {};
}
std::unordered_map<CandidateKey, PromotionCandidate, CandidateKeyHash> groups;
for (const auto& access : accesses) {
if (!access.ptr.exact_key_valid || !access.ptr.invariant_address) {
continue;
}
if (!access.is_read && !access.is_write) {
continue;
}
std::shared_ptr<Type> value_type;
if (auto* load = dyncast<LoadInst>(access.inst)) {
value_type = load->GetType();
} else if (auto* store = dyncast<StoreInst>(access.inst)) {
value_type = store->GetValue()->GetType();
} else {
continue;
}
if (!IsScalarPromotableType(value_type)) {
continue;
}
CandidateKey key{access.ptr.exact_key,
reinterpret_cast<std::uintptr_t>(value_type.get())};
auto& candidate = groups[key];
candidate.key = key;
candidate.value_type = value_type;
candidate.pointer_info = access.ptr;
if (auto* load = dyncast<LoadInst>(access.inst)) {
candidate.loads.push_back(load);
} else if (auto* store = dyncast<StoreInst>(access.inst)) {
candidate.stores.push_back(store);
candidate.def_blocks.insert(store->GetParent());
}
}
std::vector<PromotionCandidate> candidates;
candidates.reserve(groups.size());
for (auto& [key, candidate] : groups) {
if (candidate.stores.empty()) {
continue;
}
candidate.seed_store = FindSeedStoreInPreheader(loop, candidate, escapes);
if (!candidate.seed_store && loop.parent == nullptr) {
candidate.seed_store =
FindReachingSeedStoreAtLoopEntry(function, loop, candidate, escapes);
}
if (!candidate.seed_store) {
continue;
}
candidate.initial_value = candidate.seed_store->GetValue();
candidate.canonical_ptr = candidate.seed_store->GetPtr();
if (!IsSafeToPromoteCandidate(loop, candidate, accesses, iv_stride, dom_info)) {
continue;
}
candidates.push_back(std::move(candidate));
}
std::sort(candidates.begin(), candidates.end(),
[](const PromotionCandidate& lhs, const PromotionCandidate& rhs) {
return lhs.EstimatedBenefit() > rhs.EstimatedBenefit();
});
return candidates;
}
bool PromoteLoopMemory(Function& function, const Loop& loop,
const DominatorInfo& dom_info) {
if (!loop.preheader || !loop.header || !loop.IsInnermost() ||
!ShouldAnalyzeLoop(loop)) {
return false;
}
loopmem::SimpleInductionVar induction_var;
PhiInst* iv = nullptr;
int iv_stride = 1;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
iv = induction_var.phi;
iv_stride = induction_var.stride;
break;
}
}
bool changed = false;
while (true) {
const auto escapes = memutils::AnalyzeEscapes(function);
const auto accesses = loopmem::CollectMemoryAccesses(loop, iv, &escapes);
auto candidates = CollectCandidates(loop, accesses, escapes, iv_stride, function, dom_info);
if (candidates.empty()) {
break;
}
if (!PromoteCandidate(function, loop, candidates.front(), dom_info)) {
break;
}
changed = true;
}
return changed;
}
bool RunLoopMemoryPromotionOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
if (!ShouldAnalyzeFunction(function)) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool cfg_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
auto* old_preheader = loop->preheader;
auto* preheader = looputils::EnsurePreheader(function, *loop);
if (preheader != old_preheader) {
changed = true;
cfg_changed = true;
break;
}
}
if (cfg_changed) {
continue;
}
auto dom_info = BuildDominatorInfo(function);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
local_changed |= PromoteLoopMemory(function, *loop, dom_info);
}
changed |= local_changed;
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopMemoryPromotion(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopMemoryPromotionOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,506 @@
#pragma once
#include "LoopPassUtils.h"
#include "MemoryUtils.h"
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <vector>
namespace ir::loopmem {
struct SimpleInductionVar {
PhiInst* phi = nullptr;
Value* start = nullptr;
Value* latch_value = nullptr;
BasicBlock* latch = nullptr;
int stride = 0;
};
inline bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
PhiInst* phi, SimpleInductionVar& info) {
if (!phi || !preheader || phi->GetParent() != loop.header ||
!phi->GetType()->IsInt32() || phi->GetNumIncomings() != 2 ||
loop.latches.size() != 1) {
return false;
}
auto* latch = loop.latches.front();
const int preheader_index = looputils::GetPhiIncomingIndex(phi, preheader);
const int latch_index = looputils::GetPhiIncomingIndex(phi, latch);
if (preheader_index < 0 || latch_index < 0) {
return false;
}
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
if (!step_inst || step_inst->GetParent() != latch) {
return false;
}
int stride = 0;
if (step_inst->GetOpcode() == Opcode::Add) {
if (step_inst->GetLhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else if (step_inst->GetRhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else {
return false;
}
} else if (step_inst->GetOpcode() == Opcode::Sub) {
if (step_inst->GetLhs() != phi) {
return false;
}
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = -delta->GetValue();
} else {
return false;
}
if (stride == 0) {
return false;
}
info.phi = phi;
info.start = phi->GetIncomingValue(preheader_index);
info.latch_value = phi->GetIncomingValue(latch_index);
info.latch = latch;
info.stride = stride;
return true;
}
inline bool GetCanonicalLoopBlocks(const Loop& loop, BasicBlock*& body,
BasicBlock*& exit) {
body = nullptr;
exit = nullptr;
if (!loop.header || loop.latches.size() != 1 || loop.block_list.size() != 2) {
return false;
}
auto* condbr = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
if (!condbr) {
return false;
}
auto* then_block = condbr->GetThenBlock();
auto* else_block = condbr->GetElseBlock();
const bool then_in_loop = loop.Contains(then_block);
const bool else_in_loop = loop.Contains(else_block);
if (then_in_loop == else_in_loop) {
return false;
}
body = then_in_loop ? then_block : else_block;
exit = then_in_loop ? else_block : then_block;
if (!body || !exit || body != loop.latches.front() ||
body->GetSuccessors().size() != 1 || body->GetSuccessors().front() != loop.header) {
return false;
}
return true;
}
struct AffineExpr {
bool valid = false;
Value* var = nullptr;
std::int64_t coeff = 0;
std::int64_t constant = 0;
};
inline AffineExpr MakeConst(std::int64_t value) {
return {true, nullptr, 0, value};
}
inline AffineExpr Scale(const AffineExpr& expr, std::int64_t factor) {
if (!expr.valid) {
return {};
}
return {true, expr.var, expr.coeff * factor, expr.constant * factor};
}
inline AffineExpr Combine(const AffineExpr& lhs, const AffineExpr& rhs, int sign) {
if (!lhs.valid || !rhs.valid) {
return {};
}
if (lhs.var != nullptr && rhs.var != nullptr && lhs.var != rhs.var) {
return {};
}
AffineExpr out;
out.valid = true;
out.var = lhs.var ? lhs.var : rhs.var;
out.coeff = lhs.coeff + sign * rhs.coeff;
out.constant = lhs.constant + sign * rhs.constant;
return out;
}
inline AffineExpr AnalyzeAffine(Value* value, PhiInst* iv, const Loop& loop) {
if (!value) {
return {};
}
if (auto* ci = dyncast<ConstantInt>(value)) {
return MakeConst(ci->GetValue());
}
if (value == iv) {
return {true, iv, 1, 0};
}
if (looputils::IsLoopInvariantValue(loop, value)) {
return {};
}
if (auto* zext = dyncast<ZextInst>(value)) {
return AnalyzeAffine(zext->GetValue(), iv, loop);
}
auto* inst = dyncast<Instruction>(value);
if (!inst) {
return {};
}
switch (inst->GetOpcode()) {
case Opcode::Add:
return Combine(AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetLhs(), iv, loop),
AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetRhs(), iv, loop), +1);
case Opcode::Sub:
return Combine(AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetLhs(), iv, loop),
AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetRhs(), iv, loop), -1);
case Opcode::Mul: {
auto* bin = static_cast<BinaryInst*>(inst);
auto lhs = AnalyzeAffine(bin->GetLhs(), iv, loop);
auto rhs = AnalyzeAffine(bin->GetRhs(), iv, loop);
if (lhs.valid && lhs.var == nullptr && rhs.valid) {
return Scale(rhs, lhs.constant);
}
if (rhs.valid && rhs.var == nullptr && lhs.valid) {
return Scale(lhs, rhs.constant);
}
return {};
}
case Opcode::Neg:
return Scale(AnalyzeAffine(static_cast<UnaryInst*>(inst)->GetOprd(), iv, loop), -1);
default:
return {};
}
}
struct PointerInfo {
Value* base = nullptr;
AffineExpr byte_offset;
bool invariant_address = false;
bool distinct_root = false;
bool argument_root = false;
bool readonly_root = false;
bool exact_key_valid = false;
memutils::PointerRootKind root_kind = memutils::PointerRootKind::Unknown;
memutils::AddressKey exact_key;
int access_size = 0;
};
inline Value* StripPointerBase(Value* pointer) {
auto* value = pointer;
while (auto* gep = dyncast<GetElementPtrInst>(value)) {
value = gep->GetPointer();
}
return value;
}
inline std::shared_ptr<Type> AdvanceGEPType(std::shared_ptr<Type> current) {
if (current && current->IsArray()) {
return current->GetElementType();
}
return current;
}
inline PointerInfo AnalyzePointer(Value* pointer, PhiInst* iv, const Loop& loop,
int access_size,
const memutils::EscapeSummary* escapes = nullptr) {
PointerInfo info;
info.access_size = access_size;
info.base = StripPointerBase(pointer);
info.root_kind = memutils::ClassifyRoot(info.base, escapes);
info.argument_root = info.root_kind == memutils::PointerRootKind::Param;
info.readonly_root = info.root_kind == memutils::PointerRootKind::ReadonlyGlobal;
info.distinct_root = info.root_kind == memutils::PointerRootKind::Local ||
info.root_kind == memutils::PointerRootKind::Global ||
info.root_kind == memutils::PointerRootKind::ReadonlyGlobal;
info.exact_key_valid =
escapes != nullptr && memutils::BuildExactAddressKey(pointer, escapes, info.exact_key);
info.invariant_address = looputils::IsLoopInvariantValue(loop, pointer);
if (!dyncast<GetElementPtrInst>(pointer)) {
info.byte_offset = MakeConst(0);
return info;
}
auto* gep = static_cast<GetElementPtrInst*>(pointer);
std::shared_ptr<Type> current = gep->GetSourceType();
AffineExpr total = MakeConst(0);
bool all_indices_loop_invariant = looputils::IsLoopInvariantValue(loop, gep->GetPointer());
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
auto* index = gep->GetIndex(i);
all_indices_loop_invariant &= looputils::IsLoopInvariantValue(loop, index);
const std::int64_t stride = current ? current->GetSize() : 0;
auto term = AnalyzeAffine(index, iv, loop);
if (!term.valid) {
total = {};
} else if (total.valid) {
total = Combine(total, Scale(term, stride), +1);
}
current = AdvanceGEPType(current);
}
info.invariant_address = all_indices_loop_invariant;
info.byte_offset = total;
return info;
}
struct MemoryAccessInfo {
Instruction* inst = nullptr;
Value* pointer = nullptr;
PointerInfo ptr;
bool is_read = false;
bool is_write = false;
};
inline std::vector<MemoryAccessInfo> CollectMemoryAccesses(const Loop& loop,
PhiInst* iv,
const memutils::EscapeSummary* escapes =
nullptr) {
std::vector<MemoryAccessInfo> accesses;
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* load = dyncast<LoadInst>(inst)) {
accesses.push_back(
{inst, load->GetPtr(),
AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes),
true,
false});
} else if (auto* store = dyncast<StoreInst>(inst)) {
accesses.push_back({inst, store->GetPtr(),
AnalyzePointer(store->GetPtr(), iv, loop,
store->GetValue()->GetType()->GetSize(), escapes),
false, true});
} else if (auto* memset = dyncast<MemsetInst>(inst)) {
accesses.push_back(
{inst, memset->GetDest(),
AnalyzePointer(memset->GetDest(), iv, loop, 1, escapes), false, true});
}
}
}
return accesses;
}
inline bool SameAffineAddress(const PointerInfo& lhs, const PointerInfo& rhs) {
return lhs.base == rhs.base && lhs.byte_offset.valid && rhs.byte_offset.valid &&
lhs.byte_offset.var == rhs.byte_offset.var &&
lhs.byte_offset.coeff == rhs.byte_offset.coeff &&
lhs.byte_offset.constant == rhs.byte_offset.constant;
}
inline bool MayAliasSameIteration(const PointerInfo& lhs, const PointerInfo& rhs) {
if (lhs.exact_key_valid && rhs.exact_key_valid) {
return memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key);
}
if (!lhs.base || !rhs.base) {
return true;
}
if (lhs.base != rhs.base) {
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
return false;
}
return true;
}
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
return true;
}
if (lhs.byte_offset.var != rhs.byte_offset.var) {
return true;
}
if (lhs.byte_offset.coeff != rhs.byte_offset.coeff) {
return true;
}
const auto diff = std::llabs(lhs.byte_offset.constant - rhs.byte_offset.constant);
const auto overlap = std::min(lhs.access_size, rhs.access_size);
return diff < overlap;
}
inline bool HasCrossIterationDependence(const PointerInfo& lhs, const PointerInfo& rhs,
int iv_stride) {
if (lhs.exact_key_valid && rhs.exact_key_valid &&
!memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key)) {
return false;
}
if (!lhs.base || !rhs.base) {
return true;
}
if (lhs.base != rhs.base) {
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
return false;
}
return true;
}
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
return true;
}
if (lhs.byte_offset.var != rhs.byte_offset.var) {
return true;
}
const auto lhs_step = lhs.byte_offset.coeff * iv_stride;
const auto rhs_step = rhs.byte_offset.coeff * iv_stride;
if (lhs_step == 0 && rhs_step == 0) {
return MayAliasSameIteration(lhs, rhs);
}
if (lhs_step == rhs_step && lhs_step != 0) {
const auto diff = rhs.byte_offset.constant - lhs.byte_offset.constant;
return diff != 0 && diff % std::llabs(lhs_step) == 0;
}
return true;
}
inline bool CallMayWritePointer(Function* callee, const PointerInfo& ptr) {
if (ptr.readonly_root) {
return false;
}
return memutils::CallMayWriteRoot(callee, ptr.root_kind);
}
inline bool IsSafeInvariantLoadToHoist(const Loop& loop, LoadInst* load, PhiInst* iv,
int iv_stride,
const std::vector<MemoryAccessInfo>& accesses,
const memutils::EscapeSummary* escapes = nullptr) {
if (!load) {
return false;
}
auto ptr = AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes);
if (!ptr.invariant_address) {
return false;
}
if (ptr.readonly_root) {
return true;
}
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == load) {
continue;
}
if (auto* call = dyncast<CallInst>(inst)) {
if (CallMayWritePointer(call->GetCallee(), ptr)) {
return false;
}
}
}
}
for (const auto& access : accesses) {
if (access.inst == load || !access.is_write) {
continue;
}
if (MayAliasSameIteration(ptr, access.ptr)) {
return false;
}
if (HasCrossIterationDependence(ptr, access.ptr, iv_stride)) {
return false;
}
}
return true;
}
inline bool HasScalarDependenceAcrossCut(const std::vector<Instruction*>& first_group,
const std::unordered_set<Instruction*>& second_set) {
for (auto* inst : first_group) {
if (!inst || inst->IsVoid()) {
continue;
}
for (const auto& use : inst->GetUses()) {
auto* user = dyncast<Instruction>(use.GetUser());
if (user && second_set.find(user) != second_set.end()) {
return true;
}
}
}
return false;
}
inline bool HasMemoryDependenceAcrossCut(const std::vector<MemoryAccessInfo>& accesses,
const std::unordered_set<Instruction*>& first_set,
const std::unordered_set<Instruction*>& second_set,
int iv_stride) {
for (const auto& lhs : accesses) {
if (first_set.find(lhs.inst) == first_set.end()) {
continue;
}
for (const auto& rhs : accesses) {
if (second_set.find(rhs.inst) == second_set.end()) {
continue;
}
if (!lhs.is_write && !rhs.is_write) {
continue;
}
if (MayAliasSameIteration(lhs.ptr, rhs.ptr) ||
HasCrossIterationDependence(lhs.ptr, rhs.ptr, iv_stride)) {
return true;
}
}
}
return false;
}
inline bool IsLoopParallelizable(const Loop& loop, PhiInst* iv, int iv_stride,
const std::vector<MemoryAccessInfo>& accesses) {
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (phi != iv) {
return false;
}
}
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* call = dyncast<CallInst>(inst)) {
auto* callee = call->GetCallee();
if (callee == nullptr || callee->HasObservableSideEffects() || callee->IsRecursive()) {
return false;
}
for (const auto& access : accesses) {
if (CallMayWritePointer(callee, access.ptr)) {
return false;
}
}
continue;
}
if (dyncast<MemsetInst>(inst)) {
return false;
}
}
}
for (std::size_t i = 0; i < accesses.size(); ++i) {
for (std::size_t j = i + 1; j < accesses.size(); ++j) {
if (!accesses[i].is_write && !accesses[j].is_write) {
continue;
}
if (HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr, iv_stride)) {
return false;
}
}
}
return true;
}
} // namespace ir::loopmem

@ -0,0 +1,440 @@
#pragma once
#include "ir/Analysis.h"
#include "ir/IR.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir::looputils {
inline Instruction* GetTerminator(BasicBlock* block) {
if (!block || block->GetInstructions().empty()) {
return nullptr;
}
auto* inst = block->GetInstructions().back().get();
return inst && inst->IsTerminator() ? inst : nullptr;
}
inline std::size_t GetTerminatorIndex(BasicBlock* block) {
if (!block) {
return 0;
}
const auto size = block->GetInstructions().size();
if (!block->HasTerminator()) {
return size;
}
return size == 0 ? 0 : size - 1;
}
inline std::size_t GetFirstNonPhiIndex(BasicBlock* block) {
if (!block) {
return 0;
}
std::size_t index = 0;
for (const auto& inst_ptr : block->GetInstructions()) {
if (!dyncast<PhiInst>(inst_ptr.get())) {
break;
}
++index;
}
return index;
}
inline std::string NextSyntheticName(Function& function, const std::string& prefix) {
static std::unordered_map<Function*, int> counters;
const int id = ++counters[&function];
return "%" + prefix + std::to_string(id);
}
inline std::string NextSyntheticBlockName(Function& function,
const std::string& prefix) {
static std::unordered_map<Function*, int> counters;
const int id = ++counters[&function];
return prefix + "." + std::to_string(id);
}
inline ConstantInt* ConstInt(int value) {
return new ConstantInt(Type::GetInt32Type(), value);
}
inline int GetPhiIncomingIndex(PhiInst* phi, BasicBlock* block) {
if (!phi || !block) {
return -1;
}
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (phi->GetIncomingBlock(i) == block) {
return i;
}
}
return -1;
}
inline bool ReplacePhiIncoming(PhiInst* phi, BasicBlock* old_block,
Value* new_value, BasicBlock* new_block) {
if (!phi || !old_block || !new_value || !new_block) {
return false;
}
const int index = GetPhiIncomingIndex(phi, old_block);
if (index < 0) {
return false;
}
phi->SetOperand(static_cast<std::size_t>(2 * index), new_value);
phi->SetOperand(static_cast<std::size_t>(2 * index + 1), new_block);
return true;
}
inline bool RedirectSuccessorEdge(BasicBlock* pred, BasicBlock* old_succ,
BasicBlock* new_succ) {
auto* terminator = GetTerminator(pred);
if (!terminator || !old_succ || !new_succ) {
return false;
}
if (auto* br = dyncast<UncondBrInst>(terminator)) {
if (br->GetDest() != old_succ) {
return false;
}
br->SetOperand(0, new_succ);
} else if (auto* condbr = dyncast<CondBrInst>(terminator)) {
bool changed = false;
if (condbr->GetThenBlock() == old_succ) {
condbr->SetOperand(1, new_succ);
changed = true;
}
if (condbr->GetElseBlock() == old_succ) {
condbr->SetOperand(2, new_succ);
changed = true;
}
if (!changed) {
return false;
}
} else {
return false;
}
pred->RemoveSuccessor(old_succ);
pred->AddSuccessor(new_succ);
return true;
}
inline Instruction* MoveInstructionBeforeTerminator(Instruction* inst,
BasicBlock* dest) {
if (!inst || !dest) {
return nullptr;
}
auto* src = inst->GetParent();
if (!src || src == dest) {
return inst;
}
auto& src_insts = src->GetInstructions();
auto src_it = std::find_if(src_insts.begin(), src_insts.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == inst;
});
if (src_it == src_insts.end()) {
return nullptr;
}
auto moved = std::move(*src_it);
src_insts.erase(src_it);
moved->SetParent(dest);
auto& dest_insts = dest->GetInstructions();
auto insert_it = dest_insts.begin() +
static_cast<long long>(GetTerminatorIndex(dest));
auto* ptr = moved.get();
dest_insts.insert(insert_it, std::move(moved));
return ptr;
}
inline bool IsLoopInvariantValue(const Loop& loop, Value* value) {
auto* inst = dyncast<Instruction>(value);
return inst == nullptr || !loop.Contains(inst->GetParent());
}
inline Value* RemapValue(const std::unordered_map<Value*, Value*>& remap,
Value* value) {
auto it = remap.find(value);
return it == remap.end() ? value : it->second;
}
inline bool IsCloneableInstruction(const Instruction* inst) {
if (!inst || inst->IsTerminator() || inst->GetOpcode() == Opcode::Phi) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::Alloca:
case Opcode::Load:
case Opcode::Store:
case Opcode::Memset:
case Opcode::GetElementPtr:
case Opcode::Zext:
case Opcode::Call:
return true;
default:
return false;
}
}
inline Instruction* CloneInstruction(Function& function, Instruction* inst,
BasicBlock* dest,
std::unordered_map<Value*, Value*>& remap,
const std::string& prefix) {
if (!inst || !dest || !IsCloneableInstruction(inst)) {
return nullptr;
}
const auto insert_index = GetTerminatorIndex(dest);
const auto name = inst->IsVoid() ? std::string()
: NextSyntheticName(function, prefix);
auto remap_operand = [&](Value* value) { return RemapValue(remap, value); };
auto remember = [&](Instruction* clone) {
if (clone && !inst->IsVoid()) {
remap[inst] = clone;
}
return clone;
};
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto* bin = static_cast<BinaryInst*>(inst);
return remember(dest->Insert<BinaryInst>(
insert_index, inst->GetOpcode(), inst->GetType(),
remap_operand(bin->GetLhs()), remap_operand(bin->GetRhs()), nullptr,
name));
}
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF: {
auto* un = static_cast<UnaryInst*>(inst);
return remember(dest->Insert<UnaryInst>(insert_index, inst->GetOpcode(),
inst->GetType(),
remap_operand(un->GetOprd()),
nullptr, name));
}
case Opcode::Alloca: {
auto* alloca = static_cast<AllocaInst*>(inst);
return remember(dest->Insert<AllocaInst>(insert_index,
alloca->GetAllocatedType(),
nullptr, name));
}
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
return remember(dest->Insert<LoadInst>(insert_index, inst->GetType(),
remap_operand(load->GetPtr()),
nullptr, name));
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
return dest->Insert<StoreInst>(insert_index,
remap_operand(store->GetValue()),
remap_operand(store->GetPtr()), nullptr);
}
case Opcode::Memset: {
auto* memset = static_cast<MemsetInst*>(inst);
return dest->Insert<MemsetInst>(insert_index,
remap_operand(memset->GetDest()),
remap_operand(memset->GetValue()),
remap_operand(memset->GetLength()),
remap_operand(memset->GetIsVolatile()),
nullptr);
}
case Opcode::GetElementPtr: {
auto* gep = static_cast<GetElementPtrInst*>(inst);
std::vector<Value*> indices;
indices.reserve(gep->GetNumIndices());
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
indices.push_back(remap_operand(gep->GetIndex(i)));
}
return remember(dest->Insert<GetElementPtrInst>(
insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()),
indices, nullptr, name));
}
case Opcode::Zext: {
auto* zext = static_cast<ZextInst*>(inst);
return remember(dest->Insert<ZextInst>(insert_index,
remap_operand(zext->GetValue()),
inst->GetType(), nullptr, name));
}
case Opcode::Call: {
auto* call = static_cast<CallInst*>(inst);
std::vector<Value*> args;
const auto original_args = call->GetArguments();
args.reserve(original_args.size());
for (auto* arg : original_args) {
args.push_back(remap_operand(arg));
}
return remember(dest->Insert<CallInst>(insert_index, call->GetCallee(),
args, nullptr, name));
}
case Opcode::Phi:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Return:
case Opcode::Unreachable:
break;
}
return nullptr;
}
inline BasicBlock* EnsurePreheader(Function& function, Loop& loop) {
if (loop.preheader) {
return loop.preheader;
}
auto* header = loop.header;
if (!header) {
return nullptr;
}
std::vector<BasicBlock*> outside_preds;
for (auto* pred : header->GetPredecessors()) {
if (!loop.Contains(pred)) {
outside_preds.push_back(pred);
}
}
if (outside_preds.empty()) {
return nullptr;
}
if (outside_preds.size() == 1 &&
outside_preds.front()->GetSuccessors().size() == 1) {
loop.preheader = outside_preds.front();
return loop.preheader;
}
auto* preheader = function.CreateBlock(
NextSyntheticBlockName(function, header->GetName() + ".preheader"));
std::size_t phi_insert_index = 0;
for (const auto& inst_ptr : header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
std::vector<int> outside_incomings;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (!loop.Contains(phi->GetIncomingBlock(i))) {
outside_incomings.push_back(i);
}
}
if (outside_incomings.empty()) {
continue;
}
Value* merged_value = nullptr;
if (outside_incomings.size() == 1) {
merged_value = phi->GetIncomingValue(outside_incomings.front());
} else {
auto new_phi = std::make_unique<PhiInst>(
phi->GetType(), nullptr,
NextSyntheticName(function, "preheader.phi."));
auto* new_phi_ptr = new_phi.get();
new_phi_ptr->SetParent(preheader);
auto& preheader_insts = preheader->GetInstructions();
preheader_insts.insert(preheader_insts.begin() +
static_cast<long long>(phi_insert_index),
std::move(new_phi));
++phi_insert_index;
for (int incoming_index : outside_incomings) {
new_phi_ptr->AddIncoming(phi->GetIncomingValue(incoming_index),
phi->GetIncomingBlock(incoming_index));
}
merged_value = new_phi_ptr;
}
for (auto it = outside_incomings.rbegin(); it != outside_incomings.rend();
++it) {
phi->RemoveOperand(static_cast<std::size_t>(2 * *it + 1));
phi->RemoveOperand(static_cast<std::size_t>(2 * *it));
}
phi->AddIncoming(merged_value, preheader);
}
preheader->Append<UncondBrInst>(header, nullptr);
preheader->AddSuccessor(header);
header->AddPredecessor(preheader);
for (auto* pred : outside_preds) {
if (RedirectSuccessorEdge(pred, header, preheader)) {
preheader->AddPredecessor(pred);
header->RemovePredecessor(pred);
}
}
loop.preheader = preheader;
return preheader;
}
} // namespace ir::looputils

@ -0,0 +1,295 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <cstdlib>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct InductionVarInfo {
PhiInst* phi = nullptr;
Value* start = nullptr;
BasicBlock* latch = nullptr;
int stride = 0;
};
Value* BuildMulValue(Function& function, BasicBlock* block, Value* lhs, Value* rhs,
const std::string& prefix) {
if (auto* lhs_const = dyncast<ConstantInt>(lhs)) {
if (lhs_const->GetValue() == 0) {
return looputils::ConstInt(0);
}
if (lhs_const->GetValue() == 1) {
return rhs;
}
}
if (auto* rhs_const = dyncast<ConstantInt>(rhs)) {
if (rhs_const->GetValue() == 0) {
return looputils::ConstInt(0);
}
if (rhs_const->GetValue() == 1) {
return lhs;
}
}
if (auto* lhs_const = dyncast<ConstantInt>(lhs)) {
if (auto* rhs_const = dyncast<ConstantInt>(rhs)) {
return looputils::ConstInt(lhs_const->GetValue() * rhs_const->GetValue());
}
}
return block->Insert<BinaryInst>(looputils::GetTerminatorIndex(block), Opcode::Mul,
Type::GetInt32Type(), lhs, rhs, nullptr,
looputils::NextSyntheticName(function, prefix));
}
Value* BuildScaledValue(Function& function, BasicBlock* block, Value* base,
int factor, const std::string& prefix) {
if (factor == 0) {
return looputils::ConstInt(0);
}
if (factor == 1) {
return base;
}
if (auto* base_const = dyncast<ConstantInt>(base)) {
return looputils::ConstInt(base_const->GetValue() * factor);
}
if (factor == -1) {
return block->Insert<UnaryInst>(looputils::GetTerminatorIndex(block), Opcode::Neg,
base->GetType(), base, nullptr,
looputils::NextSyntheticName(function, prefix));
}
return BuildMulValue(function, block, base, looputils::ConstInt(factor), prefix);
}
bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
PhiInst* phi, InductionVarInfo& info) {
if (!phi || !phi->GetType() || !phi->GetType()->IsInt32() ||
phi->GetParent() != loop.header || phi->GetNumIncomings() != 2 ||
loop.latches.size() != 1) {
return false;
}
auto* latch = loop.latches.front();
int preheader_index = -1;
int latch_index = -1;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (phi->GetIncomingBlock(i) == preheader) {
preheader_index = i;
} else if (phi->GetIncomingBlock(i) == latch) {
latch_index = i;
}
}
if (preheader_index < 0 || latch_index < 0) {
return false;
}
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
if (!step_inst || step_inst->GetParent() != latch) {
return false;
}
int stride = 0;
if (step_inst->GetOpcode() == Opcode::Add) {
if (step_inst->GetLhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else if (step_inst->GetRhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else {
return false;
}
} else if (step_inst->GetOpcode() == Opcode::Sub) {
if (step_inst->GetLhs() != phi) {
return false;
}
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = -delta->GetValue();
} else {
return false;
}
if (stride == 0) {
return false;
}
info.phi = phi;
info.start = phi->GetIncomingValue(preheader_index);
info.latch = latch;
info.stride = stride;
return true;
}
bool IsMulCandidate(const Loop& loop, Instruction* inst, PhiInst* phi, Value*& factor) {
auto* mul = dyncast<BinaryInst>(inst);
if (!mul || mul->GetOpcode() != Opcode::Mul || !mul->GetType()->IsInt32()) {
return false;
}
if (mul->GetLhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetRhs())) {
factor = mul->GetRhs();
return true;
}
if (mul->GetRhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetLhs())) {
factor = mul->GetLhs();
return true;
}
return false;
}
Value* CreateReducedPhi(Function& function, BasicBlock* header, BasicBlock* preheader,
const InductionVarInfo& iv, Value* factor) {
auto* reduced_phi = header->Insert<PhiInst>(
looputils::GetFirstNonPhiIndex(header), Type::GetInt32Type(), nullptr,
looputils::NextSyntheticName(function, "lsr.phi."));
Value* init = BuildMulValue(function, preheader, iv.start, factor, "lsr.init.");
reduced_phi->AddIncoming(init, preheader);
Value* step = BuildScaledValue(function, preheader, factor, std::abs(iv.stride),
"lsr.step.");
Instruction* next = nullptr;
if (iv.stride > 0) {
next = iv.latch->Insert<BinaryInst>(
looputils::GetTerminatorIndex(iv.latch), Opcode::Add, Type::GetInt32Type(),
reduced_phi, step, nullptr,
looputils::NextSyntheticName(function, "lsr.next."));
} else {
next = iv.latch->Insert<BinaryInst>(
looputils::GetTerminatorIndex(iv.latch), Opcode::Sub, Type::GetInt32Type(),
reduced_phi, step, nullptr,
looputils::NextSyntheticName(function, "lsr.next."));
}
reduced_phi->AddIncoming(next, iv.latch);
return reduced_phi;
}
bool ReduceLoopMultiplications(Function& function, const Loop& loop,
BasicBlock* preheader) {
if (!preheader || loop.latches.size() != 1) {
return false;
}
std::vector<InductionVarInfo> induction_vars;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
InductionVarInfo info;
if (MatchSimpleInductionVariable(loop, preheader, phi, info)) {
induction_vars.push_back(info);
}
}
if (induction_vars.empty()) {
return false;
}
bool changed = false;
std::vector<Instruction*> to_remove;
for (const auto& iv : induction_vars) {
std::vector<std::pair<Instruction*, Value*>> candidates;
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == iv.phi) {
continue;
}
Value* factor = nullptr;
if (IsMulCandidate(loop, inst, iv.phi, factor)) {
candidates.push_back({inst, factor});
}
}
}
if (candidates.empty()) {
continue;
}
std::unordered_map<Value*, Value*> reduced_cache;
for (const auto& candidate : candidates) {
auto* inst = candidate.first;
auto* factor = candidate.second;
auto cache_it = reduced_cache.find(factor);
Value* replacement = nullptr;
if (cache_it != reduced_cache.end()) {
replacement = cache_it->second;
} else {
replacement = CreateReducedPhi(function, loop.header, preheader, iv, factor);
reduced_cache.emplace(factor, replacement);
}
inst->ReplaceAllUsesWith(replacement);
to_remove.push_back(inst);
changed = true;
}
}
for (auto* inst : to_remove) {
if (inst && inst->GetParent()) {
inst->GetParent()->EraseInstruction(inst);
}
}
return changed;
}
bool RunLoopStrengthReductionOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
auto* old_preheader = loop->preheader;
auto* preheader = looputils::EnsurePreheader(function, *loop);
bool loop_changed = preheader != old_preheader;
loop_changed |= ReduceLoopMultiplications(function, *loop, preheader);
if (!loop_changed) {
continue;
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopStrengthReduction(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopStrengthReductionOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,400 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct CountedLoopInfo {
Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* header = nullptr;
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
CondBrInst* branch = nullptr;
BinaryInst* compare = nullptr;
Opcode compare_opcode = Opcode::ICmpLT;
Value* bound = nullptr;
loopmem::SimpleInductionVar induction_var;
std::vector<PhiInst*> phis;
};
bool HasSyntheticLoopTag(const std::string& name) {
return name.find("unroll.") != std::string::npos;
}
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
if (!loop.preheader || !loop.header || !body) {
return true;
}
if (HasSyntheticLoopTag(loop.preheader->GetName()) ||
HasSyntheticLoopTag(loop.header->GetName()) ||
HasSyntheticLoopTag(body->GetName())) {
return true;
}
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int incoming = looputils::GetPhiIncomingIndex(phi, loop.preheader);
if (incoming < 0) {
continue;
}
auto* incoming_phi = dyncast<PhiInst>(phi->GetIncomingValue(incoming));
if (incoming_phi && incoming_phi->GetParent() &&
HasSyntheticLoopTag(incoming_phi->GetParent()->GetName())) {
return true;
}
}
return false;
}
bool IsSupportedCompareOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::ICmpLT:
case Opcode::ICmpLE:
case Opcode::ICmpGT:
case Opcode::ICmpGE:
return true;
default:
return false;
}
}
Opcode SwapCompareOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::ICmpLT:
return Opcode::ICmpGT;
case Opcode::ICmpLE:
return Opcode::ICmpGE;
case Opcode::ICmpGT:
return Opcode::ICmpLT;
case Opcode::ICmpGE:
return Opcode::ICmpLE;
default:
return opcode;
}
}
int CountPayloadInstructions(BasicBlock* block) {
int count = 0;
if (!block) {
return 0;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
break;
}
++count;
}
return count;
}
int ChooseUnrollFactor(BasicBlock* body) {
const int inst_count = CountPayloadInstructions(body);
int mem_ops = 0;
for (const auto& inst_ptr : body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
break;
}
if (dyncast<LoadInst>(inst) || dyncast<StoreInst>(inst)) {
++mem_ops;
}
}
if (inst_count >= 2 && inst_count <= 6 && mem_ops <= 2) {
return 4;
}
if (inst_count >= 2 && inst_count <= 18) {
return 2;
}
return 1;
}
bool HasUnsafeLoopCarriedMemoryDependence(
const std::vector<loopmem::MemoryAccessInfo>& accesses, int iv_stride) {
for (std::size_t i = 0; i < accesses.size(); ++i) {
if (accesses[i].is_write &&
loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[i].ptr,
iv_stride)) {
return true;
}
for (std::size_t j = i + 1; j < accesses.size(); ++j) {
if (!accesses[i].is_write && !accesses[j].is_write) {
continue;
}
if (loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr,
iv_stride)) {
return true;
}
}
}
return false;
}
bool MatchCountedLoop(Loop& loop, CountedLoopInfo& info) {
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
return false;
}
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
return false;
}
if (IsAlreadyTransformedLoop(loop, body)) {
return false;
}
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
if (!branch || branch->GetThenBlock() != body) {
return false;
}
auto* compare = dyncast<BinaryInst>(branch->GetCondition());
if (!compare || !compare->GetType()->IsBool() ||
!IsSupportedCompareOpcode(compare->GetOpcode())) {
return false;
}
bool found_iv = false;
loopmem::SimpleInductionVar induction_var;
std::vector<PhiInst*> phis;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
phis.push_back(phi);
if (!found_iv &&
loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
found_iv = true;
}
}
if (!found_iv) {
return false;
}
Opcode compare_opcode = compare->GetOpcode();
Value* bound = nullptr;
if (compare->GetLhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
bound = compare->GetRhs();
} else if (compare->GetRhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
bound = compare->GetLhs();
compare_opcode = SwapCompareOpcode(compare_opcode);
} else {
return false;
}
if (!bound) {
return false;
}
if ((induction_var.stride > 0 &&
!(compare_opcode == Opcode::ICmpLT || compare_opcode == Opcode::ICmpLE)) ||
(induction_var.stride < 0 &&
!(compare_opcode == Opcode::ICmpGT || compare_opcode == Opcode::ICmpGE))) {
return false;
}
const auto accesses = loopmem::CollectMemoryAccesses(loop, induction_var.phi);
if (HasUnsafeLoopCarriedMemoryDependence(accesses, induction_var.stride)) {
return false;
}
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst) || inst == compare || inst->IsTerminator()) {
continue;
}
return false;
}
for (const auto& inst_ptr : body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
continue;
}
if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
return false;
}
}
info.loop = &loop;
info.preheader = loop.preheader;
info.header = loop.header;
info.body = body;
info.exit = exit;
info.branch = branch;
info.compare = compare;
info.compare_opcode = compare_opcode;
info.bound = bound;
info.induction_var = induction_var;
info.phis = std::move(phis);
return true;
}
Value* BuildAdjustedBound(Function& function, BasicBlock* preheader, Value* bound,
int stride, int factor) {
const int delta = std::abs(stride) * (factor - 1);
if (delta == 0) {
return bound;
}
if (auto* ci = dyncast<ConstantInt>(bound)) {
return looputils::ConstInt(stride > 0 ? ci->GetValue() - delta : ci->GetValue() + delta);
}
return preheader->Insert<BinaryInst>(
looputils::GetTerminatorIndex(preheader),
stride > 0 ? Opcode::Sub : Opcode::Add, Type::GetInt32Type(), bound,
looputils::ConstInt(delta), nullptr,
looputils::NextSyntheticName(function, "unroll.bound."));
}
bool RunLoopUnrollOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock()) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
CountedLoopInfo info;
if (!MatchCountedLoop(*loop, info)) {
continue;
}
const int factor = ChooseUnrollFactor(info.body);
if (factor <= 1) {
continue;
}
auto* unrolled_header =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.header"));
auto* unrolled_body =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.body"));
auto* unrolled_exit =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.exit"));
std::unordered_map<Value*, Value*> remap;
std::unordered_map<PhiInst*, PhiInst*> unrolled_phis;
std::unordered_map<PhiInst*, PhiInst*> exit_phis;
std::unordered_map<PhiInst*, Value*> current_phi_values;
std::unordered_map<PhiInst*, Value*> latch_values;
for (auto* phi : info.phis) {
auto* cloned_phi = unrolled_header->Append<PhiInst>(
phi->GetType(), nullptr,
looputils::NextSyntheticName(function, "unroll.phi."));
const int preheader_index = looputils::GetPhiIncomingIndex(phi, info.preheader);
const int latch_index = looputils::GetPhiIncomingIndex(phi, info.body);
if (preheader_index < 0 || latch_index < 0) {
continue;
}
cloned_phi->AddIncoming(phi->GetIncomingValue(preheader_index), info.preheader);
remap[phi] = cloned_phi;
unrolled_phis.emplace(phi, cloned_phi);
current_phi_values.emplace(phi, cloned_phi);
latch_values.emplace(phi, phi->GetIncomingValue(latch_index));
}
auto* adjusted_bound = BuildAdjustedBound(function, info.preheader, info.bound,
info.induction_var.stride, factor);
auto* unrolled_cond = unrolled_header->Append<BinaryInst>(
info.compare_opcode, Type::GetBoolType(), unrolled_phis[info.induction_var.phi],
adjusted_bound, nullptr,
looputils::NextSyntheticName(function, "unroll.cmp."));
unrolled_header->Append<CondBrInst>(unrolled_cond, unrolled_body, unrolled_exit, nullptr);
unrolled_header->AddPredecessor(info.preheader);
unrolled_header->AddSuccessor(unrolled_body);
unrolled_header->AddSuccessor(unrolled_exit);
for (int lane = 0; lane < factor; ++lane) {
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
continue;
}
looputils::CloneInstruction(function, inst, unrolled_body, remap,
"unroll." + std::to_string(lane) + ".");
}
std::unordered_map<PhiInst*, Value*> next_phi_values;
for (const auto& entry : latch_values) {
next_phi_values.emplace(entry.first,
looputils::RemapValue(remap, entry.second));
}
for (const auto& entry : next_phi_values) {
remap[entry.first] = entry.second;
current_phi_values[entry.first] = entry.second;
}
}
for (const auto& entry : unrolled_phis) {
entry.second->AddIncoming(current_phi_values[entry.first], unrolled_body);
}
unrolled_body->Append<UncondBrInst>(unrolled_header, nullptr);
unrolled_body->AddPredecessor(unrolled_header);
unrolled_body->AddSuccessor(unrolled_header);
unrolled_header->AddPredecessor(unrolled_body);
for (const auto& entry : unrolled_phis) {
auto* exit_phi = unrolled_exit->Append<PhiInst>(
entry.first->GetType(), nullptr,
looputils::NextSyntheticName(function, "unroll.exit."));
exit_phi->AddIncoming(entry.second, unrolled_header);
exit_phis.emplace(entry.first, exit_phi);
}
unrolled_exit->Append<UncondBrInst>(info.header, nullptr);
unrolled_exit->AddPredecessor(unrolled_header);
unrolled_exit->AddSuccessor(info.header);
if (!looputils::RedirectSuccessorEdge(info.preheader, info.header, unrolled_header)) {
continue;
}
info.header->RemovePredecessor(info.preheader);
info.header->AddPredecessor(unrolled_exit);
for (auto* phi : info.phis) {
looputils::ReplacePhiIncoming(phi, info.preheader, exit_phis[phi], unrolled_exit);
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopUnroll(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopUnrollOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,313 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
struct UnswitchInfo {
Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* guard_block = nullptr;
CondBrInst* guard = nullptr;
Value* condition = nullptr;
std::vector<BasicBlock*> order;
};
bool HasSyntheticUnswitchTag(const std::string& name) {
return name.find("unswitch.") != std::string::npos;
}
bool IsSafeLoopBlockForUnswitch(BasicBlock* block) {
if (!block) {
return false;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<CallInst>(inst) || dyncast<MemsetInst>(inst) ||
dyncast<AllocaInst>(inst) || dyncast<UnreachableInst>(inst)) {
return false;
}
}
return true;
}
void CollectLoopDFS(BasicBlock* block, const Loop& loop,
std::unordered_set<BasicBlock*>& visited,
std::vector<BasicBlock*>& postorder) {
if (!block || !loop.Contains(block) || !visited.insert(block).second) {
return;
}
for (auto* succ : block->GetSuccessors()) {
if (loop.Contains(succ)) {
CollectLoopDFS(succ, loop, visited, postorder);
}
}
postorder.push_back(block);
}
std::vector<BasicBlock*> CollectLoopRPO(const Loop& loop) {
std::vector<BasicBlock*> postorder;
std::unordered_set<BasicBlock*> visited;
CollectLoopDFS(loop.header, loop, visited, postorder);
return std::vector<BasicBlock*>(postorder.rbegin(), postorder.rend());
}
void RemovePhiIncomingFromPred(BasicBlock* block, BasicBlock* pred) {
if (!block || !pred) {
return;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int index = looputils::GetPhiIncomingIndex(phi, pred);
if (index < 0) {
continue;
}
phi->RemoveOperand(static_cast<std::size_t>(2 * index + 1));
phi->RemoveOperand(static_cast<std::size_t>(2 * index));
}
}
void ReplaceTerminatorWithBr(BasicBlock* block, BasicBlock* dest) {
auto& instructions = block->GetInstructions();
if (instructions.empty() || !instructions.back()->IsTerminator()) {
return;
}
instructions.back()->ClearAllOperands();
instructions.pop_back();
block->Append<UncondBrInst>(dest, nullptr);
}
void ReplaceTerminatorWithCondBr(BasicBlock* block, Value* cond,
BasicBlock* then_block, BasicBlock* else_block) {
auto& instructions = block->GetInstructions();
if (instructions.empty() || !instructions.back()->IsTerminator()) {
return;
}
instructions.back()->ClearAllOperands();
instructions.pop_back();
block->Append<CondBrInst>(cond, then_block, else_block, nullptr);
}
bool MatchLoopUnswitch(Loop& loop, UnswitchInfo& info) {
if (!loop.preheader || !loop.IsInnermost() || loop.block_list.size() > 6) {
return false;
}
if (HasSyntheticUnswitchTag(loop.preheader->GetName()) ||
HasSyntheticUnswitchTag(loop.header->GetName())) {
return false;
}
int instruction_count = 0;
for (auto* block : loop.block_list) {
if (!IsSafeLoopBlockForUnswitch(block) || HasSyntheticUnswitchTag(block->GetName())) {
return false;
}
instruction_count += static_cast<int>(block->GetInstructions().size());
}
if (instruction_count > 48) {
return false;
}
for (auto* block : loop.block_list) {
auto* condbr = dyncast<CondBrInst>(looputils::GetTerminator(block));
if (!condbr) {
continue;
}
auto* cond_inst = dyncast<Instruction>(condbr->GetCondition());
if (cond_inst && loop.Contains(cond_inst->GetParent())) {
continue;
}
info.loop = &loop;
info.preheader = loop.preheader;
info.guard_block = block;
info.guard = condbr;
info.condition = condbr->GetCondition();
info.order = CollectLoopRPO(loop);
return true;
}
return false;
}
bool CloneLoopForUnswitch(Function& function, const UnswitchInfo& info,
std::unordered_map<BasicBlock*, BasicBlock*>& block_map,
std::unordered_map<Value*, Value*>& value_map) {
for (auto* block : info.order) {
block_map[block] = function.CreateBlock(
looputils::NextSyntheticBlockName(function, "unswitch.loop"));
}
for (auto* block : info.order) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
auto* cloned_phi = clone->Append<PhiInst>(
phi->GetType(), nullptr, looputils::NextSyntheticName(function, "unswitch.phi."));
value_map[phi] = cloned_phi;
}
}
for (auto* block : info.order) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst) || inst->IsTerminator()) {
continue;
}
if (!looputils::CloneInstruction(function, inst, clone, value_map, "unswitch.")) {
return false;
}
}
}
for (auto* block : info.order) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
auto* cloned_phi = static_cast<PhiInst*>(value_map.at(phi));
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
auto* incoming_block = phi->GetIncomingBlock(i);
auto value_it = value_map.find(phi->GetIncomingValue(i));
Value* incoming_value =
value_it == value_map.end() ? phi->GetIncomingValue(i) : value_it->second;
auto block_it = block_map.find(incoming_block);
cloned_phi->AddIncoming(incoming_value,
block_it == block_map.end() ? incoming_block
: block_it->second);
}
}
}
for (auto* block : info.order) {
auto* clone = block_map.at(block);
auto* terminator = looputils::GetTerminator(block);
if (auto* br = dyncast<UncondBrInst>(terminator)) {
auto target_it = block_map.find(br->GetDest());
auto* target = target_it == block_map.end() ? br->GetDest() : target_it->second;
clone->Append<UncondBrInst>(target, nullptr);
clone->AddSuccessor(target);
target->AddPredecessor(clone);
continue;
}
auto* condbr = dyncast<CondBrInst>(terminator);
if (!condbr) {
return false;
}
auto cond_it = value_map.find(condbr->GetCondition());
Value* cond = cond_it == value_map.end() ? condbr->GetCondition() : cond_it->second;
auto then_it = block_map.find(condbr->GetThenBlock());
auto else_it = block_map.find(condbr->GetElseBlock());
auto* then_block = then_it == block_map.end() ? condbr->GetThenBlock() : then_it->second;
auto* else_block = else_it == block_map.end() ? condbr->GetElseBlock() : else_it->second;
clone->Append<CondBrInst>(cond, then_block, else_block, nullptr);
clone->AddSuccessor(then_block);
clone->AddSuccessor(else_block);
then_block->AddPredecessor(clone);
else_block->AddPredecessor(clone);
}
return true;
}
bool RunLoopUnswitchOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock()) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
UnswitchInfo info;
if (!MatchLoopUnswitch(*loop, info)) {
continue;
}
std::unordered_map<BasicBlock*, BasicBlock*> block_map;
std::unordered_map<Value*, Value*> value_map;
if (!CloneLoopForUnswitch(function, info, block_map, value_map)) {
continue;
}
auto* then_target = info.guard->GetThenBlock();
auto* else_target = info.guard->GetElseBlock();
auto* cloned_guard = block_map.at(info.guard_block);
auto* cloned_then =
block_map.count(then_target) ? block_map.at(then_target) : then_target;
auto* cloned_else =
block_map.count(else_target) ? block_map.at(else_target) : else_target;
RemovePhiIncomingFromPred(else_target, info.guard_block);
if (then_target != else_target) {
RemovePhiIncomingFromPred(cloned_then, cloned_guard);
}
info.guard_block->RemoveSuccessor(else_target);
else_target->RemovePredecessor(info.guard_block);
ReplaceTerminatorWithBr(info.guard_block, then_target);
info.guard_block->AddSuccessor(then_target);
then_target->AddPredecessor(info.guard_block);
cloned_guard->RemoveSuccessor(cloned_then);
cloned_then->RemovePredecessor(cloned_guard);
ReplaceTerminatorWithBr(cloned_guard, cloned_else);
cloned_guard->AddSuccessor(cloned_else);
cloned_else->AddPredecessor(cloned_guard);
auto* cloned_header = block_map.at(loop->header);
auto* old_preheader_term = looputils::GetTerminator(info.preheader);
if (!old_preheader_term || !dyncast<UncondBrInst>(old_preheader_term)) {
continue;
}
info.preheader->RemoveSuccessor(loop->header);
loop->header->RemovePredecessor(info.preheader);
ReplaceTerminatorWithCondBr(info.preheader, info.condition, loop->header, cloned_header);
info.preheader->AddSuccessor(loop->header);
info.preheader->AddSuccessor(cloned_header);
loop->header->AddPredecessor(info.preheader);
cloned_header->AddPredecessor(info.preheader);
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopUnswitch(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopUnswitchOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,405 @@
// Mem2RegSSA 构造):
// - 将局部变量的 alloca/load/store 提升为 SSA 形式
// - 插入 PHI 并重写使用,依赖支配树等分析
#include "ir/PassManager.h"
#include "ir/IR.h"
#include <algorithm>
#include <memory>
#include <queue>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct DominatorInfo {
std::vector<BasicBlock*> blocks;
std::unordered_map<BasicBlock*, size_t> index;
std::vector<std::vector<bool>> dominates;
std::vector<BasicBlock*> idom;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_tree_children;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dominance_frontier;
};
struct PromotableAlloca {
AllocaInst* alloca = nullptr;
std::shared_ptr<Type> value_type;
std::unordered_set<BasicBlock*> def_blocks;
std::unordered_map<BasicBlock*, PhiInst*> phis;
};
bool IsScalarPromotableType(const std::shared_ptr<Type>& type) {
return type && (type->IsInt1() || type->IsInt32() || type->IsFloat());
}
Value* DefaultValueFor(Context& ctx, const std::shared_ptr<Type>& type) {
if (type->IsInt1()) {
return ctx.GetConstBool(false);
}
if (type->IsInt32()) {
return ctx.GetConstInt(0);
}
if (type->IsFloat()) {
return new ConstantFloat(Type::GetFloatType(), 0.0f);
}
throw std::runtime_error("Mem2Reg encountered unsupported promotable type");
}
std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it != nullptr) {
stack.push_back(*it);
}
}
}
return order;
}
std::vector<bool> IntersectDominators(const std::vector<std::vector<bool>>& doms,
const std::vector<size_t>& pred_indices,
size_t self_index) {
std::vector<bool> result(doms.size(), true);
if (pred_indices.empty()) {
std::fill(result.begin(), result.end(), false);
result[self_index] = true;
return result;
}
result = doms[pred_indices.front()];
for (size_t i = 1; i < pred_indices.size(); ++i) {
const auto& pred_dom = doms[pred_indices[i]];
for (size_t j = 0; j < result.size(); ++j) {
result[j] = result[j] && pred_dom[j];
}
}
result[self_index] = true;
return result;
}
DominatorInfo BuildDominatorInfo(Function& function) {
DominatorInfo info;
info.blocks = CollectReachableBlocks(function);
info.idom.resize(info.blocks.size(), nullptr);
info.dominates.assign(info.blocks.size(),
std::vector<bool>(info.blocks.size(), true));
if (info.blocks.empty()) {
return info;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
info.index[info.blocks[i]] = i;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
std::fill(info.dominates[i].begin(), info.dominates[i].end(), i != 0);
info.dominates[i][i] = true;
}
bool changed = true;
while (changed) {
changed = false;
for (size_t i = 1; i < info.blocks.size(); ++i) {
std::vector<size_t> pred_indices;
for (auto* pred : info.blocks[i]->GetPredecessors()) {
auto it = info.index.find(pred);
if (it != info.index.end()) {
pred_indices.push_back(it->second);
}
}
auto new_dom = IntersectDominators(info.dominates, pred_indices, i);
if (new_dom != info.dominates[i]) {
info.dominates[i] = std::move(new_dom);
changed = true;
}
}
}
for (size_t i = 1; i < info.blocks.size(); ++i) {
BasicBlock* candidate_idom = nullptr;
for (size_t j = 0; j < info.blocks.size(); ++j) {
if (i == j || !info.dominates[i][j]) {
continue;
}
bool is_immediate = true;
for (size_t k = 0; k < info.blocks.size(); ++k) {
if (k == i || k == j || !info.dominates[i][k]) {
continue;
}
if (info.dominates[k][j]) {
is_immediate = false;
break;
}
}
if (is_immediate) {
candidate_idom = info.blocks[j];
break;
}
}
info.idom[i] = candidate_idom;
if (candidate_idom != nullptr) {
info.dom_tree_children[candidate_idom].push_back(info.blocks[i]);
}
}
for (auto* block : info.blocks) {
info.dominance_frontier[block] = {};
}
for (auto* block : info.blocks) {
std::vector<BasicBlock*> reachable_preds;
for (auto* pred : block->GetPredecessors()) {
if (info.index.find(pred) != info.index.end()) {
reachable_preds.push_back(pred);
}
}
if (reachable_preds.size() < 2) {
continue;
}
auto* idom_block = info.idom[info.index[block]];
for (auto* pred : reachable_preds) {
auto* runner = pred;
while (runner != nullptr && runner != idom_block) {
auto& frontier = info.dominance_frontier[runner];
if (std::find(frontier.begin(), frontier.end(), block) == frontier.end()) {
frontier.push_back(block);
}
auto idom_it = info.index.find(runner);
if (idom_it == info.index.end()) {
break;
}
runner = info.idom[idom_it->second];
}
}
}
return info;
}
bool IsPromotableAlloca(AllocaInst& alloca, const DominatorInfo& dom_info) {
if (!IsScalarPromotableType(alloca.GetAllocatedType())) {
return false;
}
for (const auto& use : alloca.GetUses()) {
auto* user = use.GetUser();
auto* inst = dynamic_cast<Instruction*>(user);
if (inst == nullptr || inst->GetParent() == nullptr ||
dom_info.index.find(inst->GetParent()) == dom_info.index.end()) {
return false;
}
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
if (load->GetPtr() != &alloca) {
return false;
}
continue;
}
auto* store = dynamic_cast<StoreInst*>(inst);
if (store == nullptr || store->GetPtr() != &alloca ||
store->GetValue() == &alloca) {
return false;
}
if (store->GetValue()->GetType() != alloca.GetAllocatedType()) {
return false;
}
}
return true;
}
size_t CountLeadingPhiNodes(BasicBlock& block) {
size_t count = 0;
for (const auto& inst : block.GetInstructions()) {
if (!isa<PhiInst>(inst.get())) {
break;
}
++count;
}
return count;
}
void InsertPhiNodes(Context& ctx, PromotableAlloca& slot,
const DominatorInfo& dom_info) {
std::queue<BasicBlock*> worklist;
std::unordered_set<BasicBlock*> queued;
for (auto* block : slot.def_blocks) {
worklist.push(block);
queued.insert(block);
}
while (!worklist.empty()) {
auto* block = worklist.front();
worklist.pop();
auto frontier_it = dom_info.dominance_frontier.find(block);
if (frontier_it == dom_info.dominance_frontier.end()) {
continue;
}
for (auto* frontier_block : frontier_it->second) {
if (slot.phis.find(frontier_block) != slot.phis.end()) {
continue;
}
auto phi_index = CountLeadingPhiNodes(*frontier_block);
auto* phi = frontier_block->Insert<PhiInst>(phi_index, slot.value_type, nullptr,
ctx.NextTemp());
slot.phis[frontier_block] = phi;
if (slot.def_blocks.insert(frontier_block).second) {
worklist.push(frontier_block);
}
}
}
}
void RenamePromotedAlloca(BasicBlock* block, PromotableAlloca& slot,
const DominatorInfo& dom_info,
std::vector<Value*>& stack, Context& ctx) {
if (block == nullptr) {
return;
}
size_t pushed = 0;
PhiInst* block_phi = nullptr;
auto phi_it = slot.phis.find(block);
if (phi_it != slot.phis.end()) {
block_phi = phi_it->second;
stack.push_back(block_phi);
++pushed;
}
std::vector<Instruction*> to_remove;
Instruction* alloca_to_remove = nullptr;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == block_phi) {
continue;
}
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
if (load->GetPtr() != slot.alloca) {
continue;
}
auto* replacement =
stack.empty() ? DefaultValueFor(ctx, slot.value_type) : stack.back();
load->ReplaceAllUsesWith(replacement);
to_remove.push_back(load);
continue;
}
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
if (store->GetPtr() != slot.alloca) {
continue;
}
stack.push_back(store->GetValue());
++pushed;
to_remove.push_back(store);
continue;
}
if (inst == slot.alloca) {
alloca_to_remove = inst;
}
}
for (auto* succ : block->GetSuccessors()) {
auto succ_phi_it = slot.phis.find(succ);
if (succ_phi_it == slot.phis.end()) {
continue;
}
auto* incoming =
stack.empty() ? DefaultValueFor(ctx, slot.value_type) : stack.back();
succ_phi_it->second->AddIncoming(incoming, block);
}
auto child_it = dom_info.dom_tree_children.find(block);
if (child_it != dom_info.dom_tree_children.end()) {
for (auto* child : child_it->second) {
RenamePromotedAlloca(child, slot, dom_info, stack, ctx);
}
}
for (auto* inst : to_remove) {
block->EraseInstruction(inst);
}
if (alloca_to_remove != nullptr) {
block->EraseInstruction(alloca_to_remove);
}
while (pushed > 0) {
stack.pop_back();
--pushed;
}
}
void PromoteAllocasInFunction(Function& function, Context& ctx) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return;
}
auto dom_info = BuildDominatorInfo(function);
if (dom_info.blocks.empty()) {
return;
}
std::vector<PromotableAlloca> promotable_allocas;
for (const auto& inst_ptr : function.GetEntryBlock()->GetInstructions()) {
auto* alloca = dynamic_cast<AllocaInst*>(inst_ptr.get());
if (alloca == nullptr || !IsPromotableAlloca(*alloca, dom_info)) {
continue;
}
PromotableAlloca slot;
slot.alloca = alloca;
slot.value_type = alloca->GetAllocatedType();
for (const auto& use : alloca->GetUses()) {
auto* inst = dynamic_cast<Instruction*>(use.GetUser());
auto* store = inst == nullptr ? nullptr : dynamic_cast<StoreInst*>(inst);
if (store != nullptr && store->GetPtr() == alloca) {
slot.def_blocks.insert(inst->GetParent());
}
}
promotable_allocas.push_back(std::move(slot));
}
for (auto& slot : promotable_allocas) {
InsertPhiNodes(ctx, slot, dom_info);
std::vector<Value*> stack;
RenamePromotedAlloca(function.GetEntryBlock(), slot, dom_info, stack, ctx);
}
}
} // namespace
void RunMem2Reg(Module& module) {
auto& ctx = module.GetContext();
for (const auto& function : module.GetFunctions()) {
if (function != nullptr) {
PromoteAllocasInFunction(*function, ctx);
}
}
}
} // namespace ir

@ -0,0 +1,260 @@
#pragma once
#include "ir/IR.h"
#include "PassUtils.h"
#include <cstdint>
#include <unordered_set>
#include <vector>
namespace ir::memutils {
enum class PointerRootKind {
Local,
Global,
ReadonlyGlobal,
Param,
Unknown,
};
struct AddressComponent {
bool is_constant = false;
std::int64_t constant = 0;
Value* value = nullptr;
bool operator==(const AddressComponent& rhs) const {
return is_constant == rhs.is_constant && constant == rhs.constant &&
value == rhs.value;
}
};
struct AddressKey {
PointerRootKind kind = PointerRootKind::Unknown;
Value* root = nullptr;
std::vector<AddressComponent> components;
bool operator==(const AddressKey& rhs) const {
return kind == rhs.kind && root == rhs.root && components == rhs.components;
}
};
struct AddressKeyHash {
std::size_t operator()(const AddressKey& key) const {
std::size_t h = static_cast<std::size_t>(key.kind);
h ^= std::hash<Value*>{}(key.root) + 0x9e3779b9 + (h << 6) + (h >> 2);
for (const auto& component : key.components) {
h ^= std::hash<bool>{}(component.is_constant) + 0x9e3779b9 + (h << 6) + (h >> 2);
if (component.is_constant) {
h ^= std::hash<std::int64_t>{}(component.constant) + 0x9e3779b9 + (h << 6) +
(h >> 2);
} else {
h ^= std::hash<Value*>{}(component.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
}
return h;
}
};
struct EscapeSummary {
std::unordered_set<Value*> escaped_locals;
bool IsEscaped(Value* value) const {
return value != nullptr && escaped_locals.find(value) != escaped_locals.end();
}
};
inline bool IsNoEscapePointerUse(Value* current, Instruction* user) {
if (!current || !user) {
return false;
}
if (auto* load = dyncast<LoadInst>(user)) {
return load->GetPtr() == current;
}
if (auto* store = dyncast<StoreInst>(user)) {
return store->GetPtr() == current;
}
if (auto* memset = dyncast<MemsetInst>(user)) {
return memset->GetDest() == current;
}
return false;
}
inline bool PointerValueEscapes(Value* current, Value* root,
std::unordered_set<Value*>& visiting) {
if (!current || !root || !visiting.insert(current).second) {
return false;
}
for (const auto& use : current->GetUses()) {
auto* user = dyncast<Instruction>(use.GetUser());
if (!user) {
return true;
}
if (auto* gep = dyncast<GetElementPtrInst>(user)) {
if (gep->GetPointer() == current &&
PointerValueEscapes(gep, root, visiting)) {
return true;
}
continue;
}
if (IsNoEscapePointerUse(current, user)) {
continue;
}
return true;
}
return false;
}
inline EscapeSummary AnalyzeEscapes(Function& function) {
EscapeSummary summary;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* alloca = dyncast<AllocaInst>(inst_ptr.get());
if (!alloca) {
continue;
}
std::unordered_set<Value*> visiting;
if (PointerValueEscapes(alloca, alloca, visiting)) {
summary.escaped_locals.insert(alloca);
}
}
}
return summary;
}
inline PointerRootKind ClassifyRoot(Value* root, const EscapeSummary* summary) {
if (root == nullptr) {
return PointerRootKind::Unknown;
}
if (auto* global = dyncast<GlobalValue>(root)) {
return global->IsConstant() ? PointerRootKind::ReadonlyGlobal
: PointerRootKind::Global;
}
if (isa<Argument>(root)) {
return PointerRootKind::Param;
}
if (isa<AllocaInst>(root)) {
if (summary != nullptr && summary->IsEscaped(root)) {
return PointerRootKind::Unknown;
}
return PointerRootKind::Local;
}
return PointerRootKind::Unknown;
}
inline Value* StripPointerRoot(Value* pointer) {
auto* current = pointer;
while (auto* gep = dyncast<GetElementPtrInst>(current)) {
current = gep->GetPointer();
}
return current;
}
inline AddressComponent MakeAddressComponent(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return {true, ci->GetValue(), nullptr};
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return {true, cb->GetValue() ? 1 : 0, nullptr};
}
return {false, 0, value};
}
inline bool BuildExactAddressKey(Value* pointer, const EscapeSummary* summary,
AddressKey& key) {
if (!pointer) {
return false;
}
if (auto* gep = dyncast<GetElementPtrInst>(pointer)) {
if (!BuildExactAddressKey(gep->GetPointer(), summary, key)) {
return false;
}
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
key.components.push_back(MakeAddressComponent(gep->GetIndex(i)));
}
return true;
}
key.kind = ClassifyRoot(pointer, summary);
key.root = pointer;
key.components.clear();
return true;
}
inline bool HasOnlyConstantComponents(const AddressKey& key) {
for (const auto& component : key.components) {
if (!component.is_constant) {
return false;
}
}
return true;
}
inline bool MayAliasConservatively(const AddressKey& lhs, const AddressKey& rhs) {
if (lhs.kind == PointerRootKind::Unknown || rhs.kind == PointerRootKind::Unknown) {
return true;
}
if (lhs.kind != rhs.kind || lhs.root != rhs.root) {
return false;
}
if (lhs.components == rhs.components) {
return true;
}
if (HasOnlyConstantComponents(lhs) && HasOnlyConstantComponents(rhs)) {
return false;
}
return true;
}
inline bool CallMayReadRoot(Function* callee, PointerRootKind kind) {
if (!callee) {
return true;
}
if (callee->HasUnknownEffects()) {
return true;
}
switch (kind) {
case PointerRootKind::ReadonlyGlobal:
return callee->ReadsGlobalMemory();
case PointerRootKind::Global:
return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory();
case PointerRootKind::Param:
return callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Local:
return false;
case PointerRootKind::Unknown:
return callee->MayReadMemory();
}
return true;
}
inline bool CallMayWriteRoot(Function* callee, PointerRootKind kind) {
if (!callee) {
return true;
}
if (callee->HasUnknownEffects()) {
return true;
}
switch (kind) {
case PointerRootKind::ReadonlyGlobal:
return false;
case PointerRootKind::Global:
return callee->WritesGlobalMemory();
case PointerRootKind::Param:
return callee->WritesParamMemory();
case PointerRootKind::Local:
return false;
case PointerRootKind::Unknown:
return callee->MayWriteMemory();
}
return true;
}
inline bool IsPureCall(const CallInst* call) {
auto* callee = call == nullptr ? nullptr : call->GetCallee();
return callee != nullptr && callee->CanDiscardUnusedCall() &&
!callee->MayReadMemory();
}
} // namespace ir::memutils

@ -0,0 +1,66 @@
// IR Pass 管理骨架。
#include "ir/PassManager.h"
#include <cstdlib>
namespace ir {
void RunIRPassPipeline(Module& module) {
const char* disable_mem2reg = std::getenv("NUDTC_DISABLE_MEM2REG");
if (disable_mem2reg != nullptr && disable_mem2reg[0] != '\0' && disable_mem2reg[0] != '0') {
return;
}
const char* disable_loop_mem_promotion =
std::getenv("NUDTC_DISABLE_LOOP_MEM_PROMOTION");
const bool run_loop_mem_promotion =
disable_loop_mem_promotion == nullptr || disable_loop_mem_promotion[0] == '\0' ||
disable_loop_mem_promotion[0] == '0';
const char* disable_inline_cfg = std::getenv("NUDTC_DISABLE_CFG_INLINE");
const bool run_cfg_inline =
disable_inline_cfg == nullptr || disable_inline_cfg[0] == '\0' ||
disable_inline_cfg[0] == '0';
const char* disable_loop_unswitch = std::getenv("NUDTC_DISABLE_LOOP_UNSWITCH");
const bool run_loop_unswitch =
disable_loop_unswitch == nullptr || disable_loop_unswitch[0] == '\0' ||
disable_loop_unswitch[0] == '0';
RunMem2Reg(module);
constexpr int kMaxIterations = 8;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
bool changed = false;
if (run_cfg_inline) {
changed |= RunFunctionInlining(module);
}
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);
changed |= RunLoadStoreElim(module);
changed |= RunCSE(module);
changed |= RunDCE(module);
changed |= RunCFGSimplify(module);
changed |= RunLICM(module);
if (run_loop_mem_promotion) {
changed |= RunLoopMemoryPromotion(module);
}
if (run_loop_unswitch) {
changed |= RunLoopUnswitch(module);
}
changed |= RunLoopStrengthReduction(module);
changed |= RunLoopFission(module);
changed |= RunLoopUnroll(module);
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);
changed |= RunLoadStoreElim(module);
changed |= RunCSE(module);
changed |= RunDCE(module);
changed |= RunCFGSimplify(module);
if (!changed) {
break;
}
}
}
} // namespace ir

@ -0,0 +1,234 @@
#pragma once
#include "ir/IR.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <unordered_set>
#include <vector>
namespace ir::passutils {
inline std::uint32_t FloatBits(float value) {
std::uint32_t bits = 0;
std::memcpy(&bits, &value, sizeof(bits));
return bits;
}
inline bool AreEquivalentValues(Value* lhs, Value* rhs) {
if (lhs == rhs) {
return true;
}
auto* lhs_i32 = dyncast<ConstantInt>(lhs);
auto* rhs_i32 = dyncast<ConstantInt>(rhs);
if (lhs_i32 && rhs_i32) {
return lhs_i32->GetValue() == rhs_i32->GetValue();
}
auto* lhs_i1 = dyncast<ConstantI1>(lhs);
auto* rhs_i1 = dyncast<ConstantI1>(rhs);
if (lhs_i1 && rhs_i1) {
return lhs_i1->GetValue() == rhs_i1->GetValue();
}
auto* lhs_f32 = dyncast<ConstantFloat>(lhs);
auto* rhs_f32 = dyncast<ConstantFloat>(rhs);
if (lhs_f32 && rhs_f32) {
return FloatBits(lhs_f32->GetValue()) == FloatBits(rhs_f32->GetValue());
}
return false;
}
inline std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it != nullptr) {
stack.push_back(*it);
}
}
}
return order;
}
inline bool IsSideEffectingInstruction(const Instruction* inst) {
if (!inst) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Store:
case Opcode::Memset:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Return:
case Opcode::Unreachable:
return true;
case Opcode::Call: {
auto* call = dyncast<const CallInst>(inst);
auto* callee = call == nullptr ? nullptr : call->GetCallee();
return callee == nullptr || !callee->CanDiscardUnusedCall();
}
default:
return false;
}
}
inline bool IsTriviallyDead(Instruction* inst) {
return inst != nullptr && !IsSideEffectingInstruction(inst) &&
inst->GetUses().empty();
}
inline void RemoveIncomingForBlock(PhiInst* phi, BasicBlock* block) {
if (!phi || !block) {
return;
}
for (int i = phi->GetNumIncomings() - 1; i >= 0; --i) {
if (phi->GetIncomingBlock(i) != block) {
continue;
}
phi->RemoveOperand(static_cast<size_t>(2 * i + 1));
phi->RemoveOperand(static_cast<size_t>(2 * i));
}
}
inline void RemoveIncomingFromSuccessor(BasicBlock* succ, BasicBlock* pred) {
if (!succ || !pred) {
return;
}
for (const auto& inst_ptr : succ->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
RemoveIncomingForBlock(phi, pred);
}
}
inline void ReplaceTerminatorWithBr(BasicBlock* block, BasicBlock* dest) {
auto& instructions = block->GetInstructions();
if (instructions.empty() || !instructions.back()->IsTerminator()) {
return;
}
instructions.back()->ClearAllOperands();
auto branch = std::make_unique<UncondBrInst>(dest, nullptr);
branch->SetParent(block);
instructions.back() = std::move(branch);
}
inline bool SimplifyPhiInst(PhiInst* phi) {
if (!phi) {
return false;
}
Value* unique_value = nullptr;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
auto* incoming = phi->GetIncomingValue(i);
if (incoming == phi) {
continue;
}
if (unique_value == nullptr) {
unique_value = incoming;
continue;
}
if (!AreEquivalentValues(unique_value, incoming)) {
return false;
}
}
if (unique_value == nullptr) {
return false;
}
auto* parent = phi->GetParent();
phi->ReplaceAllUsesWith(unique_value);
parent->EraseInstruction(phi);
return true;
}
inline void EraseBlock(Function& function, BasicBlock* block) {
if (!block) {
return;
}
auto& blocks = function.GetBlocks();
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
[&](const std::unique_ptr<BasicBlock>& current) {
return current.get() == block;
}),
blocks.end());
}
inline bool RemoveUnreachableBlocks(Function& function) {
auto reachable = CollectReachableBlocks(function);
std::unordered_set<BasicBlock*> reachable_set(reachable.begin(), reachable.end());
std::vector<BasicBlock*> dead_blocks;
for (const auto& block_ptr : function.GetBlocks()) {
auto* block = block_ptr.get();
if (reachable_set.find(block) == reachable_set.end()) {
dead_blocks.push_back(block);
}
}
if (dead_blocks.empty()) {
return false;
}
for (auto* block : dead_blocks) {
auto preds = block->GetPredecessors();
auto succs = block->GetSuccessors();
for (auto* succ : succs) {
RemoveIncomingFromSuccessor(succ, block);
succ->RemovePredecessor(block);
}
for (auto* pred : preds) {
pred->RemoveSuccessor(block);
}
}
for (auto* block : dead_blocks) {
for (const auto& inst_ptr : block->GetInstructions()) {
inst_ptr->ClearAllOperands();
}
}
for (auto* block : dead_blocks) {
EraseBlock(function, block);
}
return true;
}
inline bool IsCommutativeOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::Add:
case Opcode::Mul:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::FAdd:
case Opcode::FMul:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
return true;
default:
return false;
}
}
} // namespace ir::passutils

@ -0,0 +1,15 @@
add_library(irgen STATIC
IRGenDriver.cpp
IRGenFunc.cpp
IRGenStmt.cpp
IRGenExp.cpp
IRGenDecl.cpp
)
target_link_libraries(irgen PUBLIC
build_options
frontend
${ANTLR4_RUNTIME_TARGET}
ir
sem
)

@ -0,0 +1,670 @@
#include "irgen/IRGen.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <utility>
namespace {
std::vector<int> ExpandLinearIndex(const std::vector<int>& dims, size_t flat_index) {
std::vector<int> indices(dims.size(), 0);
for (size_t i = dims.size(); i > 0; --i) {
const auto dim_index = i - 1;
indices[dim_index] = static_cast<int>(flat_index % static_cast<size_t>(dims[dim_index]));
flat_index /= static_cast<size_t>(dims[dim_index]);
}
return indices;
}
} // namespace
std::string IRGenImpl::ExpectIdent(const antlr4::ParserRuleContext& ctx,
antlr4::tree::TerminalNode* ident) const {
if (ident == nullptr) {
ThrowError(&ctx, "?????");
}
return ident->getText();
}
SemanticType IRGenImpl::ParseBType(SysYParser::BTypeContext* ctx) const {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
if (ctx->INT()) {
return SemanticType::Int;
}
if (ctx->FLOAT()) {
return SemanticType::Float;
}
ThrowError(ctx, "????? int/float ????");
}
SemanticType IRGenImpl::ParseFuncType(SysYParser::FuncTypeContext* ctx) const {
if (ctx == nullptr) {
ThrowError(ctx, "????????");
}
if (ctx->VOID()) {
return SemanticType::Void;
}
if (ctx->INT()) {
return SemanticType::Int;
}
if (ctx->FLOAT()) {
return SemanticType::Float;
}
ThrowError(ctx, "????? void/int/float ??????");
}
std::shared_ptr<ir::Type> IRGenImpl::GetIRScalarType(SemanticType type) const {
switch (type) {
case SemanticType::Void:
return ir::Type::GetVoidType();
case SemanticType::Int:
return ir::Type::GetInt32Type();
case SemanticType::Float:
return ir::Type::GetFloatType();
}
throw std::runtime_error("unknown semantic type");
}
std::shared_ptr<ir::Type> IRGenImpl::BuildArrayType(
SemanticType base_type, const std::vector<int>& dims) const {
auto type = GetIRScalarType(base_type);
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
type = ir::Type::GetArrayType(type, static_cast<size_t>(*it));
}
return type;
}
std::vector<int> IRGenImpl::ParseArrayDims(
const std::vector<SysYParser::ConstExpContext*>& dims_ctx) {
std::vector<int> dims;
dims.reserve(dims_ctx.size());
for (auto* dim_ctx : dims_ctx) {
if (dim_ctx == nullptr || dim_ctx->addExp() == nullptr) {
ThrowError(dim_ctx, "???????????");
}
auto dim = ConvertConst(EvalConstAddExp(*dim_ctx->addExp()), SemanticType::Int);
if (dim.int_value <= 0) {
ThrowError(dim_ctx, "??????????");
}
dims.push_back(dim.int_value);
}
return dims;
}
std::vector<int> IRGenImpl::ParseParamDims(SysYParser::FuncFParamContext& ctx) {
std::vector<int> dims;
for (auto* exp_ctx : ctx.exp()) {
auto dim = ConvertConst(EvalConstExp(*exp_ctx), SemanticType::Int);
if (dim.int_value <= 0) {
ThrowError(exp_ctx, "????????????");
}
dims.push_back(dim.int_value);
}
return dims;
}
void IRGenImpl::PredeclareGlobalDecl(SysYParser::DeclContext& ctx) {
auto declare_one = [&](const std::string& name, SemanticType type, bool is_const,
const std::vector<int>& dims, const antlr4::ParserRuleContext* node) {
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(node, "????????: " + name);
}
SymbolEntry entry;
entry.kind = is_const ? SymbolKind::Constant : SymbolKind::Variable;
entry.type = type;
entry.is_const = is_const;
entry.is_array = !dims.empty();
entry.is_param_array = false;
entry.dims = dims;
entry.ir_value = module_.CreateGlobalValue(
name, dims.empty() ? GetIRScalarType(type) : BuildArrayType(type, dims),
is_const, nullptr);
if (!symbols_.Insert(name, entry)) {
ThrowError(node, "????????: " + name);
}
};
if (ctx.constDecl() != nullptr) {
const auto type = ParseBType(ctx.constDecl()->bType());
for (auto* def : ctx.constDecl()->constDef()) {
const auto name = ExpectIdent(*def, def->Ident());
const auto dims = ParseArrayDims(def->constExp());
declare_one(name, type, true, dims, def);
auto* symbol = symbols_.Lookup(name);
if (symbol != nullptr && dims.empty()) {
symbol->const_scalar = ConvertConst(
EvalConstAddExp(*def->constInitVal()->constExp()->addExp()), type);
}
}
return;
}
if (ctx.varDecl() != nullptr) {
const auto type = ParseBType(ctx.varDecl()->bType());
for (auto* def : ctx.varDecl()->varDef()) {
declare_one(ExpectIdent(*def, def->Ident()), type, false,
ParseArrayDims(def->constExp()), def);
}
return;
}
ThrowError(&ctx, "????");
}
void IRGenImpl::EmitGlobalDecl(SysYParser::DeclContext& ctx) { EmitDecl(ctx, true); }
void IRGenImpl::EmitDecl(SysYParser::DeclContext& ctx, bool is_global) {
if (ctx.constDecl() != nullptr) {
EmitConstDecl(ctx.constDecl(), is_global);
return;
}
if (ctx.varDecl() != nullptr) {
EmitVarDecl(ctx.varDecl(), is_global, false);
return;
}
ThrowError(&ctx, "????");
}
void IRGenImpl::EmitVarDecl(SysYParser::VarDeclContext* ctx, bool is_global,
bool is_const) {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
const auto type = ParseBType(ctx->bType());
for (auto* def : ctx->varDef()) {
if (is_global) {
EmitGlobalVarDef(*def, type);
} else {
EmitLocalVarDef(*def, type, is_const);
}
}
}
void IRGenImpl::EmitConstDecl(SysYParser::ConstDeclContext* ctx, bool is_global) {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
const auto type = ParseBType(ctx->bType());
for (auto* def : ctx->constDef()) {
if (is_global) {
EmitGlobalConstDef(*def, type);
} else {
EmitLocalConstDef(*def, type);
}
}
}
void IRGenImpl::EmitGlobalVarDef(SysYParser::VarDefContext& ctx, SemanticType type) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || !ir::isa<ir::GlobalValue>(symbol->ir_value)) {
ThrowError(&ctx, "??????????: " + name);
}
auto* global = static_cast<ir::GlobalValue*>(symbol->ir_value);
symbol->kind = SymbolKind::Variable;
symbol->type = type;
symbol->is_const = false;
symbol->is_array = !ctx.constExp().empty();
symbol->dims = ParseArrayDims(ctx.constExp());
if (symbol->is_array) {
// Leave uninitialized globals as zeroinitializer instead of materializing
// an explicit all-zero constant array, which can explode memory usage.
if (ctx.initVal() == nullptr) {
global->SetInitializer(nullptr);
return;
}
if (IsExplicitZeroInitVal(ctx.initVal(), type)) {
global->SetInitializer(nullptr);
return;
}
auto flat = FlattenInitVal(ctx.initVal(), type, symbol->dims);
std::vector<ir::Value*> elements;
elements.reserve(flat.size());
for (const auto& value : flat) {
elements.push_back(CreateTypedConstant(value));
}
global->SetInitializer(builder_.CreateConstArray(BuildArrayType(type, symbol->dims),
elements, {}));
} else {
ConstantValue init = ZeroConst(type);
if (ctx.initVal() != nullptr) {
if (ctx.initVal()->exp() == nullptr) {
ThrowError(ctx.initVal(), "???????????????");
}
init = ConvertConst(EvalConstExp(*ctx.initVal()->exp()), type);
}
global->SetInitializer(CreateTypedConstant(init));
}
}
void IRGenImpl::EmitGlobalConstDef(SysYParser::ConstDefContext& ctx,
SemanticType type) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || !ir::isa<ir::GlobalValue>(symbol->ir_value)) {
ThrowError(&ctx, "??????????: " + name);
}
auto* global = static_cast<ir::GlobalValue*>(symbol->ir_value);
symbol->kind = SymbolKind::Constant;
symbol->type = type;
symbol->is_const = true;
symbol->is_array = !ctx.constExp().empty();
symbol->dims = ParseArrayDims(ctx.constExp());
global->SetConstant(true);
if (symbol->is_array) {
if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) {
symbol->const_array.clear();
symbol->const_array_all_zero = true;
global->SetInitializer(nullptr);
return;
}
symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
symbol->const_array_all_zero = false;
std::vector<ir::Value*> elements;
elements.reserve(symbol->const_array.size());
for (const auto& value : symbol->const_array) {
elements.push_back(CreateTypedConstant(value));
}
global->SetInitializer(builder_.CreateConstArray(BuildArrayType(type, symbol->dims),
elements, {}));
} else {
auto init = ConvertConst(EvalConstAddExp(*ctx.constInitVal()->constExp()->addExp()), type);
symbol->const_scalar = init;
global->SetInitializer(CreateTypedConstant(init));
}
}
ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr<ir::Type> allocated_type,
const std::string& name) {
if (current_function_ == nullptr || current_function_->GetEntryBlock() == nullptr) {
throw std::runtime_error("CreateEntryAlloca requires an active function entry block");
}
auto* entry = current_function_->GetEntryBlock();
size_t insert_pos = 0;
for (const auto& inst : entry->GetInstructions()) {
if (!ir::isa<ir::AllocaInst>(inst.get())) {
break;
}
++insert_pos;
}
return entry->Insert<ir::AllocaInst>(insert_pos, std::move(allocated_type), nullptr,
name);
}
void IRGenImpl::EmitLocalVarDef(SysYParser::VarDefContext& ctx, SemanticType type,
bool is_const) {
const auto name = ExpectIdent(ctx, ctx.Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(&ctx, "????????: " + name);
}
SymbolEntry entry;
entry.kind = is_const ? SymbolKind::Constant : SymbolKind::Variable;
entry.type = type;
entry.is_const = is_const;
entry.is_array = !ctx.constExp().empty();
entry.dims = ParseArrayDims(ctx.constExp());
if (entry.is_array) {
entry.ir_value = CreateEntryAlloca(BuildArrayType(type, entry.dims),
NextTemp());
} else {
entry.ir_value = CreateEntryAlloca(GetIRScalarType(type),
NextTemp());
}
if (!symbols_.Insert(name, entry)) {
ThrowError(&ctx, "????????: " + name);
}
auto* symbol = symbols_.Lookup(name);
if (!entry.is_array) {
TypedValue init_value{ZeroIRValue(type), type, false, {}};
if (ctx.initVal() != nullptr) {
if (ctx.initVal()->exp() == nullptr) {
ThrowError(ctx.initVal(), "???????????????");
}
init_value = CastScalar(EmitExp(*ctx.initVal()->exp()), type, ctx.initVal());
}
builder_.CreateStore(init_value.value, symbol->ir_value);
return;
}
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
if (ctx.initVal() != nullptr) {
auto init_slots = FlattenLocalInitVal(ctx.initVal(), symbol->dims);
StoreLocalArrayElements(symbol->ir_value, type, symbol->dims, init_slots);
}
}
void IRGenImpl::EmitLocalConstDef(SysYParser::ConstDefContext& ctx,
SemanticType type) {
const auto name = ExpectIdent(ctx, ctx.Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(&ctx, "????????: " + name);
}
SymbolEntry entry;
entry.kind = SymbolKind::Constant;
entry.type = type;
entry.is_const = true;
entry.is_array = !ctx.constExp().empty();
entry.dims = ParseArrayDims(ctx.constExp());
entry.ir_value = CreateEntryAlloca(
entry.is_array ? BuildArrayType(type, entry.dims) : GetIRScalarType(type),
NextTemp());
if (!symbols_.Insert(name, entry)) {
ThrowError(&ctx, "????????: " + name);
}
auto* symbol = symbols_.Lookup(name);
if (!entry.is_array) {
auto init = ConvertConst(EvalConstAddExp(*ctx.constInitVal()->constExp()->addExp()), type);
symbol->const_scalar = init;
builder_.CreateStore(CreateTypedConstant(init), symbol->ir_value);
return;
}
if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) {
symbol->const_array.clear();
symbol->const_array_all_zero = true;
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
return;
}
symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
symbol->const_array_all_zero = false;
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
for (size_t i = 0; i < symbol->const_array.size(); ++i) {
if (symbol->const_array[i].type == SemanticType::Int && symbol->const_array[i].int_value == 0) {
continue;
}
if (symbol->const_array[i].type == SemanticType::Float &&
symbol->const_array[i].float_value == 0.0f) {
continue;
}
const auto indices = ExpandLinearIndex(symbol->dims, i);
std::vector<ir::Value*> index_values;
index_values.reserve(indices.size());
for (int index : indices) {
index_values.push_back(builder_.CreateConstInt(index));
}
auto* addr = CreateArrayElementAddr(symbol->ir_value, false, type, symbol->dims,
index_values, &ctx);
builder_.CreateStore(CreateTypedConstant(symbol->const_array[i]), addr);
}
}
std::vector<ConstantValue> IRGenImpl::FlattenConstInitVal(
SysYParser::ConstInitValContext* ctx, SemanticType base_type,
const std::vector<int>& dims) {
std::vector<ConstantValue> out(CountArrayElements(dims), ZeroConst(base_type));
if (ctx != nullptr) {
size_t cursor = 0;
FlattenConstInitValImpl(ctx, base_type, dims, 0, 0, out.size(), cursor, out);
}
return out;
}
std::vector<ConstantValue> IRGenImpl::FlattenInitVal(
SysYParser::InitValContext* ctx, SemanticType base_type,
const std::vector<int>& dims) {
std::vector<ConstantValue> out(CountArrayElements(dims), ZeroConst(base_type));
if (ctx != nullptr) {
size_t cursor = 0;
FlattenInitValImpl(ctx, base_type, dims, 0, 0, out.size(), cursor, out);
}
return out;
}
std::vector<IRGenImpl::InitExprSlot> IRGenImpl::FlattenLocalInitVal(
SysYParser::InitValContext* ctx, const std::vector<int>& dims) {
std::vector<InitExprSlot> out;
if (ctx != nullptr) {
size_t cursor = 0;
FlattenLocalInitValImpl(ctx, dims, 0, 0, CountArrayElements(dims), cursor, out);
}
return out;
}
void IRGenImpl::FlattenConstInitValImpl(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims,
size_t depth, size_t object_begin,
size_t object_end, size_t& cursor,
std::vector<ConstantValue>& out) {
if (ctx == nullptr || cursor >= object_end) {
return;
}
if (ctx->constExp() != nullptr) {
out[cursor++] = ConvertConst(EvalConstAddExp(*ctx->constExp()->addExp()), base_type);
return;
}
for (auto* child : ctx->constInitVal()) {
if (cursor >= object_end) {
break;
}
if (child->constExp() == nullptr && depth + 1 < dims.size()) {
cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
const auto child_begin = cursor;
const auto child_end = std::min(object_end,
child_begin + CountArrayElements(dims, depth + 1));
FlattenConstInitValImpl(child, base_type, dims, depth + 1, child_begin,
child_end, cursor, out);
cursor = child_end;
} else {
FlattenConstInitValImpl(child, base_type, dims, depth + 1, object_begin,
object_end, cursor, out);
}
}
}
void IRGenImpl::FlattenInitValImpl(SysYParser::InitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor,
std::vector<ConstantValue>& out) {
if (ctx == nullptr || cursor >= object_end) {
return;
}
if (ctx->exp() != nullptr) {
out[cursor++] = ConvertConst(EvalConstExp(*ctx->exp()), base_type);
return;
}
for (auto* child : ctx->initVal()) {
if (cursor >= object_end) {
break;
}
if (child->exp() == nullptr && depth + 1 < dims.size()) {
cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
const auto child_begin = cursor;
const auto child_end = std::min(object_end,
child_begin + CountArrayElements(dims, depth + 1));
FlattenInitValImpl(child, base_type, dims, depth + 1, child_begin,
child_end, cursor, out);
cursor = child_end;
} else {
FlattenInitValImpl(child, base_type, dims, depth + 1, object_begin,
object_end, cursor, out);
}
}
}
void IRGenImpl::FlattenLocalInitValImpl(SysYParser::InitValContext* ctx,
const std::vector<int>& dims,
size_t depth, size_t object_begin,
size_t object_end, size_t& cursor,
std::vector<InitExprSlot>& out) {
if (ctx == nullptr || cursor >= object_end) {
return;
}
if (ctx->exp() != nullptr) {
out.push_back({cursor++, ctx->exp()});
return;
}
for (auto* child : ctx->initVal()) {
if (cursor >= object_end) {
break;
}
if (child->exp() == nullptr && depth + 1 < dims.size()) {
cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
const auto child_begin = cursor;
const auto child_end = std::min(object_end,
child_begin + CountArrayElements(dims, depth + 1));
FlattenLocalInitValImpl(child, dims, depth + 1, child_begin, child_end,
cursor, out);
cursor = child_end;
} else {
FlattenLocalInitValImpl(child, dims, depth + 1, object_begin, object_end,
cursor, out);
}
}
}
size_t IRGenImpl::CountArrayElements(const std::vector<int>& dims, size_t start) const {
size_t count = 1;
for (size_t i = start; i < dims.size(); ++i) {
count *= static_cast<size_t>(dims[i]);
}
return count;
}
size_t IRGenImpl::AlignInitializerCursor(const std::vector<int>& dims,
size_t depth, size_t object_begin,
size_t cursor) const {
if (depth + 1 >= dims.size()) {
return cursor;
}
const auto stride = CountArrayElements(dims, depth + 1);
const auto relative = cursor - object_begin;
return object_begin + ((relative + stride - 1) / stride) * stride;
}
size_t IRGenImpl::FlattenIndices(const std::vector<int>& dims,
const std::vector<int>& indices) const {
size_t offset = 0;
for (size_t i = 0; i < dims.size(); ++i) {
offset *= static_cast<size_t>(dims[i]);
offset += static_cast<size_t>(indices[i]);
}
return offset;
}
bool IRGenImpl::IsZeroConstant(const ConstantValue& value) const {
switch (value.type) {
case SemanticType::Int:
return value.int_value == 0;
case SemanticType::Float: {
std::uint32_t bits = 0;
std::memcpy(&bits, &value.float_value, sizeof(bits));
return bits == 0;
}
case SemanticType::Void:
return false;
}
return false;
}
bool IRGenImpl::IsExplicitZeroConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type) {
if (ctx == nullptr) {
return true;
}
if (ctx->constExp() != nullptr) {
return IsZeroConstant(
ConvertConst(EvalConstAddExp(*ctx->constExp()->addExp()), base_type));
}
for (auto* child : ctx->constInitVal()) {
if (!IsExplicitZeroConstInitVal(child, base_type)) {
return false;
}
}
return true;
}
bool IRGenImpl::IsExplicitZeroInitVal(SysYParser::InitValContext* ctx,
SemanticType base_type) {
if (ctx == nullptr) {
return true;
}
if (ctx->exp() != nullptr) {
return IsZeroConstant(ConvertConst(EvalConstExp(*ctx->exp()), base_type));
}
for (auto* child : ctx->initVal()) {
if (!IsExplicitZeroInitVal(child, base_type)) {
return false;
}
}
return true;
}
ConstantValue IRGenImpl::ZeroConst(SemanticType type) const {
ConstantValue value;
value.type = type;
value.int_value = 0;
value.float_value = 0.0f;
return value;
}
ir::Value* IRGenImpl::ZeroIRValue(SemanticType type) {
switch (type) {
case SemanticType::Int:
return builder_.CreateConstInt(0);
case SemanticType::Float:
return builder_.CreateConstFloat(0.0f);
case SemanticType::Void:
break;
}
throw std::runtime_error("void type has no zero IR value");
}
ir::Value* IRGenImpl::CreateTypedConstant(const ConstantValue& value) {
switch (value.type) {
case SemanticType::Int:
return builder_.CreateConstInt(value.int_value);
case SemanticType::Float:
return builder_.CreateConstFloat(value.float_value);
case SemanticType::Void:
break;
}
throw std::runtime_error("void type has no constant value");
}
void IRGenImpl::ZeroInitializeLocalArray(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims) {
const auto elem_count = CountArrayElements(dims);
int bytes = static_cast<int>(elem_count * (base_type == SemanticType::Float ? 4 : 4));
builder_.CreateMemset(addr, builder_.CreateConstInt(0), builder_.CreateConstInt(bytes),
builder_.CreateConstBool(false));
}
void IRGenImpl::StoreLocalArrayElements(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims,
const std::vector<InitExprSlot>& init_slots) {
for (const auto& slot : init_slots) {
const auto indices = ExpandLinearIndex(dims, slot.index);
std::vector<ir::Value*> index_values;
index_values.reserve(indices.size());
for (int index : indices) {
index_values.push_back(builder_.CreateConstInt(index));
}
auto* elem_addr = CreateArrayElementAddr(addr, false, base_type, dims,
index_values, slot.expr);
auto value = CastScalar(EmitExp(*slot.expr), base_type, slot.expr);
builder_.CreateStore(value.value, elem_addr);
}
}

@ -0,0 +1,11 @@
#include "irgen/IRGen.h"
#include <memory>
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema) {
auto module = std::make_unique<ir::Module>();
IRGenImpl gen(*module, sema);
tree.accept(&gen);
return module;
}

@ -0,0 +1,696 @@
#include "irgen/IRGen.h"
#include <cstdlib>
#include <stdexcept>
#include <utility>
bool IRGenImpl::IsNumeric(const TypedValue& value) const {
return !value.is_array && value.type != SemanticType::Void;
}
bool IRGenImpl::IsSameDims(const std::vector<int>& lhs,
const std::vector<int>& rhs) const {
if (rhs.empty()) {
return true;
}
if (lhs == rhs) {
return true;
}
if (lhs.size() == rhs.size() + 1) {
return std::equal(lhs.begin() + 1, lhs.end(), rhs.begin());
}
return false;
}
IRGenImpl::TypedValue IRGenImpl::CastScalar(
TypedValue value, SemanticType target_type,
const antlr4::ParserRuleContext* ctx) {
if (value.is_array) {
ThrowError(ctx, "????????????");
}
if (target_type == SemanticType::Void || value.type == SemanticType::Void) {
ThrowError(ctx, "void ?????????");
}
if (target_type == SemanticType::Int) {
if (value.type == SemanticType::Int) {
if (value.value->GetType()->IsInt1()) {
value.value = builder_.CreateZext(value.value, ir::Type::GetInt32Type(), NextTemp());
}
value.type = SemanticType::Int;
return value;
}
value.value = builder_.CreateFtoI(value.value, NextTemp());
value.type = SemanticType::Int;
return value;
}
if (target_type == SemanticType::Float) {
if (value.type == SemanticType::Float) {
return value;
}
if (value.value->GetType()->IsInt1()) {
value.value = builder_.CreateZext(value.value, ir::Type::GetInt32Type(), NextTemp());
}
value.value = builder_.CreateIToF(value.value, NextTemp());
value.type = SemanticType::Float;
return value;
}
ThrowError(ctx, "????????????");
}
ir::Value* IRGenImpl::CastToCondition(TypedValue value,
const antlr4::ParserRuleContext* ctx) {
if (value.is_array) {
ThrowError(ctx, "?????????");
}
if (value.type == SemanticType::Void) {
ThrowError(ctx, "void ???????");
}
if (value.value->GetType()->IsInt1()) {
return value.value;
}
if (value.type == SemanticType::Int) {
return builder_.CreateICmp(ir::Opcode::ICmpNE, value.value, builder_.CreateConstInt(0),
NextTemp());
}
return builder_.CreateFCmp(ir::Opcode::FCmpNE, value.value,
builder_.CreateConstFloat(0.0f), NextTemp());
}
IRGenImpl::TypedValue IRGenImpl::NormalizeLogicalValue(
TypedValue value, const antlr4::ParserRuleContext* ctx) {
auto* cond = CastToCondition(value, ctx);
return {builder_.CreateZext(cond, ir::Type::GetInt32Type(), NextTemp()),
SemanticType::Int, false, {}};
}
ConstantValue IRGenImpl::ParseNumber(SysYParser::NumberContext& ctx) const {
ConstantValue value;
if (ctx.IntConst() != nullptr) {
value.type = SemanticType::Int;
value.int_value = std::stoi(ctx.getText(), nullptr, 0);
value.float_value = static_cast<float>(value.int_value);
return value;
}
if (ctx.FloatConst() != nullptr) {
value.type = SemanticType::Float;
value.float_value = std::strtof(ctx.getText().c_str(), nullptr);
value.int_value = static_cast<int>(value.float_value);
return value;
}
ThrowError(&ctx, "?????????");
}
ConstantValue IRGenImpl::ConvertConst(ConstantValue value,
SemanticType target_type) const {
if (target_type == SemanticType::Void) {
throw std::runtime_error("void is not a valid constant target type");
}
if (value.type == target_type) {
return value;
}
if (target_type == SemanticType::Int) {
value.int_value = static_cast<int>(value.float_value);
value.type = SemanticType::Int;
return value;
}
value.float_value = static_cast<float>(value.int_value);
value.type = SemanticType::Float;
return value;
}
ConstantValue IRGenImpl::EvalConstExp(SysYParser::ExpContext& ctx) {
return EvalConstAddExp(*ctx.addExp());
}
ConstantValue IRGenImpl::EvalConstAddExp(SysYParser::AddExpContext& ctx) {
if (ctx.addExp() == nullptr) {
return EvalConstMulExp(*ctx.mulExp());
}
auto lhs = EvalConstAddExp(*ctx.addExp());
auto rhs = EvalConstMulExp(*ctx.mulExp());
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = ConvertConst(lhs, SemanticType::Float);
rhs = ConvertConst(rhs, SemanticType::Float);
lhs.float_value = ctx.op->getType() == SysYParser::ADD
? lhs.float_value + rhs.float_value
: lhs.float_value - rhs.float_value;
lhs.int_value = static_cast<int>(lhs.float_value);
lhs.type = SemanticType::Float;
return lhs;
}
lhs.int_value = ctx.op->getType() == SysYParser::ADD ? lhs.int_value + rhs.int_value
: lhs.int_value - rhs.int_value;
lhs.float_value = static_cast<float>(lhs.int_value);
lhs.type = SemanticType::Int;
return lhs;
}
ConstantValue IRGenImpl::EvalConstMulExp(SysYParser::MulExpContext& ctx) {
if (ctx.mulExp() == nullptr) {
return EvalConstUnaryExp(*ctx.unaryExp());
}
auto lhs = EvalConstMulExp(*ctx.mulExp());
auto rhs = EvalConstUnaryExp(*ctx.unaryExp());
if (ctx.op->getType() == SysYParser::MOD &&
(lhs.type == SemanticType::Float || rhs.type == SemanticType::Float)) {
ThrowError(&ctx, "?????? % ??");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = ConvertConst(lhs, SemanticType::Float);
rhs = ConvertConst(rhs, SemanticType::Float);
switch (ctx.op->getType()) {
case SysYParser::MUL:
lhs.float_value *= rhs.float_value;
break;
case SysYParser::DIV:
lhs.float_value /= rhs.float_value;
break;
default:
ThrowError(&ctx, "??????????");
}
lhs.int_value = static_cast<int>(lhs.float_value);
lhs.type = SemanticType::Float;
return lhs;
}
switch (ctx.op->getType()) {
case SysYParser::MUL:
lhs.int_value *= rhs.int_value;
break;
case SysYParser::DIV:
lhs.int_value /= rhs.int_value;
break;
case SysYParser::MOD:
lhs.int_value %= rhs.int_value;
break;
default:
ThrowError(&ctx, "????????");
}
lhs.float_value = static_cast<float>(lhs.int_value);
lhs.type = SemanticType::Int;
return lhs;
}
ConstantValue IRGenImpl::EvalConstUnaryExp(SysYParser::UnaryExpContext& ctx) {
if (ctx.primaryExp() != nullptr) {
return EvalConstPrimaryExp(*ctx.primaryExp());
}
if (ctx.Ident() != nullptr) {
ThrowError(&ctx, "?????????????");
}
auto operand = EvalConstUnaryExp(*ctx.unaryExp());
if (ctx.unaryOp()->ADD() != nullptr) {
return operand;
}
if (ctx.unaryOp()->SUB() != nullptr) {
if (operand.type == SemanticType::Float) {
operand.float_value = -operand.float_value;
operand.int_value = static_cast<int>(operand.float_value);
} else {
operand.int_value = -operand.int_value;
operand.float_value = static_cast<float>(operand.int_value);
}
return operand;
}
if (ctx.unaryOp()->NOT() != nullptr) {
const bool truthy = operand.type == SemanticType::Float ? operand.float_value != 0.0f
: operand.int_value != 0;
ConstantValue result;
result.type = SemanticType::Int;
result.int_value = truthy ? 0 : 1;
result.float_value = static_cast<float>(result.int_value);
return result;
}
ThrowError(&ctx, "???????");
}
ConstantValue IRGenImpl::EvalConstPrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp() != nullptr) {
return EvalConstExp(*ctx.exp());
}
if (ctx.number() != nullptr) {
return ParseNumber(*ctx.number());
}
if (ctx.lVal() != nullptr) {
return EvalConstLVal(*ctx.lVal());
}
ThrowError(&ctx, "???? primaryExp");
}
ConstantValue IRGenImpl::EvalConstLVal(SysYParser::LValContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr) {
ThrowError(&ctx, "?????: " + name);
}
if (!symbol->is_const) {
ThrowError(&ctx, "?????????????: " + name);
}
if (!symbol->is_array) {
if (!ctx.exp().empty()) {
ThrowError(&ctx, "??????????");
}
if (!symbol->const_scalar.has_value()) {
ThrowError(&ctx, "?????: " + name);
}
return *symbol->const_scalar;
}
if (ctx.exp().size() != symbol->dims.size()) {
ThrowError(&ctx, "???????????????????: " + name);
}
std::vector<int> indices;
indices.reserve(ctx.exp().size());
for (auto* exp_ctx : ctx.exp()) {
auto index = ConvertConst(EvalConstExp(*exp_ctx), SemanticType::Int);
indices.push_back(index.int_value);
}
for (size_t i = 0; i < indices.size(); ++i) {
if (indices[i] < 0 || indices[i] >= symbol->dims[i]) {
ThrowError(&ctx, "????????: " + name);
}
}
const auto offset = FlattenIndices(symbol->dims, indices);
if (symbol->const_array_all_zero) {
return ZeroConst(symbol->type);
}
if (offset >= symbol->const_array.size()) {
ThrowError(&ctx, "???????????: " + name);
}
return symbol->const_array[offset];
}
IRGenImpl::TypedValue IRGenImpl::EmitExp(SysYParser::ExpContext& ctx) {
return EmitAddExp(*ctx.addExp());
}
IRGenImpl::TypedValue IRGenImpl::EmitAddExp(SysYParser::AddExpContext& ctx) {
if (ctx.addExp() == nullptr) {
return EmitMulExp(*ctx.mulExp());
}
auto lhs = EmitAddExp(*ctx.addExp());
auto rhs = EmitMulExp(*ctx.mulExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "???????????");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
auto* value = ctx.op->getType() == SysYParser::ADD
? builder_.CreateBinary(ir::Opcode::FAdd, lhs.value, rhs.value,
NextTemp())
: builder_.CreateBinary(ir::Opcode::FSub, lhs.value, rhs.value,
NextTemp());
return {value, SemanticType::Float, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
auto* value = ctx.op->getType() == SysYParser::ADD
? builder_.CreateAdd(lhs.value, rhs.value, NextTemp())
: builder_.CreateSub(lhs.value, rhs.value, NextTemp());
return {value, SemanticType::Int, false, {}};
}
IRGenImpl::TypedValue IRGenImpl::EmitMulExp(SysYParser::MulExpContext& ctx) {
if (ctx.mulExp() == nullptr) {
return EmitUnaryExp(*ctx.unaryExp());
}
auto lhs = EmitMulExp(*ctx.mulExp());
auto rhs = EmitUnaryExp(*ctx.unaryExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "????????????");
}
if (ctx.op->getType() == SysYParser::MOD &&
(lhs.type == SemanticType::Float || rhs.type == SemanticType::Float)) {
ThrowError(&ctx, "?????? % ??");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
ir::Opcode opcode = ir::Opcode::FMul;
if (ctx.op->getType() == SysYParser::DIV) {
opcode = ir::Opcode::FDiv;
} else if (ctx.op->getType() == SysYParser::MUL) {
opcode = ir::Opcode::FMul;
} else {
ThrowError(&ctx, "?????????");
}
return {builder_.CreateBinary(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Float, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
ir::Value* value = nullptr;
switch (ctx.op->getType()) {
case SysYParser::MUL:
value = builder_.CreateMul(lhs.value, rhs.value, NextTemp());
break;
case SysYParser::DIV:
value = builder_.CreateDiv(lhs.value, rhs.value, NextTemp());
break;
case SysYParser::MOD:
value = builder_.CreateRem(lhs.value, rhs.value, NextTemp());
break;
default:
ThrowError(&ctx, "???????");
}
return {value, SemanticType::Int, false, {}};
}
IRGenImpl::TypedValue IRGenImpl::EmitUnaryExp(SysYParser::UnaryExpContext& ctx) {
if (ctx.primaryExp() != nullptr) {
return EmitPrimaryExp(*ctx.primaryExp());
}
if (ctx.Ident() != nullptr) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || symbol->kind != SymbolKind::Function || symbol->function == nullptr) {
ThrowError(&ctx, "????????: " + name);
}
const auto& function_type = symbol->function_type;
std::vector<ir::Value*> args;
std::vector<SysYParser::ExpContext*> arg_exprs;
if (ctx.funcRParams() != nullptr) {
arg_exprs = ctx.funcRParams()->exp();
}
if (arg_exprs.size() != function_type.param_types.size()) {
ThrowError(&ctx, "?????????: " + name);
}
for (size_t i = 0; i < arg_exprs.size(); ++i) {
auto arg = EmitExp(*arg_exprs[i]);
if (i < function_type.param_is_array.size() && function_type.param_is_array[i]) {
if (!arg.is_array || !IsSameDims(arg.dims, function_type.param_dims[i])) {
ThrowError(arg_exprs[i], "????????????: " + name);
}
args.push_back(arg.value);
} else {
if (arg.is_array) {
ThrowError(arg_exprs[i], "??????????: " + name);
}
arg = CastScalar(arg, function_type.param_types[i], arg_exprs[i]);
args.push_back(arg.value);
}
}
if (function_type.return_type == SemanticType::Void) {
builder_.CreateCall(symbol->function, args);
return {nullptr, SemanticType::Void, false, {}};
}
return {builder_.CreateCall(symbol->function, args, NextTemp()),
function_type.return_type, false, {}};
}
auto operand = EmitUnaryExp(*ctx.unaryExp());
if (!IsNumeric(operand)) {
ThrowError(&ctx, "???????????");
}
if (ctx.unaryOp()->ADD() != nullptr) {
return operand;
}
if (ctx.unaryOp()->SUB() != nullptr) {
if (operand.type == SemanticType::Float) {
return {builder_.CreateFNeg(operand.value, NextTemp()), SemanticType::Float,
false, {}};
}
operand = CastScalar(operand, SemanticType::Int, &ctx);
return {builder_.CreateSub(builder_.CreateConstInt(0), operand.value, NextTemp()),
SemanticType::Int, false, {}};
}
if (ctx.unaryOp()->NOT() != nullptr) {
auto* cond = CastToCondition(operand, &ctx);
auto* inverted = builder_.CreateXor(cond, builder_.CreateConstBool(true), NextTemp());
return {builder_.CreateZext(inverted, ir::Type::GetInt32Type(), NextTemp()),
SemanticType::Int, false, {}};
}
ThrowError(&ctx, "???????");
}
IRGenImpl::TypedValue IRGenImpl::EmitPrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp() != nullptr) {
return EmitExp(*ctx.exp());
}
if (ctx.number() != nullptr) {
auto number = ParseNumber(*ctx.number());
return {CreateTypedConstant(number), number.type, false, {}};
}
if (ctx.lVal() != nullptr) {
return EmitLValValue(*ctx.lVal());
}
ThrowError(&ctx, "?? primaryExp");
}
IRGenImpl::TypedValue IRGenImpl::EmitRelExp(SysYParser::RelExpContext& ctx) {
if (ctx.relExp() == nullptr) {
return EmitAddExp(*ctx.addExp());
}
auto lhs = EmitRelExp(*ctx.relExp());
auto rhs = EmitAddExp(*ctx.addExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "???????????");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
ir::Opcode opcode = ir::Opcode::FCmpLT;
switch (ctx.op->getType()) {
case SysYParser::LT:
opcode = ir::Opcode::FCmpLT;
break;
case SysYParser::GT:
opcode = ir::Opcode::FCmpGT;
break;
case SysYParser::LE:
opcode = ir::Opcode::FCmpLE;
break;
case SysYParser::GE:
opcode = ir::Opcode::FCmpGE;
break;
default:
ThrowError(&ctx, "????????");
}
return {builder_.CreateFCmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
ir::Opcode opcode = ir::Opcode::ICmpLT;
switch (ctx.op->getType()) {
case SysYParser::LT:
opcode = ir::Opcode::ICmpLT;
break;
case SysYParser::GT:
opcode = ir::Opcode::ICmpGT;
break;
case SysYParser::LE:
opcode = ir::Opcode::ICmpLE;
break;
case SysYParser::GE:
opcode = ir::Opcode::ICmpGE;
break;
default:
ThrowError(&ctx, "????????");
}
return {builder_.CreateICmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
IRGenImpl::TypedValue IRGenImpl::EmitEqExp(SysYParser::EqExpContext& ctx) {
if (ctx.eqExp() == nullptr) {
return EmitRelExp(*ctx.relExp());
}
auto lhs = EmitEqExp(*ctx.eqExp());
auto rhs = EmitRelExp(*ctx.relExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "???????????");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
const auto opcode = ctx.op->getType() == SysYParser::EQ ? ir::Opcode::FCmpEQ
: ir::Opcode::FCmpNE;
return {builder_.CreateFCmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
const auto opcode = ctx.op->getType() == SysYParser::EQ ? ir::Opcode::ICmpEQ
: ir::Opcode::ICmpNE;
return {builder_.CreateICmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
IRGenImpl::LValueInfo IRGenImpl::ResolveLVal(SysYParser::LValContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr) {
ThrowError(&ctx, "?????????: " + name);
}
if (symbol->kind == SymbolKind::Function) {
ThrowError(&ctx, "????????: " + name);
}
std::vector<ir::Value*> index_values;
index_values.reserve(ctx.exp().size());
for (auto* exp_ctx : ctx.exp()) {
auto index = CastScalar(EmitExp(*exp_ctx), SemanticType::Int, exp_ctx);
if (index.is_array) {
ThrowError(exp_ctx, "????????");
}
index_values.push_back(index.value);
}
if (!symbol->is_array) {
if (!index_values.empty()) {
ThrowError(&ctx, "??????????: " + name);
}
return {symbol, symbol->ir_value, symbol->type, false, {}, false};
}
std::vector<int> selected_dims;
if (symbol->is_param_array) {
if (index_values.size() > symbol->dims.size() + 1) {
ThrowError(&ctx, "????????: " + name);
}
if (index_values.empty()) {
selected_dims = symbol->dims;
} else if (index_values.size() <= 1) {
selected_dims = symbol->dims;
} else {
selected_dims.assign(symbol->dims.begin() + static_cast<long long>(index_values.size() - 1),
symbol->dims.end());
}
} else {
if (index_values.size() > symbol->dims.size()) {
ThrowError(&ctx, "??????: " + name);
}
selected_dims.assign(symbol->dims.begin() + static_cast<long long>(index_values.size()),
symbol->dims.end());
}
ir::Value* addr = symbol->ir_value;
if (!index_values.empty()) {
addr = CreateArrayElementAddr(symbol->ir_value, symbol->is_param_array, symbol->type,
symbol->dims, index_values, &ctx);
}
const bool root_param_array_no_index = symbol->is_param_array && index_values.empty();
const bool still_array = !selected_dims.empty() || root_param_array_no_index;
return {symbol, addr, symbol->type, still_array, selected_dims,
root_param_array_no_index};
}
ir::Value* IRGenImpl::GenLValAddr(SysYParser::LValContext& ctx) {
auto info = ResolveLVal(ctx);
if (info.is_array) {
ThrowError(&ctx, "?????????????");
}
return info.addr;
}
IRGenImpl::TypedValue IRGenImpl::EmitLValValue(SysYParser::LValContext& ctx) {
auto info = ResolveLVal(ctx);
if (!info.is_array) {
if (info.symbol != nullptr && info.symbol->const_scalar.has_value()) {
return {CreateTypedConstant(*info.symbol->const_scalar), info.type, false, {}};
}
return {builder_.CreateLoad(info.addr, GetIRScalarType(info.type), NextTemp()),
info.type, false, {}};
}
if (info.root_param_array_no_index) {
return {info.addr, info.type, true, info.dims};
}
auto* decayed = builder_.CreateGEP(info.addr, BuildArrayType(info.type, info.dims),
{builder_.CreateConstInt(0), builder_.CreateConstInt(0)},
NextTemp());
std::vector<int> decay_dims;
if (!info.dims.empty()) {
decay_dims.assign(info.dims.begin() + 1, info.dims.end());
}
return {decayed, info.type, true, decay_dims};
}
ir::Value* IRGenImpl::CreateArrayElementAddr(
ir::Value* base_addr, bool is_param_array, SemanticType base_type,
const std::vector<int>& dims, const std::vector<ir::Value*>& indices,
const antlr4::ParserRuleContext* ctx) {
if (base_addr == nullptr) {
ThrowError(ctx, "???????");
}
if (indices.empty()) {
return base_addr;
}
std::vector<ir::Value*> gep_indices;
if (!is_param_array) {
gep_indices.push_back(builder_.CreateConstInt(0));
}
gep_indices.insert(gep_indices.end(), indices.begin(), indices.end());
auto source_type = dims.empty() ? GetIRScalarType(base_type) : BuildArrayType(base_type, dims);
return builder_.CreateGEP(base_addr, source_type, gep_indices, NextTemp());
}
void IRGenImpl::EmitCond(SysYParser::CondContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block) {
EmitLOrCond(*ctx.lOrExp(), true_block, false_block);
}
void IRGenImpl::EmitLOrCond(SysYParser::LOrExpContext& ctx,
ir::BasicBlock* true_block,
ir::BasicBlock* false_block) {
if (ctx.lOrExp() == nullptr) {
EmitLAndCond(*ctx.lAndExp(), true_block, false_block);
return;
}
auto* rhs_block = current_function_->CreateBlock(NextBlockName("lor.rhs"));
EmitLOrCond(*ctx.lOrExp(), true_block, rhs_block);
builder_.SetInsertPoint(rhs_block);
EmitLAndCond(*ctx.lAndExp(), true_block, false_block);
}
void IRGenImpl::EmitLAndCond(SysYParser::LAndExpContext& ctx,
ir::BasicBlock* true_block,
ir::BasicBlock* false_block) {
if (ctx.lAndExp() == nullptr) {
auto cond = EmitEqExp(*ctx.eqExp());
builder_.CreateCondBr(CastToCondition(cond, &ctx), true_block, false_block);
return;
}
auto* rhs_block = current_function_->CreateBlock(NextBlockName("land.rhs"));
EmitLAndCond(*ctx.lAndExp(), rhs_block, false_block);
builder_.SetInsertPoint(rhs_block);
auto cond = EmitEqExp(*ctx.eqExp());
builder_.CreateCondBr(CastToCondition(cond, &ctx), true_block, false_block);
}

@ -0,0 +1,268 @@
#include "irgen/IRGen.h"
#include <stdexcept>
#include <utility>
#include "utils/Log.h"
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module), sema_(sema), builder_(module.GetContext(), nullptr) {}
[[noreturn]] void IRGenImpl::ThrowError(
const antlr4::ParserRuleContext* ctx, const std::string& message) const {
if (ctx != nullptr && ctx->getStart() != nullptr) {
throw std::runtime_error(FormatErrorAt("irgen",
static_cast<size_t>(ctx->getStart()->getLine()),
static_cast<size_t>(ctx->getStart()->getCharPositionInLine() + 1),
message));
}
throw std::runtime_error(FormatError("irgen", message));
}
std::string IRGenImpl::NextTemp() { return module_.GetContext().NextTemp(); }
std::string IRGenImpl::NextBlockName(const std::string& prefix) {
return module_.GetContext().NextBlockName(prefix);
}
void IRGenImpl::ApplyFunctionSema(const std::string& name, ir::Function& function) {
const auto* info = sema_.LookupFunction(name);
if (info == nullptr) {
return;
}
function.SetEffectInfo(info->reads_global_memory, info->writes_global_memory,
info->reads_param_memory, info->writes_param_memory,
info->has_io, info->has_unknown_effects,
info->is_recursive);
}
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
symbols_.Clear();
symbols_.EnterScope();
RegisterBuiltinFunctions();
PredeclareTopLevel(*ctx);
for (auto* child : ctx->children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
EmitGlobalDecl(*decl);
} else if (auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
EmitFunction(*func);
}
}
symbols_.ExitScope();
return {};
}
void IRGenImpl::RegisterBuiltinFunctions() {
if (builtins_registered_) {
return;
}
struct BuiltinSpec {
const char* name;
SemanticType return_type;
std::vector<SemanticType> param_types;
std::vector<bool> param_is_array;
std::vector<std::vector<int>> param_dims;
};
const std::vector<BuiltinSpec> builtins = {
{"getint", SemanticType::Int, {}, {}, {}},
{"getch", SemanticType::Int, {}, {}, {}},
{"getfloat", SemanticType::Float, {}, {}, {}},
{"getarray", SemanticType::Int, {SemanticType::Int}, {true}, {std::vector<int>{}}},
{"getfarray", SemanticType::Int, {SemanticType::Float}, {true}, {std::vector<int>{}}},
{"putint", SemanticType::Void, {SemanticType::Int}, {false}, {std::vector<int>{}}},
{"putch", SemanticType::Void, {SemanticType::Int}, {false}, {std::vector<int>{}}},
{"putfloat", SemanticType::Void, {SemanticType::Float}, {false}, {std::vector<int>{}}},
{"putarray", SemanticType::Void, {SemanticType::Int, SemanticType::Int}, {false, true}, {std::vector<int>{}, std::vector<int>{}}},
{"putfarray", SemanticType::Void, {SemanticType::Int, SemanticType::Float}, {false, true}, {std::vector<int>{}, std::vector<int>{}}},
{"starttime", SemanticType::Void, {}, {}, {}},
{"stoptime", SemanticType::Void, {}, {}, {}},
};
for (const auto& builtin : builtins) {
FunctionTypeInfo function_type;
function_type.return_type = builtin.return_type;
function_type.param_types = builtin.param_types;
function_type.param_is_array = builtin.param_is_array;
function_type.param_dims = builtin.param_dims;
std::vector<std::shared_ptr<ir::Type>> ir_param_types;
std::vector<std::string> ir_param_names;
for (size_t i = 0; i < builtin.param_types.size(); ++i) {
if (i < builtin.param_is_array.size() && builtin.param_is_array[i]) {
ir_param_types.push_back(ir::Type::GetPointerType());
} else {
ir_param_types.push_back(GetIRScalarType(builtin.param_types[i]));
}
ir_param_names.push_back("%arg" + std::to_string(i));
}
auto* function = module_.CreateFunction(
builtin.name, GetIRScalarType(builtin.return_type), ir_param_types,
ir_param_names, true);
ApplyFunctionSema(builtin.name, *function);
SymbolEntry entry;
entry.kind = SymbolKind::Function;
entry.type = builtin.return_type;
entry.function = function;
entry.ir_value = function;
entry.function_type = std::move(function_type);
symbols_.Insert(builtin.name, entry);
}
builtins_registered_ = true;
}
void IRGenImpl::PredeclareTopLevel(SysYParser::CompUnitContext& ctx) {
for (auto* child : ctx.children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
PredeclareGlobalDecl(*decl);
} else if (auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
PredeclareFunction(*func);
}
}
}
FunctionTypeInfo IRGenImpl::BuildFunctionTypeInfo(
SysYParser::FuncDefContext& ctx) {
FunctionTypeInfo function_type;
function_type.return_type = ParseFuncType(ctx.funcType());
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
const auto type = ParseBType(param->bType());
const auto dims = param->LBRACK().empty() ? std::vector<int>{} : ParseParamDims(*param);
function_type.param_types.push_back(type);
function_type.param_is_array.push_back(!param->LBRACK().empty());
function_type.param_dims.push_back(dims);
}
}
return function_type;
}
std::vector<std::shared_ptr<ir::Type>> IRGenImpl::BuildFunctionIRParamTypes(
const FunctionTypeInfo& function_type) const {
std::vector<std::shared_ptr<ir::Type>> param_types;
for (size_t i = 0; i < function_type.param_types.size(); ++i) {
if (i < function_type.param_is_array.size() && function_type.param_is_array[i]) {
param_types.push_back(ir::Type::GetPointerType());
} else {
param_types.push_back(GetIRScalarType(function_type.param_types[i]));
}
}
return param_types;
}
std::vector<std::string> IRGenImpl::BuildFunctionIRParamNames(
SysYParser::FuncDefContext& ctx) const {
std::vector<std::string> param_names;
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
param_names.push_back("%" + ExpectIdent(*param, param->Ident()));
}
}
return param_names;
}
void IRGenImpl::PredeclareFunction(SysYParser::FuncDefContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(&ctx, "????????: " + name);
}
auto function_type = BuildFunctionTypeInfo(ctx);
auto* function = module_.CreateFunction(
name, GetIRScalarType(function_type.return_type),
BuildFunctionIRParamTypes(function_type), BuildFunctionIRParamNames(ctx), false);
ApplyFunctionSema(name, *function);
SymbolEntry entry;
entry.kind = SymbolKind::Function;
entry.type = function_type.return_type;
entry.function = function;
entry.ir_value = function;
entry.function_type = std::move(function_type);
symbols_.Insert(name, entry);
}
void IRGenImpl::EmitFunction(SysYParser::FuncDefContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || symbol->kind != SymbolKind::Function || symbol->function == nullptr) {
ThrowError(&ctx, "????????: " + name);
}
current_function_ = symbol->function;
current_return_type_ = symbol->function_type.return_type;
current_function_->SetExternal(false);
auto* entry_block = current_function_->EnsureEntryBlock();
builder_.SetInsertPoint(entry_block);
symbols_.EnterScope();
BindFunctionParams(ctx, *current_function_);
EmitBlock(*ctx.block(), false);
symbols_.ExitScope();
auto* insert_block = builder_.GetInsertBlock();
if (insert_block != nullptr && !insert_block->HasTerminator()) {
if (current_return_type_ == SemanticType::Void) {
builder_.CreateRet();
} else {
builder_.CreateRet(ZeroIRValue(current_return_type_));
}
}
builder_.SetInsertPoint(nullptr);
current_function_ = nullptr;
current_return_type_ = SemanticType::Void;
loop_stack_.clear();
}
void IRGenImpl::BindFunctionParams(SysYParser::FuncDefContext& ctx, ir::Function& func) {
if (ctx.funcFParams() == nullptr) {
return;
}
const auto& params = ctx.funcFParams()->funcFParam();
for (size_t i = 0; i < params.size(); ++i) {
auto* param = params[i];
const auto name = ExpectIdent(*param, param->Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(param, "??????: " + name);
}
SymbolEntry entry;
entry.kind = SymbolKind::Variable;
entry.type = ParseBType(param->bType());
entry.is_const = false;
entry.is_array = !param->LBRACK().empty();
entry.is_param_array = entry.is_array;
entry.dims = entry.is_array ? ParseParamDims(*param) : std::vector<int>{};
auto* arg = func.GetArgument(i);
if (arg == nullptr) {
ThrowError(param, "????????: " + name);
}
if (entry.is_array) {
entry.ir_value = arg;
} else {
auto* slot = CreateEntryAlloca(GetIRScalarType(entry.type), NextTemp());
builder_.CreateStore(arg, slot);
entry.ir_value = slot;
}
if (!symbols_.Insert(name, entry)) {
ThrowError(param, "??????: " + name);
}
}
}

@ -0,0 +1,159 @@
#include "irgen/IRGen.h"
void IRGenImpl::EmitBlock(SysYParser::BlockContext& ctx, bool create_scope) {
if (create_scope) {
symbols_.EnterScope();
}
for (auto* item : ctx.blockItem()) {
if (item != nullptr && EmitBlockItem(*item) == FlowState::Terminated) {
break;
}
}
if (create_scope) {
symbols_.ExitScope();
}
}
IRGenImpl::FlowState IRGenImpl::EmitBlockItem(SysYParser::BlockItemContext& ctx) {
if (ctx.decl() != nullptr) {
EmitDecl(*ctx.decl(), false);
return FlowState::Continue;
}
if (ctx.stmt() != nullptr) {
return EmitStmt(*ctx.stmt());
}
ThrowError(&ctx, "??????");
}
IRGenImpl::FlowState IRGenImpl::EmitStmt(SysYParser::StmtContext& ctx) {
auto branch_terminated = [this]() {
auto* block = builder_.GetInsertBlock();
return block == nullptr || block->HasTerminator();
};
if (ctx.ASSIGN() != nullptr) {
auto lhs = ResolveLVal(*ctx.lVal());
if (lhs.is_array) {
ThrowError(&ctx, "????????");
}
if (lhs.symbol != nullptr && lhs.symbol->is_const) {
ThrowError(&ctx, "??? const ????");
}
auto rhs = CastScalar(EmitExp(*ctx.exp()), lhs.type, ctx.exp());
builder_.CreateStore(rhs.value, lhs.addr);
return FlowState::Continue;
}
if (ctx.RETURN() != nullptr) {
if (current_return_type_ == SemanticType::Void) {
if (ctx.exp() != nullptr) {
ThrowError(&ctx, "void ?????????");
}
builder_.CreateRet();
} else {
if (ctx.exp() == nullptr) {
ThrowError(&ctx, "? void ?????????");
}
auto value = CastScalar(EmitExp(*ctx.exp()), current_return_type_, ctx.exp());
builder_.CreateRet(value.value);
}
return FlowState::Terminated;
}
if (ctx.block() != nullptr) {
EmitBlock(*ctx.block(), true);
return branch_terminated() ? FlowState::Terminated : FlowState::Continue;
}
if (ctx.IF() != nullptr) {
auto* then_block = current_function_->CreateBlock(NextBlockName("if.then"));
if (ctx.ELSE() == nullptr) {
auto* end_block = current_function_->CreateBlock(NextBlockName("if.end"));
EmitCond(*ctx.cond(), then_block, end_block);
builder_.SetInsertPoint(then_block);
auto then_state = EmitStmt(*ctx.stmt(0));
if (then_state != FlowState::Terminated && !branch_terminated()) {
builder_.CreateBr(end_block);
}
builder_.SetInsertPoint(end_block);
return FlowState::Continue;
}
auto* else_block = current_function_->CreateBlock(NextBlockName("if.else"));
EmitCond(*ctx.cond(), then_block, else_block);
ir::BasicBlock* end_block = nullptr;
builder_.SetInsertPoint(then_block);
auto then_state = EmitStmt(*ctx.stmt(0));
const bool then_terminated = then_state == FlowState::Terminated || branch_terminated();
if (!then_terminated) {
if (end_block == nullptr) {
end_block = current_function_->CreateBlock(NextBlockName("if.end"));
}
builder_.CreateBr(end_block);
}
builder_.SetInsertPoint(else_block);
auto else_state = EmitStmt(*ctx.stmt(1));
const bool else_terminated = else_state == FlowState::Terminated || branch_terminated();
if (!else_terminated) {
if (end_block == nullptr) {
end_block = current_function_->CreateBlock(NextBlockName("if.end"));
}
builder_.CreateBr(end_block);
}
if (end_block == nullptr) {
builder_.SetInsertPoint(nullptr);
return FlowState::Terminated;
}
builder_.SetInsertPoint(end_block);
return FlowState::Continue;
}
if (ctx.WHILE() != nullptr) {
auto* cond_block = current_function_->CreateBlock(NextBlockName("while.cond"));
auto* body_block = current_function_->CreateBlock(NextBlockName("while.body"));
auto* end_block = current_function_->CreateBlock(NextBlockName("while.end"));
builder_.CreateBr(cond_block);
builder_.SetInsertPoint(cond_block);
EmitCond(*ctx.cond(), body_block, end_block);
loop_stack_.push_back({cond_block, end_block});
builder_.SetInsertPoint(body_block);
auto body_state = EmitStmt(*ctx.stmt(0));
if (body_state != FlowState::Terminated && !branch_terminated()) {
builder_.CreateBr(cond_block);
}
loop_stack_.pop_back();
builder_.SetInsertPoint(end_block);
return FlowState::Continue;
}
if (ctx.BREAK() != nullptr) {
if (loop_stack_.empty()) {
ThrowError(&ctx, "break ????? while ???");
}
builder_.CreateBr(loop_stack_.back().exit_block);
return FlowState::Terminated;
}
if (ctx.CONTINUE() != nullptr) {
if (loop_stack_.empty()) {
ThrowError(&ctx, "continue ????? while ???");
}
builder_.CreateBr(loop_stack_.back().cond_block);
return FlowState::Terminated;
}
if (ctx.exp() != nullptr) {
(void)EmitExp(*ctx.exp());
}
return FlowState::Continue;
}

@ -0,0 +1,86 @@
#include <exception>
#include <iostream>
#include <stdexcept>
#include "frontend/AntlrDriver.h"
#include "frontend/SyntaxTreePrinter.h"
#if !COMPILER_PARSE_ONLY
#include "ir/IR.h"
#include "ir/PassManager.h"
#include "irgen/IRGen.h"
#include "mir/MIR.h"
#include "sem/Sema.h"
#endif
#include "utils/CLI.h"
#include "utils/Log.h"
int main(int argc, char** argv) {
try {
auto opts = ParseCLI(argc, argv);
if (opts.show_help) {
PrintHelp(std::cout);
return 0;
}
auto antlr = ParseFileWithAntlr(opts.input);
bool need_blank_line = false;
if (opts.emit_parse_tree) {
PrintSyntaxTree(antlr.tree, antlr.parser.get(), std::cout);
need_blank_line = true;
}
#if !COMPILER_PARSE_ONLY
auto* comp_unit = dynamic_cast<SysYParser::CompUnitContext*>(antlr.tree);
if (!comp_unit) {
throw std::runtime_error(FormatError("main", "syntax tree root is not compUnit"));
}
auto sema = RunSema(*comp_unit);
std::unique_ptr<ir::Module> asm_module;
if (opts.emit_asm) {
asm_module = GenerateIR(*comp_unit, sema);
ir::RunIRPassPipeline(*asm_module);
}
if (opts.emit_ir) {
std::unique_ptr<ir::Module> ir_module;
if (opts.emit_asm) {
ir_module = GenerateIR(*comp_unit, sema);
} else {
ir_module = GenerateIR(*comp_unit, sema);
}
ir::RunIRPassPipeline(*ir_module);
if (need_blank_line) {
std::cout << "\n";
}
ir::IRPrinter printer;
printer.Print(*ir_module, std::cout);
need_blank_line = true;
}
if (opts.emit_asm) {
auto machine_module = mir::LowerToMIR(*asm_module);
mir::RunMIRPreRegAllocPassPipeline(*machine_module);
mir::RunRegAlloc(*machine_module);
mir::RunMIRPostRegAllocPassPipeline(*machine_module);
mir::RunFrameLowering(*machine_module);
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_module, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {
throw std::runtime_error(
FormatError("main", "IR/asm emission is unavailable in parse-only builds"));
}
#endif
} catch (const std::exception& ex) {
PrintException(std::cerr, ex);
return 1;
}
return 0;
}

@ -0,0 +1,140 @@
#include "mir/MIR.h"
#include <cstdint>
#include <unordered_map>
#include <vector>
namespace mir {
namespace {
bool IsHoistCandidate(const MachineFunction& function, int object_index, int use_count) {
const auto& object = function.GetStackObject(object_index);
if (object.kind != StackObjectKind::Local) {
return false;
}
if (use_count < 2) {
return false;
}
if (object.size >= 4096) {
return true;
}
return object.size >= 256 && use_count >= 4;
}
bool IsPlainFrameLea(const MachineInstr& inst, int object_index) {
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress() ||
inst.GetOperands().empty() || inst.GetOperands()[0].GetKind() != OperandKind::VReg) {
return false;
}
const auto& address = inst.GetAddress();
return address.base_kind == AddrBaseKind::FrameObject &&
address.base_index == object_index && address.const_offset == 0 &&
address.scaled_vregs.empty();
}
std::size_t FindEntryInsertPos(const MachineBasicBlock& block) {
const auto& instructions = block.GetInstructions();
std::size_t pos = 0;
while (pos < instructions.size() &&
instructions[pos].GetOpcode() == MachineInstr::Opcode::Arg) {
++pos;
}
return pos;
}
} // namespace
void RunAddressHoisting(MachineModule& module) {
for (auto& function : module.GetFunctions()) {
if (!function || function->GetBlocks().empty()) {
continue;
}
std::unordered_map<int, int> use_counts;
for (auto& block : function->GetBlocks()) {
for (auto& inst : block->GetInstructions()) {
if (!inst.HasAddress()) {
continue;
}
const auto& address = inst.GetAddress();
if (address.base_kind == AddrBaseKind::FrameObject && address.base_index >= 0) {
++use_counts[address.base_index];
}
}
}
std::unordered_map<int, int> base_vregs;
for (const auto& [object_index, count] : use_counts) {
if (!IsHoistCandidate(*function, object_index, count)) {
continue;
}
base_vregs.emplace(object_index, -1);
}
if (base_vregs.empty()) {
continue;
}
for (auto& block : function->GetBlocks()) {
for (auto& inst : block->GetInstructions()) {
if (!inst.HasAddress()) {
continue;
}
const auto& address = inst.GetAddress();
auto it = base_vregs.find(address.base_index);
if (it == base_vregs.end()) {
continue;
}
if (it->second >= 0) {
continue;
}
if (IsPlainFrameLea(inst, address.base_index)) {
it->second = inst.GetOperands()[0].GetVReg();
}
}
}
auto& entry_block = *function->GetBlocks().front();
auto& entry_insts = entry_block.GetInstructions();
std::size_t insert_pos = FindEntryInsertPos(entry_block);
for (auto& [object_index, base_vreg] : base_vregs) {
if (base_vreg >= 0) {
continue;
}
base_vreg = function->NewVReg(ValueType::Ptr);
MachineInstr lea(MachineInstr::Opcode::Lea, {MachineOperand::VReg(base_vreg)});
AddressExpr address;
address.base_kind = AddrBaseKind::FrameObject;
address.base_index = object_index;
lea.SetAddress(std::move(address));
entry_insts.insert(entry_insts.begin() + static_cast<std::ptrdiff_t>(insert_pos),
std::move(lea));
++insert_pos;
}
for (auto& block : function->GetBlocks()) {
for (auto& inst : block->GetInstructions()) {
if (!inst.HasAddress()) {
continue;
}
auto& address = inst.GetAddress();
auto it = base_vregs.find(address.base_index);
if (it == base_vregs.end()) {
continue;
}
if (IsPlainFrameLea(inst, address.base_index) &&
inst.GetOperands()[0].GetKind() == OperandKind::VReg &&
inst.GetOperands()[0].GetVReg() == it->second) {
continue;
}
if (address.base_kind != AddrBaseKind::FrameObject || address.base_index < 0) {
continue;
}
address.base_kind = AddrBaseKind::VReg;
address.base_index = it->second;
}
}
}
}
} // namespace mir

File diff suppressed because it is too large Load Diff

@ -0,0 +1,25 @@
add_library(mir_core STATIC
MIRContext.cpp
MIRFunction.cpp
MIRBasicBlock.cpp
MIRInstr.cpp
Register.cpp
Lowering.cpp
AddressHoisting.cpp
RegAlloc.cpp
FrameLowering.cpp
AsmPrinter.cpp
)
target_link_libraries(mir_core PUBLIC
build_options
ir
)
add_subdirectory(passes)
add_library(mir INTERFACE)
target_link_libraries(mir INTERFACE
mir_core
mir_passes
)

@ -0,0 +1,40 @@
#include "mir/MIR.h"
#include <string>
namespace mir {
namespace {
int AlignTo(int value, int align) {
if (align <= 1) {
return value;
}
return ((value + align - 1) / align) * align;
}
} // namespace
void RunFrameLowering(MachineModule& module) {
for (auto& function : module.GetFunctions()) {
for (int reg : function->GetUsedCalleeSavedGPRs()) {
function->CreateStackObject(8, 8, StackObjectKind::SavedGPR,
"save.x" + std::to_string(reg));
}
for (int reg : function->GetUsedCalleeSavedFPRs()) {
function->CreateStackObject(8, 8, StackObjectKind::SavedFPR,
"save.v" + std::to_string(reg));
}
int cursor = 0;
const int object_count = static_cast<int>(function->GetStackObjects().size());
for (int i = 0; i < object_count; ++i) {
auto& object = function->GetStackObject(i);
cursor = AlignTo(cursor, object.align);
cursor += object.size;
object.offset = -cursor;
}
function->SetFrameSize(AlignTo(cursor, 16));
}
}
} // namespace mir

@ -0,0 +1,996 @@
#include "mir/MIR.h"
#include <algorithm>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir {
namespace {
enum class LoweredKind { Invalid, VReg, StackObject, Global };
std::vector<ir::BasicBlock*> CollectLoweringOrder(ir::Function& function) {
std::vector<ir::BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<ir::BasicBlock*> visited;
std::vector<ir::BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it != nullptr) {
stack.push_back(*it);
}
}
}
for (const auto& block : function.GetBlocks()) {
if (block && visited.insert(block.get()).second) {
order.push_back(block.get());
}
}
return order;
}
struct LoweredValue {
LoweredKind kind = LoweredKind::Invalid;
ValueType type = ValueType::Void;
int index = -1;
std::string symbol;
};
ValueType LowerType(const std::shared_ptr<ir::Type>& type) {
if (!type || type->IsVoid()) {
return ValueType::Void;
}
if (type->IsInt1()) {
return ValueType::I1;
}
if (type->IsInt32()) {
return ValueType::I32;
}
if (type->IsFloat()) {
return ValueType::F32;
}
if (type->IsPointer()) {
return ValueType::Ptr;
}
throw std::runtime_error(FormatError("mir", "unsupported IR type in backend lowering"));
}
int GetIRTypeAlign(const std::shared_ptr<ir::Type>& type) {
if (!type) {
return 1;
}
if (type->IsArray()) {
return GetIRTypeAlign(type->GetElementType());
}
return GetValueAlign(LowerType(type));
}
bool ShouldMaterializeAllocaBase(const std::shared_ptr<ir::Type>& type) {
return type && type->IsArray() && type->GetSize() >= 256;
}
CondCode LowerIntCond(ir::Opcode opcode) {
switch (opcode) {
case ir::Opcode::ICmpEQ:
return CondCode::EQ;
case ir::Opcode::ICmpNE:
return CondCode::NE;
case ir::Opcode::ICmpLT:
return CondCode::LT;
case ir::Opcode::ICmpGT:
return CondCode::GT;
case ir::Opcode::ICmpLE:
return CondCode::LE;
case ir::Opcode::ICmpGE:
return CondCode::GE;
default:
throw std::runtime_error(FormatError("mir", "invalid integer compare opcode"));
}
}
CondCode LowerFloatCond(ir::Opcode opcode) {
switch (opcode) {
case ir::Opcode::FCmpEQ:
return CondCode::EQ;
case ir::Opcode::FCmpNE:
return CondCode::NE;
case ir::Opcode::FCmpLT:
return CondCode::LT;
case ir::Opcode::FCmpGT:
return CondCode::GT;
case ir::Opcode::FCmpLE:
return CondCode::LE;
case ir::Opcode::FCmpGE:
return CondCode::GE;
default:
throw std::runtime_error(FormatError("mir", "invalid float compare opcode"));
}
}
std::int64_t FloatBits(float value) {
std::uint32_t bits = 0;
std::memcpy(&bits, &value, sizeof(bits));
return static_cast<std::int64_t>(bits);
}
class Lowerer {
public:
explicit Lowerer(const ir::Module& module)
: module_(module), machine_module_(std::make_unique<MachineModule>(module)) {}
std::unique_ptr<MachineModule> Run() {
for (const auto& func : module_.GetFunctions()) {
if (func && !func->IsExternal()) {
LowerFunction(*func);
}
}
return std::move(machine_module_);
}
private:
using OperandMap = std::unordered_map<const ir::Value*, MachineOperand>;
MachineOperand ResolveScalarOperand(ir::Value* value,
const OperandMap* inline_values = nullptr) {
if (auto* ci = ir::dyncast<ir::ConstantInt>(value)) {
return MachineOperand::Imm(ci->GetValue());
}
if (auto* cb = ir::dyncast<ir::ConstantI1>(value)) {
return MachineOperand::Imm(cb->GetValue() ? 1 : 0);
}
if (auto* cf = ir::dyncast<ir::ConstantFloat>(value)) {
return MachineOperand::Imm(FloatBits(cf->GetValue()));
}
if (inline_values != nullptr) {
auto inline_it = inline_values->find(value);
if (inline_it != inline_values->end()) {
return inline_it->second;
}
}
auto it = values_.find(value);
if (it == values_.end() || it->second.kind != LoweredKind::VReg) {
throw std::runtime_error(
FormatError("mir", "value is not materialized as a virtual register: " +
value->GetName()));
}
return MachineOperand::VReg(it->second.index);
}
MachineOperand LowerScalarOperand(ir::Value* value) {
return ResolveScalarOperand(value, nullptr);
}
AddressExpr LowerAddress(ir::Value* value) {
if (auto* global = ir::dyncast<ir::GlobalValue>(value)) {
AddressExpr address;
address.base_kind = AddrBaseKind::Global;
address.symbol = global->GetName();
return address;
}
auto it = values_.find(value);
if (it == values_.end()) {
throw std::runtime_error(FormatError("mir", "missing lowered address value"));
}
AddressExpr address;
switch (it->second.kind) {
case LoweredKind::StackObject:
address.base_kind = AddrBaseKind::FrameObject;
address.base_index = it->second.index;
return address;
case LoweredKind::Global:
address.base_kind = AddrBaseKind::Global;
address.symbol = it->second.symbol;
return address;
case LoweredKind::VReg:
address.base_kind = AddrBaseKind::VReg;
address.base_index = it->second.index;
return address;
case LoweredKind::Invalid:
break;
}
throw std::runtime_error(FormatError("mir", "invalid address lowering"));
}
MachineInstr::Opcode LowerBinaryOpcode(ir::Opcode opcode) {
switch (opcode) {
case ir::Opcode::Add:
return MachineInstr::Opcode::Add;
case ir::Opcode::Sub:
return MachineInstr::Opcode::Sub;
case ir::Opcode::Mul:
return MachineInstr::Opcode::Mul;
case ir::Opcode::Div:
return MachineInstr::Opcode::Div;
case ir::Opcode::Rem:
return MachineInstr::Opcode::Rem;
case ir::Opcode::And:
return MachineInstr::Opcode::And;
case ir::Opcode::Or:
return MachineInstr::Opcode::Or;
case ir::Opcode::Xor:
return MachineInstr::Opcode::Xor;
case ir::Opcode::Shl:
return MachineInstr::Opcode::Shl;
case ir::Opcode::AShr:
return MachineInstr::Opcode::AShr;
case ir::Opcode::LShr:
return MachineInstr::Opcode::LShr;
case ir::Opcode::FAdd:
return MachineInstr::Opcode::FAdd;
case ir::Opcode::FSub:
return MachineInstr::Opcode::FSub;
case ir::Opcode::FMul:
return MachineInstr::Opcode::FMul;
case ir::Opcode::FDiv:
return MachineInstr::Opcode::FDiv;
default:
throw std::runtime_error(FormatError("mir", "unsupported binary opcode"));
}
}
LoweredValue NewVRegValue(ValueType type) {
return {LoweredKind::VReg, type, current_function_->NewVReg(type), ""};
}
LoweredValue MaterializeOperandAsValue(const MachineOperand& operand, ValueType type) {
if (operand.GetKind() == OperandKind::VReg) {
return {LoweredKind::VReg, type, operand.GetVReg(), ""};
}
auto lowered = NewVRegValue(type);
current_block_->Append(MachineInstr::Opcode::Copy,
{MachineOperand::VReg(lowered.index), operand});
return lowered;
}
void InsertBeforeTerminator(MachineBasicBlock* block, MachineInstr instr) {
auto& instructions = block->GetInstructions();
auto insert_pos = instructions.end();
if (!instructions.empty() && instructions.back().IsTerminator()) {
insert_pos = instructions.end() - 1;
}
instructions.insert(insert_pos, std::move(instr));
}
struct PhiCopy {
int dst_vreg = -1;
MachineOperand src;
};
void EmitResolvedPhiCopies(MachineBasicBlock* block, std::vector<PhiCopy> copies) {
copies.erase(std::remove_if(copies.begin(), copies.end(),
[](const PhiCopy& copy) {
return copy.src.GetKind() == OperandKind::VReg &&
copy.src.GetVReg() == copy.dst_vreg;
}),
copies.end());
while (!copies.empty()) {
bool progress = false;
for (auto it = copies.begin(); it != copies.end(); ++it) {
const bool dst_is_still_needed_as_source =
std::any_of(copies.begin(), copies.end(), [&](const PhiCopy& other) {
return other.src.GetKind() == OperandKind::VReg &&
other.src.GetVReg() == it->dst_vreg;
});
if (dst_is_still_needed_as_source) {
continue;
}
InsertBeforeTerminator(
block, MachineInstr(MachineInstr::Opcode::Copy,
{MachineOperand::VReg(it->dst_vreg), it->src}));
copies.erase(it);
progress = true;
break;
}
if (progress) {
continue;
}
auto& cycle = copies.front();
if (cycle.src.GetKind() != OperandKind::VReg) {
throw std::runtime_error(FormatError("mir", "invalid phi copy cycle"));
}
const int src_vreg = cycle.src.GetVReg();
const auto temp_type = current_function_->GetVRegInfo(src_vreg).type;
const int temp_vreg = current_function_->NewVReg(temp_type);
InsertBeforeTerminator(
block, MachineInstr(MachineInstr::Opcode::Copy,
{MachineOperand::VReg(temp_vreg), MachineOperand::VReg(src_vreg)}));
for (auto& copy : copies) {
if (copy.src.GetKind() == OperandKind::VReg && copy.src.GetVReg() == src_vreg) {
copy.src = MachineOperand::VReg(temp_vreg);
}
}
}
}
void RedirectEdgeToPhiBlock(MachineBasicBlock* pred_block,
const std::string& succ_name,
const std::string& phi_block_name) {
auto& instructions = pred_block->GetInstructions();
if (instructions.empty() || !instructions.back().IsTerminator()) {
throw std::runtime_error(FormatError("mir", "phi predecessor has no terminator"));
}
auto& term = instructions.back();
auto& operands = term.GetOperands();
switch (term.GetOpcode()) {
case MachineInstr::Opcode::Br:
if (!operands.empty() && operands[0].GetKind() == OperandKind::Block &&
operands[0].GetText() == succ_name) {
operands[0] = MachineOperand::Block(phi_block_name);
return;
}
break;
case MachineInstr::Opcode::CondBr:
for (size_t i = 1; i < operands.size(); ++i) {
if (operands[i].GetKind() == OperandKind::Block &&
operands[i].GetText() == succ_name) {
operands[i] = MachineOperand::Block(phi_block_name);
return;
}
}
break;
default:
break;
}
throw std::runtime_error(FormatError("mir", "failed to redirect phi edge"));
}
void PreparePhiResults(ir::Function& function) {
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
if (inst->GetOpcode() != ir::Opcode::Phi) {
break;
}
auto lowered = NewVRegValue(LowerType(inst->GetType()));
values_[inst.get()] = lowered;
}
}
}
void EmitPhiCopies(ir::Function& function) {
struct EdgeCopies {
MachineBasicBlock* succ_block = nullptr;
std::vector<PhiCopy> copies;
};
std::unordered_map<MachineBasicBlock*, std::vector<EdgeCopies>> pending;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
if (inst->GetOpcode() != ir::Opcode::Phi) {
break;
}
auto* phi = static_cast<ir::PhiInst*>(inst.get());
const int dest_vreg = values_.at(phi).index;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
auto* pred_block = blocks_.at(phi->GetIncomingBlock(i));
auto* succ_block = blocks_.at(block.get());
auto& edges = pending[pred_block];
auto edge_it = std::find_if(edges.begin(), edges.end(), [&](const EdgeCopies& edge) {
return edge.succ_block == succ_block;
});
if (edge_it == edges.end()) {
edges.push_back({succ_block, {}});
edge_it = std::prev(edges.end());
}
edge_it->copies.push_back(
{dest_vreg, LowerScalarOperand(phi->GetIncomingValue(i))});
}
}
}
int phi_block_index = 0;
for (auto& item : pending) {
auto* pred_block = item.first;
auto& pred_instructions = pred_block->GetInstructions();
if (pred_instructions.empty() || !pred_instructions.back().IsTerminator()) {
throw std::runtime_error(FormatError("mir", "phi predecessor has no terminator"));
}
const auto terminator_opcode = pred_instructions.back().GetOpcode();
for (auto& edge : item.second) {
if (terminator_opcode == MachineInstr::Opcode::Br) {
EmitResolvedPhiCopies(pred_block, std::move(edge.copies));
continue;
}
if (terminator_opcode != MachineInstr::Opcode::CondBr) {
throw std::runtime_error(
FormatError("mir", "unsupported terminator for phi lowering"));
}
auto* phi_block = current_function_->CreateBlock(
"phi.edge." + std::to_string(phi_block_index++));
EmitResolvedPhiCopies(phi_block, std::move(edge.copies));
phi_block->Append(MachineInstr::Opcode::Br,
{MachineOperand::Block(edge.succ_block->GetName())});
RedirectEdgeToPhiBlock(pred_block, edge.succ_block->GetName(), phi_block->GetName());
}
}
}
bool CanInlineDirectCall(const ir::Function& function) const {
if (function.IsExternal() || function.GetBlocks().size() != 1) {
return false;
}
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
switch (inst->GetOpcode()) {
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Rem:
case ir::Opcode::And:
case ir::Opcode::Or:
case ir::Opcode::Xor:
case ir::Opcode::Shl:
case ir::Opcode::AShr:
case ir::Opcode::LShr:
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv:
case ir::Opcode::FNeg:
case ir::Opcode::ICmpEQ:
case ir::Opcode::ICmpNE:
case ir::Opcode::ICmpLT:
case ir::Opcode::ICmpGT:
case ir::Opcode::ICmpLE:
case ir::Opcode::ICmpGE:
case ir::Opcode::FCmpEQ:
case ir::Opcode::FCmpNE:
case ir::Opcode::FCmpLT:
case ir::Opcode::FCmpGT:
case ir::Opcode::FCmpLE:
case ir::Opcode::FCmpGE:
case ir::Opcode::Zext:
case ir::Opcode::IToF:
case ir::Opcode::FtoI:
case ir::Opcode::Call:
case ir::Opcode::Return:
break;
default:
return false;
}
}
}
return true;
}
bool TryInlineFunctionBody(const ir::Function& callee, OperandMap* inline_values,
MachineOperand* return_operand, bool* has_return,
int inline_depth) {
if (inline_depth > 2) {
return false;
}
for (const auto& inst : callee.GetBlocks().front()->GetInstructions()) {
switch (inst->GetOpcode()) {
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Rem:
case ir::Opcode::And:
case ir::Opcode::Or:
case ir::Opcode::Xor:
case ir::Opcode::Shl:
case ir::Opcode::AShr:
case ir::Opcode::LShr:
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv: {
auto* binary = static_cast<ir::BinaryInst*>(inst.get());
auto lowered = NewVRegValue(LowerType(binary->GetType()));
current_block_->Append(LowerBinaryOpcode(inst->GetOpcode()),
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(binary->GetLhs(), inline_values),
ResolveScalarOperand(binary->GetRhs(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::FNeg: {
auto* unary = static_cast<ir::UnaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::F32);
current_block_->Append(MachineInstr::Opcode::FNeg,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(unary->GetOprd(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::ICmpEQ:
case ir::Opcode::ICmpNE:
case ir::Opcode::ICmpLT:
case ir::Opcode::ICmpGT:
case ir::Opcode::ICmpLE:
case ir::Opcode::ICmpGE: {
auto* binary = static_cast<ir::BinaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::I1);
MachineInstr instr(MachineInstr::Opcode::ICmp,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(binary->GetLhs(), inline_values),
ResolveScalarOperand(binary->GetRhs(), inline_values)});
instr.SetCondCode(LowerIntCond(inst->GetOpcode()));
current_block_->Append(std::move(instr));
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::FCmpEQ:
case ir::Opcode::FCmpNE:
case ir::Opcode::FCmpLT:
case ir::Opcode::FCmpGT:
case ir::Opcode::FCmpLE:
case ir::Opcode::FCmpGE: {
auto* binary = static_cast<ir::BinaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::I1);
MachineInstr instr(MachineInstr::Opcode::FCmp,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(binary->GetLhs(), inline_values),
ResolveScalarOperand(binary->GetRhs(), inline_values)});
instr.SetCondCode(LowerFloatCond(inst->GetOpcode()));
current_block_->Append(std::move(instr));
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::Zext: {
auto* zext = static_cast<ir::ZextInst*>(inst.get());
auto lowered = NewVRegValue(LowerType(zext->GetType()));
current_block_->Append(MachineInstr::Opcode::ZExt,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(zext->GetValue(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::IToF: {
auto* unary = static_cast<ir::UnaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::F32);
current_block_->Append(MachineInstr::Opcode::ItoF,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(unary->GetOprd(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::FtoI: {
auto* unary = static_cast<ir::UnaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::I32);
current_block_->Append(MachineInstr::Opcode::FtoI,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(unary->GetOprd(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::Call: {
auto* nested_call = static_cast<ir::CallInst*>(inst.get());
auto* nested_callee = nested_call->GetCallee();
if (nested_callee == nullptr || nested_callee == current_ir_function_) {
return false;
}
if (CanInlineDirectCall(*nested_callee)) {
MachineOperand nested_return_operand;
bool nested_has_return = false;
OperandMap nested_values;
const auto& nested_args = nested_callee->GetArguments();
const auto& nested_call_args = nested_call->GetArguments();
if (nested_args.size() != nested_call_args.size()) {
return false;
}
for (size_t i = 0; i < nested_call_args.size(); ++i) {
nested_values[nested_args[i].get()] =
ResolveScalarOperand(nested_call_args[i], inline_values);
}
if (!TryInlineFunctionBody(*nested_callee, &nested_values, &nested_return_operand,
&nested_has_return, inline_depth + 1)) {
return false;
}
if (!nested_call->GetType()->IsVoid()) {
if (!nested_has_return) {
throw std::runtime_error(
FormatError("mir", "inlined nested call is missing return value"));
}
auto nested_value =
MaterializeOperandAsValue(nested_return_operand, LowerType(nested_call->GetType()));
(*inline_values)[inst.get()] = MachineOperand::VReg(nested_value.index);
}
break;
}
std::vector<MachineOperand> operands;
if (!nested_call->GetType()->IsVoid()) {
auto lowered = NewVRegValue(LowerType(nested_call->GetType()));
operands.push_back(MachineOperand::VReg(lowered.index));
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
}
std::vector<ValueType> arg_types;
for (auto* arg : nested_call->GetArguments()) {
operands.push_back(ResolveScalarOperand(arg, inline_values));
arg_types.push_back(LowerType(arg->GetType()));
}
MachineInstr instr(MachineInstr::Opcode::Call, std::move(operands));
instr.SetCallInfo(nested_callee->GetName(), std::move(arg_types),
LowerType(nested_call->GetType()));
current_block_->Append(std::move(instr));
break;
}
case ir::Opcode::Return: {
auto* ret = static_cast<ir::ReturnInst*>(inst.get());
if (ret->HasReturnValue()) {
*return_operand = ResolveScalarOperand(ret->GetReturnValue(), inline_values);
*has_return = true;
}
break;
}
default:
return false;
}
}
return true;
}
bool TryInlineDirectCall(ir::CallInst* call) {
auto* callee = call->GetCallee();
if (callee == nullptr || callee == current_ir_function_ || !CanInlineDirectCall(*callee)) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto& call_args = call->GetArguments();
if (callee_args.size() != call_args.size()) {
return false;
}
OperandMap inline_values;
for (size_t i = 0; i < call_args.size(); ++i) {
inline_values[callee_args[i].get()] = ResolveScalarOperand(call_args[i], nullptr);
}
MachineOperand return_operand;
bool has_return = false;
if (!TryInlineFunctionBody(*callee, &inline_values, &return_operand, &has_return, 0)) {
return false;
}
if (!call->GetType()->IsVoid()) {
if (!has_return) {
throw std::runtime_error(FormatError("mir", "inlined call is missing return value"));
}
values_[call] = MaterializeOperandAsValue(return_operand, LowerType(call->GetType()));
}
return true;
}
void LowerInstruction(ir::Instruction& inst) {
switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: {
auto* alloca_inst = static_cast<ir::AllocaInst*>(&inst);
const auto allocated_type = alloca_inst->GetAllocatedType();
const int object = current_function_->CreateStackObject(
allocated_type->GetSize(), GetIRTypeAlign(allocated_type),
StackObjectKind::Local, inst.GetName());
if (ShouldMaterializeAllocaBase(allocated_type)) {
auto lowered = NewVRegValue(ValueType::Ptr);
MachineInstr lea(MachineInstr::Opcode::Lea,
{MachineOperand::VReg(lowered.index)});
AddressExpr address;
address.base_kind = AddrBaseKind::FrameObject;
address.base_index = object;
lea.SetAddress(std::move(address));
current_block_->Append(std::move(lea));
values_[&inst] = lowered;
} else {
values_[&inst] = {LoweredKind::StackObject, ValueType::Ptr, object, ""};
}
return;
}
case ir::Opcode::Load: {
auto* load = static_cast<ir::LoadInst*>(&inst);
auto lowered = NewVRegValue(LowerType(load->GetType()));
MachineInstr instr(MachineInstr::Opcode::Load,
{MachineOperand::VReg(lowered.index)});
instr.SetAddress(LowerAddress(load->GetPtr()));
current_block_->Append(std::move(instr));
values_[&inst] = lowered;
return;
}
case ir::Opcode::Store: {
auto* store = static_cast<ir::StoreInst*>(&inst);
MachineInstr instr(MachineInstr::Opcode::Store,
{LowerScalarOperand(store->GetValue())});
instr.SetValueType(LowerType(store->GetValue()->GetType()));
instr.SetAddress(LowerAddress(store->GetPtr()));
current_block_->Append(std::move(instr));
return;
}
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Rem:
case ir::Opcode::And:
case ir::Opcode::Or:
case ir::Opcode::Xor:
case ir::Opcode::Shl:
case ir::Opcode::AShr:
case ir::Opcode::LShr:
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv: {
auto* binary = static_cast<ir::BinaryInst*>(&inst);
auto lowered = NewVRegValue(LowerType(binary->GetType()));
current_block_->Append(LowerBinaryOpcode(inst.GetOpcode()),
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(binary->GetLhs()),
LowerScalarOperand(binary->GetRhs())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::FNeg: {
auto* unary = static_cast<ir::UnaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::F32);
current_block_->Append(MachineInstr::Opcode::FNeg,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(unary->GetOprd())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::ICmpEQ:
case ir::Opcode::ICmpNE:
case ir::Opcode::ICmpLT:
case ir::Opcode::ICmpGT:
case ir::Opcode::ICmpLE:
case ir::Opcode::ICmpGE: {
auto* binary = static_cast<ir::BinaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::I1);
MachineInstr instr(MachineInstr::Opcode::ICmp,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(binary->GetLhs()),
LowerScalarOperand(binary->GetRhs())});
instr.SetCondCode(LowerIntCond(inst.GetOpcode()));
current_block_->Append(std::move(instr));
values_[&inst] = lowered;
return;
}
case ir::Opcode::FCmpEQ:
case ir::Opcode::FCmpNE:
case ir::Opcode::FCmpLT:
case ir::Opcode::FCmpGT:
case ir::Opcode::FCmpLE:
case ir::Opcode::FCmpGE: {
auto* binary = static_cast<ir::BinaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::I1);
MachineInstr instr(MachineInstr::Opcode::FCmp,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(binary->GetLhs()),
LowerScalarOperand(binary->GetRhs())});
instr.SetCondCode(LowerFloatCond(inst.GetOpcode()));
current_block_->Append(std::move(instr));
values_[&inst] = lowered;
return;
}
case ir::Opcode::Zext: {
auto* zext = static_cast<ir::ZextInst*>(&inst);
auto lowered = NewVRegValue(LowerType(zext->GetType()));
current_block_->Append(MachineInstr::Opcode::ZExt,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(zext->GetValue())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::IToF: {
auto* unary = static_cast<ir::UnaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::F32);
current_block_->Append(MachineInstr::Opcode::ItoF,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(unary->GetOprd())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::FtoI: {
auto* unary = static_cast<ir::UnaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::I32);
current_block_->Append(MachineInstr::Opcode::FtoI,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(unary->GetOprd())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::GetElementPtr: {
auto* gep = static_cast<ir::GetElementPtrInst*>(&inst);
auto lowered = NewVRegValue(ValueType::Ptr);
AddressExpr address = LowerAddress(gep->GetPointer());
auto current_type = gep->GetSourceType();
for (size_t i = 0; i < gep->GetNumIndices(); ++i) {
auto* index = gep->GetIndex(i);
const std::int64_t stride = current_type ? current_type->GetSize() : 0;
if (auto* ci = ir::dyncast<ir::ConstantInt>(index)) {
address.const_offset += static_cast<std::int64_t>(ci->GetValue()) * stride;
} else if (auto* cb = ir::dyncast<ir::ConstantI1>(index)) {
address.const_offset +=
static_cast<std::int64_t>(cb->GetValue() ? 1 : 0) * stride;
} else {
address.scaled_vregs.push_back({LowerScalarOperand(index).GetVReg(), stride});
}
if (current_type && current_type->IsArray()) {
current_type = current_type->GetElementType();
}
}
MachineInstr instr(MachineInstr::Opcode::Lea,
{MachineOperand::VReg(lowered.index)});
instr.SetAddress(std::move(address));
current_block_->Append(std::move(instr));
values_[&inst] = lowered;
return;
}
case ir::Opcode::Call: {
auto* call = static_cast<ir::CallInst*>(&inst);
if (TryInlineDirectCall(call)) {
return;
}
std::vector<MachineOperand> operands;
if (!call->GetType()->IsVoid()) {
auto lowered = NewVRegValue(LowerType(call->GetType()));
operands.push_back(MachineOperand::VReg(lowered.index));
values_[&inst] = lowered;
}
std::vector<ValueType> arg_types;
for (auto* arg : call->GetArguments()) {
operands.push_back(LowerScalarOperand(arg));
arg_types.push_back(LowerType(arg->GetType()));
}
MachineInstr instr(MachineInstr::Opcode::Call, std::move(operands));
instr.SetCallInfo(call->GetCallee()->GetName(), std::move(arg_types),
LowerType(call->GetType()));
current_block_->Append(std::move(instr));
return;
}
case ir::Opcode::Br: {
auto* br = static_cast<ir::UncondBrInst*>(&inst);
current_block_->Append(MachineInstr::Opcode::Br,
{MachineOperand::Block(blocks_.at(br->GetDest())->GetName())});
return;
}
case ir::Opcode::CondBr: {
auto* br = static_cast<ir::CondBrInst*>(&inst);
current_block_->Append(MachineInstr::Opcode::CondBr,
{LowerScalarOperand(br->GetCondition()),
MachineOperand::Block(blocks_.at(br->GetThenBlock())->GetName()),
MachineOperand::Block(blocks_.at(br->GetElseBlock())->GetName())});
return;
}
case ir::Opcode::Return: {
auto* ret = static_cast<ir::ReturnInst*>(&inst);
if (ret->HasReturnValue()) {
MachineInstr instr(MachineInstr::Opcode::Ret,
{LowerScalarOperand(ret->GetReturnValue())});
instr.SetValueType(LowerType(ret->GetReturnValue()->GetType()));
current_block_->Append(std::move(instr));
} else {
current_block_->Append(MachineInstr::Opcode::Ret);
}
return;
}
case ir::Opcode::Memset: {
auto* memset_inst = static_cast<ir::MemsetInst*>(&inst);
MachineInstr instr(MachineInstr::Opcode::Memset,
{LowerScalarOperand(memset_inst->GetValue()),
LowerScalarOperand(memset_inst->GetLength())});
instr.SetAddress(LowerAddress(memset_inst->GetDest()));
current_block_->Append(std::move(instr));
return;
}
case ir::Opcode::Unreachable:
current_block_->Append(MachineInstr::Opcode::Unreachable);
return;
case ir::Opcode::Phi:
return;
case ir::Opcode::FRem:
case ir::Opcode::Neg:
case ir::Opcode::Not:
throw std::runtime_error(
FormatError("mir", "unsupported instruction in backend lowering"));
}
throw std::runtime_error(FormatError("mir", "unsupported IR opcode in backend lowering"));
}
void LowerFunction(ir::Function& function) {
values_.clear();
blocks_.clear();
std::vector<ValueType> param_types;
for (const auto& type : function.GetParamTypes()) {
param_types.push_back(LowerType(type));
}
auto machine_function = std::make_unique<MachineFunction>(
function.GetName(), LowerType(function.GetReturnType()), std::move(param_types));
current_ir_function_ = &function;
current_function_ = machine_function.get();
const auto ordered_blocks = CollectLoweringOrder(function);
for (const auto& block : function.GetBlocks()) {
blocks_[block.get()] = current_function_->CreateBlock(block->GetName());
}
if (!function.GetBlocks().empty()) {
auto* entry = blocks_.at(function.GetBlocks().front().get());
for (const auto& argument : function.GetArguments()) {
auto lowered = NewVRegValue(LowerType(argument->GetType()));
entry->Append(MachineInstr::Opcode::Arg,
{MachineOperand::VReg(lowered.index),
MachineOperand::Imm(static_cast<std::int64_t>(argument->GetIndex()))});
values_[argument.get()] = lowered;
}
}
PreparePhiResults(function);
for (auto* block : ordered_blocks) {
current_block_ = blocks_.at(block);
for (const auto& inst : block->GetInstructions()) {
LowerInstruction(*inst);
}
}
EmitPhiCopies(function);
machine_module_->AddFunction(std::move(machine_function));
current_ir_function_ = nullptr;
current_function_ = nullptr;
current_block_ = nullptr;
}
const ir::Module& module_;
std::unique_ptr<MachineModule> machine_module_;
ir::Function* current_ir_function_ = nullptr;
MachineFunction* current_function_ = nullptr;
MachineBasicBlock* current_block_ = nullptr;
std::unordered_map<const ir::Value*, LoweredValue> values_;
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> blocks_;
};
} // namespace
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
DefaultContext();
return Lowerer(module).Run();
}
} // namespace mir

@ -0,0 +1,21 @@
#include "mir/MIR.h"
#include <utility>
namespace mir {
MachineBasicBlock::MachineBasicBlock(std::string name)
: name_(std::move(name)) {}
MachineInstr& MachineBasicBlock::Append(MachineInstr::Opcode opcode,
std::vector<MachineOperand> operands) {
instructions_.emplace_back(opcode, std::move(operands));
return instructions_.back();
}
MachineInstr& MachineBasicBlock::Append(MachineInstr instr) {
instructions_.push_back(std::move(instr));
return instructions_.back();
}
} // namespace mir

@ -0,0 +1,11 @@
#include "mir/MIR.h"
namespace mir {
namespace {
MIRContext g_context;
} // namespace
MIRContext& DefaultContext() { return g_context; }
} // namespace mir

@ -0,0 +1,106 @@
#include "mir/MIR.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
namespace mir {
MachineFunction::MachineFunction(std::string name, ValueType return_type,
std::vector<ValueType> param_types)
: name_(std::move(name)),
return_type_(return_type),
param_types_(std::move(param_types)) {}
MachineBasicBlock* MachineFunction::CreateBlock(const std::string& name) {
auto block = std::make_unique<MachineBasicBlock>(name);
auto* ptr = block.get();
blocks_.push_back(std::move(block));
return ptr;
}
int MachineFunction::NewVReg(ValueType type) {
const int id = static_cast<int>(vregs_.size());
vregs_.push_back({id, type});
allocations_.push_back({});
return id;
}
const VRegInfo& MachineFunction::GetVRegInfo(int id) const {
if (id < 0 || id >= static_cast<int>(vregs_.size())) {
throw std::out_of_range("virtual register index out of range");
}
return vregs_[static_cast<size_t>(id)];
}
VRegInfo& MachineFunction::GetVRegInfo(int id) {
if (id < 0 || id >= static_cast<int>(vregs_.size())) {
throw std::out_of_range("virtual register index out of range");
}
return vregs_[static_cast<size_t>(id)];
}
int MachineFunction::CreateStackObject(int size, int align, StackObjectKind kind,
const std::string& name) {
const int index = static_cast<int>(stack_objects_.size());
stack_objects_.push_back({index, kind, size, align, 0, name});
return index;
}
StackObject& MachineFunction::GetStackObject(int index) {
if (index < 0 || index >= static_cast<int>(stack_objects_.size())) {
throw std::out_of_range("stack object index out of range");
}
return stack_objects_[static_cast<size_t>(index)];
}
const StackObject& MachineFunction::GetStackObject(int index) const {
if (index < 0 || index >= static_cast<int>(stack_objects_.size())) {
throw std::out_of_range("stack object index out of range");
}
return stack_objects_[static_cast<size_t>(index)];
}
void MachineFunction::SetAllocation(int vreg, Allocation allocation) {
if (vreg < 0 || vreg >= static_cast<int>(allocations_.size())) {
throw std::out_of_range("allocation index out of range");
}
allocations_[static_cast<size_t>(vreg)] = allocation;
}
const Allocation& MachineFunction::GetAllocation(int vreg) const {
if (vreg < 0 || vreg >= static_cast<int>(allocations_.size())) {
throw std::out_of_range("allocation index out of range");
}
return allocations_[static_cast<size_t>(vreg)];
}
Allocation& MachineFunction::GetAllocation(int vreg) {
if (vreg < 0 || vreg >= static_cast<int>(allocations_.size())) {
throw std::out_of_range("allocation index out of range");
}
return allocations_[static_cast<size_t>(vreg)];
}
void MachineFunction::AddUsedCalleeSavedGPR(int reg_index) {
if (std::find(used_callee_saved_gprs_.begin(), used_callee_saved_gprs_.end(),
reg_index) == used_callee_saved_gprs_.end()) {
used_callee_saved_gprs_.push_back(reg_index);
}
}
void MachineFunction::AddUsedCalleeSavedFPR(int reg_index) {
if (std::find(used_callee_saved_fprs_.begin(), used_callee_saved_fprs_.end(),
reg_index) == used_callee_saved_fprs_.end()) {
used_callee_saved_fprs_.push_back(reg_index);
}
}
MachineFunction* MachineModule::AddFunction(
std::unique_ptr<MachineFunction> function) {
auto* ptr = function.get();
functions_.push_back(std::move(function));
return ptr;
}
} // namespace mir

@ -0,0 +1,178 @@
#include "mir/MIR.h"
namespace mir {
MachineOperand::MachineOperand(OperandKind kind, int vreg, std::int64_t imm,
std::string text)
: kind_(kind), vreg_(vreg), imm_(imm), text_(std::move(text)) {}
MachineOperand MachineOperand::VReg(int reg) {
return MachineOperand(OperandKind::VReg, reg, 0, "");
}
MachineOperand MachineOperand::Imm(std::int64_t value) {
return MachineOperand(OperandKind::Imm, -1, value, "");
}
MachineOperand MachineOperand::Block(std::string name) {
return MachineOperand(OperandKind::Block, -1, 0, std::move(name));
}
MachineOperand MachineOperand::Symbol(std::string name) {
return MachineOperand(OperandKind::Symbol, -1, 0, std::move(name));
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<MachineOperand> operands)
: opcode_(opcode), operands_(std::move(operands)) {}
bool MachineInstr::IsTerminator() const {
return opcode_ == Opcode::Br || opcode_ == Opcode::CondBr ||
opcode_ == Opcode::Ret || opcode_ == Opcode::Unreachable;
}
std::vector<int> MachineInstr::GetDefs() const {
switch (opcode_) {
case Opcode::Arg:
case Opcode::Copy:
case Opcode::Load:
case Opcode::Lea:
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FNeg:
case Opcode::ICmp:
case Opcode::FCmp:
case Opcode::ZExt:
case Opcode::ItoF:
case Opcode::FtoI:
if (!operands_.empty() && operands_[0].GetKind() == OperandKind::VReg) {
return {operands_[0].GetVReg()};
}
return {};
case Opcode::Call:
if (call_return_type_ != ValueType::Void && !operands_.empty() &&
operands_[0].GetKind() == OperandKind::VReg) {
return {operands_[0].GetVReg()};
}
return {};
case Opcode::Store:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Ret:
case Opcode::Memset:
case Opcode::Unreachable:
return {};
}
return {};
}
std::vector<int> MachineInstr::GetUses() const {
std::vector<int> uses;
auto push_vreg = [&](const MachineOperand& operand) {
if (operand.GetKind() == OperandKind::VReg) {
uses.push_back(operand.GetVReg());
}
};
auto push_addr_uses = [&]() {
if (!has_address_) {
return;
}
if (address_.base_kind == AddrBaseKind::VReg && address_.base_index >= 0) {
uses.push_back(address_.base_index);
}
for (const auto& term : address_.scaled_vregs) {
uses.push_back(term.first);
}
};
switch (opcode_) {
case Opcode::Arg:
case Opcode::Br:
case Opcode::Unreachable:
break;
case Opcode::Copy:
case Opcode::ZExt:
case Opcode::ItoF:
case Opcode::FtoI:
case Opcode::FNeg:
if (operands_.size() >= 2) {
push_vreg(operands_[1]);
}
break;
case Opcode::Load:
case Opcode::Lea:
push_addr_uses();
break;
case Opcode::Store:
if (!operands_.empty()) {
push_vreg(operands_[0]);
}
push_addr_uses();
break;
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::ICmp:
case Opcode::FCmp:
if (operands_.size() >= 2) {
push_vreg(operands_[1]);
}
if (operands_.size() >= 3) {
push_vreg(operands_[2]);
}
break;
case Opcode::CondBr:
if (!operands_.empty()) {
push_vreg(operands_[0]);
}
break;
case Opcode::Call: {
size_t arg_begin = call_return_type_ == ValueType::Void ? 0 : 1;
for (size_t i = arg_begin; i < operands_.size(); ++i) {
push_vreg(operands_[i]);
}
break;
}
case Opcode::Ret:
if (!operands_.empty()) {
push_vreg(operands_[0]);
}
break;
case Opcode::Memset:
if (!operands_.empty()) {
push_vreg(operands_[0]);
}
if (operands_.size() >= 2) {
push_vreg(operands_[1]);
}
push_addr_uses();
break;
}
return uses;
}
} // namespace mir

@ -0,0 +1,820 @@
#include "mir/MIR.h"
#include <algorithm>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "utils/Log.h"
namespace mir {
namespace {
struct BlockInfo {
int start_pos = 0;
int end_pos = 0;
std::vector<int> successors;
std::vector<std::uint8_t> use;
std::vector<std::uint8_t> def;
std::vector<std::uint8_t> live_in;
std::vector<std::uint8_t> live_out;
};
struct MoveEdge {
int dst = -1;
int src = -1;
};
bool BelongsToClass(ValueType type, RegClass reg_class) {
if (type == ValueType::Void) {
return false;
}
return IsFPR(type) ? reg_class == RegClass::FPR : reg_class == RegClass::GPR;
}
bool IsCalleeSaved(PhysReg reg) {
if (reg.reg_class == RegClass::GPR) {
return reg.index >= 19 && reg.index <= 28;
}
return reg.index >= 8 && reg.index <= 15;
}
bool IsCallerSaved(PhysReg reg) {
return !IsCalleeSaved(reg);
}
bool IsCheapRematerializableInst(const MachineInstr& inst) {
if (inst.GetDefs().size() != 1) {
return false;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Copy) {
const auto& operands = inst.GetOperands();
return operands.size() >= 2 && operands[1].GetKind() == OperandKind::Imm;
}
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress()) {
return false;
}
const auto& address = inst.GetAddress();
return address.base_kind != AddrBaseKind::VReg && address.scaled_vregs.empty();
}
std::vector<PhysReg> GetAllocatableRegs(RegClass reg_class) {
std::vector<PhysReg> regs;
if (reg_class == RegClass::FPR) {
for (int i = 19; i <= 31; ++i) {
regs.push_back({RegClass::FPR, i});
}
for (int i = 8; i <= 15; ++i) {
regs.push_back({RegClass::FPR, i});
}
return regs;
}
regs.push_back({RegClass::GPR, 8});
for (int i = 13; i <= 15; ++i) {
regs.push_back({RegClass::GPR, i});
}
for (int i = 19; i <= 28; ++i) {
regs.push_back({RegClass::GPR, i});
}
return regs;
}
int CreateSpillSlot(MachineFunction& function, int vreg) {
const auto type = function.GetVRegInfo(vreg).type;
return function.CreateStackObject(GetValueSize(type), GetValueAlign(type),
StackObjectKind::Spill,
"spill." + std::to_string(vreg));
}
std::vector<BlockInfo> AnalyzeBlocks(const MachineFunction& function) {
const auto& blocks = function.GetBlocks();
const int num_blocks = static_cast<int>(blocks.size());
const int num_vregs = static_cast<int>(function.GetVRegs().size());
std::vector<BlockInfo> infos(static_cast<size_t>(num_blocks));
std::vector<std::pair<std::string, int>> block_name_to_index;
block_name_to_index.reserve(blocks.size());
for (int i = 0; i < num_blocks; ++i) {
block_name_to_index.push_back({blocks[static_cast<size_t>(i)]->GetName(), i});
}
auto find_block_index = [&](const std::string& name) {
auto it = std::find_if(block_name_to_index.begin(), block_name_to_index.end(),
[&](const auto& item) { return item.first == name; });
if (it == block_name_to_index.end()) {
throw std::runtime_error(FormatError("mir", "unknown basic block label: " + name));
}
return it->second;
};
int position = 0;
for (int block_index = 0; block_index < num_blocks; ++block_index) {
auto& info = infos[static_cast<size_t>(block_index)];
info.start_pos = position;
info.use.assign(static_cast<size_t>(num_vregs), 0);
info.def.assign(static_cast<size_t>(num_vregs), 0);
info.live_in.assign(static_cast<size_t>(num_vregs), 0);
info.live_out.assign(static_cast<size_t>(num_vregs), 0);
const auto& instructions = blocks[static_cast<size_t>(block_index)]->GetInstructions();
for (const auto& inst : instructions) {
for (int use : inst.GetUses()) {
if (use >= 0 && use < num_vregs && !info.def[static_cast<size_t>(use)]) {
info.use[static_cast<size_t>(use)] = 1;
}
}
for (int def : inst.GetDefs()) {
if (def >= 0 && def < num_vregs) {
info.def[static_cast<size_t>(def)] = 1;
}
}
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Br:
if (!inst.GetOperands().empty()) {
info.successors.push_back(find_block_index(inst.GetOperands()[0].GetText()));
}
break;
case MachineInstr::Opcode::CondBr:
if (inst.GetOperands().size() >= 3) {
info.successors.push_back(find_block_index(inst.GetOperands()[1].GetText()));
info.successors.push_back(find_block_index(inst.GetOperands()[2].GetText()));
}
break;
default:
break;
}
position += 2;
}
std::sort(info.successors.begin(), info.successors.end());
info.successors.erase(std::unique(info.successors.begin(), info.successors.end()),
info.successors.end());
info.end_pos = position;
}
bool changed = true;
while (changed) {
changed = false;
for (int block_index = num_blocks - 1; block_index >= 0; --block_index) {
auto& info = infos[static_cast<size_t>(block_index)];
std::vector<std::uint8_t> next_out(static_cast<size_t>(num_vregs), 0);
std::vector<std::uint8_t> next_in(static_cast<size_t>(num_vregs), 0);
for (int succ : info.successors) {
const auto& succ_in = infos[static_cast<size_t>(succ)].live_in;
for (int vreg = 0; vreg < num_vregs; ++vreg) {
next_out[static_cast<size_t>(vreg)] |= succ_in[static_cast<size_t>(vreg)];
}
}
for (int vreg = 0; vreg < num_vregs; ++vreg) {
const size_t idx = static_cast<size_t>(vreg);
next_in[idx] = info.use[idx] |
(next_out[idx] & static_cast<std::uint8_t>(!info.def[idx]));
}
if (next_out != info.live_out || next_in != info.live_in) {
changed = true;
info.live_out = std::move(next_out);
info.live_in = std::move(next_in);
}
}
}
return infos;
}
class GeorgeColoringAllocator {
public:
GeorgeColoringAllocator(MachineFunction& function, RegClass reg_class,
const std::vector<BlockInfo>& block_infos)
: function_(function),
reg_class_(reg_class),
regs_(GetAllocatableRegs(reg_class)),
k_(static_cast<int>(regs_.size())),
block_infos_(block_infos),
num_vregs_(static_cast<int>(function.GetVRegs().size())),
in_class_(static_cast<size_t>(num_vregs_), 0),
live_across_call_(static_cast<size_t>(num_vregs_), 0),
rematerializable_(static_cast<size_t>(num_vregs_), 0),
adjacency_(static_cast<size_t>(num_vregs_)),
degree_(static_cast<size_t>(num_vregs_), 0),
spill_cost_(static_cast<size_t>(num_vregs_), 0.0),
move_list_(static_cast<size_t>(num_vregs_)),
alias_(static_cast<size_t>(num_vregs_), -1),
color_index_(static_cast<size_t>(num_vregs_), -1),
in_select_stack_(static_cast<size_t>(num_vregs_), 0),
is_coalesced_(static_cast<size_t>(num_vregs_), 0),
is_spilled_(static_cast<size_t>(num_vregs_), 0),
is_colored_(static_cast<size_t>(num_vregs_), 0),
simplify_worklist_(static_cast<size_t>(num_vregs_), 0),
freeze_worklist_(static_cast<size_t>(num_vregs_), 0),
spill_worklist_(static_cast<size_t>(num_vregs_), 0) {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
alias_[static_cast<size_t>(vreg)] = vreg;
in_class_[static_cast<size_t>(vreg)] =
BelongsToClass(function_.GetVRegInfo(vreg).type, reg_class_) ? 1 : 0;
}
}
void Run() {
if (k_ == 0) {
throw std::runtime_error(FormatError("mir", "no allocatable physical registers"));
}
MarkRematerializableDefs();
Build();
MakeWorklists();
while (HasNodes(simplify_worklist_) || HasNodes(freeze_worklist_) ||
HasNodes(spill_worklist_) || HasMoves(worklist_moves_)) {
if (HasNodes(simplify_worklist_)) {
Simplify();
} else if (HasMoves(worklist_moves_)) {
Coalesce();
} else if (HasNodes(freeze_worklist_)) {
Freeze();
} else if (HasNodes(spill_worklist_)) {
SelectSpill();
}
}
AssignColors();
CommitAllocations();
}
private:
void MarkRematerializableDefs() {
for (const auto& block : function_.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
if (!IsCheapRematerializableInst(inst)) {
continue;
}
for (int def : inst.GetDefs()) {
if (def >= 0 && def < num_vregs_ && in_class_[static_cast<size_t>(def)]) {
rematerializable_[static_cast<size_t>(def)] = 1;
}
}
}
}
}
void Build() {
const auto& blocks = function_.GetBlocks();
for (size_t block_index = 0; block_index < blocks.size(); ++block_index) {
const auto& block = blocks[block_index];
const auto& info = block_infos_[block_index];
std::vector<std::uint8_t> live = info.live_out;
double block_weight = 1.0;
for (int succ : info.successors) {
if (succ <= static_cast<int>(block_index)) {
block_weight = 8.0;
break;
}
}
const auto& instructions = block->GetInstructions();
for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
const auto& inst = *it;
auto defs = FilterClass(inst.GetDefs());
auto uses = FilterClass(inst.GetUses());
if (inst.GetOpcode() == MachineInstr::Opcode::Call ||
inst.GetOpcode() == MachineInstr::Opcode::Memset) {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (live[static_cast<size_t>(vreg)] &&
in_class_[static_cast<size_t>(vreg)]) {
live_across_call_[static_cast<size_t>(vreg)] = 1;
}
}
}
for (int def : defs) {
spill_cost_[static_cast<size_t>(def)] +=
block_weight * (rematerializable_[static_cast<size_t>(def)] ? 0.25 : 1.0);
}
for (int use : uses) {
spill_cost_[static_cast<size_t>(use)] +=
block_weight * (rematerializable_[static_cast<size_t>(use)] ? 0.25 : 1.0);
}
// All source operands are simultaneously live at the instruction input.
// They must interfere with each other, otherwise two distinct values
// used by the same instruction may be colored to the same register.
for (size_t i = 0; i < uses.size(); ++i) {
for (size_t j = i + 1; j < uses.size(); ++j) {
AddEdge(uses[i], uses[j]);
}
}
const bool is_move = inst.GetOpcode() == MachineInstr::Opcode::Copy &&
defs.size() == 1 && uses.size() == 1 && defs[0] != uses[0];
if (is_move) {
const int dst = defs[0];
const int src = uses[0];
const int move_index = static_cast<int>(moves_.size());
moves_.push_back({dst, src});
move_list_[static_cast<size_t>(dst)].push_back(move_index);
move_list_[static_cast<size_t>(src)].push_back(move_index);
live[static_cast<size_t>(src)] = 0;
}
for (int def : defs) {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!live[static_cast<size_t>(vreg)] || !in_class_[static_cast<size_t>(vreg)]) {
continue;
}
AddEdge(def, vreg);
}
}
for (int def : defs) {
live[static_cast<size_t>(def)] = 0;
}
for (int use : uses) {
live[static_cast<size_t>(use)] = 1;
}
}
}
worklist_moves_.assign(moves_.size(), 1);
active_moves_.assign(moves_.size(), 0);
coalesced_moves_.assign(moves_.size(), 0);
constrained_moves_.assign(moves_.size(), 0);
frozen_moves_.assign(moves_.size(), 0);
}
void MakeWorklists() {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!in_class_[static_cast<size_t>(vreg)]) {
continue;
}
if (degree_[static_cast<size_t>(vreg)] >= k_) {
spill_worklist_[static_cast<size_t>(vreg)] = 1;
} else if (MoveRelated(vreg)) {
freeze_worklist_[static_cast<size_t>(vreg)] = 1;
} else {
simplify_worklist_[static_cast<size_t>(vreg)] = 1;
}
}
}
void Simplify() {
const int node = PickAnyNode(simplify_worklist_);
simplify_worklist_[static_cast<size_t>(node)] = 0;
select_stack_.push_back(node);
in_select_stack_[static_cast<size_t>(node)] = 1;
for (int neighbor : Adjacent(node)) {
DecrementDegree(neighbor);
}
}
void Coalesce() {
const int move_index = PickBestMove();
worklist_moves_[static_cast<size_t>(move_index)] = 0;
int x = GetAlias(moves_[static_cast<size_t>(move_index)].dst);
int y = GetAlias(moves_[static_cast<size_t>(move_index)].src);
if (x == y) {
coalesced_moves_[static_cast<size_t>(move_index)] = 1;
AddWorkList(x);
return;
}
if (AdjacentTo(x, y)) {
constrained_moves_[static_cast<size_t>(move_index)] = 1;
AddWorkList(x);
AddWorkList(y);
return;
}
int u = x;
int v = y;
if (degree_[static_cast<size_t>(v)] > degree_[static_cast<size_t>(u)]) {
std::swap(u, v);
}
if (GeorgeOK(v, u) || ConservativeUnion(u, v)) {
coalesced_moves_[static_cast<size_t>(move_index)] = 1;
Combine(u, v);
AddWorkList(u);
return;
}
active_moves_[static_cast<size_t>(move_index)] = 1;
}
void Freeze() {
const int node = PickAnyNode(freeze_worklist_);
freeze_worklist_[static_cast<size_t>(node)] = 0;
simplify_worklist_[static_cast<size_t>(node)] = 1;
FreezeMoves(node);
}
void SelectSpill() {
int best = -1;
double best_priority = std::numeric_limits<double>::infinity();
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!spill_worklist_[static_cast<size_t>(vreg)]) {
continue;
}
double priority = spill_cost_[static_cast<size_t>(vreg)] /
std::max(1, degree_[static_cast<size_t>(vreg)]);
if (rematerializable_[static_cast<size_t>(vreg)]) {
priority *= 0.2;
}
if (MoveRelated(vreg)) {
priority *= 1.15;
}
if (live_across_call_[static_cast<size_t>(vreg)] &&
!rematerializable_[static_cast<size_t>(vreg)]) {
priority *= 1.25;
}
if (best < 0 || priority < best_priority) {
best = vreg;
best_priority = priority;
}
}
if (best < 0) {
throw std::runtime_error(FormatError("mir", "failed to select spill candidate"));
}
spill_worklist_[static_cast<size_t>(best)] = 0;
simplify_worklist_[static_cast<size_t>(best)] = 1;
FreezeMoves(best);
}
void AssignColors() {
while (!select_stack_.empty()) {
const int node = select_stack_.back();
select_stack_.pop_back();
in_select_stack_[static_cast<size_t>(node)] = 0;
std::vector<std::uint8_t> ok_colors(static_cast<size_t>(regs_.size()), 1);
if (live_across_call_[static_cast<size_t>(node)]) {
for (size_t i = 0; i < regs_.size(); ++i) {
if (IsCallerSaved(regs_[i])) {
ok_colors[i] = 0;
}
}
}
for (int neighbor : adjacency_[static_cast<size_t>(node)]) {
const int alias = GetAlias(neighbor);
if (!is_colored_[static_cast<size_t>(alias)]) {
continue;
}
const int color = color_index_[static_cast<size_t>(alias)];
if (color >= 0 && color < static_cast<int>(regs_.size())) {
ok_colors[static_cast<size_t>(color)] = 0;
}
}
int chosen = -1;
for (size_t i = 0; i < ok_colors.size(); ++i) {
if (ok_colors[i]) {
chosen = static_cast<int>(i);
break;
}
}
if (chosen < 0) {
is_spilled_[static_cast<size_t>(node)] = 1;
continue;
}
is_colored_[static_cast<size_t>(node)] = 1;
color_index_[static_cast<size_t>(node)] = chosen;
}
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!is_coalesced_[static_cast<size_t>(vreg)]) {
continue;
}
const int alias = GetAlias(vreg);
if (is_spilled_[static_cast<size_t>(alias)]) {
is_spilled_[static_cast<size_t>(vreg)] = 1;
} else {
is_colored_[static_cast<size_t>(vreg)] = 1;
color_index_[static_cast<size_t>(vreg)] = color_index_[static_cast<size_t>(alias)];
}
}
}
void CommitAllocations() {
std::unordered_map<int, Allocation> representative_allocations;
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!in_class_[static_cast<size_t>(vreg)]) {
continue;
}
const int rep = GetAlias(vreg);
if (representative_allocations.find(rep) != representative_allocations.end()) {
continue;
}
Allocation allocation;
if (is_spilled_[static_cast<size_t>(rep)]) {
allocation.kind = Allocation::Kind::Spill;
allocation.stack_object = CreateSpillSlot(function_, rep);
} else if (is_colored_[static_cast<size_t>(rep)]) {
allocation.kind = Allocation::Kind::PhysReg;
allocation.phys = regs_[static_cast<size_t>(color_index_[static_cast<size_t>(rep)])];
} else {
allocation.kind = Allocation::Kind::Spill;
allocation.stack_object = CreateSpillSlot(function_, rep);
}
representative_allocations.emplace(rep, allocation);
}
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!in_class_[static_cast<size_t>(vreg)]) {
continue;
}
const Allocation allocation = representative_allocations.at(GetAlias(vreg));
function_.SetAllocation(vreg, allocation);
if (allocation.kind == Allocation::Kind::PhysReg && IsCalleeSaved(allocation.phys)) {
if (allocation.phys.reg_class == RegClass::GPR) {
function_.AddUsedCalleeSavedGPR(allocation.phys.index);
} else {
function_.AddUsedCalleeSavedFPR(allocation.phys.index);
}
}
}
}
void DecrementDegree(int node) {
const int old_degree = degree_[static_cast<size_t>(node)];
--degree_[static_cast<size_t>(node)];
if (old_degree != k_) {
return;
}
auto neighbors = Adjacent(node);
neighbors.push_back(node);
EnableMoves(neighbors);
spill_worklist_[static_cast<size_t>(node)] = 0;
if (MoveRelated(node)) {
freeze_worklist_[static_cast<size_t>(node)] = 1;
} else {
simplify_worklist_[static_cast<size_t>(node)] = 1;
}
}
void AddWorkList(int node) {
if (!in_class_[static_cast<size_t>(node)] || is_coalesced_[static_cast<size_t>(node)] ||
in_select_stack_[static_cast<size_t>(node)] || degree_[static_cast<size_t>(node)] >= k_ ||
MoveRelated(node)) {
return;
}
freeze_worklist_[static_cast<size_t>(node)] = 0;
spill_worklist_[static_cast<size_t>(node)] = 0;
simplify_worklist_[static_cast<size_t>(node)] = 1;
}
void Combine(int keep, int remove) {
simplify_worklist_[static_cast<size_t>(remove)] = 0;
freeze_worklist_[static_cast<size_t>(remove)] = 0;
spill_worklist_[static_cast<size_t>(remove)] = 0;
is_coalesced_[static_cast<size_t>(remove)] = 1;
alias_[static_cast<size_t>(remove)] = keep;
live_across_call_[static_cast<size_t>(keep)] |=
live_across_call_[static_cast<size_t>(remove)];
auto& keep_moves = move_list_[static_cast<size_t>(keep)];
const auto& remove_moves = move_list_[static_cast<size_t>(remove)];
keep_moves.insert(keep_moves.end(), remove_moves.begin(), remove_moves.end());
EnableMoves({remove});
for (int neighbor : Adjacent(remove)) {
AddEdge(neighbor, keep);
DecrementDegree(neighbor);
}
if (freeze_worklist_[static_cast<size_t>(keep)] && degree_[static_cast<size_t>(keep)] >= k_) {
freeze_worklist_[static_cast<size_t>(keep)] = 0;
spill_worklist_[static_cast<size_t>(keep)] = 1;
}
}
void FreezeMoves(int node) {
for (int move_index : NodeMoves(node)) {
if (worklist_moves_[static_cast<size_t>(move_index)]) {
worklist_moves_[static_cast<size_t>(move_index)] = 0;
} else if (active_moves_[static_cast<size_t>(move_index)]) {
active_moves_[static_cast<size_t>(move_index)] = 0;
} else {
continue;
}
frozen_moves_[static_cast<size_t>(move_index)] = 1;
const auto& move = moves_[static_cast<size_t>(move_index)];
const int x = GetAlias(move.dst);
const int y = GetAlias(move.src);
const int other = y == GetAlias(node) ? x : y;
if (!MoveRelated(other) && degree_[static_cast<size_t>(other)] < k_) {
freeze_worklist_[static_cast<size_t>(other)] = 0;
simplify_worklist_[static_cast<size_t>(other)] = 1;
}
}
}
void EnableMoves(const std::vector<int>& nodes) {
for (int node : nodes) {
for (int move_index : NodeMoves(node)) {
if (active_moves_[static_cast<size_t>(move_index)]) {
active_moves_[static_cast<size_t>(move_index)] = 0;
worklist_moves_[static_cast<size_t>(move_index)] = 1;
}
}
}
}
std::vector<int> Adjacent(int node) const {
std::vector<int> neighbors;
for (int neighbor : adjacency_[static_cast<size_t>(node)]) {
if (in_select_stack_[static_cast<size_t>(neighbor)] ||
is_coalesced_[static_cast<size_t>(neighbor)]) {
continue;
}
neighbors.push_back(neighbor);
}
return neighbors;
}
std::vector<int> NodeMoves(int node) const {
std::vector<int> related_moves;
for (int move_index : move_list_[static_cast<size_t>(node)]) {
if (worklist_moves_[static_cast<size_t>(move_index)] ||
active_moves_[static_cast<size_t>(move_index)]) {
related_moves.push_back(move_index);
}
}
return related_moves;
}
bool MoveRelated(int node) const { return !NodeMoves(node).empty(); }
int GetAlias(int node) const {
int current = node;
while (is_coalesced_[static_cast<size_t>(current)]) {
current = alias_[static_cast<size_t>(current)];
}
return current;
}
bool AdjacentTo(int lhs, int rhs) const {
return adjacency_[static_cast<size_t>(lhs)].find(rhs) !=
adjacency_[static_cast<size_t>(lhs)].end();
}
bool GeorgeOK(int candidate, int target) const {
for (int neighbor : Adjacent(candidate)) {
if (degree_[static_cast<size_t>(neighbor)] >= k_ && !AdjacentTo(neighbor, target)) {
return false;
}
}
return true;
}
bool ConservativeUnion(int lhs, int rhs) const {
std::unordered_set<int> union_neighbors;
for (int neighbor : Adjacent(lhs)) {
union_neighbors.insert(neighbor);
}
for (int neighbor : Adjacent(rhs)) {
union_neighbors.insert(neighbor);
}
int high_degree_count = 0;
for (int neighbor : union_neighbors) {
if (degree_[static_cast<size_t>(neighbor)] >= k_) {
++high_degree_count;
}
}
return high_degree_count < k_;
}
void AddEdge(int lhs, int rhs) {
if (lhs == rhs || !in_class_[static_cast<size_t>(lhs)] ||
!in_class_[static_cast<size_t>(rhs)]) {
return;
}
if (adjacency_[static_cast<size_t>(lhs)].insert(rhs).second) {
adjacency_[static_cast<size_t>(rhs)].insert(lhs);
++degree_[static_cast<size_t>(lhs)];
++degree_[static_cast<size_t>(rhs)];
}
}
std::vector<int> FilterClass(const std::vector<int>& regs) const {
std::vector<int> filtered;
for (int reg : regs) {
if (reg >= 0 && reg < num_vregs_ && in_class_[static_cast<size_t>(reg)]) {
filtered.push_back(reg);
}
}
return filtered;
}
bool HasNodes(const std::vector<std::uint8_t>& worklist) const {
return std::any_of(worklist.begin(), worklist.end(),
[](std::uint8_t flag) { return flag != 0; });
}
bool HasMoves(const std::vector<std::uint8_t>& move_flags) const {
return std::any_of(move_flags.begin(), move_flags.end(),
[](std::uint8_t flag) { return flag != 0; });
}
int PickAnyNode(const std::vector<std::uint8_t>& worklist) const {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (worklist[static_cast<size_t>(vreg)]) {
return vreg;
}
}
throw std::runtime_error(FormatError("mir", "failed to pick worklist node"));
}
int PickBestMove() const {
int best = -1;
int best_score = std::numeric_limits<int>::max();
for (size_t i = 0; i < worklist_moves_.size(); ++i) {
if (!worklist_moves_[i]) {
continue;
}
const auto& move = moves_[i];
const int dst = GetAlias(move.dst);
const int src = GetAlias(move.src);
int score = degree_[static_cast<size_t>(dst)] + degree_[static_cast<size_t>(src)];
if (live_across_call_[static_cast<size_t>(dst)] !=
live_across_call_[static_cast<size_t>(src)]) {
score += 4;
}
if (rematerializable_[static_cast<size_t>(dst)] ||
rematerializable_[static_cast<size_t>(src)]) {
score += 2;
}
if (score < best_score) {
best = static_cast<int>(i);
best_score = score;
}
}
if (best >= 0) {
return best;
}
throw std::runtime_error(FormatError("mir", "failed to pick worklist move"));
}
private:
MachineFunction& function_;
RegClass reg_class_;
std::vector<PhysReg> regs_;
int k_ = 0;
const std::vector<BlockInfo>& block_infos_;
int num_vregs_ = 0;
std::vector<std::uint8_t> in_class_;
std::vector<std::uint8_t> live_across_call_;
std::vector<std::uint8_t> rematerializable_;
std::vector<std::unordered_set<int>> adjacency_;
std::vector<int> degree_;
std::vector<double> spill_cost_;
std::vector<std::vector<int>> move_list_;
std::vector<MoveEdge> moves_;
std::vector<int> alias_;
std::vector<int> color_index_;
std::vector<std::uint8_t> in_select_stack_;
std::vector<std::uint8_t> is_coalesced_;
std::vector<std::uint8_t> is_spilled_;
std::vector<std::uint8_t> is_colored_;
std::vector<std::uint8_t> simplify_worklist_;
std::vector<std::uint8_t> freeze_worklist_;
std::vector<std::uint8_t> spill_worklist_;
std::vector<int> select_stack_;
std::vector<std::uint8_t> worklist_moves_;
std::vector<std::uint8_t> active_moves_;
std::vector<std::uint8_t> coalesced_moves_;
std::vector<std::uint8_t> constrained_moves_;
std::vector<std::uint8_t> frozen_moves_;
};
} // namespace
void RunRegAlloc(MachineModule& module) {
for (auto& function : module.GetFunctions()) {
const auto block_infos = AnalyzeBlocks(*function);
GeorgeColoringAllocator gpr_allocator(*function, RegClass::GPR, block_infos);
gpr_allocator.Run();
GeorgeColoringAllocator fpr_allocator(*function, RegClass::FPR, block_infos);
fpr_allocator.Run();
}
}
} // namespace mir

@ -0,0 +1,76 @@
#include "mir/MIR.h"
#include <stdexcept>
namespace mir {
namespace {
const char* kWRegNames[] = {
"w0", "w1", "w2", "w3", "w4", "w5", "w6", "w7",
"w8", "w9", "w10", "w11", "w12", "w13", "w14", "w15",
"w16", "w17", "w18", "w19", "w20", "w21", "w22", "w23",
"w24", "w25", "w26", "w27", "w28", "w29", "w30"};
const char* kXRegNames[] = {
"x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15",
"x16", "x17", "x18", "x19", "x20", "x21", "x22", "x23",
"x24", "x25", "x26", "x27", "x28", "x29", "x30"};
const char* kSRegNames[] = {
"s0", "s1", "s2", "s3", "s4", "s5", "s6", "s7",
"s8", "s9", "s10", "s11", "s12", "s13", "s14", "s15",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s24", "s25", "s26", "s27", "s28", "s29", "s30", "s31"};
} // namespace
bool IsGPR(ValueType type) {
return type == ValueType::I1 || type == ValueType::I32 || type == ValueType::Ptr;
}
bool IsFPR(ValueType type) { return type == ValueType::F32; }
int GetValueSize(ValueType type) {
switch (type) {
case ValueType::Void:
return 0;
case ValueType::I1:
case ValueType::I32:
case ValueType::F32:
return 4;
case ValueType::Ptr:
return 8;
}
return 0;
}
int GetValueAlign(ValueType type) {
switch (type) {
case ValueType::Void:
return 1;
case ValueType::Ptr:
return 8;
case ValueType::I1:
case ValueType::I32:
case ValueType::F32:
return 4;
}
return 1;
}
const char* GetPhysRegName(PhysReg reg, ValueType type) {
if (!reg.IsValid()) {
throw std::runtime_error("invalid physical register");
}
if (reg.reg_class == RegClass::FPR) {
if (reg.index < 0 || reg.index >= 32) {
throw std::runtime_error("float register index out of range");
}
return kSRegNames[reg.index];
}
if (reg.index < 0 || reg.index >= 31) {
throw std::runtime_error("gpr register index out of range");
}
return type == ValueType::Ptr ? kXRegNames[reg.index] : kWRegNames[reg.index];
}
} // namespace mir

@ -0,0 +1,239 @@
#include "mir/MIR.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace mir {
namespace {
using BlockList = std::vector<std::unique_ptr<MachineBasicBlock>>;
int FindBlockIndex(const MachineFunction& function, const std::string& name) {
const auto& blocks = function.GetBlocks();
for (size_t i = 0; i < blocks.size(); ++i) {
if (blocks[i] && blocks[i]->GetName() == name) {
return static_cast<int>(i);
}
}
return -1;
}
std::vector<int> CollectSuccessors(const MachineFunction& function, int index) {
std::vector<int> succs;
const auto& blocks = function.GetBlocks();
if (index < 0 || index >= static_cast<int>(blocks.size()) || !blocks[index]) {
return succs;
}
const auto& instructions = blocks[index]->GetInstructions();
if (instructions.empty()) {
return succs;
}
const auto& term = instructions.back();
if (term.GetOpcode() == MachineInstr::Opcode::Br && !term.GetOperands().empty() &&
term.GetOperands()[0].GetKind() == OperandKind::Block) {
const int succ = FindBlockIndex(function, term.GetOperands()[0].GetText());
if (succ >= 0) {
succs.push_back(succ);
}
return succs;
}
if (term.GetOpcode() == MachineInstr::Opcode::CondBr &&
term.GetOperands().size() >= 3) {
for (size_t i = 1; i <= 2; ++i) {
if (term.GetOperands()[i].GetKind() != OperandKind::Block) {
continue;
}
const int succ = FindBlockIndex(function, term.GetOperands()[i].GetText());
if (succ >= 0 &&
std::find(succs.begin(), succs.end(), succ) == succs.end()) {
succs.push_back(succ);
}
}
}
return succs;
}
std::vector<int> BuildPredecessorCount(const MachineFunction& function) {
std::vector<int> preds(function.GetBlocks().size(), 0);
for (size_t i = 0; i < function.GetBlocks().size(); ++i) {
for (int succ : CollectSuccessors(function, static_cast<int>(i))) {
++preds[static_cast<size_t>(succ)];
}
}
return preds;
}
bool IsTrivialJumpBlock(const MachineFunction& function, int index) {
const auto& blocks = function.GetBlocks();
if (index < 0 || index >= static_cast<int>(blocks.size()) || !blocks[index]) {
return false;
}
const auto& instructions = blocks[index]->GetInstructions();
return instructions.size() == 1 &&
instructions.front().GetOpcode() == MachineInstr::Opcode::Br &&
!instructions.front().GetOperands().empty() &&
instructions.front().GetOperands()[0].GetKind() == OperandKind::Block;
}
std::string ResolveJumpChain(const MachineFunction& function, const std::string& target) {
std::string current = target;
std::unordered_set<std::string> visited{current};
while (true) {
const int index = FindBlockIndex(function, current);
if (index < 0 || !IsTrivialJumpBlock(function, index)) {
return current;
}
const auto& inst = function.GetBlocks()[static_cast<size_t>(index)]->GetInstructions().front();
const std::string& next = inst.GetOperands()[0].GetText();
if (!visited.insert(next).second) {
return current;
}
current = next;
}
}
bool RewriteBranchTargets(MachineFunction& function) {
bool changed = false;
for (auto& block : function.GetBlocks()) {
if (!block || block->GetInstructions().empty()) {
continue;
}
auto& term = block->GetInstructions().back();
auto& operands = term.GetOperands();
if (term.GetOpcode() == MachineInstr::Opcode::Br && !operands.empty() &&
operands[0].GetKind() == OperandKind::Block) {
const std::string resolved = ResolveJumpChain(function, operands[0].GetText());
if (resolved != operands[0].GetText()) {
operands[0] = MachineOperand::Block(resolved);
changed = true;
}
continue;
}
if (term.GetOpcode() != MachineInstr::Opcode::CondBr || operands.size() < 3) {
continue;
}
for (size_t i = 1; i <= 2; ++i) {
if (operands[i].GetKind() != OperandKind::Block) {
continue;
}
const std::string resolved = ResolveJumpChain(function, operands[i].GetText());
if (resolved != operands[i].GetText()) {
operands[i] = MachineOperand::Block(resolved);
changed = true;
}
}
if (operands[1].GetKind() == OperandKind::Block &&
operands[2].GetKind() == OperandKind::Block &&
operands[1].GetText() == operands[2].GetText()) {
term = MachineInstr(MachineInstr::Opcode::Br, {operands[1]});
changed = true;
}
}
return changed;
}
bool RemoveUnreachableBlocks(MachineFunction& function) {
auto& blocks = function.GetBlocks();
if (blocks.empty() || !blocks.front()) {
return false;
}
std::unordered_set<std::string> reachable;
std::vector<std::string> stack{blocks.front()->GetName()};
while (!stack.empty()) {
std::string name = stack.back();
stack.pop_back();
if (!reachable.insert(name).second) {
continue;
}
const int index = FindBlockIndex(function, name);
if (index < 0) {
continue;
}
for (int succ : CollectSuccessors(function, index)) {
stack.push_back(blocks[static_cast<size_t>(succ)]->GetName());
}
}
const size_t old_size = blocks.size();
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
[&](const std::unique_ptr<MachineBasicBlock>& block) {
return block && reachable.count(block->GetName()) == 0;
}),
blocks.end());
return blocks.size() != old_size;
}
bool MergeLinearBlocks(MachineFunction& function) {
auto preds = BuildPredecessorCount(function);
auto& blocks = function.GetBlocks();
for (size_t i = 0; i < blocks.size(); ++i) {
auto& block = blocks[i];
if (!block || block->GetInstructions().empty()) {
continue;
}
auto& insts = block->GetInstructions();
auto& term = insts.back();
if (term.GetOpcode() != MachineInstr::Opcode::Br || term.GetOperands().empty() ||
term.GetOperands()[0].GetKind() != OperandKind::Block) {
continue;
}
const int succ_index = FindBlockIndex(function, term.GetOperands()[0].GetText());
if (succ_index <= 0 || succ_index == static_cast<int>(i) ||
preds[static_cast<size_t>(succ_index)] != 1) {
continue;
}
auto& succ = blocks[static_cast<size_t>(succ_index)];
if (!succ || succ->GetInstructions().empty()) {
continue;
}
insts.pop_back();
auto& succ_insts = succ->GetInstructions();
insts.insert(insts.end(),
std::make_move_iterator(succ_insts.begin()),
std::make_move_iterator(succ_insts.end()));
blocks.erase(blocks.begin() + succ_index);
return true;
}
return false;
}
bool RunCFGCleanupOnFunction(MachineFunction& function) {
bool changed = false;
while (true) {
bool local_changed = false;
local_changed |= RewriteBranchTargets(function);
local_changed |= RemoveUnreachableBlocks(function);
if (MergeLinearBlocks(function)) {
local_changed = true;
}
changed |= local_changed;
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunCFGCleanup(MachineModule& module) {
bool changed = false;
for (auto& function : module.GetFunctions()) {
if (function) {
changed |= RunCFGCleanupOnFunction(*function);
}
}
return changed;
}
} // namespace mir

@ -0,0 +1,11 @@
add_library(mir_passes STATIC
PassManager.cpp
Peephole.cpp
SpillReduction.cpp
CFGCleanup.cpp
)
target_link_libraries(mir_passes PUBLIC
build_options
mir_core
)

@ -0,0 +1,53 @@
#include "mir/MIR.h"
#include <cstdlib>
namespace mir {
void RunMIRPreRegAllocPassPipeline(MachineModule& module) {
const char* disable_spill_reduction = std::getenv("NUDTC_DISABLE_MIR_SPILL_REDUCTION");
const bool run_spill_reduction =
disable_spill_reduction == nullptr || disable_spill_reduction[0] == '\0' ||
disable_spill_reduction[0] == '0';
const char* disable_cfg_cleanup = std::getenv("NUDTC_DISABLE_MIR_CFG_CLEANUP");
const bool run_cfg_cleanup =
disable_cfg_cleanup == nullptr || disable_cfg_cleanup[0] == '\0' ||
disable_cfg_cleanup[0] == '0';
if (run_spill_reduction) {
RunSpillReduction(module);
}
RunAddressHoisting(module);
constexpr int kMaxIterations = 4;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
bool changed = false;
changed |= RunPeephole(module);
if (run_cfg_cleanup) {
changed |= RunCFGCleanup(module);
}
if (!changed) {
break;
}
}
}
void RunMIRPostRegAllocPassPipeline(MachineModule& module) {
const char* disable_cfg_cleanup = std::getenv("NUDTC_DISABLE_MIR_CFG_CLEANUP");
const bool run_cfg_cleanup =
disable_cfg_cleanup == nullptr || disable_cfg_cleanup[0] == '\0' ||
disable_cfg_cleanup[0] == '0';
constexpr int kMaxIterations = 2;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
bool changed = false;
changed |= RunPeephole(module);
if (run_cfg_cleanup) {
changed |= RunCFGCleanup(module);
}
if (!changed) {
break;
}
}
}
} // namespace mir

@ -0,0 +1,904 @@
#include "mir/MIR.h"
#include "ir/IR.h"
#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace mir {
namespace {
using AliasMap = std::unordered_map<int, MachineOperand>;
struct CFGInfo {
std::vector<std::vector<int>> predecessors;
std::vector<std::vector<int>> successors;
};
struct AddressKey {
AddrBaseKind base_kind = AddrBaseKind::None;
int base_index = -1;
std::string symbol;
std::int64_t const_offset = 0;
std::vector<std::pair<int, std::int64_t>> scaled_vregs;
bool operator==(const AddressKey& rhs) const {
return base_kind == rhs.base_kind && base_index == rhs.base_index &&
symbol == rhs.symbol && const_offset == rhs.const_offset &&
scaled_vregs == rhs.scaled_vregs;
}
};
struct AddressKeyHash {
std::size_t operator()(const AddressKey& key) const {
std::size_t h = static_cast<std::size_t>(key.base_kind);
h ^= std::hash<int>{}(key.base_index) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::string>{}(key.symbol) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::int64_t>{}(key.const_offset) + 0x9e3779b9 + (h << 6) + (h >> 2);
for (const auto& term : key.scaled_vregs) {
h ^= std::hash<int>{}(term.first) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::int64_t>{}(term.second) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
struct MemoryState {
MachineOperand value;
ValueType type = ValueType::Void;
int pending_store_index = -1;
};
using MemoryMap = std::unordered_map<AddressKey, MemoryState, AddressKeyHash>;
bool IsImm(const MachineOperand& operand, std::int64_t value) {
return operand.GetKind() == OperandKind::Imm && operand.GetImm() == value;
}
bool SameExactOperand(const MachineOperand& lhs, const MachineOperand& rhs) {
if (lhs.GetKind() != rhs.GetKind()) {
return false;
}
switch (lhs.GetKind()) {
case OperandKind::Invalid:
return true;
case OperandKind::VReg:
return lhs.GetVReg() == rhs.GetVReg();
case OperandKind::Imm:
return lhs.GetImm() == rhs.GetImm();
case OperandKind::Block:
case OperandKind::Symbol:
return lhs.GetText() == rhs.GetText();
}
return false;
}
bool SameResolvedLocation(const MachineFunction& function, int lhs_vreg, int rhs_vreg) {
if (lhs_vreg == rhs_vreg) {
return true;
}
const auto& lhs = function.GetAllocation(lhs_vreg);
const auto& rhs = function.GetAllocation(rhs_vreg);
if (lhs.kind == Allocation::Kind::Unassigned || rhs.kind == Allocation::Kind::Unassigned ||
lhs.kind != rhs.kind) {
return false;
}
if (lhs.kind == Allocation::Kind::PhysReg) {
return lhs.phys == rhs.phys;
}
if (lhs.kind == Allocation::Kind::Spill) {
return lhs.stack_object == rhs.stack_object;
}
return false;
}
bool SameResolvedOperand(const MachineFunction& function, const MachineOperand& lhs,
const MachineOperand& rhs) {
if (SameExactOperand(lhs, rhs)) {
return true;
}
if (lhs.GetKind() == OperandKind::VReg && rhs.GetKind() == OperandKind::VReg) {
return SameResolvedLocation(function, lhs.GetVReg(), rhs.GetVReg());
}
return false;
}
MachineOperand ResolveAlias(const AliasMap& aliases, const MachineOperand& operand) {
if (operand.GetKind() != OperandKind::VReg) {
return operand;
}
int current = operand.GetVReg();
std::unordered_set<int> visited;
visited.insert(current);
while (true) {
auto it = aliases.find(current);
if (it == aliases.end()) {
return MachineOperand::VReg(current);
}
if (it->second.GetKind() != OperandKind::VReg) {
return it->second;
}
const int next = it->second.GetVReg();
if (!visited.insert(next).second) {
return MachineOperand::VReg(current);
}
current = next;
}
}
bool RewriteOperand(MachineOperand& operand, const AliasMap& aliases) {
const auto rewritten = ResolveAlias(aliases, operand);
if (SameExactOperand(rewritten, operand)) {
return false;
}
operand = rewritten;
return true;
}
bool RewriteAddress(AddressExpr& address, const AliasMap& aliases) {
bool changed = false;
if (address.base_kind == AddrBaseKind::VReg && address.base_index >= 0) {
const auto rewritten = ResolveAlias(aliases, MachineOperand::VReg(address.base_index));
if (rewritten.GetKind() == OperandKind::VReg &&
rewritten.GetVReg() != address.base_index) {
address.base_index = rewritten.GetVReg();
changed = true;
}
}
std::vector<std::pair<int, std::int64_t>> rewritten_scaled;
rewritten_scaled.reserve(address.scaled_vregs.size());
for (const auto& term : address.scaled_vregs) {
const auto rewritten = ResolveAlias(aliases, MachineOperand::VReg(term.first));
if (rewritten.GetKind() == OperandKind::Imm) {
address.const_offset += rewritten.GetImm() * term.second;
changed = true;
continue;
}
if (rewritten.GetKind() == OperandKind::VReg && rewritten.GetVReg() != term.first) {
rewritten_scaled.push_back({rewritten.GetVReg(), term.second});
changed = true;
continue;
}
rewritten_scaled.push_back(term);
}
if (rewritten_scaled.size() != address.scaled_vregs.size()) {
changed = true;
}
address.scaled_vregs = std::move(rewritten_scaled);
return changed;
}
bool RewriteUses(MachineInstr& inst, const AliasMap& aliases) {
bool changed = false;
auto& operands = inst.GetOperands();
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Copy:
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
case MachineInstr::Opcode::FNeg:
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
break;
case MachineInstr::Opcode::Store:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv:
case MachineInstr::Opcode::ICmp:
case MachineInstr::Opcode::FCmp:
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
if (operands.size() >= 3) {
changed |= RewriteOperand(operands[2], aliases);
}
break;
case MachineInstr::Opcode::CondBr:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::Call: {
const size_t arg_begin = inst.GetCallReturnType() == ValueType::Void ? 0 : 1;
for (size_t i = arg_begin; i < operands.size(); ++i) {
changed |= RewriteOperand(operands[i], aliases);
}
break;
}
case MachineInstr::Opcode::Ret:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::Memset:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
break;
case MachineInstr::Opcode::Arg:
case MachineInstr::Opcode::Load:
case MachineInstr::Opcode::Lea:
case MachineInstr::Opcode::Br:
case MachineInstr::Opcode::Unreachable:
break;
}
if (inst.HasAddress()) {
changed |= RewriteAddress(inst.GetAddress(), aliases);
}
return changed;
}
MachineInstr MakeCopyLike(const MachineInstr& inst, MachineOperand source) {
return MachineInstr(MachineInstr::Opcode::Copy,
{inst.GetOperands()[0], std::move(source)});
}
bool SimplifyCopy(const MachineFunction& function, MachineInstr& inst) {
if (inst.GetOpcode() != MachineInstr::Opcode::Copy) {
return false;
}
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[0].GetKind() != OperandKind::VReg) {
return false;
}
return SameResolvedOperand(function, operands[0], operands[1]);
}
bool SimplifyZExt(MachineInstr& inst) {
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[1].GetKind() != OperandKind::Imm) {
return false;
}
inst = MakeCopyLike(inst, MachineOperand::Imm(operands[1].GetImm() != 0 ? 1 : 0));
return true;
}
bool SimplifyIntegerBinary(MachineInstr& inst) {
const auto opcode = inst.GetOpcode();
const auto& operands = inst.GetOperands();
if (operands.size() < 3) {
return false;
}
const auto& lhs = operands[1];
const auto& rhs = operands[2];
switch (opcode) {
case MachineInstr::Opcode::Add:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
return false;
case MachineInstr::Opcode::Sub:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
return false;
case MachineInstr::Opcode::Mul:
if (IsImm(rhs, 1)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, 1)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
if (IsImm(rhs, 0) || IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, MachineOperand::Imm(0));
return true;
}
return false;
case MachineInstr::Opcode::Div:
if (IsImm(rhs, 1)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
return false;
case MachineInstr::Opcode::And:
if (IsImm(rhs, -1)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, -1)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
if (IsImm(rhs, 0) || IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, MachineOperand::Imm(0));
return true;
}
return false;
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
return false;
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
return false;
default:
return false;
}
}
bool SimplifyCondBr(MachineInstr& inst) {
auto& operands = inst.GetOperands();
if (operands.size() < 3) {
return false;
}
if (operands[1].GetKind() == OperandKind::Block &&
operands[2].GetKind() == OperandKind::Block &&
operands[1].GetText() == operands[2].GetText()) {
inst = MachineInstr(MachineInstr::Opcode::Br, {operands[1]});
return true;
}
if (operands[0].GetKind() != OperandKind::Imm) {
return false;
}
inst = MachineInstr(MachineInstr::Opcode::Br,
{operands[0].GetImm() != 0 ? operands[1] : operands[2]});
return true;
}
bool SimplifyInstruction(MachineInstr& inst) {
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::ZExt:
return SimplifyZExt(inst);
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
return SimplifyIntegerBinary(inst);
case MachineInstr::Opcode::CondBr:
return SimplifyCondBr(inst);
default:
return false;
}
}
bool TrackAlias(const MachineInstr& inst, AliasMap& aliases) {
if (inst.GetOpcode() != MachineInstr::Opcode::Copy) {
return false;
}
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[0].GetKind() != OperandKind::VReg) {
return false;
}
aliases[operands[0].GetVReg()] = operands[1];
return true;
}
AddressKey MakeAddressKey(const AddressExpr& address) {
return {address.base_kind, address.base_index, address.symbol, address.const_offset,
address.scaled_vregs};
}
bool HasTrackedAddress(const MachineInstr& inst) {
return inst.HasAddress() && inst.GetAddress().base_kind != AddrBaseKind::None;
}
const ir::Function* LookupSourceCallee(const MachineModule& module,
const MachineInstr& inst) {
if (inst.GetOpcode() != MachineInstr::Opcode::Call || inst.GetCallee().empty()) {
return nullptr;
}
return module.GetSourceModule().GetFunction(inst.GetCallee());
}
bool CallMayReadMemory(const MachineModule& module, const MachineInstr& inst) {
auto* callee = LookupSourceCallee(module, inst);
return callee == nullptr || callee->MayReadMemory();
}
bool CallMayWriteMemory(const MachineModule& module, const MachineInstr& inst) {
auto* callee = LookupSourceCallee(module, inst);
return callee == nullptr || callee->MayWriteMemory();
}
bool SameMemoryStateValue(const MemoryState& lhs, const MemoryState& rhs) {
return lhs.type == rhs.type && SameExactOperand(lhs.value, rhs.value);
}
bool SameMemoryMap(const MemoryMap& lhs, const MemoryMap& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto& [key, value] : lhs) {
auto it = rhs.find(key);
if (it == rhs.end() || !SameMemoryStateValue(value, it->second)) {
return false;
}
}
return true;
}
MemoryMap MeetMemoryStates(const std::vector<const MemoryMap*>& predecessors) {
if (predecessors.empty()) {
return {};
}
MemoryMap in = *predecessors.front();
for (auto it = in.begin(); it != in.end();) {
bool keep = true;
for (std::size_t i = 1; i < predecessors.size(); ++i) {
auto pred_it = predecessors[i]->find(it->first);
if (pred_it == predecessors[i]->end() ||
!SameMemoryStateValue(it->second, pred_it->second)) {
keep = false;
break;
}
}
if (!keep) {
it = in.erase(it);
continue;
}
++it;
}
return in;
}
CFGInfo BuildCFG(const MachineFunction& function) {
CFGInfo cfg;
const auto& blocks = function.GetBlocks();
cfg.predecessors.resize(blocks.size());
cfg.successors.resize(blocks.size());
std::unordered_map<std::string, int> name_to_index;
for (std::size_t i = 0; i < blocks.size(); ++i) {
name_to_index.emplace(blocks[i]->GetName(), static_cast<int>(i));
}
auto add_edge = [&](int pred, const std::string& succ_name) {
auto it = name_to_index.find(succ_name);
if (it == name_to_index.end()) {
return;
}
cfg.successors[static_cast<std::size_t>(pred)].push_back(it->second);
cfg.predecessors[static_cast<std::size_t>(it->second)].push_back(pred);
};
for (std::size_t i = 0; i < blocks.size(); ++i) {
const auto& instructions = blocks[i]->GetInstructions();
if (instructions.empty()) {
continue;
}
const auto& terminator = instructions.back();
if (terminator.GetOpcode() == MachineInstr::Opcode::Br &&
!terminator.GetOperands().empty()) {
add_edge(static_cast<int>(i), terminator.GetOperands()[0].GetText());
} else if (terminator.GetOpcode() == MachineInstr::Opcode::CondBr &&
terminator.GetOperands().size() >= 3) {
add_edge(static_cast<int>(i), terminator.GetOperands()[1].GetText());
add_edge(static_cast<int>(i), terminator.GetOperands()[2].GetText());
}
auto& succs = cfg.successors[i];
std::sort(succs.begin(), succs.end());
succs.erase(std::unique(succs.begin(), succs.end()), succs.end());
}
for (auto& preds : cfg.predecessors) {
std::sort(preds.begin(), preds.end());
preds.erase(std::unique(preds.begin(), preds.end()), preds.end());
}
return cfg;
}
bool SameBaseObject(const AddressKey& lhs, const AddressKey& rhs) {
if (lhs.base_kind != rhs.base_kind) {
return false;
}
switch (lhs.base_kind) {
case AddrBaseKind::FrameObject:
case AddrBaseKind::VReg:
return lhs.base_index == rhs.base_index;
case AddrBaseKind::Global:
return lhs.symbol == rhs.symbol;
case AddrBaseKind::None:
return false;
}
return false;
}
void InvalidateMemoryState(std::unordered_map<AddressKey, MemoryState, AddressKeyHash>& states,
const AddressKey* store_key) {
if (store_key == nullptr) {
states.clear();
return;
}
if (store_key->base_kind == AddrBaseKind::VReg) {
states.clear();
return;
}
for (auto it = states.begin(); it != states.end();) {
if (it->first.base_kind == AddrBaseKind::VReg || SameBaseObject(it->first, *store_key)) {
it = states.erase(it);
continue;
}
++it;
}
}
void ObservePendingStores(MemoryMap& states) {
for (auto& [_, state] : states) {
state.pending_store_index = -1;
}
}
bool TryOptimizeMemoryInstruction(
const MachineModule& module, const MachineFunction& function,
MachineInstr& inst,
MemoryMap& states,
std::vector<bool>& removed,
std::size_t current_index,
bool* remove_current) {
*remove_current = false;
if (inst.GetOpcode() == MachineInstr::Opcode::Call) {
if (CallMayWriteMemory(module, inst)) {
InvalidateMemoryState(states, nullptr);
}
return false;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Memset) {
InvalidateMemoryState(states, nullptr);
return false;
}
if (!HasTrackedAddress(inst)) {
return false;
}
const AddressKey key = MakeAddressKey(inst.GetAddress());
if (inst.GetOpcode() == MachineInstr::Opcode::Load) {
ValueType load_type = ValueType::Void;
if (!inst.GetOperands().empty() && inst.GetOperands()[0].GetKind() == OperandKind::VReg) {
load_type = function.GetVRegInfo(inst.GetOperands()[0].GetVReg()).type;
}
auto it = states.find(key);
if (it != states.end() && it->second.type == load_type) {
inst = MakeCopyLike(inst, it->second.value);
it->second.pending_store_index = -1;
return true;
}
auto dest = inst.GetOperands()[0];
states[key] = {dest, load_type, -1};
return false;
}
if (inst.GetOpcode() != MachineInstr::Opcode::Store) {
return false;
}
const auto value = inst.GetOperands()[0];
auto existing = states.find(key);
if (existing != states.end() && existing->second.type == inst.GetValueType() &&
SameExactOperand(existing->second.value, value)) {
*remove_current = true;
return true;
}
if (existing != states.end() && existing->second.pending_store_index >= 0) {
removed[static_cast<std::size_t>(existing->second.pending_store_index)] = true;
}
InvalidateMemoryState(states, &key);
states[key] = {value, inst.GetValueType(), static_cast<int>(current_index)};
return false;
}
void ApplyMemoryDataflowInstruction(const MachineModule& module, const MachineInstr& inst,
MemoryMap& states) {
if (inst.GetOpcode() == MachineInstr::Opcode::Call) {
if (CallMayWriteMemory(module, inst)) {
InvalidateMemoryState(states, nullptr);
}
return;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Memset) {
InvalidateMemoryState(states, nullptr);
return;
}
if (!HasTrackedAddress(inst)) {
return;
}
const AddressKey key = MakeAddressKey(inst.GetAddress());
if (inst.GetOpcode() == MachineInstr::Opcode::Store) {
InvalidateMemoryState(states, &key);
states[key] = {inst.GetOperands()[0], inst.GetValueType(), -1};
return;
}
}
MemoryMap SimulateBlockMemory(const MachineModule& module, const MachineBasicBlock& block,
const MemoryMap& in_state) {
MemoryMap state = in_state;
for (const auto& inst : block.GetInstructions()) {
ApplyMemoryDataflowInstruction(module, inst, state);
}
return state;
}
bool RunPeepholeOnBlock(const MachineModule& module, const MachineFunction& function,
MachineBasicBlock& block, const MemoryMap& in_state) {
bool changed = false;
AliasMap aliases;
MemoryMap memory_states = in_state;
std::vector<MachineInstr> rewritten;
std::vector<bool> removed;
rewritten.reserve(block.GetInstructions().size());
removed.reserve(block.GetInstructions().size());
for (const auto& original : block.GetInstructions()) {
MachineInstr inst = original;
changed |= RewriteUses(inst, aliases);
changed |= SimplifyInstruction(inst);
if (SimplifyCopy(function, inst)) {
changed = true;
continue;
}
rewritten.push_back(std::move(inst));
removed.push_back(false);
MachineInstr& current = rewritten.back();
bool remove_current = false;
changed |= TryOptimizeMemoryInstruction(module, function, current, memory_states, removed,
rewritten.size() - 1, &remove_current);
if (remove_current) {
removed.back() = true;
changed = true;
continue;
}
changed |= SimplifyInstruction(current);
if (SimplifyCopy(function, current)) {
removed.back() = true;
changed = true;
continue;
}
if (current.GetOpcode() == MachineInstr::Opcode::Call) {
if (CallMayReadMemory(module, current) || CallMayWriteMemory(module, current)) {
ObservePendingStores(memory_states);
}
} else if (current.GetOpcode() == MachineInstr::Opcode::Memset) {
ObservePendingStores(memory_states);
}
TrackAlias(current, aliases);
}
std::vector<MachineInstr> compacted;
compacted.reserve(rewritten.size());
for (std::size_t i = 0; i < rewritten.size(); ++i) {
if (!removed[i]) {
compacted.push_back(std::move(rewritten[i]));
} else {
changed = true;
}
}
if (compacted.size() != block.GetInstructions().size()) {
changed = true;
}
if (changed) {
block.GetInstructions() = std::move(compacted);
}
return changed;
}
bool IsSideEffectFree(const MachineInstr& inst) {
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Arg:
case MachineInstr::Opcode::Copy:
case MachineInstr::Opcode::Load:
case MachineInstr::Opcode::Lea:
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv:
case MachineInstr::Opcode::FNeg:
case MachineInstr::Opcode::ICmp:
case MachineInstr::Opcode::FCmp:
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
return true;
case MachineInstr::Opcode::Store:
case MachineInstr::Opcode::Br:
case MachineInstr::Opcode::CondBr:
case MachineInstr::Opcode::Call:
case MachineInstr::Opcode::Ret:
case MachineInstr::Opcode::Memset:
case MachineInstr::Opcode::Unreachable:
return false;
}
return false;
}
bool RunDeadInstrElimination(MachineFunction& function) {
bool changed = false;
while (true) {
std::unordered_map<int, int> use_counts;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
for (int use : inst.GetUses()) {
++use_counts[use];
}
}
}
bool local_changed = false;
for (auto& block : function.GetBlocks()) {
std::vector<MachineInstr> rewritten;
rewritten.reserve(block->GetInstructions().size());
for (auto& inst : block->GetInstructions()) {
const auto defs = inst.GetDefs();
const bool has_live_def =
defs.empty() || use_counts.find(defs.front()) != use_counts.end();
if (has_live_def || !IsSideEffectFree(inst)) {
rewritten.push_back(inst);
continue;
}
local_changed = true;
}
if (local_changed) {
block->GetInstructions() = std::move(rewritten);
}
}
if (!local_changed) {
break;
}
changed = true;
}
return changed;
}
bool HasAssignedAllocations(const MachineFunction& function) {
for (const auto& vreg : function.GetVRegs()) {
if (function.GetAllocation(vreg.id).kind != Allocation::Kind::Unassigned) {
return true;
}
}
return false;
}
} // namespace
bool RunPeephole(MachineModule& module) {
bool changed = false;
for (auto& function : module.GetFunctions()) {
if (!function) {
continue;
}
bool function_changed = false;
const auto cfg = BuildCFG(*function);
std::vector<MemoryMap> in_states(function->GetBlocks().size());
std::vector<MemoryMap> out_states(function->GetBlocks().size());
bool dataflow_changed = true;
while (dataflow_changed) {
dataflow_changed = false;
for (std::size_t i = 0; i < function->GetBlocks().size(); ++i) {
MemoryMap in_state;
if (i != 0) {
std::vector<const MemoryMap*> predecessors;
for (int pred : cfg.predecessors[i]) {
predecessors.push_back(&out_states[static_cast<std::size_t>(pred)]);
}
in_state = MeetMemoryStates(predecessors);
}
auto out_state =
SimulateBlockMemory(module, *function->GetBlocks()[i], in_state);
if (!SameMemoryMap(in_states[i], in_state)) {
in_states[i] = std::move(in_state);
dataflow_changed = true;
}
if (!SameMemoryMap(out_states[i], out_state)) {
out_states[i] = std::move(out_state);
dataflow_changed = true;
}
}
}
for (std::size_t i = 0; i < function->GetBlocks().size(); ++i) {
function_changed |=
RunPeepholeOnBlock(module, *function, *function->GetBlocks()[i], in_states[i]);
}
if (!HasAssignedAllocations(*function)) {
function_changed |= RunDeadInstrElimination(*function);
}
changed |= function_changed;
}
return changed;
}
} // namespace mir

@ -0,0 +1,253 @@
#include "mir/MIR.h"
#include <unordered_map>
#include <utility>
#include <vector>
namespace mir {
namespace {
struct RematDef {
enum class Kind { Invalid, ImmCopy, Lea };
Kind kind = Kind::Invalid;
ValueType type = ValueType::Void;
MachineOperand source;
AddressExpr address;
};
bool IsCheapRematerializableDef(const MachineInstr& inst, RematDef& def) {
const auto defs = inst.GetDefs();
if (defs.size() != 1) {
return false;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Copy) {
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[1].GetKind() != OperandKind::Imm) {
return false;
}
def.kind = RematDef::Kind::ImmCopy;
def.type = inst.GetValueType();
def.source = operands[1];
return true;
}
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress()) {
return false;
}
const auto& address = inst.GetAddress();
if (address.base_kind == AddrBaseKind::VReg || !address.scaled_vregs.empty()) {
return false;
}
def.kind = RematDef::Kind::Lea;
def.type = ValueType::Ptr;
def.address = address;
return true;
}
MachineInstr BuildRematInstr(int dst_vreg, const RematDef& def) {
switch (def.kind) {
case RematDef::Kind::ImmCopy: {
MachineInstr inst(MachineInstr::Opcode::Copy,
{MachineOperand::VReg(dst_vreg), def.source});
inst.SetValueType(def.type);
return inst;
}
case RematDef::Kind::Lea: {
MachineInstr inst(MachineInstr::Opcode::Lea, {MachineOperand::VReg(dst_vreg)});
inst.SetAddress(def.address);
inst.SetValueType(ValueType::Ptr);
return inst;
}
case RematDef::Kind::Invalid:
break;
}
return MachineInstr(MachineInstr::Opcode::Unreachable, {});
}
bool RewriteMappedOperand(MachineOperand& operand,
const std::unordered_map<int, int>& rename_map) {
if (operand.GetKind() != OperandKind::VReg) {
return false;
}
auto it = rename_map.find(operand.GetVReg());
if (it == rename_map.end() || it->second == operand.GetVReg()) {
return false;
}
operand = MachineOperand::VReg(it->second);
return true;
}
bool RewriteMappedAddress(AddressExpr& address,
const std::unordered_map<int, int>& rename_map) {
bool changed = false;
if (address.base_kind == AddrBaseKind::VReg && address.base_index >= 0) {
auto it = rename_map.find(address.base_index);
if (it != rename_map.end() && it->second != address.base_index) {
address.base_index = it->second;
changed = true;
}
}
for (auto& term : address.scaled_vregs) {
auto it = rename_map.find(term.first);
if (it != rename_map.end() && it->second != term.first) {
term.first = it->second;
changed = true;
}
}
return changed;
}
bool RewriteUses(MachineInstr& inst, const std::unordered_map<int, int>& rename_map) {
bool changed = false;
auto& operands = inst.GetOperands();
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Copy:
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
case MachineInstr::Opcode::FNeg:
if (operands.size() >= 2) {
changed |= RewriteMappedOperand(operands[1], rename_map);
}
break;
case MachineInstr::Opcode::Store:
if (!operands.empty()) {
changed |= RewriteMappedOperand(operands[0], rename_map);
}
break;
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv:
case MachineInstr::Opcode::ICmp:
case MachineInstr::Opcode::FCmp:
if (operands.size() >= 2) {
changed |= RewriteMappedOperand(operands[1], rename_map);
}
if (operands.size() >= 3) {
changed |= RewriteMappedOperand(operands[2], rename_map);
}
break;
case MachineInstr::Opcode::CondBr:
if (!operands.empty()) {
changed |= RewriteMappedOperand(operands[0], rename_map);
}
break;
case MachineInstr::Opcode::Call: {
const size_t arg_begin = inst.GetCallReturnType() == ValueType::Void ? 0 : 1;
for (size_t i = arg_begin; i < operands.size(); ++i) {
changed |= RewriteMappedOperand(operands[i], rename_map);
}
break;
}
case MachineInstr::Opcode::Ret:
if (!operands.empty()) {
changed |= RewriteMappedOperand(operands[0], rename_map);
}
break;
case MachineInstr::Opcode::Memset:
if (!operands.empty()) {
changed |= RewriteMappedOperand(operands[0], rename_map);
}
if (operands.size() >= 2) {
changed |= RewriteMappedOperand(operands[1], rename_map);
}
break;
case MachineInstr::Opcode::Arg:
case MachineInstr::Opcode::Load:
case MachineInstr::Opcode::Lea:
case MachineInstr::Opcode::Br:
case MachineInstr::Opcode::Unreachable:
break;
}
if (inst.HasAddress()) {
changed |= RewriteMappedAddress(inst.GetAddress(), rename_map);
}
return changed;
}
bool RunSpillReductionOnFunction(MachineFunction& function) {
bool changed = false;
for (auto& block_ptr : function.GetBlocks()) {
auto& instructions = block_ptr->GetInstructions();
std::unordered_map<int, RematDef> available_defs;
std::unordered_map<int, RematDef> after_call_defs;
std::unordered_map<int, int> rename_map;
bool after_call = false;
for (size_t i = 0; i < instructions.size(); ++i) {
if (after_call) {
const auto uses = instructions[i].GetUses();
for (int use : uses) {
if (rename_map.count(use) != 0) {
continue;
}
auto it = after_call_defs.find(use);
if (it == after_call_defs.end()) {
continue;
}
const int new_vreg = function.NewVReg(function.GetVRegInfo(use).type);
instructions.insert(instructions.begin() + static_cast<long long>(i),
BuildRematInstr(new_vreg, it->second));
++i;
rename_map[use] = new_vreg;
available_defs[new_vreg] = it->second;
changed = true;
}
RewriteUses(instructions[i], rename_map);
}
const auto defs = instructions[i].GetDefs();
for (int def : defs) {
available_defs.erase(def);
after_call_defs.erase(def);
rename_map.erase(def);
}
RematDef def;
if (IsCheapRematerializableDef(instructions[i], def)) {
for (int vreg : defs) {
available_defs[vreg] = def;
}
}
if (instructions[i].GetOpcode() == MachineInstr::Opcode::Call ||
instructions[i].GetOpcode() == MachineInstr::Opcode::Memset) {
after_call_defs = available_defs;
rename_map.clear();
after_call = true;
}
}
}
return changed;
}
} // namespace
bool RunSpillReduction(MachineModule& module) {
bool changed = false;
for (auto& function : module.GetFunctions()) {
if (function) {
changed |= RunSpillReductionOnFunction(*function);
}
}
return changed;
}
} // namespace mir

@ -0,0 +1,11 @@
add_library(sem STATIC
Sema.cpp
SymbolTable.cpp
ConstEval.cpp
)
target_link_libraries(sem PUBLIC
build_options
frontend
${ANTLR4_RUNTIME_TARGET}
)

@ -0,0 +1,4 @@
// 常量求值:
// - 处理数组维度、全局初始化、const 表达式等编译期可计算场景
// - 为语义分析与 IR 生成提供常量折叠/常量值信息

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save