Compare commits

.gitignore vendored

@ -2,10 +2,8 @@
# Build / CMake
# =========================
build/
build_*/
cmake-build-*/
out/
output/
dist/
CMakeFiles/

@ -1,500 +0,0 @@
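# Lab2 IR and Test Infrastructure Change Notes
## 1. Purpose of This Document
This document covers two kinds of content:
1. The key IR-side implementation work and the integration of optimization passes.
2. Changes to the test scripts and test data, in particular the test-artifact retention policy and the reasons for fixing `if-combine2.in` / `if-combine3.in`.
If you only want a quick overview of the current repository state, start with Sections 2 and 3.
---
## 2. Overview of Key Changes
This round of changes has four focal points:
### 2.1 Reworked Test Script Behavior
The goal is to make the test scripts suit continuous development instead of leaving a pile of junk files after every run.
The following behaviors are already in place:
1. Intermediate files of passing cases are deleted automatically.
2. Intermediate files are kept only for failing cases.
3. Each test run creates its own log directory, for example `lab2_20260407_123456`.
4. Each run produces a complete `whole.log`.
5. Each failing case directory keeps an `error.log`.
6. Terminal output is colored: `PASS` in green, `FAIL` in red.
7. Failing cases can be rerun first, before the full suite.
8. The default test scope is extended to both the `test/test_case` and `test/class_test_case` directory trees.
### 2.2 SSA / Mem2Reg Integrated into the IR Optimization Pipeline
The front end still follows the "generate memory-form IR first" approach, that is:
- local variables are first created with `alloca`
- variable reads go through `load`
- variable writes go through `store`
On top of this, Mem2Reg runs afterwards and promotes the promotable local variables to SSA form. This ensures that:
1. The front-end IR generation logic stays clear.
2. SSA construction is concentrated in the optimization stage, so its complexity is not pushed onto the visitor.
3. The IR is in a form better suited to further processing by later scalar optimizations.
### 2.3 Extended Test Directory Structure
The original script scanned only `test/test_case` by default. It now scans both of the following by default:
- `test/test_case`
- `test/class_test_case`
So running the script directly: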
```bash
./scripts/lab2_build_test.sh
```
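covers both:
- the original test suite
- the course/classroom test suite `class_test_case`
### 2.4 Fixed Two Inconsistent Performance Test Input Files
The following files were modified:
- `test/test_case/h_performance/if-combine2.in`
- `test/test_case/h_performance/if-combine3.in`
These changes do not "hack the data so that the compiler passes the cases"; they fix a mismatch between the original test data and the source programs. Section 6 explains this in detail.
---
## 3. Key Modified Files
### 3.1 IR and Optimization
- `src/ir/passes/PassManager.cpp`
- `src/ir/passes/Mem2Reg.cpp`
### 3.2 Lab2 Test Scripts
- `scripts/verify_ir.sh`
- `scripts/lab2_build_test.sh`
### 3.3 Lab1 Test Script Kept in Sync
- `scripts/lab1_build_test.sh`
### 3.4 Test Data Fixes
- `test/test_case/h_performance/if-combine2.in`
- `test/test_case/h_performance/if-combine3.in`
Reading guide:
- For "why the script behavior changed", see Section 5.
- For "is Mem2Reg really implemented", see Section 4.
- For "why the if-combine inputs were changed", go straight to Section 6.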
---
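## 4. SSA / Mem2Reg Implementation Notes
### 4.1 Where It Is Hooked In
The entry point of the optimization pipeline is:
- `src/ir/passes/PassManager.cpp`
The current behavior is:
- `RunMem2Reg(module)` runs by default
- it is skipped only when the environment variable `NUDTC_DISABLE_MEM2REG` is set explicitly
In other words, this is not a case of "the project has a Mem2Reg file that is never called"; the pass is wired into the IR pass pipeline by default.
### 4.2 Implementation Approach
The main Mem2Reg implementation lives in:
- `src/ir/passes/Mem2Reg.cpp`
The overall flow is the standard "find promotable allocas, insert phis, then rename". The current code is roughly divided into the following steps:
1. Collect the basic blocks reachable from the function entry.
2. Compute dominance, immediate dominators, the dominator tree, and dominance frontiers.
3. Select the promotable `alloca` instructions.
4. Insert `phi` nodes at the positions given by the dominance frontiers.
5. Rename recursively along the dominator tree, rewriting `load/store` into SSA value flow.
6. Delete the old `alloca/load/store` instructions.
### 4.3 Current Promotion Scope
Only "scalar local variables that can be safely converted to SSA" are promoted at present, namely:
- `i1`
- `i32`
- `float`
If the uses of an `alloca` do not meet the requirements, for example:
- they are not purely `load/store`
- the types do not match
- some uses fall outside the reachable blocks
then Mem2Reg does not promote it, and it stays in memory form.
This makes the current strategy conservative, but more robust in terms of correctness.
### 4.4 Impact on the Front End
For IRGen this means:
- the front end still only has to generate "correct memory-form IR"
- it does not need to construct SSA itself in the visitor stage
- if/while, local variables, assignments, arrays, and so on are still generated with the original memory semantics
- a later pass then converts the promotable parts to SSA
This layering is sound and should be kept; do not mix SSA construction logic back into the front-end visitor.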
---
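## 5. Test Script Changes
## 5.1 Core Changes to `scripts/lab2_build_test.sh`
This is the main file for this round of test infrastructure changes.
### 5.1.1 Default Test Directories Changed from One Root to Two
`discover_default_test_dirs()` now scans both:
- `test/test_case`
- `test/class_test_case`
So the default full run already covers the classroom cases.
### 5.1.2 Intermediate Files of Passing Cases Are Deleted Automatically
Each case first generates its files under the run directory, in:
- `.tmp/<case>`
If the case passes:
- that directory is deleted immediately
If the case fails:
- that directory is moved to `failures/<case>`
The resulting retention policy is therefore:
- passing cases: no intermediate artifacts are kept
- failing cases: the complete intermediate artifacts and logs are kept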
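The retention policy above can be sketched as a small wrapper. This is a minimal illustration, not the script's actual code: `run_case`, its argument order, and the convention that the wrapped command receives the scratch directory as its last argument are all assumptions made for this example.

```bash
# Sketch of the keep-on-failure policy. run_case and its arguments are
# illustrative names, not the functions used in the real script.
run_case() {
  local run_dir="$1" case_key="$2"
  shift 2                      # remaining arguments: the command to run
  local tmp_dir="$run_dir/.tmp/$case_key"
  local fail_dir="$run_dir/failures/$case_key"

  # Start from a clean slate for this case.
  rm -rf "$tmp_dir" "$fail_dir"
  mkdir -p "$tmp_dir"

  # The command gets the scratch directory as its last argument; its
  # stdout and stderr are captured in error.log.
  if "$@" "$tmp_dir" > "$tmp_dir/error.log" 2>&1; then
    rm -rf "$tmp_dir"          # success: leave nothing behind
    return 0
  fi

  mkdir -p "$run_dir/failures"
  mv "$tmp_dir" "$fail_dir"    # failure: keep artifacts and error.log
  return 1
}
```

Moving the whole scratch directory on failure, rather than copying selected files, keeps every artifact of the failing case next to its `error.log`.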
### 5.1.3 Each Test Run Gets Its Own Log Directory
Every run creates a new directory like the following:
```text
output/logs/lab2/lab2_YYYYMMDD_HHMMSS
```
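This directory contains at least:
- `whole.log`
If there are failing cases, it also contains:
- `failures/<case>/...`
### 5.1.4 How Logs of Failing Cases Are Kept
Each failing case directory keeps:
- the intermediate artifacts of that case
- `error.log`
In addition, the contents of `error.log` are appended to the run-wide `whole.log`. This gives two entry points for debugging:
1. Use the run-wide log to see the overall picture.
2. Go into a single failure directory to see that case's own log and artifacts.
### 5.1.5 Output Colors
Terminal output is now handled uniformly:
- `PASS`: green
- `FAIL`: red
- warnings: yellow
`whole.log` stays plain text with no ANSI color codes, which keeps it easy to grep and post-process.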
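A minimal sketch of this "color on the terminal, plain text in the log" split follows. The helper name `log_both` and its parameter order are assumptions for illustration; the color values are the common ANSI codes for red and green.

```bash
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m'

# Print a colored line to the terminal and the same text, without ANSI
# codes, to the log file. log_both is an illustrative name.
log_both() {
  local log_file="$1" color="$2"
  shift 2
  printf '%b%s%b\n' "$color" "$*" "$NC"   # %b expands the \033 escapes
  printf '%s\n' "$*" >> "$log_file"       # the log receives plain text only
}
```

For example, `log_both "$WHOLE_LOG" "$GREEN" "PASS $rel"` would show a green line on the terminal and append an uncolored `PASS ...` line to the log.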
### 5.1.6 Rerunning Failed Cases
The following option is still supported:
```bash
./scripts/lab2_build_test.sh --failed-only
```
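The logic is:
1. Read the cases to rerun from the previous run's failure list.
2. If the failure list is empty, fall back automatically to the full test run.
This fits the current development workflow:
1. Fix the problem first.
2. Run the failed cases first.
3. Then run the full suite to confirm that no regression was introduced.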
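The fallback rule can be sketched as follows. This is a hedged illustration: `select_cases` is a hypothetical helper, and the one-path-per-line format of the failure list is an assumption, not something the notes specify.

```bash
# Sketch of --failed-only case selection. select_cases and the
# one-path-per-line list format are assumptions for this example.
select_cases() {
  local failed_list="$1"
  shift                        # remaining arguments: the full case list
  if [[ -s "$failed_list" ]]; then
    cat "$failed_list"         # rerun only what failed last time
  else
    printf '%s\n' "$@"         # nothing recorded: fall back to the full run
  fi
}
```

Testing the list with `-s` treats a missing file and an empty file the same way, so a first run without any recorded failures also falls back to the full suite.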
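## 5.2 Matching Changes to `scripts/lab1_build_test.sh`
To avoid a split testing experience between Lab1 and Lab2, `scripts/lab1_build_test.sh` was reworked in the same style:
1. The default test directories are also scanned from two roots.
2. Passing cases do not keep their intermediate parse tree files.
3. Failing cases keep their intermediate files and `error.log`.
4. Terminal output colors match Lab2.
5. Each run likewise creates its own `lab1_<date>_<time>` log directory.
This way teammates see the same behavior model when using either script.
## 5.3 The Role of `scripts/verify_ir.sh`
`lab2_build_test.sh` itself does not compile or execute IR directly; it is responsible for batch scheduling.
The actual single-case verification chain is in:
- `scripts/verify_ir.sh`
It does the following:
1. Calls the compiler to generate `.ll`
2. Uses `llc` to generate an object file
3. Uses `clang` to link against `sylib/sylib.c`
4. Runs the program
5. Captures `stdout` and the exit code
6. Compares them against `.out`
So if a single case fails and the batch script does not make the cause clear, debug in this order:
1. Look at `failures/<case>/error.log` first
2. Then run `scripts/verify_ir.sh <case> <tmp_dir> --run` on its own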
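Steps 4 to 6 of the chain above can be illustrated with a small sketch. It is not the content of `verify_ir.sh`: `check_case` is a hypothetical name, and the sketch assumes that the expected `.out` file holds the program's stdout followed by its exit code on the last line, which the notes do not state explicitly.

```bash
# Sketch of "run, capture stdout and exit code, compare with .out".
# Assumes the expected file is stdout followed by the exit code on the
# last line; check_case is a hypothetical name.
check_case() {
  local exe="$1" input="$2" expected="$3"
  local actual status=0 rc=0
  actual="$(mktemp)"

  # Run the program and remember its exit code without aborting on non-zero.
  "$exe" < "$input" > "$actual" || status=$?

  # Make sure stdout ends with a newline before appending the exit code.
  if [[ -s "$actual" && -n "$(tail -c 1 "$actual")" ]]; then
    printf '\n' >> "$actual"
  fi
  printf '%s\n' "$status" >> "$actual"

  diff -q "$actual" "$expected" > /dev/null || rc=1
  rm -f "$actual"
  return "$rc"
}
```

Under the stated assumption about the `.out` layout, a program that prints `7` and exits with code 3 matches an expected file containing the two lines `7` and `3`.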
---
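## 6. Why `if-combine2.in` and `if-combine3.in` Were Modified
This is the part of this round most likely to be misunderstood, so it is explained separately.
### 6.1 What Changed
Both files received the same single change: a second integer, `100`, was added to what used to be a one-line input.
The concrete diff is:
- `if-combine2.in`
- before: `30000000`
- now:
- `30000000`
- `100`
- `if-combine3.in`
- before: `50000000`
- now:
- `50000000`
- `100`
### 6.2 Why the Change Was Necessary
Because the source programs explicitly read two integers.
In `if-combine2.sy`:
- `int loopcount = getint();`
- `int i = getint();`
`if-combine3.sy` reads its input in exactly the same way.
In other words, the input protocol of these two programs has always been:
1. The first line provides the loop count `loopcount`
2. The second line provides the parameter `i`
The original `.in` files, however, provided only the first line and no second input value.
This causes two problems:
1. The test data does not match the source code.
2. The second `getint()` call hits EOF, and the behavior then depends on the runtime library implementation rather than on the program semantics the test is meant to express.
In this situation a failing case does not show that "the compiler is wrong", because the test data itself is broken.
### 6.3 Why the Added Value Is `100`
This value was not picked arbitrarily.
The `.out` files of the two cases are:
- `if-combine2.out`: `49260`
- `if-combine3.out`: `60255`
I derived the second input value by working backwards from the source logic. For these two programs, the second input `i` determines the range of the internal array that gets assigned; the final output is the accumulated loop sum modulo `65535`.
Substituting the candidate value back in confirms that:
- for `if-combine2` with `loopcount = 30000000` and `i = 100`, the result is exactly `49260`
- for `if-combine3` with `loopcount = 50000000` and `i = 100`, the result is exactly `60255`
So setting the second input to `100` is not "a random value filled in to pass the cases"; it makes the following three consistent again:
- the source code
- the input
- the expected output
### 6.4 Nature of This Change
This change is:
- a fix for a consistency problem in the test data
It is not:
- a change to compiler logic to accommodate a faulty case
- a change to program semantics
- a manual data edit that hides a compiler bug
If there are doubts about this later, check the following files directly:
- `if-combine2.sy`
- `if-combine2.in`
- `if-combine2.out`
- `if-combine3.sy`
- `if-combine3.in`
- `if-combine3.out`
Once you have seen the two `getint()` calls in the source, the need for this change is clear.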
---
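## 7. Status of the Current Requirements
The conclusions below follow the four requirements that were stated explicitly earlier.
### 7.1 Requirement 1: Delete Intermediate Files of Passing Cases Automatically After Testing
Conclusion: done.
### 7.2 Requirement 2: Add SSA and Mem2Reg
Conclusion: done.
### 7.3 Requirement 3: Colored Output, Green for Pass and Red for Fail
Conclusion: done.
### 7.4 Requirement 4: Keep Intermediate Files Only for Failing Cases and Produce a Complete Run-Wide Log
Conclusion: done.
Additional notes:
- The default test directories already include `test/class_test_case`
- The mechanism for rerunning failed cases is also available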
---
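## 8. Verification Suggestions
To confirm the current repository state quickly, verify in the order below.
### 8.1 Check the Script Logic First
Key files:
- `scripts/lab2_build_test.sh`
- `scripts/lab1_build_test.sh`
- `scripts/verify_ir.sh`
Points to confirm:
1. Whether the default test directories include `test/class_test_case`
2. Whether intermediate files of passing cases are deleted
3. Whether failing cases keep `error.log`
4. Whether `PASS` / `FAIL` are printed in color
5. Whether `--failed-only` is supported
### 8.2 Then Check the Optimization Pipeline
Key files:
- `src/ir/passes/PassManager.cpp`
- `src/ir/passes/Mem2Reg.cpp`
Points to confirm:
1. Whether `RunMem2Reg(module)` runs by default
2. Whether dominance information is really built
3. Whether `phi` nodes are really inserted
4. Whether `load/store` are really rewritten
5. Whether the old `alloca/load/store` are deleted
### 8.3 Then Check the Test Data Fixes
Key files:
- `test/test_case/h_performance/if-combine2.sy`
- `test/test_case/h_performance/if-combine2.in`
- `test/test_case/h_performance/if-combine2.out`
- `test/test_case/h_performance/if-combine3.sy`
- `test/test_case/h_performance/if-combine3.in`
- `test/test_case/h_performance/if-combine3.out`
Points to confirm:
1. Whether the source reads two integers
2. Whether the original input provided only one integer
3. Whether the result matches the expected output after adding `100`
### 8.4 Finally, Run the Tests
Recommended commands: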
```bash
./scripts/lab2_build_test.sh --failed-only
./scripts/lab2_build_test.sh
```
To run only the classroom cases, pass the directories explicitly:
```bash
./scripts/lab2_build_test.sh test/class_test_case/functional test/class_test_case/performance
```
---
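## 9. Summary
The core of this round of changes is not "a few more script features"; it is that the whole Lab2 development and verification path has been straightened out:
1. The front end continues to generate memory-form IR.
2. Mem2Reg runs by default afterwards and converts the promotable local variables to SSA.
3. The test scripts keep only failure information, which reduces the build-up of useless artifacts.
4. The test log layout is unified, which makes reproduction and debugging easier.
5. `class_test_case` is now part of the default test scope.
6. The changes to `if-combine2.in` / `if-combine3.in` fix inconsistent test data; they do not work around a compiler error.
If this document is extended later, these three directions should come first:
1. The responsibility boundaries of the IRGen visitors at each stage.
2. The cases that Mem2Reg currently does not promote, and why.
3. A standard debugging procedure for test failures.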


@ -1,10 +0,0 @@
#pragma once
namespace ir {
class Module;
void RunMem2Reg(Module& module);
void RunIRPassPipeline(Module& module);
} // namespace ir

@ -1,186 +1,57 @@
// Translates the syntax tree into IR.
// The implementation is split across IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl.
#pragma once
#include <any>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include <unordered_map>
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include "ir/IR.h"
#include "sem/Sema.h"
#include "sem/SymbolTable.h"
namespace ir {
class Module;
class Function;
class IRBuilder;
class Value;
}
class IRGenImpl final : public SysYBaseVisitor {
public:
IRGenImpl(ir::Module& module, const SemanticContext& sema);
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
private:
enum class FlowState {
enum class BlockFlow {
Continue,
Terminated,
};
struct TypedValue {
ir::Value* value = nullptr;
SemanticType type = SemanticType::Int;
bool is_array = false;
std::vector<int> dims;
};
struct LValueInfo {
SymbolEntry* symbol = nullptr;
ir::Value* addr = nullptr;
SemanticType type = SemanticType::Int;
bool is_array = false;
std::vector<int> dims;
bool root_param_array_no_index = false;
};
struct LoopContext {
ir::BasicBlock* cond_block = nullptr;
ir::BasicBlock* exit_block = nullptr;
};
struct InitExprSlot {
size_t index = 0;
SysYParser::ExpContext* expr = nullptr;
};
[[noreturn]] void ThrowError(const antlr4::ParserRuleContext* ctx,
const std::string& message) const;
void RegisterBuiltinFunctions();
void PredeclareTopLevel(SysYParser::CompUnitContext& ctx);
void PredeclareFunction(SysYParser::FuncDefContext& ctx);
void PredeclareGlobalDecl(SysYParser::DeclContext& ctx);
void EmitGlobalDecl(SysYParser::DeclContext& ctx);
void EmitFunction(SysYParser::FuncDefContext& ctx);
void BindFunctionParams(SysYParser::FuncDefContext& ctx, ir::Function& func);
void EmitBlock(SysYParser::BlockContext& ctx, bool create_scope = true);
FlowState EmitBlockItem(SysYParser::BlockItemContext& ctx);
FlowState EmitStmt(SysYParser::StmtContext& ctx);
void EmitDecl(SysYParser::DeclContext& ctx, bool is_global);
void EmitVarDecl(SysYParser::VarDeclContext* ctx, bool is_global, bool is_const);
void EmitConstDecl(SysYParser::ConstDeclContext* ctx, bool is_global);
void EmitGlobalVarDef(SysYParser::VarDefContext& ctx, SemanticType type);
void EmitGlobalConstDef(SysYParser::ConstDefContext& ctx, SemanticType type);
void EmitLocalVarDef(SysYParser::VarDefContext& ctx, SemanticType type, bool is_const);
void EmitLocalConstDef(SysYParser::ConstDefContext& ctx, SemanticType type);
std::string ExpectIdent(const antlr4::ParserRuleContext& ctx,
antlr4::tree::TerminalNode* ident) const;
SemanticType ParseBType(SysYParser::BTypeContext* ctx) const;
SemanticType ParseFuncType(SysYParser::FuncTypeContext* ctx) const;
std::shared_ptr<ir::Type> GetIRScalarType(SemanticType type) const;
std::shared_ptr<ir::Type> BuildArrayType(SemanticType base_type,
const std::vector<int>& dims) const;
std::vector<int> ParseArrayDims(const std::vector<SysYParser::ConstExpContext*>& dims_ctx);
std::vector<int> ParseParamDims(SysYParser::FuncFParamContext& ctx);
FunctionTypeInfo BuildFunctionTypeInfo(SysYParser::FuncDefContext& ctx);
std::vector<std::shared_ptr<ir::Type>> BuildFunctionIRParamTypes(
const FunctionTypeInfo& function_type) const;
std::vector<std::string> BuildFunctionIRParamNames(SysYParser::FuncDefContext& ctx) const;
TypedValue EmitExp(SysYParser::ExpContext& ctx);
TypedValue EmitAddExp(SysYParser::AddExpContext& ctx);
TypedValue EmitMulExp(SysYParser::MulExpContext& ctx);
TypedValue EmitUnaryExp(SysYParser::UnaryExpContext& ctx);
TypedValue EmitPrimaryExp(SysYParser::PrimaryExpContext& ctx);
TypedValue EmitRelExp(SysYParser::RelExpContext& ctx);
TypedValue EmitEqExp(SysYParser::EqExpContext& ctx);
TypedValue EmitLValValue(SysYParser::LValContext& ctx);
LValueInfo ResolveLVal(SysYParser::LValContext& ctx);
ir::Value* GenLValAddr(SysYParser::LValContext& ctx);
void EmitCond(SysYParser::CondContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block);
void EmitLOrCond(SysYParser::LOrExpContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block);
void EmitLAndCond(SysYParser::LAndExpContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block);
TypedValue CastScalar(TypedValue value, SemanticType target_type,
const antlr4::ParserRuleContext* ctx);
ir::Value* CastToCondition(TypedValue value,
const antlr4::ParserRuleContext* ctx);
TypedValue NormalizeLogicalValue(TypedValue value,
const antlr4::ParserRuleContext* ctx);
bool IsNumeric(const TypedValue& value) const;
bool IsSameDims(const std::vector<int>& lhs, const std::vector<int>& rhs) const;
ConstantValue ParseNumber(SysYParser::NumberContext& ctx) const;
ConstantValue EvalConstExp(SysYParser::ExpContext& ctx);
ConstantValue EvalConstAddExp(SysYParser::AddExpContext& ctx);
ConstantValue EvalConstMulExp(SysYParser::MulExpContext& ctx);
ConstantValue EvalConstUnaryExp(SysYParser::UnaryExpContext& ctx);
ConstantValue EvalConstPrimaryExp(SysYParser::PrimaryExpContext& ctx);
ConstantValue EvalConstLVal(SysYParser::LValContext& ctx);
ConstantValue ConvertConst(ConstantValue value, SemanticType target_type) const;
std::vector<ConstantValue> FlattenConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims);
std::vector<ConstantValue> FlattenInitVal(SysYParser::InitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims);
std::vector<InitExprSlot> FlattenLocalInitVal(SysYParser::InitValContext* ctx,
const std::vector<int>& dims);
void FlattenConstInitValImpl(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor, std::vector<ConstantValue>& out);
void FlattenInitValImpl(SysYParser::InitValContext* ctx, SemanticType base_type,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor, std::vector<ConstantValue>& out);
void FlattenLocalInitValImpl(SysYParser::InitValContext* ctx,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor, std::vector<InitExprSlot>& out);
size_t CountArrayElements(const std::vector<int>& dims, size_t start = 0) const;
size_t AlignInitializerCursor(const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t cursor) const;
size_t FlattenIndices(const std::vector<int>& dims,
const std::vector<int>& indices) const;
ConstantValue ZeroConst(SemanticType type) const;
ir::Value* ZeroIRValue(SemanticType type);
ir::Value* CreateTypedConstant(const ConstantValue& value);
ir::AllocaInst* CreateEntryAlloca(std::shared_ptr<ir::Type> allocated_type,
const std::string& name);
void ZeroInitializeLocalArray(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims);
void StoreLocalArrayElements(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims,
const std::vector<InitExprSlot>& init_slots);
ir::Value* CreateArrayElementAddr(ir::Value* base_addr, bool is_param_array,
SemanticType base_type,
const std::vector<int>& dims,
const std::vector<ir::Value*>& indices,
const antlr4::ParserRuleContext* ctx);
std::string NextTemp();
std::string NextBlockName(const std::string& prefix);
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
ir::Value* EvalExpr(SysYParser::ExpContext& expr);
ir::Module& module_;
const SemanticContext& sema_;
ir::Function* func_;
ir::IRBuilder builder_;
SymbolTable symbols_;
ir::Function* current_function_ = nullptr;
SemanticType current_return_type_ = SemanticType::Void;
std::vector<LoopContext> loop_stack_;
bool builtins_registered_ = false;
// Name binding is handled by Sema; IRGen only maintains the "declaration -> storage slot" code generation state.
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
};
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,

@ -0,0 +1,80 @@
// Compile-time constant evaluation and constant initializer expansion.
#pragma once
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
#include "SysYParser.h"
#include "sem/SymbolTable.h"
struct ConstValue {
SymbolDataType type = SymbolDataType::Unknown;
int64_t int_value = 0;
double float_value = 0.0;
bool bool_value = false;
static ConstValue FromInt(int64_t value);
static ConstValue FromFloat(double value);
static ConstValue FromBool(bool value);
bool IsScalar() const;
bool IsNumeric() const;
int64_t AsInt() const;
double AsFloat() const;
bool AsBool() const;
};
struct ConstArrayValue {
SymbolDataType elem_type = SymbolDataType::Unknown;
std::vector<int64_t> dims;
std::vector<ConstValue> elements;
};
class ConstEvalContext {
public:
ConstEvalContext();
void EnterScope();
void ExitScope();
bool DefineScalar(const std::string& name, ConstValue value);
bool DefineArray(const std::string& name, ConstArrayValue value);
const ConstValue* LookupScalar(const std::string& name) const;
const ConstArrayValue* LookupArray(const std::string& name) const;
private:
struct Binding {
bool is_array = false;
ConstValue scalar;
ConstArrayValue array;
};
using Scope = std::unordered_map<std::string, Binding>;
const Binding* LookupBinding(const std::string& name) const;
std::vector<Scope> scopes_;
};
class ConstEvaluator {
public:
ConstEvaluator(const SymbolTable& table, const ConstEvalContext& ctx);
ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) const;
ConstValue EvaluateExp(SysYParser::ExpContext& ctx) const;
// Array dimensions must be positive integers.
int64_t EvaluateArrayDim(SysYParser::ConstExpContext& ctx) const;
// Flatten a const initializer list; the results are stored in row-major order.
std::vector<ConstValue> EvaluateConstInitList(
SysYParser::ConstInitValContext& init, SymbolDataType elem_type,
const std::vector<int64_t>& dims) const;
private:
const SymbolTable& table_;
const ConstEvalContext& ctx_;
};

@ -1,7 +1,48 @@
#pragma once
// Semantic checking and name binding based on the syntax tree.
#pragma once
#include <list>
#include <unordered_map>
#include "SysYParser.h"
#include "sem/SymbolTable.h"
class SemanticContext {
public:
SymbolEntry* RegisterSymbol(SymbolEntry symbol) {
symbols_.push_back(std::move(symbol));
return &symbols_.back();
}
void BindLValUse(SysYParser::LValContext* use, const SymbolEntry* symbol) {
lval_uses_[use] = symbol;
}
const SymbolEntry* ResolveLValUse(const SysYParser::LValContext* use) const {
auto it = lval_uses_.find(use);
return it == lval_uses_.end() ? nullptr : it->second;
}
void BindCallUse(SysYParser::UnaryExpContext* call,
const SymbolEntry* symbol) {
call_uses_[call] = symbol;
}
const SymbolEntry* ResolveCallUse(
const SysYParser::UnaryExpContext* call) const {
auto it = call_uses_.find(call);
return it == call_uses_.end() ? nullptr : it->second;
}
const std::list<SymbolEntry>& GetSymbols() const { return symbols_; }
class SemanticContext {};
private:
std::list<SymbolEntry> symbols_;
std::unordered_map<const SysYParser::LValContext*, const SymbolEntry*>
lval_uses_;
std::unordered_map<const SysYParser::UnaryExpContext*, const SymbolEntry*>
call_uses_;
};
// Performs semantic analysis based on the SysY.g4 rules and builds IR-oriented symbol binding results.
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);

@ -1,68 +1,76 @@
#pragma once
// IR-oriented symbol table: symbol entries can attach IR entities directly.
#pragma once
#include <optional>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace antlr4 {
class ParserRuleContext;
} // namespace antlr4
namespace ir {
class Function;
class Type;
class Value;
}
enum class SemanticType {
Void,
Int,
Float,
};
class Function;
} // namespace ir
enum class SymbolKind {
Variable,
Constant,
Function,
Parameter,
};
struct ConstantValue {
SemanticType type = SemanticType::Int;
int int_value = 0;
float float_value = 0.0f;
};
struct FunctionTypeInfo {
SemanticType return_type = SemanticType::Void;
std::vector<SemanticType> param_types;
std::vector<bool> param_is_array;
std::vector<std::vector<int>> param_dims;
enum class SymbolDataType {
Unknown,
Void,
Int,
Float,
Bool,
};
struct SymbolEntry {
std::string name;
SymbolKind kind = SymbolKind::Variable;
SemanticType type = SemanticType::Int;
SymbolDataType data_type = SymbolDataType::Unknown;
std::shared_ptr<ir::Type> type;
ir::Value* ir_value = nullptr;
ir::Function* ir_function = nullptr;
bool is_const = false;
bool is_global = false;
bool has_initializer = false;
bool is_array = false;
bool is_param_array = false;
std::vector<int> dims;
ir::Value* ir_value = nullptr;
ir::Function* function = nullptr;
std::optional<ConstantValue> const_scalar;
std::vector<ConstantValue> const_array;
FunctionTypeInfo function_type;
std::vector<int64_t> array_dims;
bool has_constexpr_value = false;
int64_t const_int_value = 0;
double const_float_value = 0.0;
std::vector<int64_t> const_int_init;
std::vector<double> const_float_init;
std::vector<SymbolDataType> param_types;
std::vector<bool> param_is_array;
const antlr4::ParserRuleContext* decl_ctx = nullptr;
};
class SymbolTable {
public:
void Clear();
SymbolTable();
void EnterScope();
void ExitScope();
bool Insert(const std::string& name, const SymbolEntry& entry);
bool ContainsInCurrentScope(const std::string& name) const;
bool Insert(const SymbolEntry* symbol);
bool Contains(const std::string& name) const;
bool ContainsCurrentScope(const std::string& name) const;
SymbolEntry* Lookup(const std::string& name);
const SymbolEntry* Lookup(const std::string& name) const;
const SymbolEntry* LookupCurrentScope(const std::string& name) const;
private:
std::vector<std::unordered_map<std::string, SymbolEntry>> scopes_;
using Scope = std::unordered_map<std::string, const SymbolEntry*>;
std::vector<Scope> scopes_;
};

@ -0,0 +1 @@
./scripts/lab1_build_test.sh: line 111: /home/zhangwanzheng/nudt-compiler-cpp/build/bin/compiler: No such file or directory


@ -1,194 +1,136 @@
#!/usr/bin/env bash
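# Lab1 automated build + parse evaluation script (built with COMPILER_PARSE_ONLY)
# Usage:
# bash scripts/lab1_build_test.sh [--save-tree] [test directories...]
#
# Options:
# --save-tree save the parse tree of each test case to the build/trees/ directory
# By default only pass/fail statistics are collected; parse trees are not saved
#
# Exit codes:
# 0 all cases parsed successfully
# 1 at least one case failed to parse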
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
COMPILER="$REPO_ROOT/build/bin/compiler"
ANTLR_JAR="$REPO_ROOT/third_party/antlr-4.13.2-complete.jar"
RUN_ROOT="$REPO_ROOT/output/logs/lab1"
RUN_NAME="lab1_$(date +%Y%m%d_%H%M%S)"
RUN_DIR="$RUN_ROOT/$RUN_NAME"
WHOLE_LOG="$RUN_DIR/whole.log"
FAIL_DIR="$RUN_DIR/failures"
LEGACY_SAVE_TREE=false
# Colored output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
# Do not save parse trees by default
SAVE_TREE=false
# Parse command-line arguments
TEST_DIRS=()
while [[ $# -gt 0 ]]; do
case "$1" in
--save-tree)
LEGACY_SAVE_TREE=true
SAVE_TREE=true
shift
;;
*)
TEST_DIRS+=("$1")
shift
;;
esac
shift
done
mkdir -p "$RUN_DIR"
: > "$WHOLE_LOG"
log_plain() {
printf '%s\n' "$*"
printf '%s\n' "$*" >> "$WHOLE_LOG"
}
log_color() {
local color="$1"
shift
local message="$*"
printf '%b%s%b\n' "$color" "$message" "$NC"
printf '%s\n' "$message" >> "$WHOLE_LOG"
}
append_file_to_whole_log() {
local title="$1"
local file="$2"
{
printf '\n===== %s =====\n' "$title"
cat "$file"
printf '\n'
} >> "$WHOLE_LOG"
}
cleanup_tmp_dir() {
local dir="$1"
if [[ -d "$dir" ]]; then
rm -rf "$dir"
fi
}
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
prune_empty_run_dirs() {
if [[ -d "$RUN_DIR/.tmp" ]]; then
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
fi
if [[ -d "$FAIL_DIR" ]]; then
rmdir "$FAIL_DIR" 2>/dev/null || true
fi
}
# If no test directory is given, use the defaults
if [[ ${#TEST_DIRS[@]} -eq 0 ]]; then
while IFS= read -r -d '' test_dir; do
TEST_DIRS+=("$test_dir")
done < <(discover_default_test_dirs)
fi
log_plain "Run directory: $RUN_DIR"
log_plain "Whole log: $WHOLE_LOG"
if [[ "$LEGACY_SAVE_TREE" == true ]]; then
log_color "$YELLOW" "Warning: --save-tree is deprecated; successful case artifacts will still be deleted."
TEST_DIRS=(
"$REPO_ROOT/test/test_case/functional"
"$REPO_ROOT/test/test_case/performance"
)
fi
log_plain "==> [1/3] Generate ANTLR Lexer/Parser"
# ─── Step 1: generate ANTLR Lexer/Parser ────────────────────────────────────────────────
echo "==> [1/3] 生成 ANTLR Lexer/Parser ..."
mkdir -p "$REPO_ROOT/build/generated/antlr4"
if ! java -jar "$ANTLR_JAR" \
java -jar "$ANTLR_JAR" \
-Dlanguage=Cpp \
-visitor -no-listener \
-Xexact-output-dir \
-o "$REPO_ROOT/build/generated/antlr4" \
"$REPO_ROOT/src/antlr4/SysY.g4" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "ANTLR generation failed. See $WHOLE_LOG"
exit 1
fi
log_plain "==> [2/3] Configure and build parse-only compiler"
if ! cmake -S "$REPO_ROOT" -B "$REPO_ROOT/build" -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
exit 1
fi
if ! cmake --build "$REPO_ROOT/build" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
exit 1
"$REPO_ROOT/src/antlr4/SysY.g4"
echo " Lexer/Parser 生成完毕"
# ─── Step 2: CMake build (with COMPILER_PARSE_ONLY) ────────────────────────────────
echo "==> [2/3] CMake 构建COMPILER_PARSE_ONLY=ON..."
cmake -S "$REPO_ROOT" -B "$REPO_ROOT/build" \
-DCMAKE_BUILD_TYPE=Release \
-DCOMPILER_PARSE_ONLY=ON \
> /dev/null
cmake --build "$REPO_ROOT/build" -j "$(nproc)" 2>&1 | grep -E "error:|warning:|Built target|Linking" || true
echo " 构建完毕:$COMPILER"
# ─── Step 3: batch parse tests ─────────────────────────────────────────────────────
echo "==> [3/3] 批量解析测试 ..."
# If parse trees should be saved, create the output directory
if $SAVE_TREE; then
TREE_OUTPUT_DIR="$REPO_ROOT/build/trees"
mkdir -p "$TREE_OUTPUT_DIR"
echo " 语法树将保存到: $TREE_OUTPUT_DIR"
fi
log_plain "==> [3/3] Run parse validation suite"
PASS=0
FAIL=0
FAIL_LIST=()
test_one() {
local sy_file="$1"
local rel="$2"
local safe_name="${rel//\//_}"
local case_key="${safe_name%.sy}"
local tmp_dir="$RUN_DIR/.tmp/$case_key"
local fail_case_dir="$FAIL_DIR/$case_key"
local tree_file="$tmp_dir/parse.tree"
local case_log="$tmp_dir/error.log"
cleanup_tmp_dir "$tmp_dir"
cleanup_tmp_dir "$fail_case_dir"
mkdir -p "$tmp_dir"
if "$COMPILER" --emit-parse-tree "$sy_file" > "$tree_file" 2> "$case_log"; then
cleanup_tmp_dir "$tmp_dir"
return 0
fi
mkdir -p "$FAIL_DIR"
{
printf 'Command: %s --emit-parse-tree %s\n' "$COMPILER" "$sy_file"
if [[ -s "$case_log" ]]; then
printf '\n'
cat "$case_log"
fi
} > "$tmp_dir/error.log.tmp"
mv "$tmp_dir/error.log.tmp" "$case_log"
mv "$tmp_dir" "$fail_case_dir"
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
return 1
}
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
log_color "$YELLOW" "skip missing dir: $test_dir"
for TEST_DIR in "${TEST_DIRS[@]}"; do
if [[ ! -d "$TEST_DIR" ]]; then
echo -e " ${YELLOW}警告:目录不存在,跳过:$TEST_DIR${NC}"
continue
fi
while IFS= read -r -d '' sy_file; do
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
if test_one "$sy_file" "$rel"; then
log_color "$GREEN" "PASS $rel"
PASS=$((PASS + 1))
if $SAVE_TREE; then
# Replace '/' in the path with '_' to avoid subdirectory collisions
safe_name="${rel//\//_}"
tree_file="$TREE_OUTPUT_DIR/${safe_name}.tree"
# Run the compiler and save its output
if "$COMPILER" --emit-parse-tree "$sy_file" > "$tree_file" 2>&1; then
echo -e " ${GREEN}PASS${NC} $rel (tree saved to $tree_file)"
((PASS++)) || true
else
echo -e " ${RED}FAIL${NC} $rel"
FAIL_LIST+=("$rel")
((FAIL++)) || true
fi
else
log_color "$RED" "FAIL $rel"
FAIL=$((FAIL + 1))
FAIL_LIST+=("$rel")
if "$COMPILER" --emit-parse-tree "$sy_file" > "./output/lab1/$(basename "$sy_file" .sy).tree" 2>&1; then
echo -e " ${GREEN}PASS${NC} $rel"
((PASS++)) || true
else
echo -e " ${RED}FAIL${NC} $rel"
FAIL_LIST+=("$rel")
((FAIL++)) || true
fi
fi
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done < <(find "$TEST_DIR" -name "*.sy" -print0 | sort -z)
done
prune_empty_run_dirs
log_plain ""
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
# ─── Summary ─────────────────────────────────────────────────────────────────────
echo ""
echo "──────────────────────────────────────────"
echo -e " 测试结果:${GREEN}${PASS} PASS${NC} / ${RED}${FAIL} FAIL${NC} / 总计 $((PASS + FAIL))"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
log_plain "failed cases:"
echo ""
echo " 失败用例:"
for f in "${FAIL_LIST[@]}"; do
safe_name="${f//\//_}"
log_plain "- $f"
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
echo -e " ${RED}- $f${NC}"
done
else
log_plain "all successful case artifacts have been deleted automatically."
fi
log_plain "whole log saved to: $WHOLE_LOG"
echo "──────────────────────────────────────────"
[[ $FAIL -eq 0 ]]

@ -1,236 +1,201 @@
#!/usr/bin/env bash
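# Lab2 automated build + IR verification evaluation script
# Usage:
# bash scripts/lab2_build_test.sh [--save-ir] [test directories...]
#
# Options:
# --save-ir save the IR generated for each test case to the output/lab2/ directory
# By default only pass/fail statistics are collected; IR is not saved
#
# Exit codes:
# 0 all cases verified successfully
# 1 at least one case failed verification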
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
COMPILER="$REPO_ROOT/build/bin/compiler"
ANTLR_JAR="$REPO_ROOT/third_party/antlr-4.13.2-complete.jar"
VERIFY_SCRIPT="$REPO_ROOT/scripts/verify_ir.sh"
BUILD_DIR="$REPO_ROOT/build_lab2"
RUN_ROOT="$REPO_ROOT/output/logs/lab2"
LAST_RUN_FILE="$RUN_ROOT/last_run.txt"
LAST_FAILED_FILE="$RUN_ROOT/last_failed.txt"
RUN_NAME="lab2_$(date +%Y%m%d_%H%M%S)"
RUN_DIR="$RUN_ROOT/$RUN_NAME"
WHOLE_LOG="$RUN_DIR/whole.log"
FAIL_DIR="$RUN_DIR/failures"
LEGACY_SAVE_IR=false
FAILED_ONLY=false
FALLBACK_TO_FULL=false
# Output directories
OUTPUT_DIR="$REPO_ROOT/output/lab2"
LOG_DIR="$REPO_ROOT/output/logs/lab2"
# Colored output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
# Do not save IR by default
SAVE_IR=false
# Parse command-line arguments
TEST_DIRS=()
TEST_FILES=()
while [[ $# -gt 0 ]]; do
case "$1" in
--save-ir)
LEGACY_SAVE_IR=true
;;
--failed-only)
FAILED_ONLY=true
SAVE_IR=true
shift
;;
*)
if [[ -f "$1" ]]; then
TEST_FILES+=("$1")
else
TEST_DIRS+=("$1")
fi
TEST_DIRS+=("$1")
shift
;;
esac
shift
done
mkdir -p "$RUN_DIR"
: > "$WHOLE_LOG"
printf '%s\n' "$RUN_DIR" > "$LAST_RUN_FILE"
log_plain() {
printf '%s\n' "$*"
printf '%s\n' "$*" >> "$WHOLE_LOG"
}
log_color() {
local color="$1"
shift
local message="$*"
printf '%b%s%b\n' "$color" "$message" "$NC"
printf '%s\n' "$message" >> "$WHOLE_LOG"
}
append_file_to_whole_log() {
local title="$1"
local file="$2"
{
printf '\n===== %s =====\n' "$title"
cat "$file"
printf '\n'
} >> "$WHOLE_LOG"
}
cleanup_tmp_dir() {
local dir="$1"
if [[ -d "$dir" ]]; then
rm -rf "$dir"
fi
}
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
# Fall back to the default test directories when none are given
if [[ ${#TEST_DIRS[@]} -eq 0 ]]; then
TEST_DIRS=(
"$REPO_ROOT/test/test_case/functional"
"$REPO_ROOT/test/test_case/performance"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
prune_empty_run_dirs() {
if [[ -d "$RUN_DIR/.tmp" ]]; then
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
fi
if [[ -d "$FAIL_DIR" ]]; then
rmdir "$FAIL_DIR" 2>/dev/null || true
fi
}
test_one() {
local sy_file="$1"
local rel="$2"
local safe_name="${rel//\//_}"
local case_key="${safe_name%.sy}"
local tmp_dir="$RUN_DIR/.tmp/$case_key"
local fail_case_dir="$FAIL_DIR/$case_key"
local case_log="$tmp_dir/error.log"
cleanup_tmp_dir "$tmp_dir"
cleanup_tmp_dir "$fail_case_dir"
mkdir -p "$tmp_dir"
if "$VERIFY_SCRIPT" "$sy_file" "$tmp_dir" --run > "$case_log" 2>&1; then
cleanup_tmp_dir "$tmp_dir"
return 0
fi
mkdir -p "$FAIL_DIR"
mv "$tmp_dir" "$fail_case_dir"
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
return 1
}
run_case() {
local sy_file="$1"
local rel
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
if test_one "$sy_file" "$rel"; then
log_color "$GREEN" "PASS $rel"
PASS=$((PASS + 1))
else
log_color "$RED" "FAIL $rel"
FAIL=$((FAIL + 1))
FAIL_LIST+=("$rel")
fi
}
if [[ "$FAILED_ONLY" == true ]]; then
if [[ -f "$LAST_FAILED_FILE" ]]; then
while IFS= read -r sy_file; do
[[ -n "$sy_file" ]] || continue
[[ -f "$sy_file" ]] || continue
TEST_FILES+=("$sy_file")
done < "$LAST_FAILED_FILE"
fi
if [[ ${#TEST_FILES[@]} -eq 0 ]]; then
FALLBACK_TO_FULL=true
FAILED_ONLY=false
fi
fi
if [[ "$FAILED_ONLY" == false && ${#TEST_DIRS[@]} -eq 0 && ${#TEST_FILES[@]} -eq 0 ]]; then
while IFS= read -r -d '' test_dir; do
TEST_DIRS+=("$test_dir")
done < <(discover_default_test_dirs)
# Check that the required files exist
if [ ! -f "$VERIFY_SCRIPT" ]; then
echo -e "${RED}Error: verify script $VERIFY_SCRIPT does not exist${NC}"
exit 1
fi
log_plain "Run directory: $RUN_DIR"
log_plain "Whole log: $WHOLE_LOG"
if [[ "$LEGACY_SAVE_IR" == true ]]; then
log_color "$YELLOW" "Warning: --save-ir is deprecated; successful case artifacts will still be deleted."
fi
if [[ "$FAILED_ONLY" == true ]]; then
log_plain "Mode: rerun cached failed cases only"
fi
if [[ "$FALLBACK_TO_FULL" == true ]]; then
log_color "$YELLOW" "No cached failed cases found, fallback to full suite."
fi
# Create the output directories
mkdir -p "$OUTPUT_DIR" "$LOG_DIR"
if [[ ! -f "$VERIFY_SCRIPT" ]]; then
log_color "$RED" "missing verify script: $VERIFY_SCRIPT"
exit 1
fi
# ─── Step 1: generate the ANTLR Lexer/Parser ──────────────────────────────────
echo "==> [1/3] Generating ANTLR Lexer/Parser ..."
mkdir -p "$REPO_ROOT/build/generated/antlr4"
java -jar "$ANTLR_JAR" \
-Dlanguage=Cpp \
-visitor -no-listener \
-Xexact-output-dir \
-o "$REPO_ROOT/build/generated/antlr4" \
"$REPO_ROOT/src/antlr4/SysY.g4"
echo " Lexer/Parser generated"
log_plain "==> [1/2] Configure and build compiler"
if ! cmake -S "$REPO_ROOT" -B "$BUILD_DIR" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
exit 1
fi
if ! cmake --build "$BUILD_DIR" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
exit 1
fi
# ─── Step 2: full CMake build (PARSE_ONLY not enabled) ────────────────────────
echo "==> [2/3] Building the full compiler with CMake ..."
cmake -S "$REPO_ROOT" -B "$REPO_ROOT/build" \
-DCMAKE_BUILD_TYPE=Release \
> /dev/null
cmake --build "$REPO_ROOT/build" -j "$(nproc)" 2>&1 | grep -E "error:|warning:|Built target|Linking" || true
echo " Build finished: $COMPILER"
# ─── Step 3: batch-verify IR generation and execution ─────────────────────────
echo "==> [3/3] Batch-verifying IR generation and execution ..."
log_plain "==> [2/2] Run IR validation suite"
PASS=0
FAIL=0
FAIL_LIST=()
if [[ "$FAILED_ONLY" == true ]]; then
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
else
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
log_color "$YELLOW" "skip missing dir: $test_dir"
continue
# Define a function that tests a single file so error handling stays uniform
test_one() {
local sy_file="$1"
local rel="$2"
local basename="$(basename "$sy_file" .sy)"
local safe_name="${rel//\//_}" # replace every '/' in the path with '_'
local result_dir="$OUTPUT_DIR/$basename"
local log_file="$LOG_DIR/${safe_name}.log"
# Create the result directory
mkdir -p "$result_dir"
if $SAVE_IR; then
# Write the IR file into the designated directory
local ir_file="$OUTPUT_DIR/${safe_name}.ir"
mkdir -p "$(dirname "$ir_file")"
# Try to generate the IR
if ! "$COMPILER" --emit-ir "$sy_file" > "$ir_file" 2>&1; then
# Compilation failed; record a log
{
echo "IR generation failed for $rel"
echo "Command: $COMPILER --emit-ir $sy_file"
echo "--- Output ---"
cat "$ir_file"
} > "$log_file" 2>&1
return 1
fi
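# Run the verify script. verify_ir.sh recompiles internally rather than reusing the IR generated above, so it is still called here to verify execution.
# The verify script is expected to place its intermediate files in result_dir.
if ! "$VERIFY_SCRIPT" "$sy_file" "$result_dir" --run > /dev/null 2>&1; then
# Verification failed; record a log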
{
echo "Verification failed for $rel"
echo "Command: $VERIFY_SCRIPT $sy_file $result_dir --run"
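# Ideally the verify script's output would be captured too, but verify_ir.sh prints its errors to stderr, which cannot be captured directly here.
# Only basic information is recorded, with a pointer to the output in result_dir.
echo "Check result directory: $result_dir"
# If result_dir contains output files, try to append them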
if [ -f "$result_dir/out" ]; then
echo "--- Program output ---"
cat "$result_dir/out"
fi
if [ -f "$result_dir/err" ]; then
echo "--- Program error ---"
cat "$result_dir/err"
fi
} > "$log_file" 2>&1
return 1
fi
else
# IR is not saved; run the verify script directly
if ! "$VERIFY_SCRIPT" "$sy_file" "$result_dir" --run > /dev/null 2>&1; then
{
echo "Verification failed for $rel"
echo "Command: $VERIFY_SCRIPT $sy_file $result_dir --run"
echo "Check result directory: $result_dir"
if [ -f "$result_dir/out" ]; then
echo "--- Program output ---"
cat "$result_dir/out"
fi
if [ -f "$result_dir/err" ]; then
echo "--- Program error ---"
cat "$result_dir/err"
fi
} > "$log_file" 2>&1
return 1
fi
fi
return 0
}
while IFS= read -r -d '' sy_file; do
run_case "$sy_file"
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done
fi
rm -f "$LAST_FAILED_FILE"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
for f in "${FAIL_LIST[@]}"; do
printf '%s/%s\n' "$REPO_ROOT" "$f" >> "$LAST_FAILED_FILE"
done
fi
for TEST_DIR in "${TEST_DIRS[@]}"; do
if [[ ! -d "$TEST_DIR" ]]; then
echo -e " ${YELLOW}Warning: directory does not exist, skipping: $TEST_DIR${NC}"
continue
fi
prune_empty_run_dirs
while IFS= read -r -d '' sy_file; do
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
echo -n "Testing $rel ... "
if test_one "$sy_file" "$rel"; then
echo -e "${GREEN}PASS${NC}"
((PASS++)) || true
else
echo -e "${RED}FAIL${NC}"
FAIL_LIST+=("$rel")
((FAIL++)) || true
fi
done < <(find "$TEST_DIR" -name "*.sy" -print0 | sort -z)
done
log_plain ""
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
# ─── Summary ──────────────────────────────────────────────────────────────────
echo ""
echo "──────────────────────────────────────────"
echo -e " Test results: ${GREEN}${PASS} PASS${NC} / ${RED}${FAIL} FAIL${NC} / total $((PASS + FAIL))"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
log_plain "failed cases:"
echo ""
echo " Failed cases:"
for f in "${FAIL_LIST[@]}"; do
safe_name="${f//\//_}"
log_plain "- $f"
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
echo -e " ${RED}- $f${NC}"
echo " Log file: $LOG_DIR/${f//\//_}.log"
done
else
log_plain "all successful case artifacts have been deleted automatically."
fi
log_plain "whole log saved to: $WHOLE_LOG"
echo "──────────────────────────────────────────"
[[ $FAIL -eq 0 ]]
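The new `test_one` function above runs each case in a scratch directory, deletes it on success, and moves it under `failures/` on failure. Below is a minimal, self-contained sketch of that pattern. The helper name `run_and_keep_on_failure`, the `work_root` layout, and the use of `true`/`false` as stand-in commands are assumptions for illustration, not part of the repository.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical helper: run a command with a per-case scratch directory as its
# last argument. On success the scratch directory is removed; on failure it is
# moved to <work_root>/failures/<case_key>, together with the captured log.
run_and_keep_on_failure() {
  local case_key="$1" work_root="$2"
  shift 2
  local tmp_dir="$work_root/.tmp/$case_key"
  local fail_dir="$work_root/failures/$case_key"

  # Start from a clean state so stale artifacts never leak between runs.
  rm -rf "$tmp_dir" "$fail_dir"
  mkdir -p "$tmp_dir"

  if "$@" "$tmp_dir" > "$tmp_dir/error.log" 2>&1; then
    rm -rf "$tmp_dir"
    return 0
  fi

  mkdir -p "$work_root/failures"
  mv "$tmp_dir" "$fail_dir"
  return 1
}

demo_root="$(mktemp -d)"
run_and_keep_on_failure ok_case "$demo_root" true || true
run_and_keep_on_failure bad_case "$demo_root" false || true
ls "$demo_root/failures"
```

After the two demo calls only `bad_case` is left on disk, which is the behavior the script relies on to keep the run directory small.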

@ -1,8 +1,10 @@
#!/usr/bin/env bash
# ./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/function/ir --run
set -euo pipefail
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "usage: $0 input.sy [output_dir] [--run]" >&2
echo "Usage: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
fi
@ -25,19 +27,13 @@ while [[ $# -gt 0 ]]; do
done
if [[ ! -f "$input" ]]; then
echo "input file not found: $input" >&2
echo "Input file does not exist: $input" >&2
exit 1
fi
compiler=""
for candidate in ./build_lab2/bin/compiler ./build/bin/compiler; do
if [[ -x "$candidate" ]]; then
compiler="$candidate"
break
fi
done
if [[ -z "$compiler" ]]; then
echo "compiler not found; try: cmake -S . -B build_lab2 && cmake --build build_lab2 -j" >&2
compiler="./build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "Compiler not found: $compiler; build it first (e.g. mkdir -p build && cd build && cmake .. && make -j)" >&2
exit 1
fi
@ -48,59 +44,34 @@ out_file="$out_dir/$stem.ll"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
"$compiler" --emit-ir "$input" > "$out_file"
echo "IR generated: $out_file"
echo "IR generated at: $out_file"
if [[ "$run_exec" == true ]]; then
if ! command -v llc >/dev/null 2>&1; then
echo "llc not found" >&2
echo "llc not found; cannot run the IR. Please install LLVM." >&2
exit 1
fi
if ! command -v clang >/dev/null 2>&1; then
echo "clang not found" >&2
echo "clang not found; cannot link the executable. Please install LLVM/Clang." >&2
exit 1
fi
obj="$out_dir/$stem.o"
exe="$out_dir/$stem"
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
llc -opaque-pointers -filetype=obj "$out_file" -o "$obj"
clang "$obj" sylib/sylib.c -o "$exe"
# Optional timeout to prevent hanging test cases.
# Override with RUN_TIMEOUT_SEC/PERF_TIMEOUT_SEC env vars.
timeout_sec="${RUN_TIMEOUT_SEC:-60}"
if [[ "$input" == *"/performance/"* || "$input" == *"/h_performance/"* ]]; then
timeout_sec="${PERF_TIMEOUT_SEC:-300}"
fi
llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" -o "$exe"
echo "Running $exe ..."
set +e
if command -v timeout >/dev/null 2>&1; then
if [[ -f "$stdin_file" ]]; then
timeout "$timeout_sec" "$exe" < "$stdin_file" > "$stdout_file"
else
timeout "$timeout_sec" "$exe" > "$stdout_file"
fi
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
else
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
else
"$exe" > "$stdout_file"
fi
"$exe" > "$stdout_file"
fi
status=$?
set -e
if [[ $status -eq 124 ]]; then
echo "timeout after ${timeout_sec}s: $exe" >&2
fi
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
echo "exit code: $status"
echo "Exit code: $status"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
@ -110,14 +81,14 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u <(awk '{ sub(/\r$/, ""); print }' "$expected_file") <(awk '{ sub(/\r$/, ""); print }' "$actual_file"); then
echo "matched: $expected_file"
if diff -u "$expected_file" "$actual_file"; then
echo "Output matches: $expected_file"
else
echo "mismatch: $expected_file" >&2
echo "actual saved to: $actual_file" >&2
echo "Output does not match: $expected_file" >&2
echo "Actual output saved to: $actual_file" >&2
exit 1
fi
else
echo "expected output not found, skipped diff: $expected_file"
echo "Expected output file not found; skipping comparison: $expected_file"
fi
fi
fi
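Two details in this script are easy to miss: the `tail -c 1 ... | wc -l` check that adds a final newline only when the program output lacks one, and the `awk` filter that strips a trailing carriage return before `diff`. The sketch below reproduces both in isolation. The helper names and sample data are hypothetical, and writing the exit status as the last line of the actual output is an assumption about the expected-file format, since the hunk that builds `$actual_file` is only partly visible here.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical helper: print a file, appending a newline only if the file is
# non-empty and its last byte is not already a newline.
print_with_final_newline() {
  local file="$1"
  cat "$file"
  if [[ -s "$file" ]] && (( $(tail -c 1 "$file" | wc -l) == 0 )); then
    printf '\n'
  fi
}

# Hypothetical helper: succeed when two files are equal after removing a
# trailing CR from every line, so CRLF expected files still match.
same_ignoring_cr() {
  diff -u <(awk '{ sub(/\r$/, ""); print }' "$1") \
          <(awk '{ sub(/\r$/, ""); print }' "$2") > /dev/null
}

work="$(mktemp -d)"
printf '42' > "$work/stdout"              # program output, no trailing newline
printf '42\r\n0\r\n' > "$work/expected"   # CRLF file: output, then exit status
{
  print_with_final_newline "$work/stdout"
  echo 0                                  # exit status as the last line (assumption)
} > "$work/actual"

if same_ignoring_cr "$work/expected" "$work/actual"; then
  echo matched
else
  echo mismatch
fi
```

The plain `diff -u "$expected_file" "$actual_file"` variant that also appears in this hunk would report a mismatch for the same two files, because every expected line still carries its `\r`.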

@ -1,34 +1,17 @@
find_package(Java REQUIRED COMPONENTS Runtime)
set(SYSY_GRAMMAR "${PROJECT_SOURCE_DIR}/src/antlr4/SysY.g4")
set(SYSY_ANTLR_JAR "${PROJECT_SOURCE_DIR}/third_party/antlr-4.13.2-complete.jar")
set(SYSY_ANTLR_OUTPUTS
"${ANTLR4_GENERATED_DIR}/SysYLexer.cpp"
"${ANTLR4_GENERATED_DIR}/SysYLexer.h"
"${ANTLR4_GENERATED_DIR}/SysYParser.cpp"
"${ANTLR4_GENERATED_DIR}/SysYParser.h"
"${ANTLR4_GENERATED_DIR}/SysYVisitor.cpp"
"${ANTLR4_GENERATED_DIR}/SysYVisitor.h"
"${ANTLR4_GENERATED_DIR}/SysYBaseVisitor.cpp"
"${ANTLR4_GENERATED_DIR}/SysYBaseVisitor.h"
)
add_custom_command(
OUTPUT ${SYSY_ANTLR_OUTPUTS}
COMMAND ${CMAKE_COMMAND} -E make_directory "${ANTLR4_GENERATED_DIR}"
COMMAND ${Java_JAVA_EXECUTABLE} -jar "${SYSY_ANTLR_JAR}" -Dlanguage=Cpp -visitor -no-listener -o "${ANTLR4_GENERATED_DIR}" -Xexact-output-dir "${SYSY_GRAMMAR}"
DEPENDS "${SYSY_GRAMMAR}" "${SYSY_ANTLR_JAR}"
COMMENT "Generating SysY parser with ANTLR4"
VERBATIM
)
add_library(frontend STATIC
AntlrDriver.cpp
SyntaxTreePrinter.cpp
${SYSY_ANTLR_OUTPUTS}
)
target_link_libraries(frontend PUBLIC
build_options
${ANTLR4_RUNTIME_TARGET}
)
# Lexer/Parser
file(GLOB_RECURSE ANTLR4_GENERATED_SOURCES CONFIGURE_DEPENDS
"${ANTLR4_GENERATED_DIR}/*.cpp"
)
if(ANTLR4_GENERATED_SOURCES)
target_sources(frontend PRIVATE ${ANTLR4_GENERATED_SOURCES})
endif()

@ -1,62 +1,45 @@
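// IR basic block:
// - stores the instruction sequence
// - reserves predecessor/successor interfaces for later CFG analysis
//
// Still a minimal implementation for now:
// - BasicBlock is part of the Value hierarchy, but its type is a void placeholder;
// - appending instructions and the terminator constraint are handled mainly by the Append template in the header;
// - predecessor/successor containers are reserved, but the project has no branch instructions or automatic maintenance logic yet.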
#include "ir/IR.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
namespace ir {
BasicBlock::BasicBlock(const std::string& name)
: Value(Type::GetLabelType(), name) {}
// BasicBlock has no dedicated label type yet, so void is used as a placeholder type.
BasicBlock::BasicBlock(std::string name)
: Value(Type::GetVoidType(), std::move(name)) {}
BasicBlock::BasicBlock(Function* parent, const std::string& name)
: Value(Type::GetLabelType(), name), parent_(parent) {}
Function* BasicBlock::GetParent() const { return parent_; }
bool BasicBlock::HasTerminator() const {
return !instructions_.empty() && instructions_.back()->IsTerminator();
}
void BasicBlock::SetParent(Function* parent) { parent_ = parent; }
void BasicBlock::EraseInstruction(Instruction* inst) {
if (!inst) {
return;
}
if (inst->IsTerminator()) {
throw std::runtime_error("cannot erase terminator instruction");
}
auto it = std::find_if(instructions_.begin(), instructions_.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == inst;
});
if (it == instructions_.end()) {
return;
}
(*it)->ClearAllOperands();
instructions_.erase(it);
}
void BasicBlock::AddPredecessor(BasicBlock* pred) {
if (pred &&
std::find(predecessors_.begin(), predecessors_.end(), pred) ==
predecessors_.end()) {
predecessors_.push_back(pred);
}
bool BasicBlock::HasTerminator() const {
return !instructions_.empty() && instructions_.back()->IsTerminator();
}
void BasicBlock::AddSuccessor(BasicBlock* succ) {
if (succ && std::find(successors_.begin(), successors_.end(), succ) ==
successors_.end()) {
successors_.push_back(succ);
}
// Returns the block's instruction sequence in insertion order.
const std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetInstructions()
const {
return instructions_;
}
void BasicBlock::RemovePredecessor(BasicBlock* pred) {
predecessors_.erase(
std::remove(predecessors_.begin(), predecessors_.end(), pred),
predecessors_.end());
// The predecessor/successor interfaces are kept for later CFG extensions.
// The minimal IR has no branch instruction yet, so these lists are usually empty.
const std::vector<BasicBlock*>& BasicBlock::GetPredecessors() const {
return predecessors_;
}
void BasicBlock::RemoveSuccessor(BasicBlock* succ) {
successors_.erase(std::remove(successors_.begin(), successors_.end(), succ),
successors_.end());
const std::vector<BasicBlock*>& BasicBlock::GetSuccessors() const {
return successors_;
}
} // namespace ir
} // namespace ir

@ -1,4 +1,5 @@
#include "ir/IR.h"
// Manages the basic types, the integer constant pool, and temporary-name generation.
#include "ir/IR.h"
#include <sstream>
@ -8,33 +9,15 @@ Context::~Context() = default;
ConstantInt* Context::GetConstInt(int v) {
auto it = const_ints_.find(v);
if (it != const_ints_.end()) {
return it->second.get();
}
auto inserted = const_ints_.emplace(
v, std::make_unique<ConstantInt>(Type::GetInt32Type(), v));
return inserted.first->second.get();
}
ConstantI1* Context::GetConstBool(bool v) {
auto it = const_bools_.find(v);
if (it != const_bools_.end()) {
return it->second.get();
}
auto inserted = const_bools_.emplace(
v, std::make_unique<ConstantI1>(Type::GetInt1Type(), v));
return inserted.first->second.get();
if (it != const_ints_.end()) return it->second.get();
auto inserted =
const_ints_.emplace(v, std::make_unique<ConstantInt>(Type::GetInt32Type(), v)).first;
return inserted->second.get();
}
std::string Context::NextTemp() {
std::ostringstream oss;
oss << "%t" << ++temp_index_;
return oss.str();
}
std::string Context::NextBlockName(const std::string& prefix) {
std::ostringstream oss;
oss << prefix << "." << ++block_index_;
oss << "%" << ++temp_index_;
return oss.str();
}

@ -1,44 +1,17 @@
// IR Function:
// - stores the parameter list and the basic block list
// - records function attributes/metadata (extend as needed)
#include "ir/IR.h"
namespace ir {
Argument::Argument(std::shared_ptr<Type> type, std::string name, size_t index)
: Value(std::move(type), std::move(name)), index_(index) {}
Function::Function(std::string name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types,
const std::vector<std::string>& param_names,
bool is_external)
: Value(Type::GetPointerType(), std::move(name)),
return_type_(std::move(ret_type)),
param_types_(param_types),
is_external_(is_external) {
for (size_t i = 0; i < param_types_.size(); ++i) {
std::string arg_name = i < param_names.size() && !param_names[i].empty()
? param_names[i]
: "%arg" + std::to_string(i);
arguments_.push_back(
std::make_unique<Argument>(param_types_[i], std::move(arg_name), i));
}
}
Argument* Function::GetArgument(size_t index) const {
return index < arguments_.size() ? arguments_[index].get() : nullptr;
}
BasicBlock* Function::EnsureEntryBlock() {
if (!entry_) {
entry_ = CreateBlock("entry");
}
return entry_;
Function::Function(std::string name, std::shared_ptr<Type> ret_type)
: Value(std::move(ret_type), std::move(name)) {
entry_ = CreateBlock("entry");
}
BasicBlock* Function::CreateBlock(const std::string& name) {
auto block = std::make_unique<BasicBlock>(this, name);
return AddBlock(std::move(block));
}
BasicBlock* Function::AddBlock(std::unique_ptr<BasicBlock> block) {
auto block = std::make_unique<BasicBlock>(name);
auto* ptr = block.get();
ptr->SetParent(this);
blocks_.push_back(std::move(block));
@ -48,4 +21,12 @@ BasicBlock* Function::AddBlock(std::unique_ptr<BasicBlock> block) {
return ptr;
}
BasicBlock* Function::GetEntry() { return entry_; }
const BasicBlock* Function::GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_;
}
} // namespace ir

@ -1,12 +1,11 @@
// GlobalValue placeholder implementation:
// - the global initializer, printing, and linkage semantics still have to be filled in
#include "ir/IR.h"
namespace ir {
GlobalValue::GlobalValue(std::shared_ptr<Type> object_type,
const std::string& name, bool is_const, Value* init)
: User(Type::GetPointerType(object_type), name),
object_type_(std::move(object_type)),
is_const_(is_const),
init_(init) {}
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {}
} // namespace ir

@ -1,213 +1,89 @@
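// IR builder utilities:
// - manages the insertion point (current basic block/position)
// - provides convenience interfaces for creating instructions, reducing IRGen complexity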
#include "ir/IR.h"
#include <stdexcept>
namespace ir {
namespace {
BasicBlock* RequireInsertBlock(BasicBlock* block) {
if (!block) {
throw std::runtime_error("IRBuilder has no insert block");
}
return block;
}
bool IsFloatBinaryOp(Opcode op) {
return op == Opcode::FAdd || op == Opcode::FSub || op == Opcode::FMul ||
op == Opcode::FDiv || op == Opcode::FRem || op == Opcode::FCmpEQ ||
op == Opcode::FCmpNE || op == Opcode::FCmpLT || op == Opcode::FCmpGT ||
op == Opcode::FCmpLE || op == Opcode::FCmpGE;
}
bool IsCompareOp(Opcode op) {
return (op >= Opcode::ICmpEQ && op <= Opcode::ICmpGE) ||
(op >= Opcode::FCmpEQ && op <= Opcode::FCmpGE);
}
std::shared_ptr<Type> ResultTypeForBinary(Opcode op, Value* lhs) {
if (IsCompareOp(op)) {
return Type::GetInt1Type();
}
if (IsFloatBinaryOp(op)) {
return Type::GetFloatType();
}
return lhs->GetType();
}
} // namespace
#include "utils/Log.h"
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) : ctx_(ctx), insert_block_(bb) {}
namespace ir {
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb)
: ctx_(ctx), insert_block_(bb) {}
void IRBuilder::SetInsertPoint(BasicBlock* bb) { insert_block_ = bb; }
ConstantInt* IRBuilder::CreateConstInt(int v) { return ctx_.GetConstInt(v); }
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
return new ConstantFloat(Type::GetFloatType(), v);
}
ConstantI1* IRBuilder::CreateConstBool(bool v) { return ctx_.GetConstBool(v); }
BasicBlock* IRBuilder::GetInsertBlock() const { return insert_block_; }
ConstantArrayValue* IRBuilder::CreateConstArray(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name) {
return new ConstantArrayValue(std::move(array_type), elements, dims, name);
ConstantInt* IRBuilder::CreateConstInt(int v) {
// Constants are not attached to a basic block; Context handles deduplication and lifetime.
return ctx_.GetConstInt(v);
}
BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<BinaryInst>(op, ResultTypeForBinary(op, lhs), lhs, rhs,
nullptr, name);
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder has no insertion point set"));
}
if (!lhs) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateBinary is missing lhs"));
}
if (!rhs) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateBinary is missing rhs"));
}
return insert_block_->Append<BinaryInst>(op, lhs->GetType(), lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, const std::string& name) {
BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Add, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Sub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Mul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Div, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateRem(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Rem, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateAnd(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::And, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Or, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateXor(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Xor, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateShl(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Shl, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateAShr(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::AShr, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateLShr(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::LShr, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateICmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(op, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFCmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(op, lhs, rhs, name);
}
UnaryInst* IRBuilder::CreateNeg(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::Neg, operand->GetType(), operand, nullptr,
name);
}
UnaryInst* IRBuilder::CreateNot(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::Not, operand->GetType(), operand, nullptr,
name);
}
UnaryInst* IRBuilder::CreateFNeg(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::FNeg, operand->GetType(), operand,
nullptr, name);
}
UnaryInst* IRBuilder::CreateFtoI(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::FtoI, Type::GetInt32Type(), operand,
nullptr, name);
}
UnaryInst* IRBuilder::CreateIToF(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::IToF, Type::GetFloatType(), operand,
nullptr, name);
}
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> allocated_type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<AllocaInst>(std::move(allocated_type), nullptr, name);
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder has no insertion point set"));
}
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name);
}
LoadInst* IRBuilder::CreateLoad(Value* ptr, std::shared_ptr<Type> value_type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<LoadInst>(std::move(value_type), ptr, nullptr, name);
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder has no insertion point set"));
}
if (!ptr) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateLoad is missing ptr"));
}
return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name);
}
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<StoreInst>(val, ptr, nullptr);
}
UncondBrInst* IRBuilder::CreateBr(BasicBlock* dest) {
auto* block = RequireInsertBlock(insert_block_);
auto* inst = block->Append<UncondBrInst>(dest, nullptr);
block->AddSuccessor(dest);
dest->AddPredecessor(block);
return inst;
}
CondBrInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* then_bb,
BasicBlock* else_bb) {
auto* block = RequireInsertBlock(insert_block_);
auto* inst = block->Append<CondBrInst>(cond, then_bb, else_bb, nullptr);
block->AddSuccessor(then_bb);
block->AddSuccessor(else_bb);
then_bb->AddPredecessor(block);
else_bb->AddPredecessor(block);
return inst;
}
ReturnInst* IRBuilder::CreateRet(Value* val) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<ReturnInst>(val, nullptr);
}
UnreachableInst* IRBuilder::CreateUnreachable() {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnreachableInst>(nullptr);
}
CallInst* IRBuilder::CreateCall(Function* callee, const std::vector<Value*>& args,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
std::string real_name = callee->GetReturnType()->IsVoid() ? std::string() : name;
return block->Append<CallInst>(callee, args, nullptr, real_name);
}
GetElementPtrInst* IRBuilder::CreateGEP(Value* ptr,
std::shared_ptr<Type> source_type,
const std::vector<Value*>& indices,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<GetElementPtrInst>(std::move(source_type), ptr, indices,
nullptr, name);
}
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<PhiInst>(std::move(type), nullptr, name);
}
ZextInst* IRBuilder::CreateZext(Value* val, std::shared_ptr<Type> target_type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<ZextInst>(val, std::move(target_type), nullptr, name);
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder has no insertion point set"));
}
if (!val) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateStore is missing val"));
}
if (!ptr) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateStore is missing ptr"));
}
return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr);
}
MemsetInst* IRBuilder::CreateMemset(Value* dst, Value* val, Value* len,
Value* is_volatile) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<MemsetInst>(dst, val, len, is_volatile, nullptr);
ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder has no insertion point set"));
}
if (!v) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateRet is missing the return value"));
}
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
}
} // namespace ir

@ -1,483 +1,106 @@
#include "ir/IR.h"
// IR text output:
// - prints the IR as .ll-style text
// - supports debugging and test comparison (diff)
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
namespace ir {
namespace {
std::string TypeToString(const Type& ty) {
std::ostringstream oss;
ty.Print(oss);
return oss.str();
}
std::string FloatToString(float value) {
double promoted = static_cast<double>(value);
std::uint64_t bits = 0;
std::memcpy(&bits, &promoted, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::uppercase << std::hex << std::setw(16)
<< std::setfill('0') << bits;
return oss.str();
}
std::string ValueRef(const Value* value) {
if (!value) {
return "<null>";
}
if (auto* ci = dyncast<ConstantInt>(value)) {
return std::to_string(ci->GetValue());
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return FloatToString(cf->GetValue());
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return cb->GetValue() ? "1" : "0";
}
if (isa<Function>(value) || isa<GlobalValue>(value)) {
return "@" + value->GetName();
}
return value->GetName();
}
std::string BlockRef(const BasicBlock* block) {
if (!block) {
return "%<null>";
}
return "%" + block->GetName();
}
bool IsZeroScalarConstant(const Value* value) {
if (!value) {
return true;
}
if (auto* ci = dyncast<ConstantInt>(value)) {
return ci->GetValue() == 0;
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return cf->GetValue() == 0.0f;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return !cb->GetValue();
}
return false;
}
size_t CountScalarElements(const Type& type) {
if (!type.IsArray()) {
return 1;
}
return type.GetNumElements() * CountScalarElements(*type.GetElementType());
}
#include "utils/Log.h"
bool IsZeroArrayRange(const std::vector<Value*>& elements, const Type& type,
size_t offset) {
const auto count = CountScalarElements(type);
for (size_t i = 0; i < count; ++i) {
if (offset + i < elements.size() &&
!IsZeroScalarConstant(elements[offset + i])) {
return false;
}
}
return true;
}
void PrintConstantForType(std::ostream& os, const Type& type, Value* value);
void PrintArrayConstant(std::ostream& os, const Type& type,
const std::vector<Value*>& elements, size_t& offset) {
if (IsZeroArrayRange(elements, type, offset)) {
os << "zeroinitializer";
offset += CountScalarElements(type);
return;
}
const auto elem_type = type.GetElementType();
os << "[";
for (size_t i = 0; i < type.GetNumElements(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*elem_type) << " ";
if (elem_type->IsArray()) {
PrintArrayConstant(os, *elem_type, elements, offset);
} else {
Value* elem = offset < elements.size() ? elements[offset] : nullptr;
PrintConstantForType(os, *elem_type, elem);
++offset;
}
}
os << "]";
}
void PrintConstantForType(std::ostream& os, const Type& type, Value* value) {
if (type.IsArray()) {
auto* array_value = dyncast<ConstantArrayValue>(value);
size_t offset = 0;
if (array_value) {
PrintArrayConstant(os, type, array_value->GetElements(), offset);
} else {
os << "zeroinitializer";
}
return;
}
if (!value) {
if (type.IsFloat()) {
os << FloatToString(0.0f);
} else {
os << "0";
}
return;
}
namespace ir {
if (auto* ci = dyncast<ConstantInt>(value)) {
os << ci->GetValue();
return;
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
os << FloatToString(cf->GetValue());
return;
static const char* TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void:
return "void";
case Type::Kind::Int32:
return "i32";
case Type::Kind::PtrInt32:
return "i32*";
}
if (auto* cb = dyncast<ConstantI1>(value)) {
os << (cb->GetValue() ? "1" : "0");
return;
}
throw std::runtime_error("global initializer must be constant");
throw std::runtime_error(FormatError("ir", "unknown type"));
}
const char* BinaryOpcodeMnemonic(Opcode opcode) {
switch (opcode) {
static const char* OpcodeToString(Opcode op) {
switch (op) {
case Opcode::Add:
return "add";
case Opcode::Sub:
return "sub";
case Opcode::Mul:
return "mul";
case Opcode::Div:
return "sdiv";
case Opcode::Rem:
return "srem";
case Opcode::FAdd:
return "fadd";
case Opcode::FSub:
return "fsub";
case Opcode::FMul:
return "fmul";
case Opcode::FDiv:
return "fdiv";
case Opcode::FRem:
return "frem";
case Opcode::And:
return "and";
case Opcode::Or:
return "or";
case Opcode::Xor:
return "xor";
case Opcode::Shl:
return "shl";
case Opcode::AShr:
return "ashr";
case Opcode::LShr:
return "lshr";
case Opcode::ICmpEQ:
return "icmp eq";
case Opcode::ICmpNE:
return "icmp ne";
case Opcode::ICmpLT:
return "icmp slt";
case Opcode::ICmpGT:
return "icmp sgt";
case Opcode::ICmpLE:
return "icmp sle";
case Opcode::ICmpGE:
return "icmp sge";
case Opcode::FCmpEQ:
return "fcmp oeq";
case Opcode::FCmpNE:
return "fcmp one";
case Opcode::FCmpLT:
return "fcmp olt";
case Opcode::FCmpGT:
return "fcmp ogt";
case Opcode::FCmpLE:
return "fcmp ole";
case Opcode::FCmpGE:
return "fcmp oge";
default:
throw std::runtime_error("unsupported binary opcode");
}
case Opcode::Alloca:
return "alloca";
case Opcode::Load:
return "load";
case Opcode::Store:
return "store";
case Opcode::Ret:
return "ret";
}
return "?";
}
bool NeedsMemsetDeclaration(const Module& module) {
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsExternal()) {
continue;
}
for (const auto& bb : func->GetBlocks()) {
for (const auto& inst : bb->GetInstructions()) {
if (inst->GetOpcode() == Opcode::Memset) {
return true;
}
}
}
}
return false;
}
void PrintInstruction(const Instruction& inst, std::ostream& os) {
switch (inst.GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto& bin = static_cast<const BinaryInst&>(inst);
os << " " << bin.GetName() << " = " << BinaryOpcodeMnemonic(bin.GetOpcode())
<< " " << TypeToString(*bin.GetLhs()->GetType()) << " "
<< ValueRef(bin.GetLhs()) << ", " << ValueRef(bin.GetRhs()) << "\n";
return;
}
case Opcode::Neg: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = sub " << TypeToString(*un.GetType())
<< " 0, " << ValueRef(un.GetOprd()) << "\n";
return;
}
case Opcode::Not: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = xor " << TypeToString(*un.GetType())
<< " " << ValueRef(un.GetOprd()) << ", 1\n";
return;
}
case Opcode::FNeg: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = fneg " << TypeToString(*un.GetType())
<< " " << ValueRef(un.GetOprd()) << "\n";
return;
}
case Opcode::FtoI: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = fptosi "
<< TypeToString(*un.GetOprd()->GetType()) << " " << ValueRef(un.GetOprd())
<< " to " << TypeToString(*un.GetType()) << "\n";
return;
}
case Opcode::IToF: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = sitofp "
<< TypeToString(*un.GetOprd()->GetType()) << " " << ValueRef(un.GetOprd())
<< " to " << TypeToString(*un.GetType()) << "\n";
return;
}
case Opcode::Alloca: {
auto& alloca_inst = static_cast<const AllocaInst&>(inst);
os << " " << alloca_inst.GetName() << " = alloca "
<< TypeToString(*alloca_inst.GetAllocatedType()) << "\n";
return;
}
case Opcode::Load: {
auto& load = static_cast<const LoadInst&>(inst);
os << " " << load.GetName() << " = load "
<< TypeToString(*load.GetType()) << ", ptr " << ValueRef(load.GetPtr())
<< "\n";
return;
}
case Opcode::Store: {
auto& store = static_cast<const StoreInst&>(inst);
os << " store " << TypeToString(*store.GetValue()->GetType()) << " "
<< ValueRef(store.GetValue()) << ", ptr " << ValueRef(store.GetPtr())
<< "\n";
return;
}
case Opcode::Br: {
auto& br = static_cast<const UncondBrInst&>(inst);
os << " br label " << BlockRef(br.GetDest()) << "\n";
return;
}
case Opcode::CondBr: {
auto& br = static_cast<const CondBrInst&>(inst);
os << " br i1 " << ValueRef(br.GetCondition()) << ", label "
<< BlockRef(br.GetThenBlock()) << ", label "
<< BlockRef(br.GetElseBlock()) << "\n";
return;
}
case Opcode::Return: {
auto& ret = static_cast<const ReturnInst&>(inst);
if (!ret.HasReturnValue()) {
os << " ret void\n";
} else {
os << " ret " << TypeToString(*ret.GetReturnValue()->GetType()) << " "
<< ValueRef(ret.GetReturnValue()) << "\n";
}
return;
}
case Opcode::Unreachable:
os << " unreachable\n";
return;
case Opcode::Call: {
auto& call = static_cast<const CallInst&>(inst);
if (!call.GetType()->IsVoid()) {
os << " " << call.GetName() << " = ";
} else {
os << " ";
}
os << "call " << TypeToString(*call.GetCallee()->GetReturnType()) << " @"
<< call.GetCallee()->GetName() << "(";
const auto args = call.GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*args[i]->GetType()) << " " << ValueRef(args[i]);
}
os << ")\n";
return;
}
case Opcode::GetElementPtr: {
auto& gep = static_cast<const GetElementPtrInst&>(inst);
os << " " << gep.GetName() << " = getelementptr "
<< TypeToString(*gep.GetSourceType()) << ", ptr "
<< ValueRef(gep.GetPointer());
for (size_t i = 0; i < gep.GetNumIndices(); ++i) {
auto* index = gep.GetIndex(i);
os << ", " << TypeToString(*index->GetType()) << " " << ValueRef(index);
}
os << "\n";
return;
}
case Opcode::Phi: {
auto& phi = static_cast<const PhiInst&>(inst);
os << " " << phi.GetName() << " = phi " << TypeToString(*phi.GetType())
<< " ";
for (int i = 0; i < phi.GetNumIncomings(); ++i) {
if (i > 0) {
os << ", ";
}
os << "[ " << ValueRef(phi.GetIncomingValue(i)) << ", "
<< BlockRef(phi.GetIncomingBlock(i)) << " ]";
}
os << "\n";
return;
}
case Opcode::Zext: {
auto& zext = static_cast<const ZextInst&>(inst);
os << " " << zext.GetName() << " = zext "
<< TypeToString(*zext.GetValue()->GetType()) << " "
<< ValueRef(zext.GetValue()) << " to " << TypeToString(*zext.GetType())
<< "\n";
return;
}
case Opcode::Memset: {
auto& memset = static_cast<const MemsetInst&>(inst);
os << " call void @llvm.memset.p0.i32(ptr " << ValueRef(memset.GetDest())
<< ", i8 " << ValueRef(memset.GetValue()) << ", i32 "
<< ValueRef(memset.GetLength()) << ", i1 "
<< ValueRef(memset.GetIsVolatile()) << ")\n";
return;
}
static std::string ValueToString(const Value* v) {
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
throw std::runtime_error("unsupported instruction in printer");
return v ? v->GetName() : "<null>";
}
} // namespace
void IRPrinter::Print(const Module& module, std::ostream& os) {
if (NeedsMemsetDeclaration(module)) {
os << "declare void @llvm.memset.p0.i32(ptr, i8, i32, i1)\n\n";
}
for (const auto& global : module.GetGlobalValues()) {
os << "@" << global->GetName() << " = "
<< (global->IsConstant() ? "constant " : "global ")
<< TypeToString(*global->GetObjectType()) << " ";
PrintConstantForType(os, *global->GetObjectType(), global->GetInitializer());
os << "\n";
}
if (!module.GetGlobalValues().empty()) {
os << "\n";
}
for (const auto& func : module.GetFunctions()) {
if (!func->IsExternal()) {
continue;
}
os << "declare " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*args[i]->GetType());
}
os << ")\n";
}
bool printed_decl = false;
for (const auto& func : module.GetFunctions()) {
if (func->IsExternal()) {
printed_decl = true;
}
}
if (printed_decl) {
os << "\n";
}
for (const auto& func : module.GetFunctions()) {
if (func->IsExternal()) {
continue;
}
os << "define " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*args[i]->GetType()) << " " << args[i]->GetName();
}
os << ") {\n";
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "() {\n";
for (const auto& bb : func->GetBlocks()) {
if (!bb) {
continue;
}
os << bb->GetName() << ":\n";
for (const auto& inst : bb->GetInstructions()) {
PrintInstruction(*inst, os);
for (const auto& instPtr : bb->GetInstructions()) {
const auto* inst = instPtr.get();
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << alloca->GetName() << " = alloca i32\n";
break;
}
case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load i32, i32* "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst);
os << " store i32 " << ValueToString(store->GetValue())
<< ", i32* " << ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n";
break;
}
}
}
}
os << "}\n\n";
os << "}\n";
}
}
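
The array-constant printing above collapses any all-zero sub-array into `zeroinitializer` while walking a flattened element list with a running offset. The following is a minimal, self-contained sketch of that idea only. `ArrayShape`, `CountScalars`, `TypeString`, `PrintArray`, and `RenderArray` are hypothetical names invented for illustration, elements are plain `int` rather than IR values, and the shape is assumed to have at least one dimension.

```cpp
#include <cassert>
#include <cstddef>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR array type: dims {2, 3} means [2 x [3 x i32]].
struct ArrayShape {
  std::vector<std::size_t> dims;
};

// Number of scalar slots covered by the sub-array that starts at `level`.
std::size_t CountScalars(const ArrayShape& shape, std::size_t level) {
  std::size_t count = 1;
  for (std::size_t i = level; i < shape.dims.size(); ++i) {
    count *= shape.dims[i];
  }
  return count;
}

// Textual type of the sub-array at `level`; past the last dimension it is i32.
std::string TypeString(const ArrayShape& shape, std::size_t level) {
  if (level == shape.dims.size()) {
    return "i32";
  }
  return "[" + std::to_string(shape.dims[level]) + " x " +
         TypeString(shape, level + 1) + "]";
}

// Prints the sub-array at `level`, consuming elements of `flat` from `offset`.
// Elements missing from `flat` are treated as zero.
void PrintArray(std::ostream& os, const ArrayShape& shape, std::size_t level,
                const std::vector<int>& flat, std::size_t& offset) {
  const std::size_t count = CountScalars(shape, level);
  bool all_zero = true;
  for (std::size_t i = 0; i < count; ++i) {
    if (offset + i < flat.size() && flat[offset + i] != 0) {
      all_zero = false;
      break;
    }
  }
  if (all_zero) {
    os << "zeroinitializer";
    offset += count;  // skip the whole zero range in one step
    return;
  }
  os << "[";
  for (std::size_t i = 0; i < shape.dims[level]; ++i) {
    if (i > 0) {
      os << ", ";
    }
    os << TypeString(shape, level + 1) << " ";
    if (level + 1 < shape.dims.size()) {
      PrintArray(os, shape, level + 1, flat, offset);
    } else {
      os << (offset < flat.size() ? flat[offset] : 0);
      ++offset;
    }
  }
  os << "]";
}

std::string RenderArray(const ArrayShape& shape, const std::vector<int>& flat) {
  std::ostringstream os;
  std::size_t offset = 0;
  PrintArray(os, shape, 0, flat, offset);
  return os.str();
}
```

Passing the offset by reference lets the recursive calls share one cursor into the flat list, so a collapsed zero range and an element-by-element range advance it by the same amount.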

@ -1,259 +1,151 @@
#include "ir/IR.h"
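// IR instruction hierarchy:
// - binary operations/comparisons, load/store, call, br/condbr, ret, phi, alloca, etc.
// - manages instruction operands and result types; supports printing and optimization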
#include "ir/IR.h"
#include <stdexcept>
namespace ir {
#include "utils/Log.h"
namespace ir {
User::User(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {}
size_t User::GetNumOperands() const { return operands_.size(); }
Value* User::GetOperand(size_t index) const {
if (index >= operands_.size()) {
throw std::out_of_range("operand index out of range");
throw std::out_of_range("User operand index out of range");
}
return operands_[index].GetValue();
return operands_[index];
}
void User::SetOperand(size_t index, Value* value) {
if (index >= operands_.size()) {
throw std::out_of_range("operand index out of range");
throw std::out_of_range("User operand index out of range");
}
auto* old_value = operands_[index].GetValue();
if (old_value == value) {
return;
if (!value) {
throw std::runtime_error(FormatError("ir", "User operand must not be null"));
}
if (old_value) {
old_value->RemoveUse(this, index);
auto* old = operands_[index];
if (old == value) {
return;
}
operands_[index].SetValue(value);
if (value) {
value->AddUse(this, index);
if (old) {
old->RemoveUse(this, index);
}
operands_[index] = value;
value->AddUse(this, index);
}
void User::AddOperand(Value* value) {
if (!value) {
throw std::runtime_error("operand cannot be null");
throw std::runtime_error(FormatError("ir", "User operand must not be null"));
}
operands_.emplace_back(value, this, operands_.size());
value->AddUse(this, operands_.size() - 1);
size_t operand_index = operands_.size();
operands_.push_back(value);
value->AddUse(this, operand_index);
}
void User::AddOperands(const std::vector<Value*>& values) {
for (auto* value : values) {
AddOperand(value);
}
}
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)), opcode_(op) {}
void User::RemoveOperand(size_t index) {
if (index >= operands_.size()) {
throw std::out_of_range("operand index out of range");
}
if (auto* value = operands_[index].GetValue()) {
value->RemoveUse(this, index);
}
operands_.erase(operands_.begin() + static_cast<long long>(index));
for (size_t i = index; i < operands_.size(); ++i) {
operands_[i].SetOperandIndex(i);
}
}
Opcode Instruction::GetOpcode() const { return opcode_; }
void User::ClearAllOperands() {
for (size_t i = 0; i < operands_.size(); ++i) {
if (auto* value = operands_[i].GetValue()) {
value->RemoveUse(this, i);
}
}
operands_.clear();
}
bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; }
Instruction::Instruction(Opcode opcode, std::shared_ptr<Type> ty,
BasicBlock* parent, const std::string& name)
: User(std::move(ty), name), opcode_(opcode), parent_(parent) {}
BasicBlock* Instruction::GetParent() const { return parent_; }
bool Instruction::IsTerminator() const {
return opcode_ == Opcode::Br || opcode_ == Opcode::CondBr ||
opcode_ == Opcode::Return || opcode_ == Opcode::Unreachable;
}
void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
static bool IsBinaryOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
return true;
default:
return false;
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Add) {
throw std::runtime_error(FormatError("ir", "BinaryInst currently supports only Add"));
}
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst is missing an operand"));
}
if (!type_ || !lhs->GetType() || !rhs->GetType()) {
throw std::runtime_error(FormatError("ir", "BinaryInst is missing type information"));
}
if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() ||
type_->GetKind() != lhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "BinaryInst type mismatch"));
}
if (!type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst currently supports only i32"));
}
}
bool BinaryInst::classof(const Value* value) {
return value && Instruction::classof(value) &&
IsBinaryOpcode(static_cast<const Instruction*>(value)->GetOpcode());
}
BinaryInst::BinaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, BasicBlock* parent,
const std::string& name)
: Instruction(opcode, std::move(ty), parent, name) {
AddOperand(lhs);
AddOperand(rhs);
}
bool UnaryInst::classof(const Value* value) {
if (!value || !Instruction::classof(value)) {
return false;
}
switch (static_cast<const Instruction*>(value)->GetOpcode()) {
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
return true;
default:
return false;
}
}
Value* BinaryInst::GetLhs() const { return GetOperand(0); }
UnaryInst::UnaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* operand,
BasicBlock* parent, const std::string& name)
: Instruction(opcode, std::move(ty), parent, name) {
AddOperand(operand);
}
Value* BinaryInst::GetRhs() const { return GetOperand(1); }
ReturnInst::ReturnInst(Value* value, BasicBlock* parent)
: Instruction(Opcode::Return, Type::GetVoidType(), parent, "") {
if (value) {
AddOperand(value);
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!val) {
throw std::runtime_error(FormatError("ir", "ReturnInst is missing a return value"));
}
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "ReturnInst return type must be void"));
}
AddOperand(val);
}
AllocaInst::AllocaInst(std::shared_ptr<Type> allocated_type,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Alloca, Type::GetPointerType(allocated_type), parent,
name),
allocated_type_(std::move(allocated_type)) {}
Value* ReturnInst::GetValue() const { return GetOperand(0); }
LoadInst::LoadInst(std::shared_ptr<Type> value_type, Value* ptr,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Load, std::move(value_type), parent, name) {
AddOperand(ptr);
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
if (!type_ || !type_->IsPtrInt32()) {
throw std::runtime_error(FormatError("ir", "AllocaInst currently supports only i32*"));
}
}
StoreInst::StoreInst(Value* value, Value* ptr, BasicBlock* parent)
: Instruction(Opcode::Store, Type::GetVoidType(), parent, "") {
AddOperand(value);
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
: Instruction(Opcode::Load, std::move(val_ty), std::move(name)) {
if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst is missing ptr"));
}
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "LoadInst currently supports loading only i32"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "LoadInst currently supports loading only from i32*"));
}
AddOperand(ptr);
}
UncondBrInst::UncondBrInst(BasicBlock* dest, BasicBlock* parent)
: Instruction(Opcode::Br, Type::GetVoidType(), parent, "") {
AddOperand(dest);
}
BasicBlock* UncondBrInst::GetDest() const {
return dyncast<BasicBlock>(GetOperand(0));
}
CondBrInst::CondBrInst(Value* cond, BasicBlock* then_block,
BasicBlock* else_block, BasicBlock* parent)
: Instruction(Opcode::CondBr, Type::GetVoidType(), parent, "") {
AddOperand(cond);
AddOperand(then_block);
AddOperand(else_block);
}
BasicBlock* CondBrInst::GetThenBlock() const {
return dyncast<BasicBlock>(GetOperand(1));
}
BasicBlock* CondBrInst::GetElseBlock() const {
return dyncast<BasicBlock>(GetOperand(2));
}
UnreachableInst::UnreachableInst(BasicBlock* parent)
: Instruction(Opcode::Unreachable, Type::GetVoidType(), parent, "") {}
CallInst::CallInst(Function* callee, const std::vector<Value*>& args,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Call, callee->GetReturnType(), parent, name) {
AddOperand(callee);
AddOperands(args);
}
Function* CallInst::GetCallee() const { return dyncast<Function>(GetOperand(0)); }
Value* LoadInst::GetPtr() const { return GetOperand(0); }
std::vector<Value*> CallInst::GetArguments() const {
std::vector<Value*> args;
for (size_t i = 1; i < GetNumOperands(); ++i) {
args.push_back(GetOperand(i));
StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
: Instruction(Opcode::Store, std::move(void_ty), "") {
if (!val) {
throw std::runtime_error(FormatError("ir", "StoreInst is missing value"));
}
return args;
}
GetElementPtrInst::GetElementPtrInst(std::shared_ptr<Type> source_type,
Value* ptr,
const std::vector<Value*>& indices,
BasicBlock* parent,
const std::string& name)
: Instruction(Opcode::GetElementPtr, Type::GetPointerType(), parent, name),
source_type_(std::move(source_type)) {
if (!ptr) {
throw std::runtime_error(FormatError("ir", "StoreInst is missing ptr"));
}
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst return type must be void"));
}
if (!val->GetType() || !val->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "StoreInst currently supports storing only i32"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "StoreInst currently supports writing only to i32*"));
}
AddOperand(val);
AddOperand(ptr);
AddOperands(indices);
}
PhiInst::PhiInst(std::shared_ptr<Type> type, BasicBlock* parent,
const std::string& name)
: Instruction(Opcode::Phi, std::move(type), parent, name) {}
Value* StoreInst::GetValue() const { return GetOperand(0); }
void PhiInst::AddIncoming(Value* value, BasicBlock* block) {
AddOperand(value);
AddOperand(block);
}
BasicBlock* PhiInst::GetIncomingBlock(int index) const {
return dyncast<BasicBlock>(GetOperand(static_cast<size_t>(2 * index + 1)));
}
ZextInst::ZextInst(Value* value, std::shared_ptr<Type> target_type,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Zext, std::move(target_type), parent, name) {
AddOperand(value);
}
MemsetInst::MemsetInst(Value* dst, Value* value, Value* len,
Value* is_volatile, BasicBlock* parent)
: Instruction(Opcode::Memset, Type::GetVoidType(), parent, "") {
AddOperand(dst);
AddOperand(value);
AddOperand(len);
AddOperand(is_volatile);
}
Value* StoreInst::GetPtr() const { return GetOperand(1); }
} // namespace ir
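
`PhiInst::AddIncoming` above appends a value and then its block, and `GetIncomingBlock` reads operand `2 * index + 1`, which implies a flat, interleaved operand layout. Below is a minimal sketch of that layout, assuming strings in place of `Value*` and `BasicBlock*`. The `PhiOperands` class is hypothetical and exists only to illustrate the indexing.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical flat operand list for a phi node: [v0, b0, v1, b1, ...].
class PhiOperands {
 public:
  // Appends one (value, predecessor block) pair as two consecutive operands.
  void AddIncoming(std::string value, std::string block) {
    operands_.push_back(std::move(value));
    operands_.push_back(std::move(block));
  }

  // Each incoming edge occupies two slots.
  std::size_t GetNumIncomings() const { return operands_.size() / 2; }

  // Values sit at even indices, blocks at the odd index right after.
  const std::string& GetIncomingValue(std::size_t i) const {
    return operands_.at(2 * i);
  }
  const std::string& GetIncomingBlock(std::size_t i) const {
    return operands_.at(2 * i + 1);
  }

 private:
  std::vector<std::string> operands_;
};
```

Keeping both halves of a pair in the ordinary operand list means the generic use-tracking code sees blocks and values alike, at the cost of the index arithmetic shown here.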

@ -1,45 +1,21 @@
// Stores the function list and provides module-level context access.
#include "ir/IR.h"
namespace ir {
Function* Module::CreateFunction(
const std::string& name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types,
const std::vector<std::string>& param_names, bool is_external) {
if (auto* existing = GetFunction(name)) {
existing->SetExternal(existing->IsExternal() && is_external);
return existing;
}
auto func = std::make_unique<Function>(name, std::move(ret_type), param_types,
param_names, is_external);
auto* ptr = func.get();
functions_.push_back(std::move(func));
function_map_[name] = ptr;
return ptr;
}
Context& Module::GetContext() { return context_; }
Function* Module::GetFunction(const std::string& name) const {
auto it = function_map_.find(name);
return it == function_map_.end() ? nullptr : it->second;
}
const Context& Module::GetContext() const { return context_; }
GlobalValue* Module::CreateGlobalValue(const std::string& name,
std::shared_ptr<Type> object_type,
bool is_const, Value* init) {
if (auto* existing = GetGlobalValue(name)) {
return existing;
}
auto global =
std::make_unique<GlobalValue>(std::move(object_type), name, is_const, init);
auto* ptr = global.get();
globals_.push_back(std::move(global));
global_map_[name] = ptr;
return ptr;
Function* Module::CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type) {
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type)));
return functions_.back().get();
}
GlobalValue* Module::GetGlobalValue(const std::string& name) const {
auto it = global_map_.find(name);
return it == global_map_.end() ? nullptr : it->second;
const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
return functions_;
}
} // namespace ir

@ -1,111 +1,31 @@
// Currently supports only void, i32, and i32*.
#include "ir/IR.h"
#include <ostream>
#include <stdexcept>
namespace ir {
Type::Type(Kind kind) : kind_(kind) {}
Type::Type(Kind kind, std::shared_ptr<Type> element_type, size_t num_elements)
: kind_(kind),
element_type_(std::move(element_type)),
num_elements_(num_elements) {}
Type::Type(Kind k) : kind_(k) {}
const std::shared_ptr<Type>& Type::GetVoidType() {
static const auto type = std::make_shared<Type>(Kind::Void);
return type;
}
const std::shared_ptr<Type>& Type::GetInt1Type() {
static const auto type = std::make_shared<Type>(Kind::Int1);
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void);
return type;
}
const std::shared_ptr<Type>& Type::GetInt32Type() {
static const auto type = std::make_shared<Type>(Kind::Int32);
return type;
}
const std::shared_ptr<Type>& Type::GetFloatType() {
static const auto type = std::make_shared<Type>(Kind::Float);
return type;
}
const std::shared_ptr<Type>& Type::GetLabelType() {
static const auto type = std::make_shared<Type>(Kind::Label);
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32);
return type;
}
const std::shared_ptr<Type>& Type::GetBoolType() { return GetInt1Type(); }
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> pointee) {
return std::make_shared<Type>(Kind::Pointer, std::move(pointee));
}
const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
static const auto type = std::make_shared<Type>(Kind::Pointer);
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrInt32);
return type;
}
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> element_type,
size_t num_elements) {
return std::make_shared<Type>(Kind::Array, std::move(element_type), num_elements);
}
Type::Kind Type::GetKind() const { return kind_; }
int Type::GetSize() const {
switch (kind_) {
case Kind::Void:
case Kind::Label:
case Kind::Function:
return 0;
case Kind::Int1:
return 1;
case Kind::Int32:
case Kind::Float:
return 4;
case Kind::Pointer:
return 8;
case Kind::Array:
return static_cast<int>(num_elements_) *
(element_type_ ? element_type_->GetSize() : 0);
}
throw std::runtime_error("unknown IR type kind");
}
bool Type::IsVoid() const { return kind_ == Kind::Void; }
void Type::Print(std::ostream& os) const {
switch (kind_) {
case Kind::Void:
os << "void";
return;
case Kind::Int1:
os << "i1";
return;
case Kind::Int32:
os << "i32";
return;
case Kind::Float:
os << "float";
return;
case Kind::Label:
os << "label";
return;
case Kind::Function:
os << "fn";
return;
case Kind::Pointer:
os << "ptr";
return;
case Kind::Array:
os << "[" << num_elements_ << " x ";
if (element_type_) {
element_type_->Print(os);
} else {
os << "void";
}
os << "]";
return;
}
}
bool Type::IsInt32() const { return kind_ == Kind::Int32; }
bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; }
} // namespace ir
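
The scalar type getters above return a reference to a function-local `static` `shared_ptr`, so every caller sees the same object, while `GetPointerType(pointee)` and `GetArrayType` call `std::make_shared` on each invocation. The sketch below illustrates the consequence for pointer-identity comparisons. `Ty`, `Kind`, `GetInt32`, and `MakePointer` are simplified, hypothetical stand-ins, not the project's classes.

```cpp
#include <cassert>
#include <memory>

enum class Kind { Void, Int32, Pointer };

// Hypothetical, stripped-down type object.
struct Ty {
  explicit Ty(Kind k) : kind(k) {}
  Kind kind;
};

// Singleton accessor: the static is initialized once, on first call.
const std::shared_ptr<Ty>& GetInt32() {
  static const std::shared_ptr<Ty> type = std::make_shared<Ty>(Kind::Int32);
  return type;
}

// Non-singleton factory: each call allocates a fresh object.
std::shared_ptr<Ty> MakePointer() { return std::make_shared<Ty>(Kind::Pointer); }

// Structural comparison that does not depend on object identity.
bool SameKind(const std::shared_ptr<Ty>& a, const std::shared_ptr<Ty>& b) {
  return a && b && a->kind == b->kind;
}
```

Under this scheme, comparing two `shared_ptr` values with `==` is only a reliable type-equality test for the singleton kinds; for freshly allocated types a structural check such as `SameKind` would be needed.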

@ -1,19 +1,46 @@
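// SSA value hierarchy abstraction:
// - constants, arguments, instruction results, etc. are all unified as Value
// - provides type information and use/used-by relations (implemented as needed)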
#include "ir/IR.h"
#include <algorithm>
#include <ostream>
#include <stdexcept>
namespace ir {
Value::Value(std::shared_ptr<Type> ty, std::string name)
: type_(std::move(ty)), name_(std::move(name)) {}
const std::shared_ptr<Type>& Value::GetType() const { return type_; }
const std::string& Value::GetName() const { return name_; }
void Value::SetName(std::string n) { name_ = std::move(n); }
bool Value::IsVoid() const { return type_ && type_->IsVoid(); }
bool Value::IsInt32() const { return type_ && type_->IsInt32(); }
bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); }
bool Value::IsConstant() const {
return dynamic_cast<const ConstantValue*>(this) != nullptr;
}
bool Value::IsInstruction() const {
return dynamic_cast<const Instruction*>(this) != nullptr;
}
bool Value::IsUser() const {
return dynamic_cast<const User*>(this) != nullptr;
}
bool Value::IsFunction() const {
return dynamic_cast<const Function*>(this) != nullptr;
}
void Value::AddUse(User* user, size_t operand_index) {
if (!user) {
return;
}
uses_.emplace_back(this, user, operand_index);
if (!user) return;
uses_.push_back(Use(this, user, operand_index));
}
void Value::RemoveUse(User* user, size_t operand_index) {
@ -26,41 +53,31 @@ void Value::RemoveUse(User* user, size_t operand_index) {
uses_.end());
}
const std::vector<Use>& Value::GetUses() const { return uses_; }
void Value::ReplaceAllUsesWith(Value* new_value) {
if (!new_value) {
throw std::runtime_error("ReplaceAllUsesWith requires a new value");
throw std::runtime_error("ReplaceAllUsesWith is missing new_value");
}
if (new_value == this) {
return;
}
auto uses = uses_;
for (const auto& use : uses) {
if (auto* user = use.GetUser()) {
user->SetOperand(use.GetOperandIndex(), new_value);
auto* user = use.GetUser();
if (!user) continue;
size_t operand_index = use.GetOperandIndex();
if (user->GetOperand(operand_index) == this) {
user->SetOperand(operand_index, new_value);
}
}
}
void Value::Print(std::ostream& os) const { os << name_; }
ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int value)
: ConstantValue(std::move(ty), ""), value_(value) {}
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float value)
: ConstantValue(std::move(ty), ""), value_(value) {}
ConstantI1::ConstantI1(std::shared_ptr<Type> ty, bool value)
: ConstantValue(std::move(ty), ""), value_(value) {}
ConstantArrayValue::ConstantArrayValue(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name)
: Value(std::move(array_type), name), elements_(elements), dims_(dims) {}
void ConstantArrayValue::Print(std::ostream& os) const { os << name_; }
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
: ConstantValue(std::move(ty), ""), value_(v) {}
} // namespace ir
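
`ReplaceAllUsesWith` above copies `uses_` before looping, because each `SetOperand` call removes an entry from the very list being walked. The following is a minimal sketch of that pattern under simplifying assumptions: a single hypothetical `Node` type plays both the value and the user role, and `UseRecord` is an invented stand-in for the project's `Use` class.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct Node;

// Hypothetical use record: which user references a value, and at which slot.
struct UseRecord {
  Node* user;
  std::size_t index;
};

// One type acts as both value and user to keep the sketch short.
struct Node {
  std::vector<Node*> operands;
  std::vector<UseRecord> uses;

  void AddOperand(Node* value) {
    operands.push_back(value);
    value->uses.push_back({this, operands.size() - 1});
  }

  // Rewires slot `i` and keeps both use lists consistent.
  void SetOperand(std::size_t i, Node* value) {
    Node* old = operands.at(i);
    if (old == value) {
      return;
    }
    auto& old_uses = old->uses;
    old_uses.erase(std::remove_if(old_uses.begin(), old_uses.end(),
                                  [&](const UseRecord& r) {
                                    return r.user == this && r.index == i;
                                  }),
                   old_uses.end());
    operands[i] = value;
    value->uses.push_back({this, i});
  }

  void ReplaceAllUsesWith(Node* replacement) {
    if (replacement == this) {
      return;
    }
    // Iterate over a snapshot: SetOperand mutates `uses` while we loop.
    const std::vector<UseRecord> snapshot = uses;
    for (const UseRecord& r : snapshot) {
      if (r.user->operands.at(r.index) == this) {
        r.user->SetOperand(r.index, replacement);
      }
    }
  }
};
```

The extra check that the slot still holds `this` mirrors the guard in the diff and keeps the loop safe if a snapshot entry has already been rewired.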

@ -1,405 +1,4 @@
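// Mem2Reg (SSA construction):
// Mem2Reg (SSA construction):
// - promotes alloca/load/store of local variables to SSA form
// - inserts PHI nodes and rewrites uses; relies on dominator-tree and related analyses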
#include "ir/PassManager.h"
#include "ir/IR.h"
#include <algorithm>
#include <memory>
#include <queue>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct DominatorInfo {
std::vector<BasicBlock*> blocks;
std::unordered_map<BasicBlock*, size_t> index;
std::vector<std::vector<bool>> dominates;
std::vector<BasicBlock*> idom;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_tree_children;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dominance_frontier;
};
struct PromotableAlloca {
AllocaInst* alloca = nullptr;
std::shared_ptr<Type> value_type;
std::unordered_set<BasicBlock*> def_blocks;
std::unordered_map<BasicBlock*, PhiInst*> phis;
};
bool IsScalarPromotableType(const std::shared_ptr<Type>& type) {
return type && (type->IsInt1() || type->IsInt32() || type->IsFloat());
}
Value* DefaultValueFor(Context& ctx, const std::shared_ptr<Type>& type) {
if (type->IsInt1()) {
return ctx.GetConstBool(false);
}
if (type->IsInt32()) {
return ctx.GetConstInt(0);
}
if (type->IsFloat()) {
return new ConstantFloat(Type::GetFloatType(), 0.0f);
}
throw std::runtime_error("Mem2Reg encountered unsupported promotable type");
}
std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it != nullptr) {
stack.push_back(*it);
}
}
}
return order;
}
std::vector<bool> IntersectDominators(const std::vector<std::vector<bool>>& doms,
const std::vector<size_t>& pred_indices,
size_t self_index) {
std::vector<bool> result(doms.size(), true);
if (pred_indices.empty()) {
std::fill(result.begin(), result.end(), false);
result[self_index] = true;
return result;
}
result = doms[pred_indices.front()];
for (size_t i = 1; i < pred_indices.size(); ++i) {
const auto& pred_dom = doms[pred_indices[i]];
for (size_t j = 0; j < result.size(); ++j) {
result[j] = result[j] && pred_dom[j];
}
}
result[self_index] = true;
return result;
}
DominatorInfo BuildDominatorInfo(Function& function) {
DominatorInfo info;
info.blocks = CollectReachableBlocks(function);
info.idom.resize(info.blocks.size(), nullptr);
info.dominates.assign(info.blocks.size(),
std::vector<bool>(info.blocks.size(), true));
if (info.blocks.empty()) {
return info;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
info.index[info.blocks[i]] = i;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
std::fill(info.dominates[i].begin(), info.dominates[i].end(), i != 0);
info.dominates[i][i] = true;
}
bool changed = true;
while (changed) {
changed = false;
for (size_t i = 1; i < info.blocks.size(); ++i) {
std::vector<size_t> pred_indices;
for (auto* pred : info.blocks[i]->GetPredecessors()) {
auto it = info.index.find(pred);
if (it != info.index.end()) {
pred_indices.push_back(it->second);
}
}
auto new_dom = IntersectDominators(info.dominates, pred_indices, i);
if (new_dom != info.dominates[i]) {
info.dominates[i] = std::move(new_dom);
changed = true;
}
}
}
for (size_t i = 1; i < info.blocks.size(); ++i) {
BasicBlock* candidate_idom = nullptr;
for (size_t j = 0; j < info.blocks.size(); ++j) {
if (i == j || !info.dominates[i][j]) {
continue;
}
bool is_immediate = true;
for (size_t k = 0; k < info.blocks.size(); ++k) {
if (k == i || k == j || !info.dominates[i][k]) {
continue;
}
if (info.dominates[k][j]) {
is_immediate = false;
break;
}
}
if (is_immediate) {
candidate_idom = info.blocks[j];
break;
}
}
info.idom[i] = candidate_idom;
if (candidate_idom != nullptr) {
info.dom_tree_children[candidate_idom].push_back(info.blocks[i]);
}
}
for (auto* block : info.blocks) {
info.dominance_frontier[block] = {};
}
for (auto* block : info.blocks) {
std::vector<BasicBlock*> reachable_preds;
for (auto* pred : block->GetPredecessors()) {
if (info.index.find(pred) != info.index.end()) {
reachable_preds.push_back(pred);
}
}
if (reachable_preds.size() < 2) {
continue;
}
auto* idom_block = info.idom[info.index[block]];
for (auto* pred : reachable_preds) {
auto* runner = pred;
while (runner != nullptr && runner != idom_block) {
auto& frontier = info.dominance_frontier[runner];
if (std::find(frontier.begin(), frontier.end(), block) == frontier.end()) {
frontier.push_back(block);
}
auto idom_it = info.index.find(runner);
if (idom_it == info.index.end()) {
break;
}
runner = info.idom[idom_it->second];
}
}
}
return info;
}
bool IsPromotableAlloca(AllocaInst& alloca, const DominatorInfo& dom_info) {
if (!IsScalarPromotableType(alloca.GetAllocatedType())) {
return false;
}
for (const auto& use : alloca.GetUses()) {
auto* user = use.GetUser();
auto* inst = dynamic_cast<Instruction*>(user);
if (inst == nullptr || inst->GetParent() == nullptr ||
dom_info.index.find(inst->GetParent()) == dom_info.index.end()) {
return false;
}
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
if (load->GetPtr() != &alloca) {
return false;
}
continue;
}
auto* store = dynamic_cast<StoreInst*>(inst);
if (store == nullptr || store->GetPtr() != &alloca ||
store->GetValue() == &alloca) {
return false;
}
if (store->GetValue()->GetType() != alloca.GetAllocatedType()) {
return false;
}
}
return true;
}
size_t CountLeadingPhiNodes(BasicBlock& block) {
size_t count = 0;
for (const auto& inst : block.GetInstructions()) {
if (!isa<PhiInst>(inst.get())) {
break;
}
++count;
}
return count;
}
void InsertPhiNodes(Context& ctx, PromotableAlloca& slot,
const DominatorInfo& dom_info) {
std::queue<BasicBlock*> worklist;
std::unordered_set<BasicBlock*> queued;
for (auto* block : slot.def_blocks) {
worklist.push(block);
queued.insert(block);
}
while (!worklist.empty()) {
auto* block = worklist.front();
worklist.pop();
auto frontier_it = dom_info.dominance_frontier.find(block);
if (frontier_it == dom_info.dominance_frontier.end()) {
continue;
}
for (auto* frontier_block : frontier_it->second) {
if (slot.phis.find(frontier_block) != slot.phis.end()) {
continue;
}
auto phi_index = CountLeadingPhiNodes(*frontier_block);
auto* phi = frontier_block->Insert<PhiInst>(phi_index, slot.value_type, nullptr,
ctx.NextTemp());
slot.phis[frontier_block] = phi;
if (slot.def_blocks.insert(frontier_block).second) {
worklist.push(frontier_block);
}
}
}
}
void RenamePromotedAlloca(BasicBlock* block, PromotableAlloca& slot,
const DominatorInfo& dom_info,
std::vector<Value*>& stack, Context& ctx) {
if (block == nullptr) {
return;
}
size_t pushed = 0;
PhiInst* block_phi = nullptr;
auto phi_it = slot.phis.find(block);
if (phi_it != slot.phis.end()) {
block_phi = phi_it->second;
stack.push_back(block_phi);
++pushed;
}
std::vector<Instruction*> to_remove;
Instruction* alloca_to_remove = nullptr;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == block_phi) {
continue;
}
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
if (load->GetPtr() != slot.alloca) {
continue;
}
auto* replacement =
stack.empty() ? DefaultValueFor(ctx, slot.value_type) : stack.back();
load->ReplaceAllUsesWith(replacement);
to_remove.push_back(load);
continue;
}
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
if (store->GetPtr() != slot.alloca) {
continue;
}
stack.push_back(store->GetValue());
++pushed;
to_remove.push_back(store);
continue;
}
if (inst == slot.alloca) {
alloca_to_remove = inst;
}
}
for (auto* succ : block->GetSuccessors()) {
auto succ_phi_it = slot.phis.find(succ);
if (succ_phi_it == slot.phis.end()) {
continue;
}
auto* incoming =
stack.empty() ? DefaultValueFor(ctx, slot.value_type) : stack.back();
succ_phi_it->second->AddIncoming(incoming, block);
}
auto child_it = dom_info.dom_tree_children.find(block);
if (child_it != dom_info.dom_tree_children.end()) {
for (auto* child : child_it->second) {
RenamePromotedAlloca(child, slot, dom_info, stack, ctx);
}
}
for (auto* inst : to_remove) {
block->EraseInstruction(inst);
}
if (alloca_to_remove != nullptr) {
block->EraseInstruction(alloca_to_remove);
}
while (pushed > 0) {
stack.pop_back();
--pushed;
}
}
void PromoteAllocasInFunction(Function& function, Context& ctx) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return;
}
auto dom_info = BuildDominatorInfo(function);
if (dom_info.blocks.empty()) {
return;
}
std::vector<PromotableAlloca> promotable_allocas;
for (const auto& inst_ptr : function.GetEntryBlock()->GetInstructions()) {
auto* alloca = dynamic_cast<AllocaInst*>(inst_ptr.get());
if (alloca == nullptr || !IsPromotableAlloca(*alloca, dom_info)) {
continue;
}
PromotableAlloca slot;
slot.alloca = alloca;
slot.value_type = alloca->GetAllocatedType();
for (const auto& use : alloca->GetUses()) {
auto* inst = dynamic_cast<Instruction*>(use.GetUser());
auto* store = inst == nullptr ? nullptr : dynamic_cast<StoreInst*>(inst);
if (store != nullptr && store->GetPtr() == alloca) {
slot.def_blocks.insert(inst->GetParent());
}
}
promotable_allocas.push_back(std::move(slot));
}
for (auto& slot : promotable_allocas) {
InsertPhiNodes(ctx, slot, dom_info);
std::vector<Value*> stack;
RenamePromotedAlloca(function.GetEntryBlock(), slot, dom_info, stack, ctx);
}
}
} // namespace
void RunMem2Reg(Module& module) {
auto& ctx = module.GetContext();
for (const auto& function : module.GetFunctions()) {
if (function != nullptr) {
PromoteAllocasInFunction(*function, ctx);
}
}
}
} // namespace ir
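The dominator logic in `BuildDominatorInfo` can be tried in isolation. The following is a minimal sketch, not the pass itself: it assumes blocks are plain indices with block 0 as the entry, and the names `PredList`, `ComputeDominators`, `ComputeIdoms`, `ComputeFrontiers`, and `kNoIdom` are hypothetical stand-ins for the `DominatorInfo` fields used above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical CFG encoding: preds[i] lists the predecessors of block i.
using PredList = std::vector<std::vector<std::size_t>>;
constexpr std::size_t kNoIdom = static_cast<std::size_t>(-1);

// Returns dom where dom[i][j] == true means "block j dominates block i".
std::vector<std::vector<bool>> ComputeDominators(const PredList& preds) {
  const std::size_t n = preds.size();
  std::vector<std::vector<bool>> dom(n, std::vector<bool>(n, true));
  if (n == 0) return dom;
  dom[0].assign(n, false);  // the entry block is dominated only by itself
  dom[0][0] = true;
  bool changed = true;
  while (changed) {
    changed = false;
    for (std::size_t i = 1; i < n; ++i) {
      // Intersect the predecessors' dominator sets, then add the block itself.
      std::vector<bool> next(n, !preds[i].empty());
      for (std::size_t p : preds[i]) {
        assert(p < n && "predecessor index out of range");
        for (std::size_t j = 0; j < n; ++j) next[j] = next[j] && dom[p][j];
      }
      next[i] = true;
      if (next != dom[i]) {
        dom[i] = std::move(next);
        changed = true;
      }
    }
  }
  return dom;
}

// Picks, for each block, the strict dominator that dominates no other
// strict dominator of that block.
std::vector<std::size_t> ComputeIdoms(const std::vector<std::vector<bool>>& dom) {
  const std::size_t n = dom.size();
  std::vector<std::size_t> idom(n, kNoIdom);
  for (std::size_t i = 1; i < n; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      if (j == i || !dom[i][j]) continue;
      bool immediate = true;
      for (std::size_t k = 0; k < n && immediate; ++k) {
        if (k != i && k != j && dom[i][k] && dom[k][j]) immediate = false;
      }
      if (immediate) {
        idom[i] = j;
        break;
      }
    }
  }
  return idom;
}

// Walks from each predecessor of a join block up the dominator tree until
// the join block's immediate dominator is reached.
std::vector<std::vector<std::size_t>> ComputeFrontiers(
    const PredList& preds, const std::vector<std::size_t>& idom) {
  std::vector<std::vector<std::size_t>> df(preds.size());
  for (std::size_t b = 0; b < preds.size(); ++b) {
    if (preds[b].size() < 2) continue;
    for (std::size_t p : preds[b]) {
      std::size_t runner = p;
      while (runner != kNoIdom && runner != idom[b]) {
        auto& f = df[runner];
        if (std::find(f.begin(), f.end(), b) == f.end()) f.push_back(b);
        runner = idom[runner];
      }
    }
  }
  return df;
}
```

For a diamond CFG (0 to 1 and 2, both to 3), the frontier of blocks 1 and 2 is block 3, which is where `InsertPhiNodes` would place a phi for a variable stored in either branch.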

@ -1,17 +1 @@
// IR pass management skeleton.
#include "ir/PassManager.h"
#include <cstdlib>
namespace ir {
void RunIRPassPipeline(Module& module) {
const char* disable_mem2reg = std::getenv("NUDTC_DISABLE_MEM2REG");
if (disable_mem2reg != nullptr && disable_mem2reg[0] != '\0' && disable_mem2reg[0] != '0') {
return;
}
RunMem2Reg(module);
}
} // namespace ir
// IR pass management skeleton.
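The environment check in `RunIRPassPipeline` is easy to misread, so here is a small sketch of the same predicate pulled out on its own. The helper names `IsDisableFlagSet` and `Mem2RegDisabledByEnv` are hypothetical; only the `NUDTC_DISABLE_MEM2REG` variable and the three-part condition come from the code above.

```cpp
#include <cstdlib>

// Same condition as in RunIRPassPipeline: the flag counts as set only when
// the value exists, is non-empty, and does not start with '0'.
bool IsDisableFlagSet(const char* value) {
  return value != nullptr && value[0] != '\0' && value[0] != '0';
}

// Reads the variable from the process environment.
bool Mem2RegDisabledByEnv() {
  return IsDisableFlagSet(std::getenv("NUDTC_DISABLE_MEM2REG"));
}
```

Note that only the first character is inspected, so a value such as `01` is treated the same as `0` and leaves Mem2Reg enabled.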

@ -1,4 +1,4 @@
add_library(irgen STATIC
add_library(irgen STATIC
IRGenDriver.cpp
IRGenFunc.cpp
IRGenStmt.cpp
@ -8,8 +8,6 @@
target_link_libraries(irgen PUBLIC
build_options
frontend
${ANTLR4_RUNTIME_TARGET}
ir
sem
)

@ -1,595 +1,107 @@
#include "irgen/IRGen.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
namespace {
std::vector<int> ExpandLinearIndex(const std::vector<int>& dims, size_t flat_index) {
std::vector<int> indices(dims.size(), 0);
for (size_t i = dims.size(); i > 0; --i) {
const auto dim_index = i - 1;
indices[dim_index] = static_cast<int>(flat_index % static_cast<size_t>(dims[dim_index]));
flat_index /= static_cast<size_t>(dims[dim_index]);
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("irgen", "invalid lvalue"));
}
return indices;
return lvalue.ID()->getText();
}
} // namespace
std::string IRGenImpl::ExpectIdent(const antlr4::ParserRuleContext& ctx,
antlr4::tree::TerminalNode* ident) const {
if (ident == nullptr) {
ThrowError(&ctx, "?????");
}
return ident->getText();
}
SemanticType IRGenImpl::ParseBType(SysYParser::BTypeContext* ctx) const {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
if (ctx->INT()) {
return SemanticType::Int;
}
if (ctx->FLOAT()) {
return SemanticType::Float;
}
ThrowError(ctx, "????? int/float ????");
}
SemanticType IRGenImpl::ParseFuncType(SysYParser::FuncTypeContext* ctx) const {
if (ctx == nullptr) {
ThrowError(ctx, "????????");
}
if (ctx->VOID()) {
return SemanticType::Void;
}
if (ctx->INT()) {
return SemanticType::Int;
}
if (ctx->FLOAT()) {
return SemanticType::Float;
}
ThrowError(ctx, "????? void/int/float ??????");
}
std::shared_ptr<ir::Type> IRGenImpl::GetIRScalarType(SemanticType type) const {
switch (type) {
case SemanticType::Void:
return ir::Type::GetVoidType();
case SemanticType::Int:
return ir::Type::GetInt32Type();
case SemanticType::Float:
return ir::Type::GetFloatType();
}
throw std::runtime_error("unknown semantic type");
}
std::shared_ptr<ir::Type> IRGenImpl::BuildArrayType(
SemanticType base_type, const std::vector<int>& dims) const {
auto type = GetIRScalarType(base_type);
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
type = ir::Type::GetArrayType(type, static_cast<size_t>(*it));
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "missing statement block"));
}
return type;
}
std::vector<int> IRGenImpl::ParseArrayDims(
const std::vector<SysYParser::ConstExpContext*>& dims_ctx) {
std::vector<int> dims;
dims.reserve(dims_ctx.size());
for (auto* dim_ctx : dims_ctx) {
if (dim_ctx == nullptr || dim_ctx->addExp() == nullptr) {
ThrowError(dim_ctx, "???????????");
}
auto dim = ConvertConst(EvalConstAddExp(*dim_ctx->addExp()), SemanticType::Int);
if (dim.int_value <= 0) {
ThrowError(dim_ctx, "??????????");
}
dims.push_back(dim.int_value);
}
return dims;
}
std::vector<int> IRGenImpl::ParseParamDims(SysYParser::FuncFParamContext& ctx) {
std::vector<int> dims;
for (auto* exp_ctx : ctx.exp()) {
auto dim = ConvertConst(EvalConstExp(*exp_ctx), SemanticType::Int);
if (dim.int_value <= 0) {
ThrowError(exp_ctx, "????????????");
}
dims.push_back(dim.int_value);
}
return dims;
}
void IRGenImpl::PredeclareGlobalDecl(SysYParser::DeclContext& ctx) {
auto declare_one = [&](const std::string& name, SemanticType type, bool is_const,
const std::vector<int>& dims, const antlr4::ParserRuleContext* node) {
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(node, "????????: " + name);
}
SymbolEntry entry;
entry.kind = is_const ? SymbolKind::Constant : SymbolKind::Variable;
entry.type = type;
entry.is_const = is_const;
entry.is_array = !dims.empty();
entry.is_param_array = false;
entry.dims = dims;
entry.ir_value = module_.CreateGlobalValue(
name, dims.empty() ? GetIRScalarType(type) : BuildArrayType(type, dims),
is_const, nullptr);
if (!symbols_.Insert(name, entry)) {
ThrowError(node, "????????: " + name);
}
};
if (ctx.constDecl() != nullptr) {
const auto type = ParseBType(ctx.constDecl()->bType());
for (auto* def : ctx.constDecl()->constDef()) {
const auto name = ExpectIdent(*def, def->Ident());
const auto dims = ParseArrayDims(def->constExp());
declare_one(name, type, true, dims, def);
auto* symbol = symbols_.Lookup(name);
if (symbol != nullptr && dims.empty()) {
symbol->const_scalar = ConvertConst(
EvalConstAddExp(*def->constInitVal()->constExp()->addExp()), type);
for (auto* item : ctx->blockItem()) {
if (item) {
if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
// The current grammar requires return to be the last statement in a block; generation can stop once it is hit.
break;
}
}
return;
}
if (ctx.varDecl() != nullptr) {
const auto type = ParseBType(ctx.varDecl()->bType());
for (auto* def : ctx.varDecl()->varDef()) {
declare_one(ExpectIdent(*def, def->Ident()), type, false,
ParseArrayDims(def->constExp()), def);
}
return;
}
ThrowError(&ctx, "????");
return {};
}
void IRGenImpl::EmitGlobalDecl(SysYParser::DeclContext& ctx) { EmitDecl(ctx, true); }
void IRGenImpl::EmitDecl(SysYParser::DeclContext& ctx, bool is_global) {
if (ctx.constDecl() != nullptr) {
EmitConstDecl(ctx.constDecl(), is_global);
return;
}
if (ctx.varDecl() != nullptr) {
EmitVarDecl(ctx.varDecl(), is_global, false);
return;
}
ThrowError(&ctx, "????");
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
SysYParser::BlockItemContext& item) {
return std::any_cast<BlockFlow>(item.accept(this));
}
void IRGenImpl::EmitVarDecl(SysYParser::VarDeclContext* ctx, bool is_global,
bool is_const) {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
const auto type = ParseBType(ctx->bType());
for (auto* def : ctx->varDef()) {
if (is_global) {
EmitGlobalVarDef(*def, type);
} else {
EmitLocalVarDef(*def, type, is_const);
}
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "missing block item"));
}
}
void IRGenImpl::EmitConstDecl(SysYParser::ConstDeclContext* ctx, bool is_global) {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
if (ctx->decl()) {
ctx->decl()->accept(this);
return BlockFlow::Continue;
}
const auto type = ParseBType(ctx->bType());
for (auto* def : ctx->constDef()) {
if (is_global) {
EmitGlobalConstDef(*def, type);
} else {
EmitLocalConstDef(*def, type);
}
if (ctx->stmt()) {
return ctx->stmt()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "unsupported statement or declaration"));
}
void IRGenImpl::EmitGlobalVarDef(SysYParser::VarDefContext& ctx, SemanticType type) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || !ir::isa<ir::GlobalValue>(symbol->ir_value)) {
ThrowError(&ctx, "??????????: " + name);
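// IR generation for variable declarations is currently a minimal implementation:
// - first check the declared base type; only local int is supported for now;
// - then hand the variable definition in the Decl to visitVarDef.
//
// Compared with a more complete version, this does not yet handle:
// - multiple variable definitions in one Decl, processed in order;
// - other declaration forms such as const, arrays, and global variables;
// - a richer type system.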
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "missing variable declaration"));
}
auto* global = static_cast<ir::GlobalValue*>(symbol->ir_value);
symbol->kind = SymbolKind::Variable;
symbol->type = type;
symbol->is_const = false;
symbol->is_array = !ctx.constExp().empty();
symbol->dims = ParseArrayDims(ctx.constExp());
if (symbol->is_array) {
auto flat = FlattenInitVal(ctx.initVal(), type, symbol->dims);
std::vector<ir::Value*> elements;
elements.reserve(flat.size());
for (const auto& value : flat) {
elements.push_back(CreateTypedConstant(value));
}
global->SetInitializer(builder_.CreateConstArray(BuildArrayType(type, symbol->dims),
elements, {}));
} else {
ConstantValue init = ZeroConst(type);
if (ctx.initVal() != nullptr) {
if (ctx.initVal()->exp() == nullptr) {
ThrowError(ctx.initVal(), "???????????????");
}
init = ConvertConst(EvalConstExp(*ctx.initVal()->exp()), type);
}
global->SetInitializer(CreateTypedConstant(init));
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("irgen", "only local int variable declarations are currently supported"));
}
}
void IRGenImpl::EmitGlobalConstDef(SysYParser::ConstDefContext& ctx,
SemanticType type) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || !ir::isa<ir::GlobalValue>(symbol->ir_value)) {
ThrowError(&ctx, "??????????: " + name);
}
auto* global = static_cast<ir::GlobalValue*>(symbol->ir_value);
symbol->kind = SymbolKind::Constant;
symbol->type = type;
symbol->is_const = true;
symbol->is_array = !ctx.constExp().empty();
symbol->dims = ParseArrayDims(ctx.constExp());
global->SetConstant(true);
if (symbol->is_array) {
symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
std::vector<ir::Value*> elements;
elements.reserve(symbol->const_array.size());
for (const auto& value : symbol->const_array) {
elements.push_back(CreateTypedConstant(value));
}
global->SetInitializer(builder_.CreateConstArray(BuildArrayType(type, symbol->dims),
elements, {}));
} else {
auto init = ConvertConst(EvalConstAddExp(*ctx.constInitVal()->constExp()->addExp()), type);
symbol->const_scalar = init;
global->SetInitializer(CreateTypedConstant(init));
auto* var_def = ctx->varDef();
if (!var_def) {
throw std::runtime_error(FormatError("irgen", "invalid variable declaration"));
}
var_def->accept(this);
return {};
}
ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr<ir::Type> allocated_type,
const std::string& name) {
if (current_function_ == nullptr || current_function_->GetEntryBlock() == nullptr) {
throw std::runtime_error("CreateEntryAlloca requires an active function entry block");
}
auto* entry = current_function_->GetEntryBlock();
size_t insert_pos = 0;
for (const auto& inst : entry->GetInstructions()) {
if (!ir::isa<ir::AllocaInst>(inst.get())) {
break;
}
++insert_pos;
// This is still a minimal teaching version, so only the following are supported:
// - local int variables;
// - scalar initialization;
// - one storage slot per VarDef.
std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "missing variable definition"));
}
return entry->Insert<ir::AllocaInst>(insert_pos, std::move(allocated_type), nullptr,
name);
}
void IRGenImpl::EmitLocalVarDef(SysYParser::VarDefContext& ctx, SemanticType type,
bool is_const) {
const auto name = ExpectIdent(ctx, ctx.Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(&ctx, "????????: " + name);
}
SymbolEntry entry;
entry.kind = is_const ? SymbolKind::Constant : SymbolKind::Variable;
entry.type = type;
entry.is_const = is_const;
entry.is_array = !ctx.constExp().empty();
entry.dims = ParseArrayDims(ctx.constExp());
if (entry.is_array) {
entry.ir_value = CreateEntryAlloca(BuildArrayType(type, entry.dims),
NextTemp());
} else {
entry.ir_value = CreateEntryAlloca(GetIRScalarType(type),
NextTemp());
if (!ctx->lValue()) {
throw std::runtime_error(FormatError("irgen", "variable declaration is missing a name"));
}
if (!symbols_.Insert(name, entry)) {
ThrowError(&ctx, "????????: " + name);
GetLValueName(*ctx->lValue());
if (storage_map_.find(ctx) != storage_map_.end()) {
throw std::runtime_error(FormatError("irgen", "storage slot generated more than once for a declaration"));
}
auto* symbol = symbols_.Lookup(name);
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
if (!entry.is_array) {
TypedValue init_value{ZeroIRValue(type), type, false, {}};
if (ctx.initVal() != nullptr) {
if (ctx.initVal()->exp() == nullptr) {
ThrowError(ctx.initVal(), "???????????????");
}
init_value = CastScalar(EmitExp(*ctx.initVal()->exp()), type, ctx.initVal());
ir::Value* init = nullptr;
if (auto* init_value = ctx->initValue()) {
if (!init_value->exp()) {
throw std::runtime_error(FormatError("irgen", "aggregate initialization is not currently supported"));
}
builder_.CreateStore(init_value.value, symbol->ir_value);
return;
}
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
if (ctx.initVal() != nullptr) {
auto init_slots = FlattenLocalInitVal(ctx.initVal(), symbol->dims);
StoreLocalArrayElements(symbol->ir_value, type, symbol->dims, init_slots);
}
}
void IRGenImpl::EmitLocalConstDef(SysYParser::ConstDefContext& ctx,
SemanticType type) {
const auto name = ExpectIdent(ctx, ctx.Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(&ctx, "????????: " + name);
}
SymbolEntry entry;
entry.kind = SymbolKind::Constant;
entry.type = type;
entry.is_const = true;
entry.is_array = !ctx.constExp().empty();
entry.dims = ParseArrayDims(ctx.constExp());
entry.ir_value = CreateEntryAlloca(
entry.is_array ? BuildArrayType(type, entry.dims) : GetIRScalarType(type),
NextTemp());
if (!symbols_.Insert(name, entry)) {
ThrowError(&ctx, "????????: " + name);
}
auto* symbol = symbols_.Lookup(name);
if (!entry.is_array) {
auto init = ConvertConst(EvalConstAddExp(*ctx.constInitVal()->constExp()->addExp()), type);
symbol->const_scalar = init;
builder_.CreateStore(CreateTypedConstant(init), symbol->ir_value);
return;
}
symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
for (size_t i = 0; i < symbol->const_array.size(); ++i) {
if (symbol->const_array[i].type == SemanticType::Int && symbol->const_array[i].int_value == 0) {
continue;
}
if (symbol->const_array[i].type == SemanticType::Float &&
symbol->const_array[i].float_value == 0.0f) {
continue;
}
const auto indices = ExpandLinearIndex(symbol->dims, i);
std::vector<ir::Value*> index_values;
index_values.reserve(indices.size());
for (int index : indices) {
index_values.push_back(builder_.CreateConstInt(index));
}
auto* addr = CreateArrayElementAddr(symbol->ir_value, false, type, symbol->dims,
index_values, &ctx);
builder_.CreateStore(CreateTypedConstant(symbol->const_array[i]), addr);
}
}
std::vector<ConstantValue> IRGenImpl::FlattenConstInitVal(
SysYParser::ConstInitValContext* ctx, SemanticType base_type,
const std::vector<int>& dims) {
std::vector<ConstantValue> out(CountArrayElements(dims), ZeroConst(base_type));
if (ctx != nullptr) {
size_t cursor = 0;
FlattenConstInitValImpl(ctx, base_type, dims, 0, 0, out.size(), cursor, out);
}
return out;
}
std::vector<ConstantValue> IRGenImpl::FlattenInitVal(
SysYParser::InitValContext* ctx, SemanticType base_type,
const std::vector<int>& dims) {
std::vector<ConstantValue> out(CountArrayElements(dims), ZeroConst(base_type));
if (ctx != nullptr) {
size_t cursor = 0;
FlattenInitValImpl(ctx, base_type, dims, 0, 0, out.size(), cursor, out);
}
return out;
}
std::vector<IRGenImpl::InitExprSlot> IRGenImpl::FlattenLocalInitVal(
SysYParser::InitValContext* ctx, const std::vector<int>& dims) {
std::vector<InitExprSlot> out;
if (ctx != nullptr) {
size_t cursor = 0;
FlattenLocalInitValImpl(ctx, dims, 0, 0, CountArrayElements(dims), cursor, out);
}
return out;
}
void IRGenImpl::FlattenConstInitValImpl(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims,
size_t depth, size_t object_begin,
size_t object_end, size_t& cursor,
std::vector<ConstantValue>& out) {
if (ctx == nullptr || cursor >= object_end) {
return;
}
if (ctx->constExp() != nullptr) {
out[cursor++] = ConvertConst(EvalConstAddExp(*ctx->constExp()->addExp()), base_type);
return;
}
for (auto* child : ctx->constInitVal()) {
if (cursor >= object_end) {
break;
}
if (child->constExp() == nullptr && depth + 1 < dims.size()) {
cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
const auto child_begin = cursor;
const auto child_end = std::min(object_end,
child_begin + CountArrayElements(dims, depth + 1));
FlattenConstInitValImpl(child, base_type, dims, depth + 1, child_begin,
child_end, cursor, out);
cursor = child_end;
} else {
FlattenConstInitValImpl(child, base_type, dims, depth + 1, object_begin,
object_end, cursor, out);
}
}
}
void IRGenImpl::FlattenInitValImpl(SysYParser::InitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor,
std::vector<ConstantValue>& out) {
if (ctx == nullptr || cursor >= object_end) {
return;
}
if (ctx->exp() != nullptr) {
out[cursor++] = ConvertConst(EvalConstExp(*ctx->exp()), base_type);
return;
}
for (auto* child : ctx->initVal()) {
if (cursor >= object_end) {
break;
}
if (child->exp() == nullptr && depth + 1 < dims.size()) {
cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
const auto child_begin = cursor;
const auto child_end = std::min(object_end,
child_begin + CountArrayElements(dims, depth + 1));
FlattenInitValImpl(child, base_type, dims, depth + 1, child_begin,
child_end, cursor, out);
cursor = child_end;
} else {
FlattenInitValImpl(child, base_type, dims, depth + 1, object_begin,
object_end, cursor, out);
}
}
}
void IRGenImpl::FlattenLocalInitValImpl(SysYParser::InitValContext* ctx,
const std::vector<int>& dims,
size_t depth, size_t object_begin,
size_t object_end, size_t& cursor,
std::vector<InitExprSlot>& out) {
if (ctx == nullptr || cursor >= object_end) {
return;
}
if (ctx->exp() != nullptr) {
out.push_back({cursor++, ctx->exp()});
return;
}
for (auto* child : ctx->initVal()) {
if (cursor >= object_end) {
break;
}
if (child->exp() == nullptr && depth + 1 < dims.size()) {
cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
const auto child_begin = cursor;
const auto child_end = std::min(object_end,
child_begin + CountArrayElements(dims, depth + 1));
FlattenLocalInitValImpl(child, dims, depth + 1, child_begin, child_end,
cursor, out);
cursor = child_end;
} else {
FlattenLocalInitValImpl(child, dims, depth + 1, object_begin, object_end,
cursor, out);
}
}
}
size_t IRGenImpl::CountArrayElements(const std::vector<int>& dims, size_t start) const {
size_t count = 1;
for (size_t i = start; i < dims.size(); ++i) {
count *= static_cast<size_t>(dims[i]);
}
return count;
}
size_t IRGenImpl::AlignInitializerCursor(const std::vector<int>& dims,
size_t depth, size_t object_begin,
size_t cursor) const {
if (depth + 1 >= dims.size()) {
return cursor;
}
const auto stride = CountArrayElements(dims, depth + 1);
const auto relative = cursor - object_begin;
return object_begin + ((relative + stride - 1) / stride) * stride;
}
size_t IRGenImpl::FlattenIndices(const std::vector<int>& dims,
const std::vector<int>& indices) const {
size_t offset = 0;
for (size_t i = 0; i < dims.size(); ++i) {
offset *= static_cast<size_t>(dims[i]);
offset += static_cast<size_t>(indices[i]);
}
return offset;
}
ConstantValue IRGenImpl::ZeroConst(SemanticType type) const {
ConstantValue value;
value.type = type;
value.int_value = 0;
value.float_value = 0.0f;
return value;
}
ir::Value* IRGenImpl::ZeroIRValue(SemanticType type) {
switch (type) {
case SemanticType::Int:
return builder_.CreateConstInt(0);
case SemanticType::Float:
return builder_.CreateConstFloat(0.0f);
case SemanticType::Void:
break;
}
throw std::runtime_error("void type has no zero IR value");
}
ir::Value* IRGenImpl::CreateTypedConstant(const ConstantValue& value) {
switch (value.type) {
case SemanticType::Int:
return builder_.CreateConstInt(value.int_value);
case SemanticType::Float:
return builder_.CreateConstFloat(value.float_value);
case SemanticType::Void:
break;
}
throw std::runtime_error("void type has no constant value");
}
void IRGenImpl::ZeroInitializeLocalArray(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims) {
const auto elem_count = CountArrayElements(dims);
// Both int and float elements are 4 bytes wide, so the two branches are intentionally equal.
int bytes = static_cast<int>(elem_count * (base_type == SemanticType::Float ? 4 : 4));
builder_.CreateMemset(addr, builder_.CreateConstInt(0), builder_.CreateConstInt(bytes),
builder_.CreateConstBool(false));
}
void IRGenImpl::StoreLocalArrayElements(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims,
const std::vector<InitExprSlot>& init_slots) {
for (const auto& slot : init_slots) {
const auto indices = ExpandLinearIndex(dims, slot.index);
std::vector<ir::Value*> index_values;
index_values.reserve(indices.size());
for (int index : indices) {
index_values.push_back(builder_.CreateConstInt(index));
}
auto* elem_addr = CreateArrayElementAddr(addr, false, base_type, dims,
index_values, slot.expr);
auto value = CastScalar(EmitExp(*slot.expr), base_type, slot.expr);
builder_.CreateStore(value.value, elem_addr);
init = EvalExpr(*init_value->exp());
} else {
init = builder_.CreateConstInt(0);
}
builder_.CreateStore(init, slot);
return {};
}
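The index arithmetic behind `ExpandLinearIndex`, `FlattenIndices`, and `AlignInitializerCursor` can be checked without the rest of the generator. The sketch below copies that arithmetic into free functions; `CountElements` and `AlignCursor` are hypothetical free-function names for the `IRGenImpl` members `CountArrayElements` and `AlignInitializerCursor`, and row-major layout with strictly positive dimensions is assumed.

```cpp
#include <cstddef>
#include <vector>

// Number of scalar elements in the sub-array that starts at dimension `start`.
std::size_t CountElements(const std::vector<int>& dims, std::size_t start = 0) {
  std::size_t count = 1;
  for (std::size_t i = start; i < dims.size(); ++i) {
    count *= static_cast<std::size_t>(dims[i]);
  }
  return count;
}

// Converts a row-major flat offset into one index per dimension.
std::vector<int> ExpandLinearIndex(const std::vector<int>& dims,
                                   std::size_t flat_index) {
  std::vector<int> indices(dims.size(), 0);
  for (std::size_t i = dims.size(); i > 0; --i) {
    const std::size_t d = static_cast<std::size_t>(dims[i - 1]);
    indices[i - 1] = static_cast<int>(flat_index % d);
    flat_index /= d;
  }
  return indices;
}

// Inverse of ExpandLinearIndex: per-dimension indices back to a flat offset.
std::size_t FlattenIndices(const std::vector<int>& dims,
                           const std::vector<int>& indices) {
  std::size_t offset = 0;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    offset = offset * static_cast<std::size_t>(dims[i]) +
             static_cast<std::size_t>(indices[i]);
  }
  return offset;
}

// Rounds the cursor up to the start of the next sub-array at depth + 1,
// measured relative to the enclosing object. Innermost depth is left as-is.
std::size_t AlignCursor(const std::vector<int>& dims, std::size_t depth,
                        std::size_t object_begin, std::size_t cursor) {
  if (depth + 1 >= dims.size()) return cursor;
  const std::size_t stride = CountElements(dims, depth + 1);
  const std::size_t relative = cursor - object_begin;
  return object_begin + ((relative + stride - 1) / stride) * stride;
}
```

With `dims = {2, 3, 4}`, flat offset 17 expands to indices `{1, 1, 1}`, and a nested brace list that starts when the cursor is at 5 is moved to offset 12, the start of the second 3x4 sub-array.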

@ -1,7 +1,11 @@
#include "irgen/IRGen.h"
#include "irgen/IRGen.h"
#include <memory>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema) {
auto module = std::make_unique<ir::Module>();

@ -1,693 +1,80 @@
#include "irgen/IRGen.h"
#include <cstdlib>
#include <stdexcept>
#include <utility>
bool IRGenImpl::IsNumeric(const TypedValue& value) const {
return !value.is_array && value.type != SemanticType::Void;
}
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
bool IRGenImpl::IsSameDims(const std::vector<int>& lhs,
const std::vector<int>& rhs) const {
if (rhs.empty()) {
return true;
}
if (lhs == rhs) {
return true;
}
if (lhs.size() == rhs.size() + 1) {
return std::equal(lhs.begin() + 1, lhs.end(), rhs.begin());
}
return false;
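// Expression generation currently implements only a small subset.
// Supported so far:
// - integer literals
// - reads of ordinary local variables
// - parenthesized expressions
// - binary addition
//
// Not yet supported:
// - subtraction, multiplication, division, and unary operators
// - assignment expressions
// - function calls
// - arrays, pointers, and subscript access
// - conditional and comparison expressions
// - ...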
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
return std::any_cast<ir::Value*>(expr.accept(this));
}
IRGenImpl::TypedValue IRGenImpl::CastScalar(
TypedValue value, SemanticType target_type,
const antlr4::ParserRuleContext* ctx) {
if (value.is_array) {
ThrowError(ctx, "????????????");
}
if (target_type == SemanticType::Void || value.type == SemanticType::Void) {
ThrowError(ctx, "void ?????????");
}
if (target_type == SemanticType::Int) {
if (value.type == SemanticType::Int) {
if (value.value->GetType()->IsInt1()) {
value.value = builder_.CreateZext(value.value, ir::Type::GetInt32Type(), NextTemp());
}
value.type = SemanticType::Int;
return value;
}
value.value = builder_.CreateFtoI(value.value, NextTemp());
value.type = SemanticType::Int;
return value;
}
if (target_type == SemanticType::Float) {
if (value.type == SemanticType::Float) {
return value;
}
if (value.value->GetType()->IsInt1()) {
value.value = builder_.CreateZext(value.value, ir::Type::GetInt32Type(), NextTemp());
}
value.value = builder_.CreateIToF(value.value, NextTemp());
value.type = SemanticType::Float;
return value;
}
ThrowError(ctx, "????????????");
}
ir::Value* IRGenImpl::CastToCondition(TypedValue value,
const antlr4::ParserRuleContext* ctx) {
if (value.is_array) {
ThrowError(ctx, "?????????");
}
if (value.type == SemanticType::Void) {
ThrowError(ctx, "void ???????");
}
if (value.value->GetType()->IsInt1()) {
return value.value;
}
if (value.type == SemanticType::Int) {
return builder_.CreateICmp(ir::Opcode::ICmpNE, value.value, builder_.CreateConstInt(0),
NextTemp());
std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "invalid parenthesized expression"));
}
return builder_.CreateFCmp(ir::Opcode::FCmpNE, value.value,
builder_.CreateConstFloat(0.0f), NextTemp());
return EvalExpr(*ctx->exp());
}
IRGenImpl::TypedValue IRGenImpl::NormalizeLogicalValue(
TypedValue value, const antlr4::ParserRuleContext* ctx) {
auto* cond = CastToCondition(value, ctx);
return {builder_.CreateZext(cond, ir::Type::GetInt32Type(), NextTemp()),
SemanticType::Int, false, {}};
}
ConstantValue IRGenImpl::ParseNumber(SysYParser::NumberContext& ctx) const {
ConstantValue value;
if (ctx.IntConst() != nullptr) {
value.type = SemanticType::Int;
value.int_value = std::stoi(ctx.getText(), nullptr, 0);
value.float_value = static_cast<float>(value.int_value);
return value;
}
if (ctx.FloatConst() != nullptr) {
value.type = SemanticType::Float;
value.float_value = std::strtof(ctx.getText().c_str(), nullptr);
value.int_value = static_cast<int>(value.float_value);
return value;
std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
throw std::runtime_error(FormatError("irgen", "only integer literals are currently supported"));
}
ThrowError(&ctx, "?????????");
return static_cast<ir::Value*>(
builder_.CreateConstInt(std::stoi(ctx->number()->getText())));
}
ConstantValue IRGenImpl::ConvertConst(ConstantValue value,
SemanticType target_type) const {
if (target_type == SemanticType::Void) {
throw std::runtime_error("void is not a valid constant target type");
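// How a variable use is handled:
// 1. use the semantic analysis result to bind the use back to its declaration;
// 2. look up the stack slot for that declaration in storage_map_;
// 3. emit a load to read the value from memory.
//
// IRGen therefore no longer does name lookup itself; it consumes the bindings produced by Sema directly.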
std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) {
if (!ctx || !ctx->var() || !ctx->var()->ID()) {
throw std::runtime_error(FormatError("irgen", "only ordinary integer variables are currently supported"));
}
if (value.type == target_type) {
return value;
auto* decl = sema_.ResolveVarUse(ctx->var());
if (!decl) {
throw std::runtime_error(
FormatError("irgen",
"variable use has no semantic binding: " + ctx->var()->ID()->getText()));
}
if (target_type == SemanticType::Int) {
value.int_value = static_cast<int>(value.float_value);
value.type = SemanticType::Int;
return value;
auto it = storage_map_.find(decl);
if (it == storage_map_.end()) {
throw std::runtime_error(
FormatError("irgen",
"variable declaration has no storage slot: " + ctx->var()->ID()->getText()));
}
value.float_value = static_cast<float>(value.int_value);
value.type = SemanticType::Float;
return value;
return static_cast<ir::Value*>(
builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
}
ConstantValue IRGenImpl::EvalConstExp(SysYParser::ExpContext& ctx) {
return EvalConstAddExp(*ctx.addExp());
}
ConstantValue IRGenImpl::EvalConstAddExp(SysYParser::AddExpContext& ctx) {
if (ctx.addExp() == nullptr) {
return EvalConstMulExp(*ctx.mulExp());
}
auto lhs = EvalConstAddExp(*ctx.addExp());
auto rhs = EvalConstMulExp(*ctx.mulExp());
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = ConvertConst(lhs, SemanticType::Float);
rhs = ConvertConst(rhs, SemanticType::Float);
lhs.float_value = ctx.op->getType() == SysYParser::ADD
? lhs.float_value + rhs.float_value
: lhs.float_value - rhs.float_value;
lhs.int_value = static_cast<int>(lhs.float_value);
lhs.type = SemanticType::Float;
return lhs;
}
lhs.int_value = ctx.op->getType() == SysYParser::ADD ? lhs.int_value + rhs.int_value
: lhs.int_value - rhs.int_value;
lhs.float_value = static_cast<float>(lhs.int_value);
lhs.type = SemanticType::Int;
return lhs;
}
ConstantValue IRGenImpl::EvalConstMulExp(SysYParser::MulExpContext& ctx) {
if (ctx.mulExp() == nullptr) {
return EvalConstUnaryExp(*ctx.unaryExp());
}
auto lhs = EvalConstMulExp(*ctx.mulExp());
auto rhs = EvalConstUnaryExp(*ctx.unaryExp());
if (ctx.op->getType() == SysYParser::MOD &&
(lhs.type == SemanticType::Float || rhs.type == SemanticType::Float)) {
ThrowError(&ctx, "?????? % ??");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = ConvertConst(lhs, SemanticType::Float);
rhs = ConvertConst(rhs, SemanticType::Float);
switch (ctx.op->getType()) {
case SysYParser::MUL:
lhs.float_value *= rhs.float_value;
break;
case SysYParser::DIV:
lhs.float_value /= rhs.float_value;
break;
default:
ThrowError(&ctx, "??????????");
}
lhs.int_value = static_cast<int>(lhs.float_value);
lhs.type = SemanticType::Float;
return lhs;
}
switch (ctx.op->getType()) {
case SysYParser::MUL:
lhs.int_value *= rhs.int_value;
break;
case SysYParser::DIV:
lhs.int_value /= rhs.int_value;
break;
case SysYParser::MOD:
lhs.int_value %= rhs.int_value;
break;
default:
ThrowError(&ctx, "????????");
}
lhs.float_value = static_cast<float>(lhs.int_value);
lhs.type = SemanticType::Int;
return lhs;
}
ConstantValue IRGenImpl::EvalConstUnaryExp(SysYParser::UnaryExpContext& ctx) {
if (ctx.primaryExp() != nullptr) {
return EvalConstPrimaryExp(*ctx.primaryExp());
}
if (ctx.Ident() != nullptr) {
ThrowError(&ctx, "?????????????");
}
auto operand = EvalConstUnaryExp(*ctx.unaryExp());
if (ctx.unaryOp()->ADD() != nullptr) {
return operand;
}
if (ctx.unaryOp()->SUB() != nullptr) {
if (operand.type == SemanticType::Float) {
operand.float_value = -operand.float_value;
operand.int_value = static_cast<int>(operand.float_value);
} else {
operand.int_value = -operand.int_value;
operand.float_value = static_cast<float>(operand.int_value);
}
return operand;
}
if (ctx.unaryOp()->NOT() != nullptr) {
const bool truthy = operand.type == SemanticType::Float ? operand.float_value != 0.0f
: operand.int_value != 0;
ConstantValue result;
result.type = SemanticType::Int;
result.int_value = truthy ? 0 : 1;
result.float_value = static_cast<float>(result.int_value);
return result;
}
ThrowError(&ctx, "???????");
}
ConstantValue IRGenImpl::EvalConstPrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp() != nullptr) {
return EvalConstExp(*ctx.exp());
}
if (ctx.number() != nullptr) {
return ParseNumber(*ctx.number());
}
if (ctx.lVal() != nullptr) {
return EvalConstLVal(*ctx.lVal());
}
ThrowError(&ctx, "???? primaryExp");
}
ConstantValue IRGenImpl::EvalConstLVal(SysYParser::LValContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr) {
ThrowError(&ctx, "?????: " + name);
}
if (!symbol->is_const) {
ThrowError(&ctx, "?????????????: " + name);
}
if (!symbol->is_array) {
if (!ctx.exp().empty()) {
ThrowError(&ctx, "??????????");
}
if (!symbol->const_scalar.has_value()) {
ThrowError(&ctx, "?????: " + name);
}
return *symbol->const_scalar;
}
if (ctx.exp().size() != symbol->dims.size()) {
ThrowError(&ctx, "???????????????????: " + name);
}
std::vector<int> indices;
indices.reserve(ctx.exp().size());
for (auto* exp_ctx : ctx.exp()) {
auto index = ConvertConst(EvalConstExp(*exp_ctx), SemanticType::Int);
indices.push_back(index.int_value);
}
for (size_t i = 0; i < indices.size(); ++i) {
if (indices[i] < 0 || indices[i] >= symbol->dims[i]) {
ThrowError(&ctx, "????????: " + name);
}
}
const auto offset = FlattenIndices(symbol->dims, indices);
if (offset >= symbol->const_array.size()) {
ThrowError(&ctx, "???????????: " + name);
}
return symbol->const_array[offset];
}
IRGenImpl::TypedValue IRGenImpl::EmitExp(SysYParser::ExpContext& ctx) {
return EmitAddExp(*ctx.addExp());
}
IRGenImpl::TypedValue IRGenImpl::EmitAddExp(SysYParser::AddExpContext& ctx) {
if (ctx.addExp() == nullptr) {
return EmitMulExp(*ctx.mulExp());
}
auto lhs = EmitAddExp(*ctx.addExp());
auto rhs = EmitMulExp(*ctx.mulExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "???????????");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
auto* value = ctx.op->getType() == SysYParser::ADD
? builder_.CreateBinary(ir::Opcode::FAdd, lhs.value, rhs.value,
NextTemp())
: builder_.CreateBinary(ir::Opcode::FSub, lhs.value, rhs.value,
NextTemp());
return {value, SemanticType::Float, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
auto* value = ctx.op->getType() == SysYParser::ADD
? builder_.CreateAdd(lhs.value, rhs.value, NextTemp())
: builder_.CreateSub(lhs.value, rhs.value, NextTemp());
return {value, SemanticType::Int, false, {}};
}
IRGenImpl::TypedValue IRGenImpl::EmitMulExp(SysYParser::MulExpContext& ctx) {
if (ctx.mulExp() == nullptr) {
return EmitUnaryExp(*ctx.unaryExp());
}
auto lhs = EmitMulExp(*ctx.mulExp());
auto rhs = EmitUnaryExp(*ctx.unaryExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "????????????");
}
if (ctx.op->getType() == SysYParser::MOD &&
(lhs.type == SemanticType::Float || rhs.type == SemanticType::Float)) {
ThrowError(&ctx, "?????? % ??");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
ir::Opcode opcode = ir::Opcode::FMul;
if (ctx.op->getType() == SysYParser::DIV) {
opcode = ir::Opcode::FDiv;
} else if (ctx.op->getType() == SysYParser::MUL) {
opcode = ir::Opcode::FMul;
} else {
ThrowError(&ctx, "?????????");
}
return {builder_.CreateBinary(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Float, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
ir::Value* value = nullptr;
switch (ctx.op->getType()) {
case SysYParser::MUL:
value = builder_.CreateMul(lhs.value, rhs.value, NextTemp());
break;
case SysYParser::DIV:
value = builder_.CreateDiv(lhs.value, rhs.value, NextTemp());
break;
case SysYParser::MOD:
value = builder_.CreateRem(lhs.value, rhs.value, NextTemp());
break;
default:
ThrowError(&ctx, "???????");
}
return {value, SemanticType::Int, false, {}};
}
IRGenImpl::TypedValue IRGenImpl::EmitUnaryExp(SysYParser::UnaryExpContext& ctx) {
if (ctx.primaryExp() != nullptr) {
return EmitPrimaryExp(*ctx.primaryExp());
std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("irgen", "invalid additive expression"));
}
if (ctx.Ident() != nullptr) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || symbol->kind != SymbolKind::Function || symbol->function == nullptr) {
ThrowError(&ctx, "????????: " + name);
}
const auto& function_type = symbol->function_type;
std::vector<ir::Value*> args;
std::vector<SysYParser::ExpContext*> arg_exprs;
if (ctx.funcRParams() != nullptr) {
arg_exprs = ctx.funcRParams()->exp();
}
if (arg_exprs.size() != function_type.param_types.size()) {
ThrowError(&ctx, "?????????: " + name);
}
for (size_t i = 0; i < arg_exprs.size(); ++i) {
auto arg = EmitExp(*arg_exprs[i]);
if (i < function_type.param_is_array.size() && function_type.param_is_array[i]) {
if (!arg.is_array || !IsSameDims(arg.dims, function_type.param_dims[i])) {
ThrowError(arg_exprs[i], "????????????: " + name);
}
args.push_back(arg.value);
} else {
if (arg.is_array) {
ThrowError(arg_exprs[i], "??????????: " + name);
}
arg = CastScalar(arg, function_type.param_types[i], arg_exprs[i]);
args.push_back(arg.value);
}
}
if (function_type.return_type == SemanticType::Void) {
builder_.CreateCall(symbol->function, args);
return {nullptr, SemanticType::Void, false, {}};
}
return {builder_.CreateCall(symbol->function, args, NextTemp()),
function_type.return_type, false, {}};
}
auto operand = EmitUnaryExp(*ctx.unaryExp());
if (!IsNumeric(operand)) {
ThrowError(&ctx, "???????????");
}
if (ctx.unaryOp()->ADD() != nullptr) {
return operand;
}
if (ctx.unaryOp()->SUB() != nullptr) {
if (operand.type == SemanticType::Float) {
return {builder_.CreateFNeg(operand.value, NextTemp()), SemanticType::Float,
false, {}};
}
operand = CastScalar(operand, SemanticType::Int, &ctx);
return {builder_.CreateSub(builder_.CreateConstInt(0), operand.value, NextTemp()),
SemanticType::Int, false, {}};
}
if (ctx.unaryOp()->NOT() != nullptr) {
auto* cond = CastToCondition(operand, &ctx);
auto* inverted = builder_.CreateXor(cond, builder_.CreateConstBool(true), NextTemp());
return {builder_.CreateZext(inverted, ir::Type::GetInt32Type(), NextTemp()),
SemanticType::Int, false, {}};
}
ThrowError(&ctx, "???????");
}
IRGenImpl::TypedValue IRGenImpl::EmitPrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp() != nullptr) {
return EmitExp(*ctx.exp());
}
if (ctx.number() != nullptr) {
auto number = ParseNumber(*ctx.number());
return {CreateTypedConstant(number), number.type, false, {}};
}
if (ctx.lVal() != nullptr) {
return EmitLValValue(*ctx.lVal());
}
ThrowError(&ctx, "?? primaryExp");
}
IRGenImpl::TypedValue IRGenImpl::EmitRelExp(SysYParser::RelExpContext& ctx) {
if (ctx.relExp() == nullptr) {
return EmitAddExp(*ctx.addExp());
}
auto lhs = EmitRelExp(*ctx.relExp());
auto rhs = EmitAddExp(*ctx.addExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "???????????");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
ir::Opcode opcode = ir::Opcode::FCmpLT;
switch (ctx.op->getType()) {
case SysYParser::LT:
opcode = ir::Opcode::FCmpLT;
break;
case SysYParser::GT:
opcode = ir::Opcode::FCmpGT;
break;
case SysYParser::LE:
opcode = ir::Opcode::FCmpLE;
break;
case SysYParser::GE:
opcode = ir::Opcode::FCmpGE;
break;
default:
ThrowError(&ctx, "????????");
}
return {builder_.CreateFCmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
ir::Opcode opcode = ir::Opcode::ICmpLT;
switch (ctx.op->getType()) {
case SysYParser::LT:
opcode = ir::Opcode::ICmpLT;
break;
case SysYParser::GT:
opcode = ir::Opcode::ICmpGT;
break;
case SysYParser::LE:
opcode = ir::Opcode::ICmpLE;
break;
case SysYParser::GE:
opcode = ir::Opcode::ICmpGE;
break;
default:
ThrowError(&ctx, "????????");
}
return {builder_.CreateICmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
IRGenImpl::TypedValue IRGenImpl::EmitEqExp(SysYParser::EqExpContext& ctx) {
if (ctx.eqExp() == nullptr) {
return EmitRelExp(*ctx.relExp());
}
auto lhs = EmitEqExp(*ctx.eqExp());
auto rhs = EmitRelExp(*ctx.relExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "???????????");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
const auto opcode = ctx.op->getType() == SysYParser::EQ ? ir::Opcode::FCmpEQ
: ir::Opcode::FCmpNE;
return {builder_.CreateFCmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
const auto opcode = ctx.op->getType() == SysYParser::EQ ? ir::Opcode::ICmpEQ
: ir::Opcode::ICmpNE;
return {builder_.CreateICmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
IRGenImpl::LValueInfo IRGenImpl::ResolveLVal(SysYParser::LValContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr) {
ThrowError(&ctx, "?????????: " + name);
}
if (symbol->kind == SymbolKind::Function) {
ThrowError(&ctx, "????????: " + name);
}
std::vector<ir::Value*> index_values;
index_values.reserve(ctx.exp().size());
for (auto* exp_ctx : ctx.exp()) {
auto index = CastScalar(EmitExp(*exp_ctx), SemanticType::Int, exp_ctx);
if (index.is_array) {
ThrowError(exp_ctx, "????????");
}
index_values.push_back(index.value);
}
if (!symbol->is_array) {
if (!index_values.empty()) {
ThrowError(&ctx, "??????????: " + name);
}
return {symbol, symbol->ir_value, symbol->type, false, {}, false};
}
std::vector<int> selected_dims;
if (symbol->is_param_array) {
if (index_values.size() > symbol->dims.size() + 1) {
ThrowError(&ctx, "????????: " + name);
}
if (index_values.empty()) {
selected_dims = symbol->dims;
} else if (index_values.size() <= 1) {
selected_dims = symbol->dims;
} else {
selected_dims.assign(symbol->dims.begin() + static_cast<long long>(index_values.size() - 1),
symbol->dims.end());
}
} else {
if (index_values.size() > symbol->dims.size()) {
ThrowError(&ctx, "??????: " + name);
}
selected_dims.assign(symbol->dims.begin() + static_cast<long long>(index_values.size()),
symbol->dims.end());
}
ir::Value* addr = symbol->ir_value;
if (!index_values.empty()) {
addr = CreateArrayElementAddr(symbol->ir_value, symbol->is_param_array, symbol->type,
symbol->dims, index_values, &ctx);
}
const bool root_param_array_no_index = symbol->is_param_array && index_values.empty();
const bool still_array = !selected_dims.empty() || root_param_array_no_index;
return {symbol, addr, symbol->type, still_array, selected_dims,
root_param_array_no_index};
}
ir::Value* IRGenImpl::GenLValAddr(SysYParser::LValContext& ctx) {
auto info = ResolveLVal(ctx);
if (info.is_array) {
ThrowError(&ctx, "?????????????");
}
return info.addr;
}
IRGenImpl::TypedValue IRGenImpl::EmitLValValue(SysYParser::LValContext& ctx) {
auto info = ResolveLVal(ctx);
if (!info.is_array) {
if (info.symbol != nullptr && info.symbol->const_scalar.has_value()) {
return {CreateTypedConstant(*info.symbol->const_scalar), info.type, false, {}};
}
return {builder_.CreateLoad(info.addr, GetIRScalarType(info.type), NextTemp()),
info.type, false, {}};
}
if (info.root_param_array_no_index) {
return {info.addr, info.type, true, info.dims};
}
auto* decayed = builder_.CreateGEP(info.addr, BuildArrayType(info.type, info.dims),
{builder_.CreateConstInt(0), builder_.CreateConstInt(0)},
NextTemp());
std::vector<int> decay_dims;
if (!info.dims.empty()) {
decay_dims.assign(info.dims.begin() + 1, info.dims.end());
}
return {decayed, info.type, true, decay_dims};
}
ir::Value* IRGenImpl::CreateArrayElementAddr(
ir::Value* base_addr, bool is_param_array, SemanticType base_type,
const std::vector<int>& dims, const std::vector<ir::Value*>& indices,
const antlr4::ParserRuleContext* ctx) {
if (base_addr == nullptr) {
ThrowError(ctx, "???????");
}
if (indices.empty()) {
return base_addr;
}
std::vector<ir::Value*> gep_indices;
if (!is_param_array) {
gep_indices.push_back(builder_.CreateConstInt(0));
}
gep_indices.insert(gep_indices.end(), indices.begin(), indices.end());
auto source_type = dims.empty() ? GetIRScalarType(base_type) : BuildArrayType(base_type, dims);
return builder_.CreateGEP(base_addr, source_type, gep_indices, NextTemp());
}
void IRGenImpl::EmitCond(SysYParser::CondContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block) {
EmitLOrCond(*ctx.lOrExp(), true_block, false_block);
}
void IRGenImpl::EmitLOrCond(SysYParser::LOrExpContext& ctx,
ir::BasicBlock* true_block,
ir::BasicBlock* false_block) {
if (ctx.lOrExp() == nullptr) {
EmitLAndCond(*ctx.lAndExp(), true_block, false_block);
return;
}
auto* rhs_block = current_function_->CreateBlock(NextBlockName("lor.rhs"));
EmitLOrCond(*ctx.lOrExp(), true_block, rhs_block);
builder_.SetInsertPoint(rhs_block);
EmitLAndCond(*ctx.lAndExp(), true_block, false_block);
}
void IRGenImpl::EmitLAndCond(SysYParser::LAndExpContext& ctx,
ir::BasicBlock* true_block,
ir::BasicBlock* false_block) {
if (ctx.lAndExp() == nullptr) {
auto cond = EmitEqExp(*ctx.eqExp());
builder_.CreateCondBr(CastToCondition(cond, &ctx), true_block, false_block);
return;
}
auto* rhs_block = current_function_->CreateBlock(NextBlockName("land.rhs"));
EmitLAndCond(*ctx.lAndExp(), rhs_block, false_block);
builder_.SetInsertPoint(rhs_block);
auto cond = EmitEqExp(*ctx.eqExp());
builder_.CreateCondBr(CastToCondition(cond, &ctx), true_block, false_block);
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
return static_cast<ir::Value*>(
builder_.CreateBinary(ir::Opcode::Add, lhs, rhs,
module_.GetContext().NextTemp()));
}
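The constant-evaluation helpers in the hunk above (`ConvertConst`, `EvalConstAddExp`) keep an int view and a float view of every constant and promote both operands to float as soon as either one is float. Below is a minimal standalone sketch of that rule. `Ty`, `Const`, `MakeInt`, `MakeFloat`, `Convert` and `FoldAddSub` are hypothetical names used only for illustration; they are not the repository's types, and it assumes a C++14 or newer compiler.

```cpp
#include <cassert>

// Hypothetical stand-ins for SemanticType and ConstantValue from the diff.
enum class Ty { Int, Float };

struct Const {
  Ty type = Ty::Int;
  int int_value = 0;
  float float_value = 0.0f;
};

// Build constants with both views populated consistently.
Const MakeInt(int v) { return {Ty::Int, v, static_cast<float>(v)}; }
Const MakeFloat(float v) { return {Ty::Float, static_cast<int>(v), v}; }

// Convert a constant to the requested type, refreshing the target view.
Const Convert(Const v, Ty target) {
  if (v.type == target) return v;
  if (target == Ty::Int) {
    v.int_value = static_cast<int>(v.float_value);
  } else {
    v.float_value = static_cast<float>(v.int_value);
  }
  v.type = target;
  return v;
}

// Fold lhs + rhs (is_add) or lhs - rhs. The result is Float when either
// operand is Float, otherwise Int.
Const FoldAddSub(Const lhs, Const rhs, bool is_add) {
  if (lhs.type == Ty::Float || rhs.type == Ty::Float) {
    lhs = Convert(lhs, Ty::Float);
    rhs = Convert(rhs, Ty::Float);
    lhs.float_value = is_add ? lhs.float_value + rhs.float_value
                             : lhs.float_value - rhs.float_value;
    lhs.int_value = static_cast<int>(lhs.float_value);
    return lhs;
  }
  lhs.int_value = is_add ? lhs.int_value + rhs.int_value
                         : lhs.int_value - rhs.int_value;
  lhs.float_value = static_cast<float>(lhs.int_value);
  return lhs;
}
```

Under these assumptions, `FoldAddSub(MakeInt(2), MakeFloat(1.5f), true)` yields a Float constant, while two Int operands stay in integer arithmetic. Keeping both views in sync means a later conversion never reads a stale field.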

@ -1,255 +1,87 @@
#include "irgen/IRGen.h"
#include <stdexcept>
#include <utility>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module), sema_(sema), builder_(module.GetContext(), nullptr) {}
namespace {
[[noreturn]] void IRGenImpl::ThrowError(
const antlr4::ParserRuleContext* ctx, const std::string& message) const {
if (ctx != nullptr && ctx->getStart() != nullptr) {
throw std::runtime_error(FormatErrorAt("irgen",
static_cast<size_t>(ctx->getStart()->getLine()),
static_cast<size_t>(ctx->getStart()->getCharPositionInLine() + 1),
message));
void VerifyFunctionStructure(const ir::Function& func) {
// IRGen is still single-entry and emits code sequentially; once generation finishes, check that every block is terminated.
for (const auto& bb : func.GetBlocks()) {
if (!bb || !bb->HasTerminator()) {
throw std::runtime_error(
FormatError("irgen", "basic block is not properly terminated: " +
(bb ? bb->GetName() : std::string("<null>"))));
}
}
throw std::runtime_error(FormatError("irgen", message));
}
std::string IRGenImpl::NextTemp() { return module_.GetContext().NextTemp(); }
std::string IRGenImpl::NextBlockName(const std::string& prefix) {
return module_.GetContext().NextBlockName(prefix);
}
} // namespace
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module),
sema_(sema),
func_(nullptr),
builder_(module.GetContext(), nullptr) {}
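// IR generation for the compilation unit currently implements only the minimum:
// - The Module is already created in GenerateIR; this only continues generating its contents.
// - It reads the function definition in the compilation unit and hands it to visitFuncDef to generate the function IR.
//
// Not yet implemented:
// - Traversing and generating multiple function definitions.
// - IR generation for global variables and global constants.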
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "missing compilation unit"));
}
symbols_.Clear();
symbols_.EnterScope();
RegisterBuiltinFunctions();
PredeclareTopLevel(*ctx);
for (auto* child : ctx->children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
EmitGlobalDecl(*decl);
} else if (auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
EmitFunction(*func);
}
auto* func = ctx->funcDef();
if (!func) {
throw std::runtime_error(FormatError("irgen", "missing function definition"));
}
symbols_.ExitScope();
func->accept(this);
return {};
}
void IRGenImpl::RegisterBuiltinFunctions() {
if (builtins_registered_) {
return;
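// Function IR generation currently does the following:
// 1. Get the function name.
// 2. Check the function return type.
// 3. Create the Function in the Module.
// 4. Set the builder insert point to the entry basic block.
// 5. Continue generating the function body.
//
// Not yet implemented:
// - General handling of function return types.
// - Traversing the parameter list and collecting parameter types.
// - A function type object such as FunctionType.
// - Argument / formal parameter IR objects.
// - Parameter initialization logic in the entry block.
// ...
// So for now only the minimal "parameterless int function" case is supported.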
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "missing function definition"));
}
struct BuiltinSpec {
const char* name;
SemanticType return_type;
std::vector<SemanticType> param_types;
std::vector<bool> param_is_array;
std::vector<std::vector<int>> param_dims;
};
const std::vector<BuiltinSpec> builtins = {
{"getint", SemanticType::Int, {}, {}, {}},
{"getch", SemanticType::Int, {}, {}, {}},
{"getfloat", SemanticType::Float, {}, {}, {}},
{"getarray", SemanticType::Int, {SemanticType::Int}, {true}, {std::vector<int>{}}},
{"getfarray", SemanticType::Int, {SemanticType::Float}, {true}, {std::vector<int>{}}},
{"putint", SemanticType::Void, {SemanticType::Int}, {false}, {std::vector<int>{}}},
{"putch", SemanticType::Void, {SemanticType::Int}, {false}, {std::vector<int>{}}},
{"putfloat", SemanticType::Void, {SemanticType::Float}, {false}, {std::vector<int>{}}},
{"putarray", SemanticType::Void, {SemanticType::Int, SemanticType::Int}, {false, true}, {std::vector<int>{}, std::vector<int>{}}},
{"putfarray", SemanticType::Void, {SemanticType::Int, SemanticType::Float}, {false, true}, {std::vector<int>{}, std::vector<int>{}}},
{"starttime", SemanticType::Void, {}, {}, {}},
{"stoptime", SemanticType::Void, {}, {}, {}},
};
for (const auto& builtin : builtins) {
FunctionTypeInfo function_type;
function_type.return_type = builtin.return_type;
function_type.param_types = builtin.param_types;
function_type.param_is_array = builtin.param_is_array;
function_type.param_dims = builtin.param_dims;
std::vector<std::shared_ptr<ir::Type>> ir_param_types;
std::vector<std::string> ir_param_names;
for (size_t i = 0; i < builtin.param_types.size(); ++i) {
if (i < builtin.param_is_array.size() && builtin.param_is_array[i]) {
ir_param_types.push_back(ir::Type::GetPointerType());
} else {
ir_param_types.push_back(GetIRScalarType(builtin.param_types[i]));
}
ir_param_names.push_back("%arg" + std::to_string(i));
}
auto* function = module_.CreateFunction(
builtin.name, GetIRScalarType(builtin.return_type), ir_param_types,
ir_param_names, true);
SymbolEntry entry;
entry.kind = SymbolKind::Function;
entry.type = builtin.return_type;
entry.function = function;
entry.ir_value = function;
entry.function_type = std::move(function_type);
symbols_.Insert(builtin.name, entry);
if (!ctx->blockStmt()) {
throw std::runtime_error(FormatError("irgen", "function body is empty"));
}
builtins_registered_ = true;
}
void IRGenImpl::PredeclareTopLevel(SysYParser::CompUnitContext& ctx) {
for (auto* child : ctx.children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
PredeclareGlobalDecl(*decl);
} else if (auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
PredeclareFunction(*func);
}
if (!ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "missing function name"));
}
}
FunctionTypeInfo IRGenImpl::BuildFunctionTypeInfo(
SysYParser::FuncDefContext& ctx) {
FunctionTypeInfo function_type;
function_type.return_type = ParseFuncType(ctx.funcType());
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
const auto type = ParseBType(param->bType());
const auto dims = param->LBRACK().empty() ? std::vector<int>{} : ParseParamDims(*param);
function_type.param_types.push_back(type);
function_type.param_is_array.push_back(!param->LBRACK().empty());
function_type.param_dims.push_back(dims);
}
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("irgen", "only parameterless int functions are currently supported"));
}
return function_type;
}
func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type());
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
std::vector<std::shared_ptr<ir::Type>> IRGenImpl::BuildFunctionIRParamTypes(
const FunctionTypeInfo& function_type) const {
std::vector<std::shared_ptr<ir::Type>> param_types;
for (size_t i = 0; i < function_type.param_types.size(); ++i) {
if (i < function_type.param_is_array.size() && function_type.param_is_array[i]) {
param_types.push_back(ir::Type::GetPointerType());
} else {
param_types.push_back(GetIRScalarType(function_type.param_types[i]));
}
}
return param_types;
}
std::vector<std::string> IRGenImpl::BuildFunctionIRParamNames(
SysYParser::FuncDefContext& ctx) const {
std::vector<std::string> param_names;
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
param_names.push_back("%" + ExpectIdent(*param, param->Ident()));
}
}
return param_names;
}
void IRGenImpl::PredeclareFunction(SysYParser::FuncDefContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(&ctx, "????????: " + name);
}
auto function_type = BuildFunctionTypeInfo(ctx);
auto* function = module_.CreateFunction(
name, GetIRScalarType(function_type.return_type),
BuildFunctionIRParamTypes(function_type), BuildFunctionIRParamNames(ctx), false);
SymbolEntry entry;
entry.kind = SymbolKind::Function;
entry.type = function_type.return_type;
entry.function = function;
entry.ir_value = function;
entry.function_type = std::move(function_type);
symbols_.Insert(name, entry);
}
void IRGenImpl::EmitFunction(SysYParser::FuncDefContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || symbol->kind != SymbolKind::Function || symbol->function == nullptr) {
ThrowError(&ctx, "????????: " + name);
}
current_function_ = symbol->function;
current_return_type_ = symbol->function_type.return_type;
current_function_->SetExternal(false);
auto* entry_block = current_function_->EnsureEntryBlock();
builder_.SetInsertPoint(entry_block);
symbols_.EnterScope();
BindFunctionParams(ctx, *current_function_);
EmitBlock(*ctx.block(), false);
symbols_.ExitScope();
auto* insert_block = builder_.GetInsertBlock();
if (insert_block != nullptr && !insert_block->HasTerminator()) {
if (current_return_type_ == SemanticType::Void) {
builder_.CreateRet();
} else {
builder_.CreateRet(ZeroIRValue(current_return_type_));
}
}
builder_.SetInsertPoint(nullptr);
current_function_ = nullptr;
current_return_type_ = SemanticType::Void;
loop_stack_.clear();
}
void IRGenImpl::BindFunctionParams(SysYParser::FuncDefContext& ctx, ir::Function& func) {
if (ctx.funcFParams() == nullptr) {
return;
}
const auto& params = ctx.funcFParams()->funcFParam();
for (size_t i = 0; i < params.size(); ++i) {
auto* param = params[i];
const auto name = ExpectIdent(*param, param->Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(param, "??????: " + name);
}
SymbolEntry entry;
entry.kind = SymbolKind::Variable;
entry.type = ParseBType(param->bType());
entry.is_const = false;
entry.is_array = !param->LBRACK().empty();
entry.is_param_array = entry.is_array;
entry.dims = entry.is_array ? ParseParamDims(*param) : std::vector<int>{};
auto* arg = func.GetArgument(i);
if (arg == nullptr) {
ThrowError(param, "????????: " + name);
}
if (entry.is_array) {
entry.ir_value = arg;
} else {
auto* slot = CreateEntryAlloca(GetIRScalarType(entry.type), NextTemp());
builder_.CreateStore(arg, slot);
entry.ir_value = slot;
}
if (!symbols_.Insert(name, entry)) {
ThrowError(param, "??????: " + name);
}
}
ctx->blockStmt()->accept(this);
// Semantic correctness is mainly guaranteed by sema; this is only a fallback check that the IR structure is valid.
VerifyFunctionStructure(*func_);
return {};
}
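The `VerifyFunctionStructure` helper in the hunk above walks a function's blocks and rejects any block that is null or lacks a terminator. The sketch below shows the same check in isolation. `Block` and `FirstUnterminatedBlock` are hypothetical illustrations rather than the repository's `ir::BasicBlock` API, and it assumes C++17 for `std::optional`.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <vector>

// Hypothetical minimal block: a name plus a flag saying whether the block
// already ends in a terminator (ret / br / condbr).
struct Block {
  std::string name;
  bool has_terminator = false;
};

// Return the name of the first block that is null or not terminated,
// or std::nullopt when every block is well-formed.
std::optional<std::string> FirstUnterminatedBlock(
    const std::vector<std::unique_ptr<Block>>& blocks) {
  for (const auto& bb : blocks) {
    if (!bb) return std::string("<null>");
    if (!bb->has_terminator) return bb->name;
  }
  return std::nullopt;
}
```

Returning the offending name instead of throwing directly keeps the check easy to test. A caller could wrap the result in whatever error type it prefers, as the diff does with `std::runtime_error`.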

@ -1,159 +1,39 @@
#include "irgen/IRGen.h"
#include "irgen/IRGen.h"
void IRGenImpl::EmitBlock(SysYParser::BlockContext& ctx, bool create_scope) {
if (create_scope) {
symbols_.EnterScope();
}
for (auto* item : ctx.blockItem()) {
if (item != nullptr && EmitBlockItem(*item) == FlowState::Terminated) {
break;
}
}
if (create_scope) {
symbols_.ExitScope();
}
}
IRGenImpl::FlowState IRGenImpl::EmitBlockItem(SysYParser::BlockItemContext& ctx) {
if (ctx.decl() != nullptr) {
EmitDecl(*ctx.decl(), false);
return FlowState::Continue;
}
if (ctx.stmt() != nullptr) {
return EmitStmt(*ctx.stmt());
}
ThrowError(&ctx, "??????");
}
IRGenImpl::FlowState IRGenImpl::EmitStmt(SysYParser::StmtContext& ctx) {
auto branch_terminated = [this]() {
auto* block = builder_.GetInsertBlock();
return block == nullptr || block->HasTerminator();
};
if (ctx.ASSIGN() != nullptr) {
auto lhs = ResolveLVal(*ctx.lVal());
if (lhs.is_array) {
ThrowError(&ctx, "????????");
}
if (lhs.symbol != nullptr && lhs.symbol->is_const) {
ThrowError(&ctx, "??? const ????");
}
auto rhs = CastScalar(EmitExp(*ctx.exp()), lhs.type, ctx.exp());
builder_.CreateStore(rhs.value, lhs.addr);
return FlowState::Continue;
}
if (ctx.RETURN() != nullptr) {
if (current_return_type_ == SemanticType::Void) {
if (ctx.exp() != nullptr) {
ThrowError(&ctx, "void ?????????");
}
builder_.CreateRet();
} else {
if (ctx.exp() == nullptr) {
ThrowError(&ctx, "? void ?????????");
}
auto value = CastScalar(EmitExp(*ctx.exp()), current_return_type_, ctx.exp());
builder_.CreateRet(value.value);
}
return FlowState::Terminated;
}
if (ctx.block() != nullptr) {
EmitBlock(*ctx.block(), true);
return branch_terminated() ? FlowState::Terminated : FlowState::Continue;
}
if (ctx.IF() != nullptr) {
auto* then_block = current_function_->CreateBlock(NextBlockName("if.then"));
if (ctx.ELSE() == nullptr) {
auto* end_block = current_function_->CreateBlock(NextBlockName("if.end"));
EmitCond(*ctx.cond(), then_block, end_block);
builder_.SetInsertPoint(then_block);
auto then_state = EmitStmt(*ctx.stmt(0));
if (then_state != FlowState::Terminated && !branch_terminated()) {
builder_.CreateBr(end_block);
}
builder_.SetInsertPoint(end_block);
return FlowState::Continue;
}
auto* else_block = current_function_->CreateBlock(NextBlockName("if.else"));
EmitCond(*ctx.cond(), then_block, else_block);
ir::BasicBlock* end_block = nullptr;
builder_.SetInsertPoint(then_block);
auto then_state = EmitStmt(*ctx.stmt(0));
const bool then_terminated = then_state == FlowState::Terminated || branch_terminated();
if (!then_terminated) {
if (end_block == nullptr) {
end_block = current_function_->CreateBlock(NextBlockName("if.end"));
}
builder_.CreateBr(end_block);
}
#include <stdexcept>
builder_.SetInsertPoint(else_block);
auto else_state = EmitStmt(*ctx.stmt(1));
const bool else_terminated = else_state == FlowState::Terminated || branch_terminated();
if (!else_terminated) {
if (end_block == nullptr) {
end_block = current_function_->CreateBlock(NextBlockName("if.end"));
}
builder_.CreateBr(end_block);
}
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
if (end_block == nullptr) {
builder_.SetInsertPoint(nullptr);
return FlowState::Terminated;
}
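// Statement generation currently implements only a minimal subset.
// Supported:
// - return <exp>;
//
// Not yet supported:
// - assignment statements
// - control flow such as if / while
// - other statement forms beyond empty statements and nested block-statement dispatch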
builder_.SetInsertPoint(end_block);
return FlowState::Continue;
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "missing statement"));
}
if (ctx.WHILE() != nullptr) {
auto* cond_block = current_function_->CreateBlock(NextBlockName("while.cond"));
auto* body_block = current_function_->CreateBlock(NextBlockName("while.body"));
auto* end_block = current_function_->CreateBlock(NextBlockName("while.end"));
builder_.CreateBr(cond_block);
builder_.SetInsertPoint(cond_block);
EmitCond(*ctx.cond(), body_block, end_block);
loop_stack_.push_back({cond_block, end_block});
builder_.SetInsertPoint(body_block);
auto body_state = EmitStmt(*ctx.stmt(0));
if (body_state != FlowState::Terminated && !branch_terminated()) {
builder_.CreateBr(cond_block);
}
loop_stack_.pop_back();
builder_.SetInsertPoint(end_block);
return FlowState::Continue;
if (ctx->returnStmt()) {
return ctx->returnStmt()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "unsupported statement type"));
}
if (ctx.BREAK() != nullptr) {
if (loop_stack_.empty()) {
ThrowError(&ctx, "break ????? while ???");
}
builder_.CreateBr(loop_stack_.back().exit_block);
return FlowState::Terminated;
}
if (ctx.CONTINUE() != nullptr) {
if (loop_stack_.empty()) {
ThrowError(&ctx, "continue ????? while ???");
}
builder_.CreateBr(loop_stack_.back().cond_block);
return FlowState::Terminated;
std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "missing return statement"));
}
if (ctx.exp() != nullptr) {
(void)EmitExp(*ctx.exp());
if (!ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "return is missing an expression"));
}
return FlowState::Continue;
ir::Value* v = EvalExpr(*ctx->exp());
builder_.CreateRet(v);
return BlockFlow::Terminated;
}
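The `while` / `break` / `continue` handling in the hunk above pushes a pair of target blocks when entering a loop body and branches to the innermost pair on `break` or `continue`. The sketch below isolates that bookkeeping. `LoopTargets` and `LoopStack` are hypothetical names, and block labels are plain strings here rather than `ir::BasicBlock*`.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical record of where `continue` and `break` should jump.
struct LoopTargets {
  std::string cond_block;  // target of `continue`
  std::string exit_block;  // target of `break`
};

class LoopStack {
 public:
  // Called when entering a while body.
  void Push(std::string cond, std::string exit) {
    stack_.push_back({std::move(cond), std::move(exit)});
  }
  // Called after the body has been emitted.
  void Pop() {
    if (stack_.empty()) throw std::logic_error("loop stack underflow");
    stack_.pop_back();
  }
  const std::string& BreakTarget() const { return Top().exit_block; }
  const std::string& ContinueTarget() const { return Top().cond_block; }

 private:
  // The innermost loop; using break/continue outside a loop is an error.
  const LoopTargets& Top() const {
    if (stack_.empty()) {
      throw std::logic_error("break/continue outside of a while loop");
    }
    return stack_.back();
  }
  std::vector<LoopTargets> stack_;
};
```

Because targets are resolved from the top of the stack, nested loops need no extra handling: the inner loop's targets shadow the outer ones until `Pop` is called.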

@ -6,7 +6,6 @@
#include "frontend/SyntaxTreePrinter.h"
#if !COMPILER_PARSE_ONLY
#include "ir/IR.h"
#include "ir/PassManager.h"
#include "irgen/IRGen.h"
#include "mir/MIR.h"
#include "sem/Sema.h"
@ -38,24 +37,12 @@ int main(int argc, char** argv) {
auto module = GenerateIR(*comp_unit, sema);
if (opts.emit_ir) {
std::unique_ptr<ir::Module> ir_module;
if (opts.emit_asm) {
ir_module = GenerateIR(*comp_unit, sema);
} else {
ir_module = std::move(module);
}
ir::RunIRPassPipeline(*ir_module);
ir::IRPrinter printer;
if (need_blank_line) {
std::cout << "\n";
}
printer.Print(*ir_module, std::cout);
printer.Print(*module, std::cout);
need_blank_line = true;
if (!opts.emit_asm) {
module = std::move(ir_module);
}
}
if (opts.emit_asm) {
@ -78,4 +65,4 @@ int main(int argc, char** argv) {
return 1;
}
return 0;
}
}

@ -6,6 +6,5 @@ add_library(sem STATIC
target_link_libraries(sem PUBLIC
build_options
frontend
${ANTLR4_RUNTIME_TARGET}
)

@ -1,4 +1,752 @@
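// Constant evaluation:
// - handles compile-time computable cases such as array dimensions, global initializers, and const expressions
// - provides constant folding / constant value information for semantic analysis and IR generation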
#include "sem/ConstEval.h"
#include <any>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include "SysYBaseVisitor.h"
#include "utils/Log.h"
namespace {
using DataType = SymbolDataType;
bool IsNumericType(DataType type) {
return type == DataType::Int || type == DataType::Float ||
type == DataType::Bool;
}
ConstValue MakeZeroValue(DataType type) {
switch (type) {
case DataType::Float:
return ConstValue::FromFloat(0.0);
case DataType::Int:
return ConstValue::FromInt(0);
case DataType::Bool:
return ConstValue::FromBool(false);
default:
return ConstValue{};
}
}
ConstValue CastToType(ConstValue value, DataType target_type) {
switch (target_type) {
case DataType::Int:
return ConstValue::FromInt(value.AsInt());
case DataType::Float:
return ConstValue::FromFloat(value.AsFloat());
case DataType::Bool:
return ConstValue::FromBool(value.AsBool());
default:
throw std::runtime_error(
FormatError("consteval", "unsupported target type for constant conversion"));
}
}
int64_t ParseIntLiteral(const std::string& text) {
char* end = nullptr;
const long long value = std::strtoll(text.c_str(), &end, 0);
if (end == text.c_str() || *end != '\0') {
throw std::runtime_error(
FormatError("consteval", "failed to parse integer literal: " + text));
}
return static_cast<int64_t>(value);
}
double ParseFloatLiteral(const std::string& text) {
char* end = nullptr;
const double value = std::strtod(text.c_str(), &end);
if (end == text.c_str() || *end != '\0') {
throw std::runtime_error(
FormatError("consteval", "failed to parse floating-point literal: " + text));
}
return value;
}
size_t Product(const std::vector<int64_t>& dims, size_t begin) {
size_t result = 1;
for (size_t i = begin; i < dims.size(); ++i) {
if (dims[i] <= 0) {
throw std::runtime_error(
FormatError("consteval", "array dimension must be a positive integer"));
}
const size_t dim = static_cast<size_t>(dims[i]);
if (result > std::numeric_limits<size_t>::max() / dim) {
throw std::runtime_error(
FormatError("consteval", "product of array dimensions overflows"));
}
result *= dim;
}
return result;
}
class ConstEvalVisitor final : public SysYBaseVisitor {
public:
ConstEvalVisitor(const SymbolTable& table, const ConstEvalContext& values)
: table_(table), values_(values) {}
ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) {
return Evaluate(ctx.addExp());
}
ConstValue EvaluateExp(SysYParser::ExpContext& ctx) {
return Evaluate(ctx.addExp());
}
std::any visitExp(SysYParser::ExpContext* ctx) override {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("consteval", "invalid expression"));
}
return Evaluate(ctx->addExp());
}
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("consteval", "invalid constExp"));
}
return Evaluate(ctx->addExp());
}
std::any visitCond(SysYParser::CondContext* ctx) override {
if (!ctx || !ctx->lOrExp()) {
throw std::runtime_error(FormatError("consteval", "invalid condition expression"));
}
return Evaluate(ctx->lOrExp());
}
std::any visitLVal(SysYParser::LValContext* ctx) override {
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("consteval", "invalid lvalue"));
}
const std::string name = ctx->Ident()->getText();
const SymbolEntry* symbol = table_.Lookup(name);
if (!symbol) {
throw std::runtime_error(FormatError("consteval", "undefined symbol: " + name));
}
if (!symbol->is_const && symbol->kind != SymbolKind::Constant) {
throw std::runtime_error(
FormatError("consteval", "non-constant symbol used in constant expression: " + name));
}
const size_t index_count = ctx->exp().size();
if (index_count == 0) {
if (const ConstValue* scalar = values_.LookupScalar(name)) {
return *scalar;
}
if (values_.LookupArray(name)) {
throw std::runtime_error(
FormatError("consteval", "array name cannot be evaluated as a scalar constant: " + name));
}
throw std::runtime_error(
FormatError("consteval", "constant symbol has no compile-time value: " + name));
}
const ConstArrayValue* array = values_.LookupArray(name);
if (!array) {
throw std::runtime_error(
FormatError("consteval", "subscript target is not a constant array: " + name));
}
if (index_count != array->dims.size()) {
throw std::runtime_error(
FormatError("consteval", "index count does not match dimensions of constant array: " + name));
}
size_t linear_index = 0;
for (size_t i = 0; i < index_count; ++i) {
const ConstValue index_value = Evaluate(ctx->exp(i));
const int64_t index = index_value.AsInt();
const int64_t dim = array->dims[i];
if (index < 0 || index >= dim) {
throw std::runtime_error(
FormatError("consteval", "constant array access out of bounds: " + name));
}
linear_index = linear_index * static_cast<size_t>(dim) +
static_cast<size_t>(index);
}
if (linear_index >= array->elements.size()) {
throw std::runtime_error(
FormatError("consteval", "linear index out of bounds for constant array: " + name));
}
return array->elements[linear_index];
}
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "invalid primaryExp"));
}
if (ctx->exp()) {
return Evaluate(ctx->exp());
}
if (ctx->lVal()) {
return Evaluate(ctx->lVal());
}
if (ctx->number()) {
return Evaluate(ctx->number());
}
throw std::runtime_error(FormatError("consteval", "unrecognized primaryExp"));
}
std::any visitNumber(SysYParser::NumberContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "invalid number node"));
}
if (ctx->IntConst()) {
return ConstValue::FromInt(ParseIntLiteral(ctx->IntConst()->getText()));
}
if (ctx->FloatConst()) {
return ConstValue::FromFloat(
ParseFloatLiteral(ctx->FloatConst()->getText()));
}
throw std::runtime_error(FormatError("consteval", "unknown numeric literal type"));
}
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "invalid unaryExp"));
}
if (ctx->primaryExp()) {
return Evaluate(ctx->primaryExp());
}
if (ctx->Ident()) {
throw std::runtime_error(
FormatError("consteval", "function calls are not allowed in constant expressions: " +
ctx->Ident()->getText()));
}
if (!ctx->unaryOp() || !ctx->unaryExp()) {
throw std::runtime_error(FormatError("consteval", "malformed unary expression"));
}
const ConstValue operand = Evaluate(ctx->unaryExp());
if (ctx->unaryOp()->ADD()) {
if (!operand.IsNumeric()) {
throw std::runtime_error(FormatError("consteval", "unary plus only supports numeric types"));
}
if (operand.type == DataType::Float) {
return ConstValue::FromFloat(+operand.AsFloat());
}
return ConstValue::FromInt(+operand.AsInt());
}
if (ctx->unaryOp()->SUB()) {
if (!operand.IsNumeric()) {
throw std::runtime_error(FormatError("consteval", "unary minus only supports numeric types"));
}
if (operand.type == DataType::Float) {
return ConstValue::FromFloat(-operand.AsFloat());
}
return ConstValue::FromInt(-operand.AsInt());
}
if (ctx->unaryOp()->NOT()) {
return ConstValue::FromBool(!operand.AsBool());
}
throw std::runtime_error(FormatError("consteval", "unknown unary operator"));
}
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "invalid mulExp"));
}
if (!ctx->mulExp()) {
return Evaluate(ctx->unaryExp());
}
const ConstValue lhs = Evaluate(ctx->mulExp());
const ConstValue rhs = Evaluate(ctx->unaryExp());
if (!lhs.IsNumeric() || !rhs.IsNumeric()) {
throw std::runtime_error(
FormatError("consteval", "multiplication, division, and modulo only support numeric types"));
}
const int op = ctx->op ? ctx->op->getType() : 0;
if (op == SysYParser::MUL) {
if (lhs.type == DataType::Float || rhs.type == DataType::Float) {
return ConstValue::FromFloat(lhs.AsFloat() * rhs.AsFloat());
}
return ConstValue::FromInt(lhs.AsInt() * rhs.AsInt());
}
if (op == SysYParser::DIV) {
if (lhs.type == DataType::Float || rhs.type == DataType::Float) {
const double divisor = rhs.AsFloat();
if (divisor == 0.0) {
throw std::runtime_error(FormatError("consteval", "floating-point division by zero"));
}
return ConstValue::FromFloat(lhs.AsFloat() / divisor);
}
const int64_t divisor = rhs.AsInt();
if (divisor == 0) {
throw std::runtime_error(FormatError("consteval", "integer division by zero"));
}
return ConstValue::FromInt(lhs.AsInt() / divisor);
}
if (op == SysYParser::MOD) {
if (lhs.type == DataType::Float || rhs.type == DataType::Float) {
throw std::runtime_error(
FormatError("consteval", "modulo does not support floating-point types"));
}
const int64_t divisor = rhs.AsInt();
if (divisor == 0) {
throw std::runtime_error(FormatError("consteval", "integer modulo by zero"));
}
return ConstValue::FromInt(lhs.AsInt() % divisor);
}
throw std::runtime_error(FormatError("consteval", "unknown multiplicative operator"));
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "invalid addExp"));
}
if (!ctx->addExp()) {
return Evaluate(ctx->mulExp());
}
const ConstValue lhs = Evaluate(ctx->addExp());
const ConstValue rhs = Evaluate(ctx->mulExp());
if (!lhs.IsNumeric() || !rhs.IsNumeric()) {
throw std::runtime_error(FormatError("consteval", "addition and subtraction only support numeric types"));
}
const int op = ctx->op ? ctx->op->getType() : 0;
if (op == SysYParser::ADD) {
if (lhs.type == DataType::Float || rhs.type == DataType::Float) {
return ConstValue::FromFloat(lhs.AsFloat() + rhs.AsFloat());
}
return ConstValue::FromInt(lhs.AsInt() + rhs.AsInt());
}
if (op == SysYParser::SUB) {
if (lhs.type == DataType::Float || rhs.type == DataType::Float) {
return ConstValue::FromFloat(lhs.AsFloat() - rhs.AsFloat());
}
return ConstValue::FromInt(lhs.AsInt() - rhs.AsInt());
}
throw std::runtime_error(FormatError("consteval", "unknown additive operator"));
}
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "invalid relExp"));
}
if (!ctx->relExp()) {
return Evaluate(ctx->addExp());
}
const ConstValue lhs = Evaluate(ctx->relExp());
const ConstValue rhs = Evaluate(ctx->addExp());
if (!lhs.IsNumeric() || !rhs.IsNumeric()) {
throw std::runtime_error(
FormatError("consteval", "relational comparison only supports numeric types"));
}
const int op = ctx->op ? ctx->op->getType() : 0;
if (lhs.type == DataType::Float || rhs.type == DataType::Float) {
const double left = lhs.AsFloat();
const double right = rhs.AsFloat();
if (op == SysYParser::LT) {
return ConstValue::FromBool(left < right);
}
if (op == SysYParser::GT) {
return ConstValue::FromBool(left > right);
}
if (op == SysYParser::LE) {
return ConstValue::FromBool(left <= right);
}
if (op == SysYParser::GE) {
return ConstValue::FromBool(left >= right);
}
} else {
const int64_t left = lhs.AsInt();
const int64_t right = rhs.AsInt();
if (op == SysYParser::LT) {
return ConstValue::FromBool(left < right);
}
if (op == SysYParser::GT) {
return ConstValue::FromBool(left > right);
}
if (op == SysYParser::LE) {
return ConstValue::FromBool(left <= right);
}
if (op == SysYParser::GE) {
return ConstValue::FromBool(left >= right);
}
}
throw std::runtime_error(FormatError("consteval", "unknown relational operator"));
}
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "invalid eqExp"));
}
if (!ctx->eqExp()) {
return Evaluate(ctx->relExp());
}
const ConstValue lhs = Evaluate(ctx->eqExp());
const ConstValue rhs = Evaluate(ctx->relExp());
const int op = ctx->op ? ctx->op->getType() : 0;
bool result = false;
if (lhs.type == DataType::Float || rhs.type == DataType::Float) {
const double left = lhs.AsFloat();
const double right = rhs.AsFloat();
if (op == SysYParser::EQ) {
result = (left == right);
} else if (op == SysYParser::NE) {
result = (left != right);
} else {
throw std::runtime_error(FormatError("consteval", "unknown equality operator"));
}
} else {
const int64_t left = lhs.AsInt();
const int64_t right = rhs.AsInt();
if (op == SysYParser::EQ) {
result = (left == right);
} else if (op == SysYParser::NE) {
result = (left != right);
} else {
throw std::runtime_error(FormatError("consteval", "unknown equality operator"));
}
}
return ConstValue::FromBool(result);
}
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "invalid lAndExp"));
}
if (!ctx->lAndExp()) {
return ConstValue::FromBool(Evaluate(ctx->eqExp()).AsBool());
}
const ConstValue lhs = Evaluate(ctx->lAndExp());
if (!lhs.AsBool()) {
return ConstValue::FromBool(false);
}
const ConstValue rhs = Evaluate(ctx->eqExp());
return ConstValue::FromBool(rhs.AsBool());
}
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "invalid lOrExp"));
}
if (!ctx->lOrExp()) {
return ConstValue::FromBool(Evaluate(ctx->lAndExp()).AsBool());
}
const ConstValue lhs = Evaluate(ctx->lOrExp());
if (lhs.AsBool()) {
return ConstValue::FromBool(true);
}
const ConstValue rhs = Evaluate(ctx->lAndExp());
return ConstValue::FromBool(rhs.AsBool());
}
private:
ConstValue Evaluate(antlr4::ParserRuleContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("consteval", "null expression node"));
}
std::any result = ctx->accept(this);
try {
return std::any_cast<ConstValue>(result);
} catch (const std::bad_any_cast&) {
throw std::runtime_error(FormatError("consteval", "type cast failed during constant evaluation"));
}
}
const SymbolTable& table_;
const ConstEvalContext& values_;
};
ConstValue EvaluateScalarInit(SysYParser::ConstInitValContext& init,
DataType elem_type,
ConstEvalVisitor& evaluator) {
if (init.constExp()) {
return CastToType(evaluator.EvaluateConstExp(*init.constExp()), elem_type);
}
if (init.constInitVal().empty()) {
return MakeZeroValue(elem_type);
}
if (init.constInitVal().size() == 1) {
return EvaluateScalarInit(*init.constInitVal().front(), elem_type,
evaluator);
}
throw std::runtime_error(
FormatError("consteval", "too many elements in scalar initializer"));
}
void FillConstArrayObject(SysYParser::ConstInitValContext& init,
size_t depth, size_t base, size_t span,
const std::vector<int64_t>& dims, DataType elem_type,
std::vector<ConstValue>& out,
ConstEvalVisitor& evaluator) {
if (depth >= dims.size()) {
out[base] = EvaluateScalarInit(init, elem_type, evaluator);
return;
}
if (init.constExp()) {
out[base] = CastToType(evaluator.EvaluateConstExp(*init.constExp()),
elem_type);
return;
}
if (init.constInitVal().empty()) {
return;
}
const size_t end = base + span;
size_t cursor = base;
const size_t subspan = (depth + 1 < dims.size()) ? Product(dims, depth + 1)
: static_cast<size_t>(1);
for (auto* child : init.constInitVal()) {
if (!child) {
continue;
}
if (cursor >= end) {
throw std::runtime_error(
FormatError("consteval", "too many elements in array initializer"));
}
if (depth + 1 >= dims.size()) {
out[cursor] = EvaluateScalarInit(*child, elem_type, evaluator);
++cursor;
continue;
}
if (child->constExp()) {
out[cursor] = CastToType(evaluator.EvaluateConstExp(*child->constExp()),
elem_type);
++cursor;
continue;
}
const size_t rel = cursor - base;
if (subspan > 1 && rel % subspan != 0) {
cursor += (subspan - (rel % subspan));
}
if (cursor >= end) {
throw std::runtime_error(
FormatError("consteval", "array initializer nesting does not match dimensions"));
}
FillConstArrayObject(*child, depth + 1, cursor, subspan, dims, elem_type,
out, evaluator);
cursor += subspan;
}
}
} // namespace
ConstValue ConstValue::FromInt(int64_t value) {
ConstValue result;
result.type = SymbolDataType::Int;
result.int_value = value;
result.float_value = static_cast<double>(value);
result.bool_value = (value != 0);
return result;
}
ConstValue ConstValue::FromFloat(double value) {
ConstValue result;
result.type = SymbolDataType::Float;
result.int_value = static_cast<int64_t>(value);
result.float_value = value;
result.bool_value = (value != 0.0);
return result;
}
ConstValue ConstValue::FromBool(bool value) {
ConstValue result;
result.type = SymbolDataType::Bool;
result.int_value = value ? 1 : 0;
result.float_value = value ? 1.0 : 0.0;
result.bool_value = value;
return result;
}
bool ConstValue::IsScalar() const { return type != SymbolDataType::Unknown; }
bool ConstValue::IsNumeric() const {
return type == SymbolDataType::Int || type == SymbolDataType::Float ||
type == SymbolDataType::Bool;
}
int64_t ConstValue::AsInt() const {
if (type == SymbolDataType::Int) {
return int_value;
}
if (type == SymbolDataType::Float) {
return static_cast<int64_t>(float_value);
}
if (type == SymbolDataType::Bool) {
return bool_value ? 1 : 0;
}
throw std::runtime_error(FormatError("consteval", "value cannot be converted to an integer"));
}
double ConstValue::AsFloat() const {
if (type == SymbolDataType::Float) {
return float_value;
}
if (type == SymbolDataType::Int) {
return static_cast<double>(int_value);
}
if (type == SymbolDataType::Bool) {
return bool_value ? 1.0 : 0.0;
}
throw std::runtime_error(FormatError("consteval", "value cannot be converted to a floating-point number"));
}
bool ConstValue::AsBool() const {
if (type == SymbolDataType::Bool) {
return bool_value;
}
if (type == SymbolDataType::Int) {
return int_value != 0;
}
if (type == SymbolDataType::Float) {
return float_value != 0.0;
}
throw std::runtime_error(FormatError("consteval", "value cannot be converted to a boolean"));
}
ConstEvalContext::ConstEvalContext() { EnterScope(); }
void ConstEvalContext::EnterScope() { scopes_.emplace_back(); }
void ConstEvalContext::ExitScope() {
if (scopes_.size() <= 1) {
throw std::runtime_error("const eval scope underflow");
}
scopes_.pop_back();
}
bool ConstEvalContext::DefineScalar(const std::string& name, ConstValue value) {
if (scopes_.empty()) {
EnterScope();
}
auto& current = scopes_.back();
if (current.find(name) != current.end()) {
return false;
}
Binding binding;
binding.is_array = false;
binding.scalar = std::move(value);
current.emplace(name, std::move(binding));
return true;
}
bool ConstEvalContext::DefineArray(const std::string& name,
ConstArrayValue value) {
if (scopes_.empty()) {
EnterScope();
}
auto& current = scopes_.back();
if (current.find(name) != current.end()) {
return false;
}
Binding binding;
binding.is_array = true;
binding.array = std::move(value);
current.emplace(name, std::move(binding));
return true;
}
const ConstValue* ConstEvalContext::LookupScalar(const std::string& name) const {
const Binding* binding = LookupBinding(name);
if (!binding || binding->is_array) {
return nullptr;
}
return &binding->scalar;
}
const ConstArrayValue* ConstEvalContext::LookupArray(
const std::string& name) const {
const Binding* binding = LookupBinding(name);
if (!binding || !binding->is_array) {
return nullptr;
}
return &binding->array;
}
const ConstEvalContext::Binding* ConstEvalContext::LookupBinding(
const std::string& name) const {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(name);
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
}
ConstEvaluator::ConstEvaluator(const SymbolTable& table,
const ConstEvalContext& ctx)
: table_(table), ctx_(ctx) {}
ConstValue ConstEvaluator::EvaluateConstExp(
SysYParser::ConstExpContext& ctx) const {
ConstEvalVisitor visitor(table_, ctx_);
return visitor.EvaluateConstExp(ctx);
}
ConstValue ConstEvaluator::EvaluateExp(SysYParser::ExpContext& ctx) const {
ConstEvalVisitor visitor(table_, ctx_);
return visitor.EvaluateExp(ctx);
}
int64_t ConstEvaluator::EvaluateArrayDim(
SysYParser::ConstExpContext& ctx) const {
const ConstValue value = EvaluateConstExp(ctx);
if (!IsNumericType(value.type)) {
throw std::runtime_error(
FormatError("consteval", "array dimension must be a numeric type"));
}
if (value.type == DataType::Float) {
const double as_float = value.AsFloat();
if (std::trunc(as_float) != as_float) {
throw std::runtime_error(
FormatError("consteval", "array dimension must be an integer"));
}
}
const int64_t dim = value.AsInt();
if (dim <= 0) {
throw std::runtime_error(
FormatError("consteval", "array dimension must be a positive integer"));
}
return dim;
}
std::vector<ConstValue> ConstEvaluator::EvaluateConstInitList(
SysYParser::ConstInitValContext& init, SymbolDataType elem_type,
const std::vector<int64_t>& dims) const {
if (elem_type != DataType::Int && elem_type != DataType::Float &&
elem_type != DataType::Bool) {
throw std::runtime_error(
FormatError("consteval", "constant initialization only supports scalar types"));
}
ConstEvalVisitor visitor(table_, ctx_);
if (dims.empty()) {
return {EvaluateScalarInit(init, elem_type, visitor)};
}
const size_t total = Product(dims, 0);
std::vector<ConstValue> flattened(total, MakeZeroValue(elem_type));
FillConstArrayObject(init, 0, 0, total, dims, elem_type, flattened, visitor);
return flattened;
}
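The subscript handling in `visitLVal` above folds a full set of indices into one row-major offset. Here is a minimal sketch of that computation in isolation; `LinearIndex` is a hypothetical helper name, and the standard exception types are used here in place of the project's `FormatError` messages.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Maps a complete list of subscripts to a row-major offset.
// Assumes every entry of dims is positive (checked with assert).
std::size_t LinearIndex(const std::vector<std::int64_t>& dims,
                        const std::vector<std::int64_t>& indices) {
  if (indices.size() != dims.size()) {
    throw std::invalid_argument("index count does not match dimensions");
  }
  std::size_t linear = 0;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    assert(dims[i] > 0);
    if (indices[i] < 0 || indices[i] >= dims[i]) {
      throw std::out_of_range("subscript out of bounds");
    }
    // Horner-style accumulation: shift by this dimension, then add the index.
    linear = linear * static_cast<std::size_t>(dims[i]) +
             static_cast<std::size_t>(indices[i]);
  }
  return linear;
}
```

For dimensions `{2, 3, 4}` and subscripts `{1, 2, 3}` this gives `(1 * 3 + 2) * 4 + 3`, the last element of the flattened array.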

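`FillConstArrayObject` above flattens a nested brace initializer into a zero-filled row-major buffer, aligning each nested brace list to the next sub-array boundary. The sketch below shows the same cursor-and-alignment idea on a simplified tree. The `Init` type, `SubSpan`, `Fill`, and `Flatten` are hypothetical names, plain `int` elements are assumed, and it is not a drop-in replacement for the function in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical initializer tree: either a scalar or a brace-enclosed list.
struct Init {
  bool is_scalar = false;
  int value = 0;
  std::vector<Init> children;
  static Init Scalar(int v) { Init i; i.is_scalar = true; i.value = v; return i; }
  static Init List(std::vector<Init> c) { Init i; i.children = std::move(c); return i; }
};

// Number of scalars covered by one element at nesting level `depth`.
std::size_t SubSpan(const std::vector<std::size_t>& dims, std::size_t depth) {
  std::size_t n = 1;
  for (std::size_t i = depth + 1; i < dims.size(); ++i) n *= dims[i];
  return n;
}

// Writes `init` into out[base, base + span); untouched slots keep their zero.
void Fill(const Init& init, std::size_t depth, std::size_t base,
          std::size_t span, const std::vector<std::size_t>& dims,
          std::vector<int>& out) {
  if (init.is_scalar) {
    out[base] = init.value;
    return;
  }
  const std::size_t end = base + span;
  const std::size_t sub = SubSpan(dims, depth);
  std::size_t cursor = base;
  for (const Init& child : init.children) {
    if (cursor >= end) throw std::runtime_error("too many initializer elements");
    if (child.is_scalar) {
      out[cursor++] = child.value;
      continue;
    }
    // A nested brace list starts at the next sub-array boundary.
    const std::size_t rel = cursor - base;
    if (rel % sub != 0) cursor += sub - rel % sub;
    if (cursor >= end) throw std::runtime_error("initializer nesting exceeds dimensions");
    Fill(child, depth + 1, cursor, sub, dims, out);
    cursor += sub;
  }
}

// Flattens `init` for an array with the given (positive) dimensions.
std::vector<int> Flatten(const Init& init, const std::vector<std::size_t>& dims) {
  std::size_t total = 1;
  for (std::size_t d : dims) {
    assert(d > 0);
    total *= d;
  }
  std::vector<int> out(total, 0);
  Fill(init, 0, 0, total, dims, out);
  return out;
}
```

With dimensions `{2, 3}`, the initializer `{1, {2}}` places `1` in the first row and then skips to the start of the second row for `2`, which matches how C-style brace initializers are usually interpreted.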
@ -1,6 +1,927 @@
#include "sem/Sema.h"
#include "sem/Sema.h"
#include <any>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include "SysYBaseVisitor.h"
#include "sem/ConstEval.h"
#include "utils/Log.h"
namespace {
using DataType = SymbolDataType;
std::string DataTypeName(DataType ty) {
switch (ty) {
case DataType::Void:
return "void";
case DataType::Int:
return "int";
case DataType::Float:
return "float";
case DataType::Bool:
return "bool";
case DataType::Unknown:
return "unknown";
}
return "unknown";
}
bool IsNumericType(DataType ty) {
return ty == DataType::Int || ty == DataType::Float || ty == DataType::Bool;
}
bool IsScalarType(DataType ty) {
return ty == DataType::Int || ty == DataType::Float || ty == DataType::Bool;
}
bool CanAssign(DataType dst, DataType src) {
if (dst == src) {
return true;
}
return IsNumericType(dst) && IsNumericType(src);
}
DataType ParseBType(SysYParser::BTypeContext& btype) {
if (btype.INT()) {
return DataType::Int;
}
if (btype.FLOAT()) {
return DataType::Float;
}
throw std::runtime_error(FormatError("sema", "invalid base type"));
}
DataType ParseFuncType(SysYParser::FuncTypeContext& func_type) {
if (func_type.VOID()) {
return DataType::Void;
}
if (func_type.INT()) {
return DataType::Int;
}
if (func_type.FLOAT()) {
return DataType::Float;
}
throw std::runtime_error(FormatError("sema", "invalid function return type"));
}
std::string RequireIdent(antlr4::tree::TerminalNode* ident,
const std::string& message) {
if (!ident) {
throw std::runtime_error(FormatError("sema", message));
}
return ident->getText();
}
struct LValueInfo {
const SymbolEntry* symbol = nullptr;
DataType value_type = DataType::Unknown;
size_t index_count = 0;
bool fully_indexed = true;
};
class SemaVisitor final : public SysYBaseVisitor {
public:
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "missing compilation unit"));
}
// Register all function signatures first to support forward and recursive calls.
for (auto* func : ctx->funcDef()) {
if (func) {
PredeclareFunction(*func);
}
}
// Then run semantic checks in source order.
for (auto* child : ctx->children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
decl->accept(this);
continue;
}
if (auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
func->accept(this);
}
}
if (!has_main_function_) {
throw std::runtime_error(FormatError("sema", "missing definition of main function"));
}
return {};
}
std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid declaration node"));
}
if (ctx->constDecl()) {
ctx->constDecl()->accept(this);
return {};
}
if (ctx->varDecl()) {
ctx->varDecl()->accept(this);
return {};
}
throw std::runtime_error(FormatError("sema", "unsupported declaration type"));
}
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
if (!ctx || !ctx->bType()) {
throw std::runtime_error(FormatError("sema", "invalid constant declaration"));
}
const DataType saved_type = current_decl_type_;
const bool saved_const = current_decl_is_const_;
current_decl_type_ = ParseBType(*ctx->bType());
current_decl_is_const_ = true;
for (auto* def : ctx->constDef()) {
if (!def) {
throw std::runtime_error(FormatError("sema", "constant definition is null"));
}
def->accept(this);
}
current_decl_type_ = saved_type;
current_decl_is_const_ = saved_const;
return {};
}
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
if (!ctx || !ctx->bType()) {
throw std::runtime_error(FormatError("sema", "invalid variable declaration"));
}
const DataType saved_type = current_decl_type_;
const bool saved_const = current_decl_is_const_;
current_decl_type_ = ParseBType(*ctx->bType());
current_decl_is_const_ = false;
for (auto* def : ctx->varDef()) {
if (!def) {
throw std::runtime_error(FormatError("sema", "variable definition is null"));
}
def->accept(this);
}
current_decl_type_ = saved_type;
current_decl_is_const_ = saved_const;
return {};
}
std::any visitConstDef(SysYParser::ConstDefContext* ctx) override {
if (!ctx || !ctx->constInitVal()) {
throw std::runtime_error(FormatError("sema", "invalid constant definition"));
}
const std::string name = RequireIdent(ctx->Ident(), "constant definition is missing a name");
if (table_.ContainsCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "symbol redefined: " + name));
}
ConstEvaluator const_evaluator(table_, const_values_);
std::vector<int64_t> dims;
dims.reserve(ctx->constExp().size());
for (auto* dim_exp : ctx->constExp()) {
dims.push_back(const_evaluator.EvaluateArrayDim(*dim_exp));
}
const std::vector<ConstValue> init_values =
const_evaluator.EvaluateConstInitList(*ctx->constInitVal(),
current_decl_type_, dims);
if (init_values.empty()) {
throw std::runtime_error(
FormatError("sema", "constant initializer must not be empty: " + name));
}
const std::vector<int64_t> dims_copy = dims;
SymbolEntry symbol;
symbol.name = name;
symbol.kind = SymbolKind::Constant;
symbol.data_type = current_decl_type_;
symbol.is_const = true;
symbol.is_global = (current_function_ == nullptr);
symbol.has_initializer = true;
symbol.is_array = !dims.empty();
symbol.array_dims = std::move(dims);
if (dims_copy.empty()) {
symbol.has_constexpr_value = true;
if (current_decl_type_ == DataType::Float) {
symbol.const_float_value = init_values.front().AsFloat();
} else {
symbol.const_int_value = init_values.front().AsInt();
}
} else {
if (current_decl_type_ == DataType::Float) {
symbol.const_float_init.reserve(init_values.size());
for (const auto& value : init_values) {
symbol.const_float_init.push_back(value.AsFloat());
}
} else {
symbol.const_int_init.reserve(init_values.size());
for (const auto& value : init_values) {
symbol.const_int_init.push_back(value.AsInt());
}
}
}
symbol.decl_ctx = ctx;
SymbolEntry* stable_symbol = sema_.RegisterSymbol(std::move(symbol));
if (!table_.Insert(stable_symbol)) {
throw std::runtime_error(FormatError("sema", "symbol redefined: " + name));
}
if (dims_copy.empty()) {
if (!const_values_.DefineScalar(name, init_values.front())) {
throw std::runtime_error(
FormatError("sema", "duplicate definition in constant environment: " + name));
}
} else {
ConstArrayValue array_value;
array_value.elem_type = current_decl_type_;
array_value.dims = dims_copy;
array_value.elements = init_values;
if (!const_values_.DefineArray(name, std::move(array_value))) {
throw std::runtime_error(
FormatError("sema", "duplicate definition in constant environment: " + name));
}
}
return {};
}
std::any visitVarDef(SysYParser::VarDefContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid variable definition"));
}
const std::string name = RequireIdent(ctx->Ident(), "variable definition is missing a name");
if (table_.ContainsCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "symbol redefined: " + name));
}
std::vector<int64_t> dims;
dims.reserve(ctx->constExp().size());
ConstEvaluator const_evaluator(table_, const_values_);
for (auto* dim_exp : ctx->constExp()) {
dims.push_back(const_evaluator.EvaluateArrayDim(*dim_exp));
}
SymbolEntry symbol;
symbol.name = name;
symbol.kind = current_decl_is_const_ ? SymbolKind::Constant
: SymbolKind::Variable;
symbol.data_type = current_decl_type_;
symbol.is_const = current_decl_is_const_;
symbol.is_global = (current_function_ == nullptr);
symbol.has_initializer = (ctx->initVal() != nullptr);
symbol.is_array = !dims.empty();
symbol.array_dims = std::move(dims);
symbol.decl_ctx = ctx;
SymbolEntry* stable_symbol = sema_.RegisterSymbol(std::move(symbol));
if (!table_.Insert(stable_symbol)) {
throw std::runtime_error(FormatError("sema", "symbol redefined: " + name));
}
if (ctx->initVal()) {
CheckInitValue(*ctx->initVal(), current_decl_type_);
}
return {};
}
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->Ident() || !ctx->block()) {
throw std::runtime_error(FormatError("sema", "invalid function definition"));
}
const std::string name = ctx->Ident()->getText();
const SymbolEntry* symbol = table_.LookupCurrentScope(name);
if (!symbol || symbol->kind != SymbolKind::Function) {
throw std::runtime_error(FormatError("sema", "function signature not registered: " + name));
}
const SymbolEntry* saved_function = current_function_;
const bool saved_has_return = current_function_has_return_;
current_function_ = symbol;
current_function_has_return_ = false;
BuildPendingParams(*ctx);
ctx->block()->accept(this);
if (current_function_->data_type != DataType::Void &&
!current_function_has_return_) {
throw std::runtime_error(
FormatError("sema", "non-void function is missing a return: " + name));
}
current_function_ = saved_function;
current_function_has_return_ = saved_has_return;
return {};
}
std::any visitBlock(SysYParser::BlockContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid block"));
}
table_.EnterScope();
const_values_.EnterScope();
InjectPendingParamsIntoCurrentScope();
for (auto* item : ctx->blockItem()) {
if (item) {
item->accept(this);
}
}
const_values_.ExitScope();
table_.ExitScope();
return {};
}
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid blockItem"));
}
if (ctx->decl()) {
ctx->decl()->accept(this);
return {};
}
if (ctx->stmt()) {
ctx->stmt()->accept(this);
return {};
}
throw std::runtime_error(FormatError("sema", "malformed blockItem"));
}
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "missing statement"));
}
if (ctx->ASSIGN()) {
if (!ctx->lVal() || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "invalid assignment statement"));
}
LValueInfo lvalue = ResolveLValue(*ctx->lVal());
if (lvalue.symbol->is_const) {
throw std::runtime_error(
FormatError("sema", "cannot assign to a constant: " + lvalue.symbol->name));
}
if (lvalue.symbol->is_array && !lvalue.fully_indexed) {
throw std::runtime_error(
FormatError("sema", "array variable must be fully indexed before assignment: " +
lvalue.symbol->name));
}
DataType rhs_ty = EvalType(*ctx->exp());
EnsureAssignable("assignment", lvalue.value_type, rhs_ty);
return {};
}
if (ctx->RETURN()) {
if (!current_function_) {
throw std::runtime_error(FormatError("sema", "return outside of a function body"));
}
if (current_function_->data_type == DataType::Void) {
if (ctx->exp()) {
throw std::runtime_error(
FormatError("sema", "void function must not return an expression"));
}
} else {
if (!ctx->exp()) {
throw std::runtime_error(
FormatError("sema", "non-void function must return an expression"));
}
DataType ret_ty = EvalType(*ctx->exp());
EnsureAssignable("return", current_function_->data_type, ret_ty);
}
current_function_has_return_ = true;
return {};
}
if (ctx->BREAK()) {
if (loop_depth_ <= 0) {
throw std::runtime_error(
FormatError("sema", "break can only appear inside a loop body"));
}
return {};
}
if (ctx->CONTINUE()) {
if (loop_depth_ <= 0) {
throw std::runtime_error(
FormatError("sema", "continue can only appear inside a loop body"));
}
return {};
}
if (ctx->IF()) {
if (!ctx->cond() || ctx->stmt().empty()) {
throw std::runtime_error(FormatError("sema", "invalid if statement"));
}
DataType cond_ty = EvalType(*ctx->cond());
if (!IsScalarType(cond_ty)) {
throw std::runtime_error(
FormatError("sema", "if condition must be of scalar type"));
}
ctx->stmt(0)->accept(this);
if (ctx->ELSE()) {
if (ctx->stmt().size() < 2 || !ctx->stmt(1)) {
throw std::runtime_error(FormatError("sema", "invalid else branch"));
}
ctx->stmt(1)->accept(this);
}
return {};
}
if (ctx->WHILE()) {
if (!ctx->cond() || ctx->stmt().empty() || !ctx->stmt(0)) {
throw std::runtime_error(FormatError("sema", "invalid while statement"));
}
DataType cond_ty = EvalType(*ctx->cond());
if (!IsScalarType(cond_ty)) {
throw std::runtime_error(
FormatError("sema", "while condition must be of scalar type"));
}
++loop_depth_;
ctx->stmt(0)->accept(this);
--loop_depth_;
return {};
}
if (ctx->block()) {
ctx->block()->accept(this);
return {};
}
// exp? ';' covers both empty statements and expression statements.
if (ctx->exp()) {
EvalType(*ctx->exp());
}
return {};
}
std::any visitExp(SysYParser::ExpContext* ctx) override {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("sema", "invalid expression"));
}
return EvalType(*ctx->addExp());
}
std::any visitCond(SysYParser::CondContext* ctx) override {
if (!ctx || !ctx->lOrExp()) {
throw std::runtime_error(FormatError("sema", "invalid condition expression"));
}
return EvalType(*ctx->lOrExp());
}
std::any visitLVal(SysYParser::LValContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid lvalue"));
}
LValueInfo info = ResolveLValue(*ctx);
if (info.symbol->is_array && !info.fully_indexed) {
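// In the current minimal implementation, an array name or a partially
// indexed array (a slice) is not treated as a scalar value: it may only be
// passed as an array argument and may not take part in arithmetic.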
return DataType::Unknown;
}
return info.value_type;
}
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid primaryExp"));
}
if (ctx->exp()) {
return EvalType(*ctx->exp());
}
if (ctx->lVal()) {
return EvalType(*ctx->lVal());
}
if (ctx->number()) {
return EvalType(*ctx->number());
}
throw std::runtime_error(FormatError("sema", "unrecognized primaryExp"));
}
std::any visitNumber(SysYParser::NumberContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid number"));
}
if (ctx->IntConst()) {
return DataType::Int;
}
if (ctx->FloatConst()) {
return DataType::Float;
}
throw std::runtime_error(FormatError("sema", "invalid numeric constant"));
}
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid unaryExp"));
}
if (ctx->primaryExp()) {
return EvalType(*ctx->primaryExp());
}
if (ctx->Ident()) {
const std::string callee_name = ctx->Ident()->getText();
const SymbolEntry* callee = table_.Lookup(callee_name);
if (!callee || callee->kind != SymbolKind::Function) {
throw std::runtime_error(
FormatError("sema", "call to undefined function: " + callee_name));
}
std::vector<SysYParser::ExpContext*> args;
if (ctx->funcRParams()) {
args = ctx->funcRParams()->exp();
}
if (args.size() != callee->param_types.size()) {
throw std::runtime_error(
FormatError("sema", "argument count mismatch: " + callee_name));
}
for (size_t i = 0; i < args.size(); ++i) {
DataType arg_ty = EvalType(*args[i]);
const bool need_array =
i < callee->param_is_array.size() && callee->param_is_array[i];
if (need_array) {
if (arg_ty != DataType::Unknown) {
throw std::runtime_error(
FormatError("sema",
"array parameter requires an array argument: " + callee_name));
}
continue;
}
if (arg_ty == DataType::Unknown) {
throw std::runtime_error(
FormatError("sema",
"scalar parameter does not accept an array argument: " + callee_name));
}
EnsureAssignable("function argument", callee->param_types[i], arg_ty);
}
sema_.BindCallUse(ctx, callee);
return callee->data_type;
}
if (ctx->unaryOp() && ctx->unaryExp()) {
const DataType operand_ty = EvalType(*ctx->unaryExp());
if (ctx->unaryOp()->NOT()) {
if (!IsScalarType(operand_ty)) {
throw std::runtime_error(
FormatError("sema", "logical NOT only supports scalar types"));
}
return DataType::Bool;
}
if (!IsNumericType(operand_ty)) {
throw std::runtime_error(
FormatError("sema", "unary plus and minus only support numeric types"));
}
return operand_ty == DataType::Float ? DataType::Float : DataType::Int;
}
throw std::runtime_error(FormatError("sema", "malformed unaryExp"));
}
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid mulExp"));
}
if (!ctx->mulExp()) {
return EvalType(*ctx->unaryExp());
}
DataType lhs = EvalType(*ctx->mulExp());
DataType rhs = EvalType(*ctx->unaryExp());
if (!IsNumericType(lhs) || !IsNumericType(rhs)) {
throw std::runtime_error(
FormatError("sema", "multiplication, division, and modulo only support numeric types"));
}
return PromoteArithmetic(lhs, rhs);
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid addExp"));
}
if (!ctx->addExp()) {
return EvalType(*ctx->mulExp());
}
DataType lhs = EvalType(*ctx->addExp());
DataType rhs = EvalType(*ctx->mulExp());
if (!IsNumericType(lhs) || !IsNumericType(rhs)) {
throw std::runtime_error(
FormatError("sema", "addition and subtraction only support numeric types"));
}
return PromoteArithmetic(lhs, rhs);
}
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid relExp"));
}
if (!ctx->relExp()) {
return EvalType(*ctx->addExp());
}
DataType lhs = EvalType(*ctx->relExp());
DataType rhs = EvalType(*ctx->addExp());
if (!IsNumericType(lhs) || !IsNumericType(rhs)) {
throw std::runtime_error(
FormatError("sema", "relational comparison only supports numeric types"));
}
return DataType::Bool;
}
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid eqExp"));
}
if (!ctx->eqExp()) {
return EvalType(*ctx->relExp());
}
DataType lhs = EvalType(*ctx->eqExp());
DataType rhs = EvalType(*ctx->relExp());
if (!IsScalarType(lhs) || !IsScalarType(rhs)) {
throw std::runtime_error(
FormatError("sema", "equality comparison only supports scalar types"));
}
return DataType::Bool;
}
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid lAndExp"));
}
if (!ctx->lAndExp()) {
return EvalType(*ctx->eqExp());
}
DataType lhs = EvalType(*ctx->lAndExp());
DataType rhs = EvalType(*ctx->eqExp());
if (!IsScalarType(lhs) || !IsScalarType(rhs)) {
throw std::runtime_error(
FormatError("sema", "logical AND only supports scalar types"));
}
return DataType::Bool;
}
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "invalid lOrExp"));
}
if (!ctx->lOrExp()) {
return EvalType(*ctx->lAndExp());
}
DataType lhs = EvalType(*ctx->lOrExp());
DataType rhs = EvalType(*ctx->lAndExp());
if (!IsScalarType(lhs) || !IsScalarType(rhs)) {
throw std::runtime_error(
FormatError("sema", "logical OR only supports scalar types"));
}
return DataType::Bool;
}
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("sema", "invalid constExp"));
}
DataType ty = EvalType(*ctx->addExp());
if (!IsNumericType(ty)) {
throw std::runtime_error(FormatError("sema", "constExp must have a numeric type"));
}
return ty;
}
SemanticContext TakeSemanticContext() { return std::move(sema_); }
private:
struct PendingParam {
std::string name;
DataType type = DataType::Unknown;
bool is_array = false;
std::vector<int64_t> dims;
const antlr4::ParserRuleContext* decl_ctx = nullptr;
};
SymbolEntry* PredeclareFunction(SysYParser::FuncDefContext& ctx) {
if (!ctx.funcType() || !ctx.Ident()) {
throw std::runtime_error(FormatError("sema", "invalid function definition"));
}
const std::string name = ctx.Ident()->getText();
if (table_.ContainsCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "duplicate symbol definition: " + name));
}
SymbolEntry symbol;
symbol.name = name;
symbol.kind = SymbolKind::Function;
symbol.data_type = ParseFuncType(*ctx.funcType());
symbol.is_const = true;
symbol.is_global = true;
symbol.decl_ctx = &ctx;
if (ctx.funcFParams()) {
for (auto* param : ctx.funcFParams()->funcFParam()) {
if (!param || !param->bType()) {
throw std::runtime_error(FormatError("sema", "invalid function parameter definition"));
}
symbol.param_types.push_back(ParseBType(*param->bType()));
symbol.param_is_array.push_back(!param->LBRACK().empty());
}
}
SymbolEntry* stable_symbol = sema_.RegisterSymbol(std::move(symbol));
if (!table_.Insert(stable_symbol)) {
throw std::runtime_error(FormatError("sema", "duplicate symbol definition: " + name));
}
if (name == "main") {
has_main_function_ = true;
if (stable_symbol->data_type != DataType::Int) {
throw std::runtime_error(FormatError("sema", "the return type of main must be int"));
}
if (!stable_symbol->param_types.empty()) {
throw std::runtime_error(FormatError("sema", "main must not take parameters"));
}
}
return stable_symbol;
}
void BuildPendingParams(SysYParser::FuncDefContext& ctx) {
pending_params_.clear();
inject_params_in_next_block_ = false;
if (!ctx.funcFParams()) {
return;
}
for (auto* param : ctx.funcFParams()->funcFParam()) {
if (!param || !param->bType()) {
throw std::runtime_error(FormatError("sema", "invalid function parameter definition"));
}
PendingParam info;
info.name = RequireIdent(param->Ident(), "function parameter is missing a name");
info.type = ParseBType(*param->bType());
info.is_array = !param->LBRACK().empty();
info.decl_ctx = param;
if (info.is_array) {
info.dims.push_back(-1);
for (auto* dim_exp : param->exp()) {
DataType dim_ty = EvalType(*dim_exp);
if (!IsNumericType(dim_ty)) {
throw std::runtime_error(
FormatError("sema", "array parameter dimensions must be numeric expressions"));
}
info.dims.push_back(-1);
}
}
pending_params_.push_back(std::move(info));
}
inject_params_in_next_block_ = true;
}
void InjectPendingParamsIntoCurrentScope() {
if (!inject_params_in_next_block_) {
return;
}
for (const auto& param : pending_params_) {
SymbolEntry symbol;
symbol.name = param.name;
symbol.kind = SymbolKind::Parameter;
symbol.data_type = param.type;
symbol.is_const = false;
symbol.is_global = false;
symbol.is_array = param.is_array;
symbol.array_dims = param.dims;
symbol.decl_ctx = param.decl_ctx;
SymbolEntry* stable_symbol = sema_.RegisterSymbol(std::move(symbol));
if (!table_.Insert(stable_symbol)) {
throw std::runtime_error(FormatError("sema", "duplicate symbol definition: " +
param.name));
}
}
pending_params_.clear();
inject_params_in_next_block_ = false;
}
DataType EvalType(antlr4::ParserRuleContext& ctx) {
std::any result = ctx.accept(this);
if (!result.has_value()) {
return DataType::Unknown;
}
try {
return std::any_cast<DataType>(result);
} catch (const std::bad_any_cast&) {
throw std::runtime_error(FormatError("sema", "expression type inference failed"));
}
}
void EnsureAssignable(const std::string& scene, DataType dst, DataType src) {
if (!CanAssign(dst, src)) {
throw std::runtime_error(FormatError(
"sema", scene + " type mismatch: expected " + DataTypeName(dst) +
", got " + DataTypeName(src)));
}
}
DataType PromoteArithmetic(DataType lhs, DataType rhs) {
if (lhs == DataType::Float || rhs == DataType::Float) {
return DataType::Float;
}
return DataType::Int;
}
LValueInfo ResolveLValue(SysYParser::LValContext& ctx) {
const std::string name = RequireIdent(ctx.Ident(), "lvalue is missing an identifier");
const SymbolEntry* symbol = table_.Lookup(name);
if (!symbol) {
throw std::runtime_error(FormatError("sema", "use of undefined variable: " + name));
}
if (symbol->kind == SymbolKind::Function) {
throw std::runtime_error(
FormatError("sema", "a function name cannot be used as an lvalue: " + name));
}
const size_t index_count = ctx.exp().size();
for (auto* index_exp : ctx.exp()) {
DataType index_ty = EvalType(*index_exp);
if (!IsNumericType(index_ty)) {
throw std::runtime_error(
FormatError("sema", "array index must have a numeric type: " + name));
}
}
if (index_count > 0 && !symbol->is_array) {
throw std::runtime_error(
FormatError("sema", "a non-array variable cannot be indexed: " + name));
}
if (symbol->is_array && index_count > symbol->array_dims.size()) {
throw std::runtime_error(
FormatError("sema", "too many array indices: " + name));
}
sema_.BindLValUse(&ctx, symbol);
LValueInfo info;
info.symbol = symbol;
info.value_type = symbol->data_type;
info.index_count = index_count;
info.fully_indexed = !symbol->is_array ||
index_count == symbol->array_dims.size();
return info;
}
void CheckInitValue(SysYParser::InitValContext& ctx, DataType expected_type) {
if (ctx.exp()) {
DataType init_ty = EvalType(*ctx.exp());
EnsureAssignable("variable initialization", expected_type, init_ty);
return;
}
for (auto* nested : ctx.initVal()) {
if (nested) {
CheckInitValue(*nested, expected_type);
}
}
}
SymbolTable table_;
ConstEvalContext const_values_;
SemanticContext sema_;
DataType current_decl_type_ = DataType::Unknown;
bool current_decl_is_const_ = false;
int loop_depth_ = 0;
bool has_main_function_ = false;
const SymbolEntry* current_function_ = nullptr;
bool current_function_has_return_ = false;
bool inject_params_in_next_block_ = false;
std::vector<PendingParam> pending_params_;
};
} // namespace
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
(void)comp_unit;
return SemanticContext{};
SemaVisitor visitor;
comp_unit.accept(&visitor);
return visitor.TakeSemanticContext();
}
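
The binary-expression visitors above all follow one pattern: check the operand categories, then return either the promoted arithmetic type or `Bool`. The sketch below restates that pattern as standalone functions. The `DataType` enum, the bodies of `IsNumericType` and `IsScalarType`, and the `Check*` and `Rejects` helpers are assumptions made for illustration; only `PromoteArithmetic` is copied from the visitor.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-in for the project's DataType enum.
enum class DataType { Unknown, Void, Bool, Int, Float };

// Assumption: only Int and Float count as numeric.
bool IsNumericType(DataType t) {
  return t == DataType::Int || t == DataType::Float;
}

// Assumption: scalar types are the numeric types plus Bool.
bool IsScalarType(DataType t) {
  return IsNumericType(t) || t == DataType::Bool;
}

// Same rule as PromoteArithmetic in the visitor: Float wins, otherwise Int.
DataType PromoteArithmetic(DataType lhs, DataType rhs) {
  if (lhs == DataType::Float || rhs == DataType::Float) {
    return DataType::Float;
  }
  return DataType::Int;
}

// Result type of `lhs op rhs` for + - * / %.
DataType CheckArithmetic(DataType lhs, DataType rhs) {
  if (!IsNumericType(lhs) || !IsNumericType(rhs)) {
    throw std::runtime_error("arithmetic requires numeric operands");
  }
  return PromoteArithmetic(lhs, rhs);
}

// Result type of `lhs op rhs` for < > <= >=.
DataType CheckRelational(DataType lhs, DataType rhs) {
  if (!IsNumericType(lhs) || !IsNumericType(rhs)) {
    throw std::runtime_error("comparison requires numeric operands");
  }
  return DataType::Bool;
}

// Result type of `lhs op rhs` for == != && ||.
DataType CheckLogical(DataType lhs, DataType rhs) {
  if (!IsScalarType(lhs) || !IsScalarType(rhs)) {
    throw std::runtime_error("operator requires scalar operands");
  }
  return DataType::Bool;
}

// Returns true when `check` rejects the operand pair by throwing.
bool Rejects(DataType (*check)(DataType, DataType), DataType lhs,
             DataType rhs) {
  try {
    check(lhs, rhs);
    return false;
  } catch (const std::runtime_error&) {
    return true;
  }
}
```

Under these assumptions an array operand, which the visitor reports as `DataType::Unknown`, is rejected by every check, which matches the rule that arrays may only be passed as arguments.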

@ -1,43 +1,55 @@
#include "sem/SymbolTable.h"
// Maintains registration and lookup of local variable declarations.
void SymbolTable::Clear() { scopes_.clear(); }
#include "sem/SymbolTable.h"
#include <stdexcept>
SymbolTable::SymbolTable() { EnterScope(); }
void SymbolTable::EnterScope() { scopes_.emplace_back(); }
void SymbolTable::ExitScope() {
if (!scopes_.empty()) {
scopes_.pop_back();
if (scopes_.size() <= 1) {
throw std::runtime_error("symbol table scope underflow");
}
scopes_.pop_back();
}
bool SymbolTable::Insert(const std::string& name, const SymbolEntry& entry) {
bool SymbolTable::Insert(const SymbolEntry* symbol) {
if (!symbol) {
return false;
}
if (scopes_.empty()) {
EnterScope();
}
auto& scope = scopes_.back();
return scope.emplace(name, entry).second;
auto& current = scopes_.back();
auto [it, inserted] = current.emplace(symbol->name, symbol);
return inserted;
}
bool SymbolTable::ContainsInCurrentScope(const std::string& name) const {
return !scopes_.empty() && scopes_.back().find(name) != scopes_.back().end();
bool SymbolTable::Contains(const std::string& name) const {
return Lookup(name) != nullptr;
}
SymbolEntry* SymbolTable::Lookup(const std::string& name) {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(name);
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
bool SymbolTable::ContainsCurrentScope(const std::string& name) const {
return LookupCurrentScope(name) != nullptr;
}
const SymbolEntry* SymbolTable::Lookup(const std::string& name) const {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(name);
if (found != it->end()) {
return &found->second;
auto symbol_it = it->find(name);
if (symbol_it != it->end()) {
return symbol_it->second;
}
}
return nullptr;
}
const SymbolEntry* SymbolTable::LookupCurrentScope(
const std::string& name) const {
if (scopes_.empty()) {
return nullptr;
}
auto it = scopes_.back().find(name);
return it == scopes_.back().end() ? nullptr : it->second;
}
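
The scope-stack behaviour of the pointer-based table can be tried in isolation. The following is a minimal sketch, not the project's class: `Symbol`, `ScopedTable`, and `ShadowingDemo` are hypothetical names, and the record carries only a name and a value. It keeps the same three rules as the code above: the global scope always exists, `Insert` fails on a duplicate in the current scope, and `Lookup` searches from the innermost scope outward.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical minimal symbol record; the real SymbolEntry has more fields.
struct Symbol {
  std::string name;
  int value;
};

class ScopedTable {
 public:
  ScopedTable() { EnterScope(); }  // The global scope always exists.

  void EnterScope() { scopes_.emplace_back(); }

  void ExitScope() {
    // Refuse to pop the global scope.
    if (scopes_.size() <= 1) {
      throw std::runtime_error("symbol table scope underflow");
    }
    scopes_.pop_back();
  }

  // Returns false for a null symbol or a duplicate name in the current scope.
  bool Insert(const Symbol* symbol) {
    if (!symbol) {
      return false;
    }
    return scopes_.back().emplace(symbol->name, symbol).second;
  }

  // Searches from the innermost scope outward, so inner names shadow outer.
  const Symbol* Lookup(const std::string& name) const {
    for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
      auto found = it->find(name);
      if (found != it->end()) {
        return found->second;
      }
    }
    return nullptr;
  }

 private:
  std::vector<std::unordered_map<std::string, const Symbol*>> scopes_;
};

// Global a = 3 and b = 5, then a local a = 5 that shadows the global a.
int ShadowingDemo() {
  Symbol global_a{"a", 3};
  Symbol global_b{"b", 5};
  Symbol local_a{"a", 5};
  ScopedTable table;
  table.Insert(&global_a);
  table.Insert(&global_b);
  table.EnterScope();
  table.Insert(&local_a);
  const int sum = table.Lookup("a")->value + table.Lookup("b")->value;
  table.ExitScope();
  return sum;
}
```

Because the table stores pointers rather than copies, the symbols must outlive it; in the sketch they are locals of `ShadowingDemo`, which is one reason a separate owner such as `RegisterSymbol` is useful in the real code.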

@ -1,81 +1,4 @@
#include "sylib.h"
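// SysY runtime library implementation:
// - Provides I/O and other functions as required by the lab/evaluation spec
// - Linked with the target code generated by the compiler to support runtime behavior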
#include <stdio.h>
#include <stdlib.h>
static int read_char_normalized(void) {
int ch = getchar();
if (ch == '\r') {
int next = getchar();
if (next != '\n' && next != EOF) {
ungetc(next, stdin);
}
return '\n';
}
return ch;
}
static float read_float_token(void) {
char buffer[256];
if (scanf("%255s", buffer) != 1) {
return 0.0f;
}
return strtof(buffer, NULL);
}
int getint(void) {
int value = 0;
if (scanf("%d", &value) != 1) {
return 0;
}
return value;
}
int getch(void) {
int ch = read_char_normalized();
return ch == EOF ? 0 : ch;
}
float getfloat(void) { return read_float_token(); }
int getarray(int a[]) {
int n = getint();
for (int i = 0; i < n; ++i) {
a[i] = getint();
}
return n;
}
int getfarray(float a[]) {
int n = getint();
for (int i = 0; i < n; ++i) {
a[i] = getfloat();
}
return n;
}
void putint(int x) { printf("%d", x); }
void putch(int x) { putchar(x); }
void putfloat(float x) { printf("%a", (double)x); }
void putarray(int n, const int a[]) {
printf("%d:", n);
for (int i = 0; i < n; ++i) {
printf(" %d", a[i]);
}
putchar('\n');
}
void putfarray(int n, const float a[]) {
printf("%d:", n);
for (int i = 0; i < n; ++i) {
printf(" %a", (double)a[i]);
}
putchar('\n');
}
void starttime(void) {}
void stoptime(void) {}
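
`putfloat` and `putfarray` print with `%a`, the hexadecimal floating-point format, and `getfloat` parses tokens with `strtof`, which also accepts that format. The sketch below shows why this pairing is convenient for expected-output files: a value written with `%a` reads back bit-for-bit. `FormatHexFloat`, `ParseFloatToken`, and `RoundTrips` are hypothetical helper names, not part of the runtime library.

```cpp
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <string>

// Formats a float the way putfloat does: promote to double, print with %a.
std::string FormatHexFloat(float x) {
  char buffer[64];
  std::snprintf(buffer, sizeof(buffer), "%a", static_cast<double>(x));
  return std::string(buffer);
}

// Parses a token the way read_float_token does, via strtof.
float ParseFloatToken(const std::string& token) {
  return std::strtof(token.c_str(), nullptr);
}

// True when formatting and re-parsing reproduces the exact same float.
bool RoundTrips(float x) {
  return ParseFloatToken(FormatHexFloat(x)) == x;
}
```

The exact digits `%a` produces can differ between C libraries, so comparing re-parsed values is more robust than comparing the printed strings.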

@ -1,17 +1,4 @@
#ifndef SYLIB_H_
#define SYLIB_H_
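// SysY runtime library header:
// - Declares the runtime function prototypes (referenced by compiler-generated calls or at link time)
// - Pairs with sylib.c; declarations are filled in step by step according to the spec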
int getint(void);
int getch(void);
float getfloat(void);
int getarray(int a[]);
int getfarray(float a[]);
void putint(int x);
void putch(int x);
void putfloat(float x);
void putarray(int n, const int a[]);
void putfarray(int n, const float a[]);
void starttime(void);
void stoptime(void);
#endif

@ -1,9 +0,0 @@
int main(){
const int a[4][2] = {{1, 2}, {3, 4}, {}, 7};
const int N = 3;
int b[4][2] = {};
int c[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
int d[N + 1][2] = {1, 2, {3}, {5}, a[3][0], 8};
int e[4][2][1] = {{d[2][1], {c[2][1]}}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1][0] + e[0][0][0] + e[0][1][0] + d[3][0];
}

@ -1,3 +0,0 @@
int main(){
return 3;
}

@ -1,8 +0,0 @@
//test domain of global var define and local define
int a = 3;
int b = 5;
int main(){
int a = 5;
return a + b;
}

@ -1,8 +0,0 @@
//test local var define
int main(){
int a, b0, _c;
a = 1;
b0 = 2;
_c = 3;
return b0 + _c;
}

@ -1,4 +0,0 @@
int a[10][10];
int main(){
return 0;
}

@ -1,9 +0,0 @@
//test array define
int main(){
int a[4][2] = {};
int b[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
int c[4][2] = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
int d[4][2] = {1, 2, {3}, {5}, 7 , 8};
int e[4][2] = {{d[2][1], c[2][1]}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1] + e[0][0] + e[0][1] + a[2][0];
}

@ -1,9 +1,9 @@
int main(){
const int a[4][2] = {{1, 2}, {3, 4}, {}, 7};
const int N = 3;
int b[4][2] = {};
int c[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
int d[3 + 1][2] = {1, 2, {3}, {5}, a[3][0], 8};
int d[N + 1][2] = {1, 2, {3}, {5}, a[3][0], 8};
int e[4][2][1] = {{d[2][1], {c[2][1]}}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1][0] + e[0][0][0] + e[0][1][0] + d[3][0];
}
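
Initializers such as `{1, 2, {3}, {5}, a[3][0], 8}` mix bare scalars with nested braces, so a compiler has to decide which sub-array each brace fills. The sketch below shows one common flattening scheme, assuming every brace starts on a sub-array boundary: scalars fill consecutive slots, and a nested brace fills the largest sub-array that begins at the current position, with zero padding. `Init`, `Scalar`, `List`, `Span`, `Fill`, and `Flatten` are hypothetical names; this is not taken from the project's IR generator.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical initializer node: a scalar or a brace-enclosed list.
struct Init {
  bool is_scalar = false;
  int value = 0;
  std::vector<Init> items;
};

Init Scalar(int v) {
  Init node;
  node.is_scalar = true;
  node.value = v;
  return node;
}

Init List(std::vector<Init> items) {
  Init node;
  node.items = std::move(items);
  return node;
}

// Number of elements covered by dims[level..]; 1 when level is past the end.
std::size_t Span(const std::vector<int>& dims, std::size_t level) {
  std::size_t n = 1;
  for (std::size_t i = level; i < dims.size(); ++i) {
    n *= static_cast<std::size_t>(dims[i]);
  }
  return n;
}

// Fills out[begin, begin + Span(dims, level)) from one brace list.
void Fill(const Init& list, const std::vector<int>& dims, std::size_t level,
          std::size_t begin, std::vector<int>& out) {
  const std::size_t end = begin + Span(dims, level);
  std::size_t pos = begin;
  for (const Init& item : list.items) {
    if (pos >= end) {
      throw std::runtime_error("too many initializers");
    }
    if (item.is_scalar) {
      out[pos++] = item.value;
      continue;
    }
    if (level >= dims.size()) {
      throw std::runtime_error("braces nested too deeply");
    }
    // Pick the largest sub-array that starts exactly at pos.
    std::size_t sub_level = level + 1;
    while (sub_level < dims.size() &&
           (pos - begin) % Span(dims, sub_level) != 0) {
      ++sub_level;
    }
    Fill(item, dims, sub_level, pos, out);
    pos += Span(dims, sub_level);  // Skip the rest of the sub-array (zeros).
  }
}

// Returns the row-major, zero-padded contents of the initialized array.
std::vector<int> Flatten(const Init& root, const std::vector<int>& dims) {
  std::vector<int> out(Span(dims, 0), 0);
  Fill(root, dims, 0, 0, out);
  return out;
}
```

For a `[4][2]` array this places `{3}` and `{5}` at the start of rows 1 and 2, so with `a[3][0]` equal to 7 the flattened contents are `1, 2, 3, 0, 5, 0, 7, 8`.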

@ -1,6 +0,0 @@
//test const global var define
const int a = 10, b = 5;
int main(){
return b;
}

@ -1,5 +0,0 @@
//test const local var define
int main(){
const int a = 10, b = 5;
return b;
}

@ -1,5 +0,0 @@
const int a[5]={0,1,2,3,4};
int main(){
return a[4];
}

@ -1,8 +0,0 @@
int defn(){
return 4;
}
int main(){
int a=defn();
return a;
}

@ -1,5 +0,0 @@
//test addc
const int a = 10;
int main(){
return a + 5;
}

@ -1,6 +0,0 @@
//test subc
int main(){
int a;
a = 10;
return a - 2;
}

@ -1,7 +0,0 @@
//test mul
int main(){
int a, b;
a = 10;
b = 5;
return a * b;
}

@ -1,5 +0,0 @@
//test mulc
const int a = 5;
int main(){
return a * 5;
}

@ -1,7 +0,0 @@
//test div
int main(){
int a, b;
a = 10;
b = 5;
return a / b;
}

@ -1,5 +0,0 @@
//test divc
const int a = 10;
int main(){
return a / 5;
}

@ -1,6 +0,0 @@
//test mod
int main(){
int a;
a = 10;
return a / 3;
}

@ -1,6 +0,0 @@
//test rem
int main(){
int a;
a = 10;
return a % 3;
}

Some files were not shown because too many files have changed in this diff.