Compare commits

2 Commits

.gitignore vendored

@ -73,5 +73,4 @@ test/test_result/
# mxr
# =========================
result.txt
build.sh
gdb.sh
build.sh

@ -0,0 +1,336 @@
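# Lab4 Work Log: Basic Scalar Optimizations
This document records the optimization passes completed in Lab4, the key implementation ideas, the problems encountered, and how to use the test script.
## 1. Overview of completed work
The following were implemented and wired into the compiler:
- `Mem2Reg`: promotes promotable local scalar variables from `alloca/load/store` form to SSA form.
- `ConstFold`: evaluates constant expressions at compile time.
- `ConstProp`: performs simple constant propagation and algebraic simplification.
- `DCE`: removes dead instructions that have no uses and no side effects.
- `verify_mem2reg.sh`: batch-verifies that every test case under `test/` gets through the existing passes, and optionally runs a semantic regression.
The optimization pipeline is hooked up in `src/irgen/IRGenDriver.cpp` and runs in this order:
```text
Mem2Reg -> ConstFold -> ConstProp -> ConstFold -> DCE
```
The second `ConstFold` consumes the new constant expressions exposed by `ConstProp`, and the final `DCE` cleans up instructions that are no longer used after replacement.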
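The pass ordering above can be illustrated with a minimal, self-contained sketch. It is not the project's code: `ToyFunction`, `ToyPass`, `NamedPass`, `ToyPassManager`, and `BuildLab4Pipeline` are hypothetical stand-ins that only mimic the shape of the `Pass` / `PassManager` interface in this diff, and each pass merely records its name instead of rewriting IR.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ir::Function; it only records which passes ran.
struct ToyFunction {
  std::vector<std::string> trace;
};

// Same shape as the Pass interface in this diff, but on the toy function type.
struct ToyPass {
  virtual ~ToyPass() = default;
  virtual bool RunOnFunction(ToyFunction& function) = 0;
};

// A pass that only logs its name; a real pass would rewrite the IR.
struct NamedPass : ToyPass {
  explicit NamedPass(std::string n) : name(std::move(n)) {}
  bool RunOnFunction(ToyFunction& function) override {
    function.trace.push_back(name);
    return false;  // this toy pass never changes anything
  }
  std::string name;
};

class ToyPassManager {
 public:
  void AddPass(std::unique_ptr<ToyPass> pass) { passes_.push_back(std::move(pass)); }

  // Runs every pass once, in insertion order; returns true if any pass reported a change.
  bool Run(ToyFunction& function) {
    bool changed = false;
    for (auto& pass : passes_) changed |= pass->RunOnFunction(function);
    return changed;
  }

 private:
  std::vector<std::unique_ptr<ToyPass>> passes_;
};

// Registers the order described above: Mem2Reg -> ConstFold -> ConstProp -> ConstFold -> DCE.
inline ToyPassManager BuildLab4Pipeline() {
  ToyPassManager pm;
  for (const char* n : {"Mem2Reg", "ConstFold", "ConstProp", "ConstFold", "DCE"}) {
    pm.AddPass(std::make_unique<NamedPass>(n));
  }
  return pm;
}
```

Calling `BuildLab4Pipeline().Run(f)` on a `ToyFunction f` leaves the five pass names in `f.trace` in pipeline order. Registering `ConstFold` twice is the simplest way to pick up constants exposed by `ConstProp` without iterating to a fixed point.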
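## 2. What Mem2Reg does
When the front end generates IR, local variables are usually first represented in memory form:
```llvm
%x = alloca i32
store i32 1, i32* %x
%v = load i32, i32* %x
```
This form is semantically straightforward, but it makes it hard for later optimizations to determine which value `%v` actually holds. `Mem2Reg` rewrites such promotable variables into SSA values.
For example:
```c
int x;
if (cond) {
x = 1;
} else {
x = 2;
}
return x;
```
After promotion, the core IR becomes:
```llvm
merge:
%x.merge.phi0 = phi i32 [1, %then], [2, %else]
ret i32 %x.merge.phi0
```
A `phi` means "pick the value that corresponds to the predecessor block from which control flow arrived".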
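To make the selection rule concrete, here is a small sketch of how a `phi` could be evaluated by an interpreter. The `Incoming` struct and `EvalPhi` function are hypothetical names used only for illustration; nothing like them appears in this diff.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of one phi entry: a constant value and the predecessor it belongs to.
struct Incoming {
  int value;
  std::string pred_block;
};

// Returns the value selected by the phi when control arrives from `came_from`,
// or nullopt if the phi has no entry for that predecessor (a malformed phi).
inline std::optional<int> EvalPhi(const std::vector<Incoming>& incoming,
                                  const std::string& came_from) {
  for (const auto& in : incoming) {
    if (in.pred_block == came_from) return in.value;
  }
  return std::nullopt;
}
```

For the `%x.merge.phi0` example, the incoming list would be `{{1, "then"}, {2, "else"}}`, so arriving from `then` selects 1 and arriving from `else` selects 2.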
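### 2.1 Promotable objects
In this lab, `Mem2Reg` promotes only local scalar allocas:
- `i32*`
- `float*`
- `i1*`
and it requires that they are used only by direct `load/store` instructions.
The following are not promoted:
- Array allocas, for example `[100 x i32]`
- Memory accessed in more complex ways through `getelementptr`
- Variables whose address is passed to a function
- Global variables
- Any other variable whose address escapes
So it is normal to still see array-related `alloca`s in the tests. `mem2reg` does not eliminate all memory; it eliminates the local scalars that can safely be converted to SSA.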
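### 2.2 Core algorithm
The implementation proceeds as follows:
1. Scan the entry block to find promotable `alloca`s.
2. Collect the defining blocks of each variable, that is, the blocks that contain a `store` to it.
3. Build the CFG and compute the dominator tree and the dominance frontiers.
4. Insert `phi`s at the dominance frontiers of the defining blocks.
5. Rename recursively along the dominator tree:
   - On a `store`, update the variable's current value.
   - On a `load`, replace the `load` with the variable's current value.
   - When visiting successor blocks, fill in the incoming values of the `phi`s in those successors.
6. Delete the promoted `alloca/load/store` instructions.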
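Steps 2 to 4 can be sketched on a toy CFG as follows. This is an illustration under stated assumptions, not the `DominatorTree` implementation from this diff: blocks are plain integers with block 0 as the entry, every block is assumed reachable, dominators are computed with the simple set-intersection iteration, and `Graph`, `Predecessors`, `ComputeIDoms`, `ComputeDF`, and `PhiBlocks` are hypothetical names. The renaming of step 5 is not covered.

```cpp
#include <cstddef>
#include <set>
#include <vector>

// Successor lists per block; block 0 is the entry and all blocks are assumed reachable.
using Graph = std::vector<std::vector<int>>;

inline Graph Predecessors(const Graph& succs) {
  Graph preds(succs.size());
  for (std::size_t b = 0; b < succs.size(); ++b)
    for (int s : succs[b]) preds[s].push_back(static_cast<int>(b));
  return preds;
}

// Immediate dominators via iterative dominator-set intersection (entry gets -1).
inline std::vector<int> ComputeIDoms(const Graph& succs) {
  const int n = static_cast<int>(succs.size());
  const Graph preds = Predecessors(succs);
  std::set<int> all;
  for (int b = 0; b < n; ++b) all.insert(b);
  std::vector<std::set<int>> dom(n, all);
  dom[0] = {0};
  for (bool changed = true; changed;) {
    changed = false;
    for (int b = 1; b < n; ++b) {
      std::set<int> next = all;
      for (int p : preds[b]) {  // dom(b) = {b} + intersection of dom(p) over predecessors
        std::set<int> keep;
        for (int d : next)
          if (dom[p].count(d)) keep.insert(d);
        next = keep;
      }
      next.insert(b);
      if (next != dom[b]) { dom[b] = next; changed = true; }
    }
  }
  std::vector<int> idom(n, -1);
  for (int b = 1; b < n; ++b)
    for (int d : dom[b])  // the strict dominator with the largest dominator set is the closest
      if (d != b && (idom[b] < 0 || dom[d].size() > dom[idom[b]].size())) idom[b] = d;
  return idom;
}

// Dominance frontiers: walk up from each predecessor of a join block until its idom.
inline std::vector<std::set<int>> ComputeDF(const Graph& succs, const std::vector<int>& idom) {
  const Graph preds = Predecessors(succs);
  std::vector<std::set<int>> df(succs.size());
  for (std::size_t b = 0; b < succs.size(); ++b) {
    if (preds[b].size() < 2) continue;
    for (int runner : preds[b])
      for (; runner != idom[b]; runner = idom[runner]) df[runner].insert(static_cast<int>(b));
  }
  return df;
}

// Blocks that need a phi for a variable stored to in `def_blocks` (iterated frontier).
inline std::set<int> PhiBlocks(const std::vector<std::set<int>>& df,
                               const std::set<int>& def_blocks) {
  std::set<int> phis;
  std::vector<int> work(def_blocks.begin(), def_blocks.end());
  while (!work.empty()) {
    const int b = work.back();
    work.pop_back();
    for (int f : df[b])
      if (phis.insert(f).second) work.push_back(f);  // a phi is itself a new definition
  }
  return phis;
}
```

On the diamond CFG `{{1, 2}, {3}, {3}, {}}` with stores in blocks 1 and 2, `PhiBlocks` yields block 3, which matches the `merge` block in the if/else example. On a simple loop, the phi lands in the loop header.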
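### 2.3 Problem fixed
One key problem was fixed during implementation: several phi results ended up with the same name.
Originally, phi names looked like:
```text
<variable name>.phi
```
In complex loops the same variable may need several phis, which made LLVM report:
```text
multiple definition of local value named '...phi'
```
Phi names now include the variable name, the basic block name, and an incrementing counter, for example:
```text
%t45_i.while.cond.t72.phi3
```
This keeps SSA names unique within a function.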
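A naming scheme of this kind could look like the sketch below. `PhiNamer` is a hypothetical helper, and the exact separator layout is only inferred from the single example name above, so treat it as an assumption rather than the pass's actual code.

```cpp
#include <string>

// Hypothetical helper that builds names of the form %<var>.<block>.phi<N>.
class PhiNamer {
 public:
  // The counter is never reset, so two phis for the same variable in the
  // same block still receive different names.
  std::string Next(const std::string& var, const std::string& block) {
    return "%" + var + "." + block + ".phi" + std::to_string(counter_++);
  }

 private:
  int counter_ = 0;  // assumed to be one counter per function
};
```

With one `PhiNamer` per function, the fourth call to `Next("t45_i", "while.cond.t72")` produces `%t45_i.while.cond.t72.phi3`. The counter alone already guarantees uniqueness; the variable and block names are there for readability.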
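## 3. Constant folding and constant propagation
### 3.1 ConstFold
`ConstFold` evaluates instructions whose operands are all constants and replaces the original instruction with the resulting constant.
Currently supported:
- Integer arithmetic: `add/sub/mul/div/mod/and/or`
- Floating-point arithmetic: `fadd/fsub/fmul/fdiv`
- Integer comparison: `icmp`
- Floating-point comparison: `fcmp`
- Type conversion: `zext/trunc/sitofp/fptosi`
- Simple `phi`s whose incoming values are all the same constant
For example:
```llvm
%t = add i32 20, 4
ret i32 %t
```
becomes:
```llvm
ret i32 24
```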
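The integer part of such folding can be sketched as below. This is a hedged illustration, not the `ConstFoldPass` source: `IntOp` and `FoldInt` are hypothetical names, `and`/`or` are assumed to be bitwise, and the sketch declines to fold division or remainder by zero (and `INT32_MIN / -1`) so the compiler itself never executes undefined behaviour.

```cpp
#include <cstdint>
#include <optional>

enum class IntOp { Add, Sub, Mul, Div, Mod, And, Or };

// Returns the folded i32 result, or nullopt when folding is skipped.
inline std::optional<std::int32_t> FoldInt(IntOp op, std::int32_t a, std::int32_t b) {
  // Add/sub/mul are done on unsigned values so overflow wraps (like LLVM `add`
  // without nsw) instead of being undefined behaviour in the host compiler.
  const auto ua = static_cast<std::uint32_t>(a);
  const auto ub = static_cast<std::uint32_t>(b);
  switch (op) {
    case IntOp::Add: return static_cast<std::int32_t>(ua + ub);
    case IntOp::Sub: return static_cast<std::int32_t>(ua - ub);
    case IntOp::Mul: return static_cast<std::int32_t>(ua * ub);
    case IntOp::Div:
    case IntOp::Mod:
      // Leave these to run time: the host division would trap or be undefined.
      if (b == 0 || (a == INT32_MIN && b == -1)) return std::nullopt;
      return op == IntOp::Div ? a / b : a % b;
    case IntOp::And: return a & b;  // assumption: bitwise and
    case IntOp::Or: return a | b;   // assumption: bitwise or
  }
  return std::nullopt;
}
```

`FoldInt(IntOp::Add, 20, 4)` gives 24, matching the example above, while `FoldInt(IntOp::Div, 1, 0)` returns no value so the instruction is left in place.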
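### 3.2 ConstProp
`ConstProp` mainly performs simple algebraic simplification and propagation:
- `x + 0 -> x`
- `x - 0 -> x`
- `x * 1 -> x`
- `x * 0 -> 0`
- `x / 1 -> x`
- `0 / x -> 0`
- A `phi` whose valid incoming values are all the same is replaced by that value
It does not perform complex global data-flow analysis; the goal is to work with the SSA values exposed by `Mem2Reg` and eliminate some obviously redundant expressions.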
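The arithmetic rules listed above can be written as a small rewrite function. The sketch is restricted to i32 operands and uses hypothetical names (`Val`, `Const`, `Var`, `IsConst`, `Simplify`); it implements exactly the listed patterns and nothing more, so commuted forms such as `0 + x` are deliberately not handled here.

```cpp
#include <optional>
#include <string>
#include <utility>

// Hypothetical operand: either an i32 constant or a named SSA value.
struct Val {
  bool is_const;
  int value;         // meaningful only when is_const is true
  std::string name;  // meaningful only when is_const is false
};

inline Val Const(int v) { return {true, v, ""}; }
inline Val Var(std::string n) { return {false, 0, std::move(n)}; }
inline bool IsConst(const Val& v, int c) { return v.is_const && v.value == c; }

// Returns the replacement for `lhs op rhs`, or nullopt if none of the listed rules applies.
inline std::optional<Val> Simplify(char op, const Val& lhs, const Val& rhs) {
  switch (op) {
    case '+':
      if (IsConst(rhs, 0)) return lhs;       // x + 0 -> x
      break;
    case '-':
      if (IsConst(rhs, 0)) return lhs;       // x - 0 -> x
      break;
    case '*':
      if (IsConst(rhs, 1)) return lhs;       // x * 1 -> x
      if (IsConst(rhs, 0)) return Const(0);  // x * 0 -> 0
      break;
    case '/':
      if (IsConst(rhs, 1)) return lhs;       // x / 1 -> x
      if (IsConst(lhs, 0)) return Const(0);  // 0 / x -> 0 (relies on x / 0 being undefined)
      break;
    default:
      break;
  }
  return std::nullopt;
}
```

`Simplify('+', Var("x"), Const(0))` returns the operand `x`, which the pass would then use to replace all uses of the original instruction; the now-unused instruction is left for `DCE`.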
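## 4. DCE
`DCE` removes instructions that have no side effects and no uses.
Instructions that have side effects or affect control flow are always kept, including:
- `store`
- `ret`
- `call`
- `br`
- `condbr`
Binary operations, comparisons, and conversions that were replaced by constants during optimization are removed by DCE once they are no longer used.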
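A worklist version of this idea is sketched below on a toy instruction list. `ToyInst`, `MustKeep`, and `RunDCE` are hypothetical names, operands are modelled as indices of other instructions, and whether the real `DCEPass` iterates in the same way is an assumption.

```cpp
#include <string>
#include <vector>

// Hypothetical toy instruction: operands are indices of other instructions.
struct ToyInst {
  std::string opcode;
  std::vector<int> operands;
  bool dead = false;
};

// Mirrors the keep-list above: these are never removed, even without uses.
inline bool MustKeep(const ToyInst& inst) {
  return inst.opcode == "store" || inst.opcode == "ret" || inst.opcode == "call" ||
         inst.opcode == "br" || inst.opcode == "condbr";
}

// Marks dead instructions and returns how many were removed.
inline int RunDCE(std::vector<ToyInst>& insts) {
  std::vector<int> uses(insts.size(), 0);
  for (const auto& inst : insts)
    for (int op : inst.operands) ++uses[op];

  std::vector<int> work;
  for (int i = 0; i < static_cast<int>(insts.size()); ++i)
    if (uses[i] == 0 && !MustKeep(insts[i])) work.push_back(i);

  int removed = 0;
  while (!work.empty()) {
    const int i = work.back();
    work.pop_back();
    insts[i].dead = true;
    ++removed;
    // Removing an instruction drops its uses, which may expose more dead code.
    for (int op : insts[i].operands)
      if (--uses[op] == 0 && !MustKeep(insts[op])) work.push_back(op);
  }
  return removed;
}
```

For the list `add`, `mul(%0)`, `add`, `ret(%2)`, the unused `mul` is removed first, which makes the first `add` dead as well, while the second `add` survives because `ret` still uses it.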
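## 5. Test script design
Script added or rewritten:
```text
scripts/verify_mem2reg.sh
```
The script no longer tests only `test/test_case/mem2reg`; by default it scans every `.sy` file under the whole `test/` directory.
The script verifies in three layers.
### 5.1 IR generation check
The first layer checks that each `.sy` file can complete:
```bash
./build/bin/compiler --emit-ir xxx.sy
```
If the output is IR that contains a `define`, both the front end and the current pass pipeline ran to completion.
### 5.2 Optimization result check
The second layer checks whether scalar allocas remain in the optimized IR:
```llvm
%x = alloca i32
%y = alloca float
%b = alloca i1
```
By default, a leftover scalar alloca is only a warning and does not fail the test. Not every alloca can be promoted safely, and being conservative is reasonable, especially around complex arrays, address uses, and function calls.
For a stricter check, use:
```bash
./scripts/verify_mem2reg.sh --strict-mem2reg
```
This treats leftover scalar allocas as failures.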
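The script performs this check with `grep`; the same pattern can be expressed in C++ with `std::regex`, for instance if the check were ever moved into a unit test. `FindScalarAllocas` is a hypothetical name, and the regular expression is a direct transcription of the one in `check_scalar_mem2reg`.

```cpp
#include <regex>
#include <sstream>
#include <string>
#include <vector>

// Returns every IR line that still allocates an i32, float, or i1 scalar.
inline std::vector<std::string> FindScalarAllocas(const std::string& ir_text) {
  // The word boundary keeps `i1` from matching wider types such as i16.
  static const std::regex pattern(R"(=\s*alloca\s+(i32|float|i1)\b)");
  std::vector<std::string> hits;
  std::istringstream in(ir_text);
  for (std::string line; std::getline(in, line);) {
    if (std::regex_search(line, pattern)) hits.push_back(line);
  }
  return hits;
}
```

Array allocas such as `alloca [100 x i32]` are not reported, because the type directly after `alloca` must be one of the three scalar types.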
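### 5.3 Runtime semantic regression
The third layer must be enabled manually:
```bash
./scripts/verify_mem2reg.sh --run
```
It performs the following steps:
1. Generate the optimized IR.
2. Use `llc` to turn the `.ll` file into an object file.
3. Link the object file with `clang`.
4. If `sylib/sylib.c` exists, compile it first and link the runtime library.
5. Automatically read the same-named `.in` file as input.
6. Concatenate the program's stdout and exit code into the actual result.
7. Compare it against the same-named `.out` file.
When comparing, the script normalizes:
- Windows-style line endings `\r\n`
- Whether the file ends with an extra newline
This avoids false failures caused by differences in text formatting.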
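Steps 6 and 7, together with the normalization rules, can be sketched in C++ as follows. The function names `BuildActual`, `Normalize`, and `SameOutput` are hypothetical; the behaviour is modelled on `compare_result` and `normalize_file` in the script (strip `\r` at line ends, ignore trailing newlines).

```cpp
#include <cstddef>
#include <string>

// Step 6: stdout, a newline if stdout is non-empty and lacks one, then the exit code.
inline std::string BuildActual(const std::string& program_stdout, int exit_status) {
  std::string actual = program_stdout;
  if (!actual.empty() && actual.back() != '\n') actual.push_back('\n');
  return actual + std::to_string(exit_status) + "\n";
}

// Drops each '\r' that ends a line, then strips trailing newlines,
// so "a\r\n" and "a" compare equal.
inline std::string Normalize(const std::string& text) {
  std::string out;
  out.reserve(text.size());
  for (std::size_t i = 0; i < text.size(); ++i) {
    const bool cr_at_line_end =
        text[i] == '\r' && (i + 1 == text.size() || text[i + 1] == '\n');
    if (!cr_at_line_end) out.push_back(text[i]);
  }
  while (!out.empty() && out.back() == '\n') out.pop_back();
  return out;
}

// Step 7: compare expected and actual text after normalization.
inline bool SameOutput(const std::string& expected, const std::string& actual) {
  return Normalize(expected) == Normalize(actual);
}
```

A program that prints `42` without a trailing newline and exits with status 0 yields the actual text `42\n0\n`, which matches an expected file containing `42\r\n0\r\n`.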
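## 6. Using the test script
### 6.1 Build the project
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j "$(nproc)"
```
### 6.2 Only check that the passes run to completion
```bash
./scripts/verify_mem2reg.sh
```
By default it scans:
```text
test/
```
Sample output (the script prints its summary in Chinese: IR generation, pass optimization check, all checks passed):
```text
IR 生成: 22 / 22
Pass 优化检查: 22 / 22
全部检查通过。
```
### 6.3 Also run the semantic regression
```bash
./scripts/verify_mem2reg.sh --run
```
Sample output (the extra line reports the run results):
```text
IR 生成: 22 / 22
Pass 优化检查: 22 / 22
运行结果: 22 / 22
全部检查通过。
```
### 6.4 Test a single directory
```bash
./scripts/verify_mem2reg.sh --test-root test/test_case/functional --run
```
### 6.5 Print detailed information
```bash
./scripts/verify_mem2reg.sh --debug --run
```
### 6.6 Strict mem2reg check
```bash
./scripts/verify_mem2reg.sh --strict-mem2reg
```
This mode is meant for finding out which scalar allocas were not promoted. Some complex performance cases currently produce warnings; whether to optimize them further depends on how the IR uses those allocas.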
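## 7. Current test results
Running:
```bash
./scripts/verify_mem2reg.sh --run
```
currently produces:
```text
IR 生成: 22 / 22
Pass 优化检查: 22 / 22
运行结果: 22 / 22
全部检查通过。
```
This shows that:
- All tests get through the current pass pipeline.
- The generated IR is accepted by the LLVM toolchain.
- After linking the runtime library, both program output and exit code match the `.out` files.
The scalar alloca warnings reported by the script do not affect current semantic correctness; they only indicate room for further promotion or a more precise escape analysis.
## 8. Possible future improvements
If Lab4 is extended further, the following could be considered:
- Adding a more complete promotability analysis to `Mem2Reg`.
- Handling uninitialized loads with more robust default values.
- Adding CFG Simplify to delete constant-condition branches and unreachable blocks.
- Adding CSE to eliminate repeated expressions.
- Iterating constant folding and propagation to a fixed point.
- Adding an IR verifier to catch problems early, such as phi incoming counts, SSA name uniqueness, and mismatched basic block predecessors.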

@ -333,6 +333,7 @@ enum class Opcode {
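FPToSI, // floating point to integer
FPExt, // floating-point extension
FPTrunc, // floating-point truncation
Phi,
};
// ZExt and Trunc are the zero-extension and truncation instructions; they convert between SysY's int (i32) and LLVM IR comparison results (i1).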
@ -567,6 +568,56 @@ class BranchInst : public Instruction {
BasicBlock* false_target_; // false-branch target (used by conditional branches)
};
class PhiInst : public Instruction {
public:
PhiInst(std::shared_ptr<Type> ty, std::string name = "")
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
void AddIncoming(Value* value, BasicBlock* block) {
if (!block) {
throw std::runtime_error("PhiInst incoming block cannot be null");
}
incoming_.push_back({value, block});
if (value) {
AddOperand(value);
}
}
Value* GetIncomingValue(size_t index) const {
if (index >= incoming_.size()) {
throw std::out_of_range("PhiInst incoming value index out of range");
}
return incoming_[index].first;
}
BasicBlock* GetIncomingBlock(size_t index) const {
if (index >= incoming_.size()) {
throw std::out_of_range("PhiInst incoming block index out of range");
}
return incoming_[index].second;
}
size_t GetNumIncoming() const { return incoming_.size(); }
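void SetIncomingValue(size_t index, Value* value) {
if (index >= incoming_.size()) {
throw std::out_of_range("PhiInst incoming value index out of range");
}
incoming_[index].first = value;
// AddIncoming only creates an operand slot for non-null values, so this call
// assumes operand index == incoming index (no null value was added earlier).
SetOperand(index, value);
}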
void SetIncomingBlock(size_t index, BasicBlock* block) {
if (index >= incoming_.size()) {
throw std::out_of_range("PhiInst incoming block index out of range");
}
incoming_[index].second = block;
}
private:
std::vector<std::pair<Value*, BasicBlock*>> incoming_;
};
// Integer comparison instruction
class IcmpInst : public Instruction {
public:
@ -730,6 +781,7 @@ class CallInst : public Instruction {
const std::string& name);
Function* GetCallee() const;
const std::vector<Value*>& GetArgs() const;
void SetArg(size_t index, Value* value);
private:
Function* callee_;
@ -774,6 +826,17 @@ class BasicBlock : public Value {
return ptr;
}
template <typename T, typename... Args>
T* InsertAtBeginning(Args&&... args) {
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin(), std::move(inst));
return ptr;
}
void RemoveInstruction(Instruction* inst);
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;

@ -0,0 +1,40 @@
#pragma once
#include "ir/IR.h"
#include <unordered_map>
#include <vector>
namespace ir {
class DominatorTree {
public:
DominatorTree() = default;
~DominatorTree() = default;
void Recalculate(Function& function);
BasicBlock* GetRoot() const;
BasicBlock* GetIDom(BasicBlock* block) const;
bool Dominates(BasicBlock* a, BasicBlock* b) const;
const std::vector<BasicBlock*>& GetChildren(BasicBlock* block) const;
const std::vector<BasicBlock*>& GetDominanceFrontier(BasicBlock* block) const;
const std::vector<BasicBlock*>& GetPredecessors(BasicBlock* block) const;
const std::vector<BasicBlock*>& GetSuccessors(BasicBlock* block) const;
private:
void BuildCFG(Function& function);
void ComputeIDoms();
void ComputeDominanceFrontiers();
BasicBlock* Intersect(BasicBlock* first, BasicBlock* second) const;
std::vector<BasicBlock*> blocks_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> preds_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> succs_;
std::unordered_map<BasicBlock*, BasicBlock*> idom_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> children_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dominance_frontier_;
std::unordered_map<BasicBlock*, int> dfs_number_;
};
} // namespace ir

@ -0,0 +1,15 @@
#pragma once
#include "ir/passes/PassManager.h"
namespace ir {
class ConstFoldPass : public Pass {
public:
ConstFoldPass() = default;
~ConstFoldPass() override = default;
bool RunOnFunction(Function& function) override;
};
} // namespace ir

@ -0,0 +1,15 @@
#pragma once
#include "ir/passes/PassManager.h"
namespace ir {
class ConstPropPass : public Pass {
public:
ConstPropPass() = default;
~ConstPropPass() override = default;
bool RunOnFunction(Function& function) override;
};
} // namespace ir

@ -0,0 +1,15 @@
#pragma once
#include "ir/passes/PassManager.h"
namespace ir {
class DCEPass : public Pass {
public:
DCEPass() = default;
~DCEPass() override = default;
bool RunOnFunction(Function& function) override;
};
} // namespace ir

@ -0,0 +1,27 @@
#pragma once
#include "ir/IR.h"
#include "ir/analysis/DominatorTree.h"
#include "ir/passes/PassManager.h"
namespace ir {
class Mem2RegPass : public Pass {
public:
Mem2RegPass() = default;
~Mem2RegPass() = default;
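// Promotes the promotable memory variables in a function to SSA form.
// Returns whether any change was made to the function.
bool RunOnFunction(Function& function) override;
// Optional: run mem2reg at module level.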
bool RunOnModule(Module& module);
private:
bool PromoteAllocas(Function& function, DominatorTree& domtree);
bool changed_ = false;
};
} // namespace ir

@ -0,0 +1,29 @@
#pragma once
#include "ir/IR.h"
#include <memory>
#include <vector>
namespace ir {
class Pass {
public:
virtual ~Pass() = default;
virtual bool RunOnFunction(Function& function) = 0;
};
class PassManager {
public:
PassManager() = default;
~PassManager() = default;
void AddPass(std::unique_ptr<Pass> pass);
bool Run(Function& function);
bool Run(Module& module);
private:
std::vector<std::unique_ptr<Pass>> passes_;
};
} // namespace ir

@ -8,22 +8,12 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <iostream>
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include "ir/IR.h"
#include "sem/Sema.h"
//#define DEBUG_IRGen
#ifdef DEBUG_IRGen
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[IRGen Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace ir {
class Module;
class Function;

@ -19,161 +19,41 @@ class MIRContext {
MIRContext& DefaultContext();
enum class PhysReg {
//W0, W8, W9, X29, X30, SP
// 32-bit general-purpose registers
W0, W1, W2, W3, W4, W5, W6, W7, // argument passing / temporaries
W8, W9, // temporaries (the ones mainly used at present)
W10, W11, W12, W13, W14, W15, // temporaries (extended)
W16, W17, // intra-procedure-call temporaries
W18, // platform reserved
W19, W20, W21, W22, W23, W24, // callee-saved (for extension)
W25, W26, W27, W28, // callee-saved
W29, // frame pointer (FP)
W30, // link register (LR)
// 64-bit versions
X0, X1, X2, X3, X4, X5, X6, X7,
X8, X9, X10, X11, X12, X13, X14, X15,
X16, X17, X18,
X19, X20, X21, X22, X23, X24, X25, X26, X27, X28,
X29, // FP
X30, // LR
// floating-point registers (32-bit)
S0, S1, S2, S3, S4, S5, S6, S7,
S8, S9, S10, S11, S12, S13, S14, S15,
S16, S17, S18, S19, S20, S21, S22, S23,
S24, S25, S26, S27, S28, S29, S30, S31,
// special registers
SP, // stack pointer
ZR, // zero register
};
enum class PhysReg { W0, W8, W9, X29, X30, SP };
const char* PhysRegName(PhysReg reg);
// ========== Condition codes (used by the BCond instruction) ==========
enum class CondCode {
EQ, // equal
NE, // not equal
CS, // carry set / unsigned greater than or equal
CC, // carry clear / unsigned less than
MI, // minus (negative)
PL, // plus (non-negative)
VS, // overflow set
VC, // overflow clear
HI, // unsigned higher
LS, // unsigned lower or same
GE, // signed greater than or equal
LT, // signed less than
GT, // signed greater than
LE, // signed less than or equal
AL, // always
};
const char* CondCodeName(CondCode cc);
// ========== MIR instruction opcodes ==========
enum class Opcode {
// ---------- Stack frame ----------
Prologue, // function prologue (pseudo-instruction)
Epilogue, // function epilogue (pseudo-instruction)
// ---------- Data movement ----------
MovImm, // move immediate to register: MOV w8, #imm
MovReg, // move between registers: MOV w8, w9
LoadStack, // load from a stack slot: LDR w8, [sp, #offset]
StoreStack, // store to a stack slot: STR w8, [sp, #offset]
LoadStackPair,// paired load: LDP x29, x30, [sp], #16
StoreStackPair,// paired store: STP x29, x30, [sp, #-16]!
// ---------- Integer arithmetic ----------
AddRR, // addition: ADD w8, w8, w9
AddRI, // addition (immediate): ADD w8, w8, #imm
SubRR, // subtraction: SUB w8, w8, w9
SubRI, // subtraction (immediate): SUB w8, w8, #imm
MulRR, // multiplication: MUL w8, w8, w9
SDivRR, // signed division: SDIV w8, w8, w9
UDivRR, // unsigned division: UDIV w8, w8, w9
// ---------- Floating-point arithmetic ----------
FAddRR, // floating-point addition: FADD s0, s0, s1
FSubRR, // floating-point subtraction: FSUB s0, s0, s1
FMulRR, // floating-point multiplication: FMUL s0, s0, s1
FDivRR, // floating-point division: FDIV s0, s0, s1
// ---------- Comparison ----------
CmpRR, // compare (register): CMP w8, w9
CmpRI, // compare (immediate): CMP w8, #imm
FCmpRR, // floating-point compare: FCMP s0, s1
// ---------- Type conversion ----------
SIToFP, // signed integer to floating point: SCVTF s0, w0
FPToSI, // floating point to signed integer: FCVTZS w0, s0
ZExt, // zero extension, i1 -> i32: AND w8, w8, #1
// ---------- Control flow ----------
B, // unconditional branch: B label
BCond, // conditional branch: B.EQ label, B.NE label, B.GT label, etc.
Call, // function call: BL target
Ret, // function return: RET
// ---------- Logical operations ----------
AndRR, // bitwise and: AND w8, w8, w9
OrRR, // bitwise or: ORR w8, w8, w9
EorRR, // bitwise exclusive or: EOR w8, w8, w9
LslRR, // logical shift left: LSL w8, w8, w9
LsrRR, // logical shift right: LSR w8, w8, w9
AsrRR, // arithmetic shift right: ASR w8, w8, w9
// ---------- Special ----------
Nop, // no operation: NOP
Label, // inline label; emits no real instruction, only the label name
// added
Movk, // movk Rd, #imm16, lsl #shift
// added
LoadStackAddr, // load a stack frame address into a register (add xd, sp, #offset)
// used for global variable address computation
Adrp, // ADRP Xd, label
AddLabel, // ADD Xd, Xn, :lo12:label
// newly added
Sxtw, // sign-extend word to doubleword: sxtw Xd, Wn
Prologue,
Epilogue,
MovImm,
LoadStack,
StoreStack,
AddRR,
Ret,
};
// ========== Operand class ==========
class Operand {
public:
enum class Kind { Reg, Imm, FrameIndex, Cond, Label };
enum class Kind { Reg, Imm, FrameIndex };
static Operand Reg(PhysReg reg);
static Operand Imm(int value);
static Operand FrameIndex(int index);
static Operand Cond(CondCode cc);
static Operand Label(const std::string& label);
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; }
int GetFrameIndex() const { return imm_; }
CondCode GetCondCode() const { return cc_; }
const std::string& GetLabel() const { return label_; }
private:
Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string& label);
Operand(Kind kind, PhysReg reg, int imm);
Kind kind_;
PhysReg reg_;
int imm_;
CondCode cc_;
std::string label_;
};
// ========== MIR instruction class ==========
class MachineInstr {
public:
MachineInstr(Opcode opcode, std::vector<Operand> operands = {});
@ -186,14 +66,12 @@ class MachineInstr {
std::vector<Operand> operands_;
};
// ========== Stack slot structure ==========
struct FrameSlot {
int index = 0;
int size = 4;
int offset = 0;
};
// ========== MIR basic block ==========
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
@ -204,133 +82,38 @@ class MachineBasicBlock {
MachineInstr& Append(Opcode opcode,
std::initializer_list<Operand> operands = {});
MachineInstr& Append(Opcode opcode, std::vector<Operand> operands);
// Control-flow information
std::vector<MachineBasicBlock*>& GetSuccessors() { return successors_; }
const std::vector<MachineBasicBlock*>& GetSuccessors() const { return successors_; }
void AddSuccessor(MachineBasicBlock* succ) { successors_.push_back(succ); }
private:
std::string name_;
std::vector<MachineInstr> instructions_;
std::vector<MachineBasicBlock*> successors_;
};
// ========== MIR function ==========
class MachineFunction {
public:
explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; }
// Basic block management
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
std::vector<std::unique_ptr<MachineBasicBlock>>& GetBasicBlocks() {
return basic_blocks_;
}
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBasicBlocks() const {
return basic_blocks_;
}
void AddBasicBlock(std::unique_ptr<MachineBasicBlock> bb) {
basic_blocks_.push_back(std::move(bb));
}
// Stack slot management
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
std::vector<FrameSlot>& GetFrameSlots() { return frame_slots_; }
const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; }
// Stack frame size
int GetFrameSize() const { return frame_size_; }
void SetFrameSize(int size) { frame_size_ = size; }
private:
std::string name_;
MachineBasicBlock entry_;
std::vector<std::unique_ptr<MachineBasicBlock>> basic_blocks_;
std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0;
};
// ========== MIR module ==========
class MachineModule {
public:
MachineModule() = default;
// Add a MachineFunction
void AddFunction(std::unique_ptr<MachineFunction> func) {
functions_.push_back(std::move(func));
}
// Get all functions
const std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() const {
return functions_;
}
std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() {
return functions_;
}
// Look up a function by name
MachineFunction* GetFunction(const std::string& name) {
for (auto& func : functions_) {
if (func->GetName() == name) {
return func.get();
}
}
return nullptr;
}
const MachineFunction* GetFunction(const std::string& name) const {
for (const auto& func : functions_) {
if (func->GetName() == name) {
return func.get();
}
}
return nullptr;
}
struct GlobalDecl {
std::string name;
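int size; // size in bytes
int alignment; // alignment requirement (usually 4 or 8)
bool is_zero_init; // whether it is zero-initialized
bool has_init_data; // whether it carries initialization data (used for scalar constants)
uint64_t init_data; // initialization data (at most 8 bytes)
// Constructor; zero-initialized by default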
GlobalDecl(const std::string& n, int sz, int align, bool zero = true,
bool has_data = false, uint64_t data = 0)
: name(n), size(sz), alignment(align), is_zero_init(zero),
has_init_data(has_data), init_data(data) {}
};
void AddGlobal(const std::string& name, int size, int alignment,
bool is_zero_init = true,
bool has_init_data = false, uint64_t init_data = 0) {
globals_.emplace_back(name, size, alignment, is_zero_init,
has_init_data, init_data);
}
const std::vector<GlobalDecl>& GetGlobals() const { return globals_; }
private:
std::vector<std::unique_ptr<MachineFunction>> functions_;
std::vector<GlobalDecl> globals_;
};
// ========== Backend pipeline functions ==========
/* std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os); */
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineModule& module);
void RunFrameLowering(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);
void PrintAsm(const MachineFunction& function, std::ostream& os);
} // namespace mir

@ -8,6 +8,7 @@ struct CLIOptions {
bool emit_parse_tree = false;
bool emit_ir = true;
bool emit_asm = false;
bool debug = false;
bool show_help = false;
};

@ -10,6 +10,11 @@
void LogInfo(std::string_view msg, std::ostream& os);
void LogError(std::string_view msg, std::ostream& os);
extern bool g_debug_enabled;
void SetDebugEnabled(bool enabled);
bool IsDebugEnabled();
std::ostream& DebugStream();
std::string FormatError(std::string_view stage, std::string_view msg);
std::string FormatErrorAt(std::string_view stage, std::size_t line,
std::size_t column, std::string_view msg);

File diff suppressed because it is too large

@ -4,7 +4,6 @@ set -euo pipefail
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
COMPILER="$ROOT_DIR/build/bin/compiler"
TMP_DIR="$ROOT_DIR/build/test_compiler"
RESULT_BASE_DIR="$ROOT_DIR/test/test_result"
TEST_DIRS=("$ROOT_DIR/test/test_case/functional" "$ROOT_DIR/test/test_case/performance")
CC_BIN="${CC:-cc}"
RUNTIME_SRC="$ROOT_DIR/sylib/sylib.c"
@ -67,17 +66,7 @@ for test_dir in "${TEST_DIRS[@]}"; do
ir_total=$((ir_total+1))
base=$(basename "$input")
stem=${base%.sy}
case "$(basename "$test_dir")" in
functional)
out_dir="$RESULT_BASE_DIR/functional/ir"
;;
performance)
out_dir="$RESULT_BASE_DIR/performance/ir"
;;
*)
out_dir="$RESULT_BASE_DIR/$(basename "$test_dir")"
;;
esac
out_dir="$TMP_DIR/$(basename "$test_dir")"
mkdir -p "$out_dir"
ll_file="$out_dir/$stem.ll"
stdout_file="$out_dir/$stem.stdout"

@ -1,167 +0,0 @@
#!/bin/bash
# Test script (supports dynamic output directories)
# Usage: ./run_tests.sh [--verbose] [--mode <mode>] [<single test file>]
# Modes: parse-tree, ir, asm, run (default: run)
COMPILER="./build/bin/compiler"
TEST_SCRIPT="./scripts/verify_asm.sh"
TEST_DIR="test/test_case"
RESULT_FILE="result.txt"
VERBOSE=0
MODE="run" # the default mode is the full test
# Parse arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--verbose|-v)
VERBOSE=1
shift
;;
--mode)
MODE="$2"
shift 2
;;
--emit-parse-tree)
MODE="parse-tree"
shift
;;
--emit-ir)
MODE="ir"
shift
;;
--emit-asm)
MODE="asm"
shift
;;
--run)
MODE="run"
shift
;;
-*)
echo "未知选项: $1"
exit 1
;;
*)
# Non-option arguments are treated as a test file
SINGLE_FILE="$1"
shift
;;
esac
done
# Check that the compiler exists
if [ ! -f "$COMPILER" ]; then
echo "错误: 编译器未找到于 $COMPILER"
echo "请先完成项目构建 (cmake 和 make)"
exit 1
fi
# In run mode, check that the test script exists
if [ "$MODE" = "run" ] && [ ! -f "$TEST_SCRIPT" ]; then
echo "错误: 测试脚本未找到于 $TEST_SCRIPT"
exit 1
fi
# If a single file was given, check that it exists
if [ -n "$SINGLE_FILE" ] && [ ! -f "$SINGLE_FILE" ]; then
echo "错误: 文件 $SINGLE_FILE 不存在"
exit 1
fi
# Truncate (or create) the result file
> "$RESULT_FILE"
# Counters
total=0
passed=0
failed=0
echo "开始测试 (模式: $MODE)..."
echo "输出将保存到 $RESULT_FILE"
echo "------------------------"
# Determine the list of test files
if [ -n "$SINGLE_FILE" ]; then
TEST_FILES=("$SINGLE_FILE")
else
mapfile -t TEST_FILES < <(find "$TEST_DIR" -type f -name "*.sy" | sort)
fi
# Run the tests according to the mode
for file in "${TEST_FILES[@]}"; do
((total++))
if [ $VERBOSE -eq 1 ]; then
echo "测试文件: $file"
else
echo -n "测试 $file ... "
fi
echo "========== $file ==========" >> "$RESULT_FILE"
# Build the command according to the mode
case "$MODE" in
parse-tree)
cmd="$COMPILER --emit-parse-tree \"$file\""
;;
ir)
cmd="$COMPILER --emit-ir \"$file\""
;;
asm)
cmd="$COMPILER --emit-asm \"$file\""
;;
run)
# Choose the output directory from the subdirectory that contains the input file
# Extract the subdirectory name after test/test_case/ (functional or performance)
rel_path="${file#$TEST_DIR/}"
subdir=$(echo "$rel_path" | cut -d'/' -f1)
if [[ "$subdir" != "functional" && "$subdir" != "performance" ]]; then
echo "警告: 未知子目录 $subdir,使用默认输出目录" >> "$RESULT_FILE"
out_dir="test/test_result/asm"
else
out_dir="test/test_result/$subdir/asm"
fi
cmd="$TEST_SCRIPT \"$file\" \"$out_dir\" --run"
;;
*)
echo "未知模式: $MODE" >&2
exit 1
;;
esac
if [ $VERBOSE -eq 1 ]; then
eval "$cmd" 2>&1 | tee -a "$RESULT_FILE"
result=${PIPESTATUS[0]}
else
eval "$cmd" >> "$RESULT_FILE" 2>&1
result=$?
fi
echo "" >> "$RESULT_FILE"
if [ $result -eq 0 ]; then
if [ $VERBOSE -eq 0 ]; then
echo "通过"
fi
((passed++))
else
if [ $VERBOSE -eq 0 ]; then
echo "失败"
else
echo ">>> 测试失败: $file"
fi
((failed++))
fi
done
echo "------------------------"
echo "总计: $total"
echo "通过: $passed"
echo "失败: $failed"
echo "详细输出已保存至 $RESULT_FILE"
if [ $failed -gt 0 ]; then
exit 1
else
exit 0
fi

@ -52,8 +52,7 @@ expected_file="$input_dir/$stem.out"
"$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file"
# Static linking
aarch64-linux-gnu-gcc -no-pie "$asm_file" -L./sylib -lsysy -static -o "$exe"
aarch64-linux-gnu-gcc "$asm_file" -o "$exe"
echo "可执行文件已生成: $exe"
if [[ "$run_exec" == true ]]; then
@ -84,7 +83,7 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -w -u "$expected_file" "$actual_file"; then
if diff -u "$expected_file" "$actual_file"; then
echo "输出匹配: $expected_file"
else
echo "输出不匹配: $expected_file" >&2

@ -0,0 +1,377 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
COMPILER="$ROOT_DIR/build/bin/compiler"
DEFAULT_TEST_ROOT="$ROOT_DIR/test"
TMP_DIR="$ROOT_DIR/build/test_passes"
CC_BIN="${CC:-cc}"
LLC_BIN="${LLC:-llc}"
CLANG_BIN="${CLANG:-clang}"
RUNTIME_SRC="$ROOT_DIR/sylib/sylib.c"
RUNTIME_OBJ="$TMP_DIR/sylib.o"
debug=false
run_exec=false
test_root="$DEFAULT_TEST_ROOT"
stop_on_fail=false
strict_mem2reg=false
usage() {
cat <<EOF
用法: $0 [选项]
选项:
--run 生成 IR 后继续用 llc/clang 运行,并和同名 .out 对比
--debug 打印每个用例的命令与更多诊断信息
--test-root <dir> 指定测试根目录,默认: $DEFAULT_TEST_ROOT
--stop-on-fail 遇到第一个失败立即退出
--strict-mem2reg 将优化后残留标量 alloca 视为失败;默认只作为警告统计
-h, --help 显示帮助
环境变量:
LLC=<path> 指定 llc默认 llc
CLANG=<path> 指定 clang默认 clang
CC=<path> 指定 C 编译器,用于编译 sylib.c默认 cc
EOF
}
while [[ $# -gt 0 ]]; do
case "$1" in
--run)
run_exec=true
shift
;;
--debug)
debug=true
shift
;;
--test-root)
if [[ $# -lt 2 ]]; then
echo "--test-root 需要目录参数" >&2
exit 1
fi
test_root="$2"
shift 2
;;
--stop-on-fail)
stop_on_fail=true
shift
;;
--strict-mem2reg)
strict_mem2reg=true
shift
;;
-h|--help)
usage
exit 0
;;
*)
echo "未知参数: $1" >&2
usage >&2
exit 1
;;
esac
done
if [[ ! -x "$COMPILER" ]]; then
echo "未找到编译器: $COMPILER" >&2
echo "请先构建编译器,例如: cmake -S . -B build && cmake --build build -j" >&2
exit 1
fi
if [[ ! -d "$test_root" ]]; then
echo "测试目录不存在: $test_root" >&2
exit 1
fi
mkdir -p "$TMP_DIR"
runtime_ready=0
if [[ "$run_exec" == true ]]; then
if ! command -v "$LLC_BIN" >/dev/null 2>&1; then
echo "未找到 llc: $LLC_BIN" >&2
exit 1
fi
if ! command -v "$CLANG_BIN" >/dev/null 2>&1; then
echo "未找到 clang: $CLANG_BIN" >&2
exit 1
fi
if [[ -f "$RUNTIME_SRC" ]]; then
if "$CC_BIN" -c "$RUNTIME_SRC" -o "$RUNTIME_OBJ" >/dev/null 2>&1; then
runtime_ready=1
else
echo "[WARN] 运行库编译失败,将只链接目标文件: $RUNTIME_SRC" >&2
fi
else
echo "[WARN] 未找到运行库源码,将只链接目标文件: $RUNTIME_SRC" >&2
fi
fi
normalize_file() {
sed 's/\r$//' "$1"
}
make_case_out_dir() {
local input=$1
local rel
rel=$(realpath --relative-to="$test_root" "$(dirname "$input")")
echo "$TMP_DIR/$rel"
}
extract_ir() {
local raw_file=$1
local ll_file=$2
# In debug mode the compiler may also write diagnostics to stdout, so keep only LLVM-like IR lines here.
grep -E '^(define |declare |@|[[:space:]]|})|^[A-Za-z_.$%][A-Za-z0-9_.$%]*:$' \
"$raw_file" > "$ll_file" || true
}
record_failure() {
local bucket=$1
local message=$2
case "$bucket" in
ir) ir_failures+=("$message") ;;
opt) opt_failures+=("$message") ;;
run) run_failures+=("$message") ;;
esac
if [[ "$stop_on_fail" == true ]]; then
echo ""
echo "遇到失败,按 --stop-on-fail 停止。失败文件保留在: $TMP_DIR"
exit 1
fi
}
record_warning() {
local bucket=$1
local message=$2
case "$bucket" in
opt) opt_warnings+=("$message") ;;
esac
}
check_scalar_mem2reg() {
local ll_file=$1
grep -nE '=[[:space:]]*alloca[[:space:]]+(i32|float|i1)\b' "$ll_file" || true
}
compare_result() {
local input=$1
local expected_file=$2
local stdout_file=$3
local status=$4
local actual_file="${stdout_file%.stdout}.actual.out"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && [[ "$(tail -c 1 "$stdout_file" | wc -l)" -eq 0 ]]; then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
local expected_text
local actual_text
expected_text=$(normalize_file "$expected_file")
actual_text=$(normalize_file "$actual_file")
if [[ "$expected_text" == "$actual_text" ]]; then
echo " [RUN] OK"
return 0
fi
echo " [RUN] FAIL: 输出或退出码不匹配"
echo " expected: $expected_file"
echo " actual: $actual_file"
if [[ "$debug" == true ]]; then
diff -u <(printf '%s\n' "$expected_text") <(printf '%s\n' "$actual_text") || true
fi
record_failure run "$input: output mismatch"
return 1
}
mapfile -t test_files < <(find "$test_root" -type f -name '*.sy' | sort)
if [[ ${#test_files[@]} -eq 0 ]]; then
echo "未在目录中找到 .sy 测试: $test_root" >&2
exit 1
fi
ir_total=0
ir_pass=0
opt_total=0
opt_pass=0
run_total=0
run_pass=0
ir_failures=()
opt_failures=()
opt_warnings=()
run_failures=()
echo "测试根目录: $test_root"
echo "输出目录: $TMP_DIR"
echo "测试数量: ${#test_files[@]}"
if [[ "$run_exec" == true ]]; then
echo "运行验证: 开启"
else
echo "运行验证: 关闭(加 --run 可开启语义对拍)"
fi
echo ""
for input in "${test_files[@]}"; do
ir_total=$((ir_total + 1))
opt_total=$((opt_total + 1))
out_dir=$(make_case_out_dir "$input")
mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
raw_ir="$out_dir/$stem.raw.ll"
ll_file="$out_dir/$stem.ll"
log_file="$out_dir/$stem.compiler.log"
stdout_file="$out_dir/$stem.stdout"
obj_file="$out_dir/$stem.o"
exe_file="$out_dir/$stem"
input_dir=$(dirname "$input")
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
echo "[TEST] ${input#$ROOT_DIR/}"
if [[ "$debug" == true ]]; then
echo " [CMD] $COMPILER --emit-ir $input"
fi
compiler_status=0
"$COMPILER" --emit-ir "$input" > "$raw_ir" 2> "$log_file" || compiler_status=$?
extract_ir "$raw_ir" "$ll_file"
if [[ $compiler_status -ne 0 ]]; then
echo " [IR] FAIL: 编译器返回 $compiler_status"
record_failure ir "$input: compiler failed ($compiler_status)"
continue
fi
if ! grep -qE '^define ' "$ll_file"; then
echo " [IR] FAIL: 未提取到有效函数定义"
record_failure ir "$input: invalid IR"
continue
fi
ir_pass=$((ir_pass + 1))
echo " [IR] OK"
scalar_allocas=$(check_scalar_mem2reg "$ll_file")
if [[ -n "$scalar_allocas" ]]; then
if [[ "$strict_mem2reg" == true ]]; then
echo " [OPT] FAIL: 优化后仍有可提升标量 alloca"
else
echo " [OPT] WARN: 优化后仍有标量 alloca 残留"
fi
if [[ "$debug" == true ]]; then
echo "$scalar_allocas" | sed 's/^/ /'
fi
if [[ "$strict_mem2reg" == true ]]; then
record_failure opt "$input: scalar alloca remains"
else
opt_pass=$((opt_pass + 1))
record_warning opt "$input: scalar alloca remains"
fi
else
opt_pass=$((opt_pass + 1))
echo " [OPT] OK: 未发现标量 alloca 残留"
fi
if [[ "$run_exec" != true ]]; then
continue
fi
if [[ ! -f "$expected_file" ]]; then
echo " [RUN] SKIP: 未找到期望输出 $expected_file"
continue
fi
run_total=$((run_total + 1))
if ! "$LLC_BIN" -filetype=obj "$ll_file" -o "$obj_file" > "$stdout_file" 2>&1; then
echo " [RUN] FAIL: llc 生成对象文件失败"
record_failure run "$input: llc failed"
continue
fi
if [[ $runtime_ready -eq 1 ]]; then
if ! "$CLANG_BIN" "$obj_file" "$RUNTIME_OBJ" -o "$exe_file" >> "$stdout_file" 2>&1; then
echo " [RUN] FAIL: clang 链接失败"
record_failure run "$input: clang link failed"
continue
fi
else
if ! "$CLANG_BIN" "$obj_file" -o "$exe_file" >> "$stdout_file" 2>&1; then
echo " [RUN] FAIL: clang 链接失败"
record_failure run "$input: clang link failed"
continue
fi
fi
run_status=0
if [[ -f "$stdin_file" ]]; then
"$exe_file" < "$stdin_file" > "$stdout_file" 2>&1 || run_status=$?
else
"$exe_file" > "$stdout_file" 2>&1 || run_status=$?
fi
if compare_result "$input" "$expected_file" "$stdout_file" "$run_status"; then
run_pass=$((run_pass + 1))
fi
done
echo ""
echo "测试完成。"
echo "IR 生成: $ir_pass / $ir_total"
echo "Pass 优化检查: $opt_pass / $opt_total"
if [[ "$run_exec" == true ]]; then
echo "运行结果: $run_pass / $run_total"
fi
if [[ ${#ir_failures[@]} -gt 0 ]]; then
echo ""
echo "IR 失败列表:"
for item in "${ir_failures[@]}"; do
echo " $item"
done
fi
if [[ ${#opt_failures[@]} -gt 0 ]]; then
echo ""
echo "优化检查失败列表:"
for item in "${opt_failures[@]}"; do
echo " $item"
done
fi
if [[ ${#opt_warnings[@]} -gt 0 ]]; then
echo ""
echo "优化警告列表(默认不算失败;加 --strict-mem2reg 可升级为失败):"
for item in "${opt_warnings[@]}"; do
echo " $item"
done
fi
if [[ ${#run_failures[@]} -gt 0 ]]; then
echo ""
echo "运行失败列表:"
for item in "${run_failures[@]}"; do
echo " $item"
done
fi
if [[ ${#ir_failures[@]} -gt 0 || ${#opt_failures[@]} -gt 0 || ${#run_failures[@]} -gt 0 ]]; then
echo ""
echo "失败产物已保留在: $TMP_DIR"
exit 1
fi
echo ""
echo "全部检查通过。"

@ -9,6 +9,7 @@
#include "ir/IR.h"
#include <algorithm>
#include <utility>
namespace ir {
@ -32,6 +33,31 @@ const std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetInstructions()
return instructions_;
}
void BasicBlock::RemoveInstruction(Instruction* inst) {
if (!inst) {
return;
}
auto it = std::find_if(instructions_.begin(), instructions_.end(),
[&](const std::unique_ptr<Instruction>& ptr) {
return ptr.get() == inst;
});
if (it == instructions_.end()) {
return;
}
// Clear the use relationships between this instruction and its operands.
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
Value* operand = inst->GetOperand(i);
if (operand) {
operand->RemoveUse(inst, i);
}
}
inst->SetParent(nullptr);
instructions_.erase(it);
}
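// The predecessor/successor interfaces are kept for later CFG extensions.
// The current minimal IR has no branch instruction yet, so these lists are usually empty.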
const std::vector<BasicBlock*>& BasicBlock::GetPredecessors() const {

@ -196,6 +196,7 @@ static const char* OpcodeToString(Opcode op) {
case Opcode::FPToSI: return "fptosi";
case Opcode::FPExt: return "fpext";
case Opcode::FPTrunc: return "fptrunc";
case Opcode::Phi: return "phi";
}
return "?";
}
@ -457,6 +458,19 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " " << phi->GetName() << " = phi "
<< TypeToString(*phi->GetType());
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
os << (i == 0 ? " " : ", ")
<< "[" << ValueToString(phi->GetIncomingValue(i))
<< ", %" << phi->GetIncomingBlock(i)->GetName() << "]";
}
os << "\n";
break;
}
case Opcode::ZExt: {
auto* zext = static_cast<const ZExtInst*>(inst);
os << " " << zext->GetName() << " = zext "

@ -226,6 +226,14 @@ Function* CallInst::GetCallee() const { return callee_; }
const std::vector<Value*>& CallInst::GetArgs() const { return args_; }
void CallInst::SetArg(size_t index, Value* value) {
if (index >= args_.size()) {
throw std::out_of_range("CallInst argument index out of range");
}
args_[index] = value;
SetOperand(index, value);
}
GEPInst::GEPInst(std::shared_ptr<Type> ptr_ty,
Value* base,
const std::vector<Value*>& indices,
@ -278,4 +286,3 @@ CallInst::CallInst(std::shared_ptr<Type> ret_ty, Function* callee,
} // namespace ir

@ -14,7 +14,6 @@ size_t Type::Size() const {
case Kind::PtrInt32: return 8; // assumes 64-bit pointers
case Kind::PtrFloat: return 8;
case Kind::Label: return 8; // size of a label address (pointer-sized)
case Kind::PtrInt1: return 8; // size of a pointer to i1
default: return 0; // derived classes should override this
}
}

@ -69,7 +69,19 @@ void Value::ReplaceAllUsesWith(Value* new_value) {
if (!user) continue;
size_t operand_index = use.GetOperandIndex();
if (user->GetOperand(operand_index) == this) {
user->SetOperand(operand_index, new_value);
if (auto* phi = dynamic_cast<PhiInst*>(user)) {
phi->SetIncomingValue(operand_index, new_value);
} else if (auto* br = dynamic_cast<BranchInst*>(user)) {
if (br->IsConditional() && operand_index == 0) {
br->SetCondition(new_value);
} else {
user->SetOperand(operand_index, new_value);
}
} else if (auto* call = dynamic_cast<CallInst*>(user)) {
call->SetArg(operand_index, new_value);
} else {
user->SetOperand(operand_index, new_value);
}
}
}
}

@ -1,4 +1,254 @@
// Dominator tree analysis:
// - Builds and queries the dominator tree and related relations
// - Provides the foundation for mem2reg, CFG optimizations, and loop analysis
#include "ir/analysis/DominatorTree.h"
#include <algorithm>
#include <functional>
#include <unordered_set>
namespace ir {
namespace {
std::vector<BasicBlock*> GetBlockSuccessors(BasicBlock* block) {
std::vector<BasicBlock*> succs;
if (!block) {
return succs;
}
const auto& instructions = block->GetInstructions();
if (instructions.empty()) {
return succs;
}
Instruction* term = instructions.back().get();
if (!term->IsTerminator()) {
return succs;
}
if (term->GetOpcode() == Opcode::Br) {
auto* br = static_cast<BranchInst*>(term);
succs.push_back(br->GetTarget());
} else if (term->GetOpcode() == Opcode::CondBr) {
auto* br = static_cast<BranchInst*>(term);
succs.push_back(br->GetTrueTarget());
succs.push_back(br->GetFalseTarget());
}
return succs;
}
} // namespace
void DominatorTree::Recalculate(Function& function) {
BuildCFG(function);
if (blocks_.empty()) {
return;
}
idom_.clear();
ComputeIDoms();
ComputeDominanceFrontiers();
}
BasicBlock* DominatorTree::GetRoot() const {
if (blocks_.empty()) {
return nullptr;
}
return blocks_.front();
}
BasicBlock* DominatorTree::GetIDom(BasicBlock* block) const {
auto it = idom_.find(block);
if (it == idom_.end()) {
return nullptr;
}
return it->second;
}
bool DominatorTree::Dominates(BasicBlock* a, BasicBlock* b) const {
if (!a || !b) {
return false;
}
if (a == b) {
return true;
}
auto it = idom_.find(b);
while (it != idom_.end() && it->second != b) {
if (it->second == a) {
return true;
}
b = it->second;
it = idom_.find(b);
}
return false;
}
const std::vector<BasicBlock*>& DominatorTree::GetChildren(BasicBlock* block) const {
auto it = children_.find(block);
if (it == children_.end()) {
static const std::vector<BasicBlock*> empty;
return empty;
}
return it->second;
}
const std::vector<BasicBlock*>& DominatorTree::GetDominanceFrontier(BasicBlock* block) const {
auto it = dominance_frontier_.find(block);
if (it == dominance_frontier_.end()) {
static const std::vector<BasicBlock*> empty;
return empty;
}
return it->second;
}
const std::vector<BasicBlock*>& DominatorTree::GetPredecessors(BasicBlock* block) const {
auto it = preds_.find(block);
if (it == preds_.end()) {
static const std::vector<BasicBlock*> empty;
return empty;
}
return it->second;
}
const std::vector<BasicBlock*>& DominatorTree::GetSuccessors(BasicBlock* block) const {
auto it = succs_.find(block);
if (it == succs_.end()) {
static const std::vector<BasicBlock*> empty;
return empty;
}
return it->second;
}
void DominatorTree::BuildCFG(Function& function) {
blocks_.clear();
preds_.clear();
succs_.clear();
idom_.clear();
children_.clear();
dominance_frontier_.clear();
dfs_number_.clear();
BasicBlock* entry = function.GetEntry();
if (!entry) {
return;
}
std::unordered_set<BasicBlock*> visited;
int next_number = 0;
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* block) {
if (!block || visited.count(block)) {
return;
}
visited.insert(block);
dfs_number_[block] = next_number++;
blocks_.push_back(block);
auto successors = GetBlockSuccessors(block);
succs_[block] = successors;
for (BasicBlock* succ : successors) {
preds_[succ].push_back(block);
dfs(succ);
}
};
dfs(entry);
}
void DominatorTree::ComputeIDoms() {
if (blocks_.empty()) {
return;
}
BasicBlock* entry = blocks_.front();
idom_[entry] = entry;
bool changed = true;
while (changed) {
changed = false;
for (BasicBlock* block : blocks_) {
if (block == entry) {
continue;
}
const auto& predecessors = preds_[block];
BasicBlock* new_idom = nullptr;
for (BasicBlock* pred : predecessors) {
auto pred_it = idom_.find(pred);
if (pred_it == idom_.end()) {
continue;
}
if (!new_idom) {
new_idom = pred;
} else {
new_idom = Intersect(pred, new_idom);
}
}
if (!new_idom) {
continue;
}
if (idom_.find(block) == idom_.end() || idom_[block] != new_idom) {
idom_[block] = new_idom;
changed = true;
}
}
}
children_.clear();
for (const auto& pair : idom_) {
BasicBlock* block = pair.first;
BasicBlock* parent = pair.second;
if (block != parent) {
children_[parent].push_back(block);
}
}
}
void DominatorTree::ComputeDominanceFrontiers() {
dominance_frontier_.clear();
for (BasicBlock* block : blocks_) {
const auto& predecessors = preds_[block];
if (predecessors.size() < 2) {
continue;
}
for (BasicBlock* pred : predecessors) {
BasicBlock* runner = pred;
while (runner != idom_[block]) {
auto& frontier = dominance_frontier_[runner];
if (std::find(frontier.begin(), frontier.end(), block) == frontier.end()) {
frontier.push_back(block);
}
runner = idom_[runner];
}
}
}
}
BasicBlock* DominatorTree::Intersect(BasicBlock* first, BasicBlock* second) const {
std::unordered_set<BasicBlock*> first_ancestors;
for (BasicBlock* block = first; block;) {
first_ancestors.insert(block);
auto it = idom_.find(block);
if (it == idom_.end() || it->second == block) {
break;
}
block = it->second;
}
for (BasicBlock* block = second; block;) {
if (first_ancestors.count(block)) {
return block;
}
auto it = idom_.find(block);
if (it == idom_.end() || it->second == block) {
break;
}
block = it->second;
}
return GetRoot();
}
} // namespace ir
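
The frontier walk in `ComputeDominanceFrontiers` can be tried in isolation. The following is a minimal sketch, assuming a hypothetical index-based CFG (`preds`, `idom`, `ComputeFrontiers`, and `DiamondFrontiers` are illustrative names, not part of this repository) and a valid immediate-dominator array in which the entry block is its own idom.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical index-based CFG: preds[b] lists the predecessors of block b,
// and idom[b] is the immediate dominator of b (the entry is its own idom).
std::vector<std::vector<int>> ComputeFrontiers(
    const std::vector<std::vector<int>>& preds, const std::vector<int>& idom) {
  assert(preds.size() == idom.size());
  std::vector<std::vector<int>> frontier(preds.size());
  for (std::size_t i = 0; i < preds.size(); ++i) {
    const int block = static_cast<int>(i);
    if (preds[i].size() < 2) continue;  // only join points appear in a frontier
    for (int pred : preds[i]) {
      // Walk up the dominator tree from each predecessor until reaching the
      // immediate dominator of the join point.
      for (int runner = pred; runner != idom[i]; runner = idom[runner]) {
        auto& df = frontier[runner];
        // All insertions of `block` happen back to back, so checking the last
        // element is enough to avoid duplicates.
        if (df.empty() || df.back() != block) df.push_back(block);
      }
    }
  }
  return frontier;
}

// Diamond CFG: 0 -> {1, 2} -> 3. Block 0 immediately dominates every block.
std::vector<std::vector<int>> DiamondFrontiers() {
  const std::vector<std::vector<int>> preds = {{}, {0}, {0}, {1, 2}};
  const std::vector<int> idom = {0, 0, 0, 0};
  return ComputeFrontiers(preds, idom);
}
```

For the diamond, blocks 1 and 2 each get block 3 in their frontier, which is where a variable assigned in both arms would need a `phi`.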

@ -1,4 +1,299 @@
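// IR constant folding:
// - Folds constant expressions whose value can be decided at compile time
// - Simplifies branches on constant conditions (trimmed to the implemented scope)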
#include "ir/passes/ConstFold.h"
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct ConstKey {
Type::Kind kind;
int int_value;
uint32_t float_bits;
bool operator==(const ConstKey& other) const {
return kind == other.kind && int_value == other.int_value &&
float_bits == other.float_bits;
}
};
struct ConstKeyHash {
size_t operator()(const ConstKey& key) const {
size_t h = std::hash<int>{}(static_cast<int>(key.kind));
h ^= std::hash<int>{}(key.int_value) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<uint32_t>{}(key.float_bits) + 0x9e3779b9 + (h << 6) + (h >> 2);
return h;
}
};
ConstantInt* GetIntConstant(std::shared_ptr<Type> ty, int value) {
static std::unordered_map<ConstKey, std::unique_ptr<ConstantInt>, ConstKeyHash> cache;
ConstKey key{ty->GetKind(), value, 0};
auto it = cache.find(key);
if (it != cache.end()) {
return it->second.get();
}
auto constant = std::make_unique<ConstantInt>(ty, value);
auto* ptr = constant.get();
cache.emplace(key, std::move(constant));
return ptr;
}
ConstantFloat* GetFloatConstant(float value) {
static std::unordered_map<ConstKey, std::unique_ptr<ConstantFloat>, ConstKeyHash> cache;
uint32_t bits = 0;
std::memcpy(&bits, &value, sizeof(bits));
ConstKey key{Type::Kind::Float, 0, bits};
auto it = cache.find(key);
if (it != cache.end()) {
return it->second.get();
}
auto constant = std::make_unique<ConstantFloat>(Type::GetFloatType(), value);
auto* ptr = constant.get();
cache.emplace(key, std::move(constant));
return ptr;
}
void ReplaceUse(User* user, size_t index, Value* value) {
if (auto* phi = dynamic_cast<PhiInst*>(user)) {
phi->SetIncomingValue(index, value);
} else if (auto* br = dynamic_cast<BranchInst*>(user)) {
if (br->IsConditional() && index == 0) {
br->SetCondition(value);
} else {
user->SetOperand(index, value);
}
} else if (auto* call = dynamic_cast<CallInst*>(user)) {
call->SetArg(index, value);
} else {
user->SetOperand(index, value);
}
}
void ReplaceAllUses(Value* old_value, Value* new_value) {
auto uses = old_value->GetUses();
for (const auto& use : uses) {
User* user = use.GetUser();
if (!user) {
continue;
}
size_t index = use.GetOperandIndex();
if (user->GetOperand(index) == old_value) {
ReplaceUse(user, index, new_value);
}
}
}
bool IsFoldableBinary(Opcode op) {
switch (op) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::And:
case Opcode::Or:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
return true;
default:
return false;
}
}
ConstantValue* FoldBinary(BinaryInst* inst) {
auto* lhs_i = dynamic_cast<ConstantInt*>(inst->GetLhs());
auto* rhs_i = dynamic_cast<ConstantInt*>(inst->GetRhs());
if (lhs_i && rhs_i) {
int lhs = lhs_i->GetValue();
int rhs = rhs_i->GetValue();
int result = 0;
switch (inst->GetOpcode()) {
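// Add/Sub/Mul are computed in uint32_t so that overflow wraps around
// instead of being undefined behavior inside the compiler itself.
case Opcode::Add:
result = static_cast<int>(static_cast<uint32_t>(lhs) + static_cast<uint32_t>(rhs));
break;
case Opcode::Sub:
result = static_cast<int>(static_cast<uint32_t>(lhs) - static_cast<uint32_t>(rhs));
break;
case Opcode::Mul:
result = static_cast<int>(static_cast<uint32_t>(lhs) * static_cast<uint32_t>(rhs));
break;
case Opcode::Div:
// Division by zero and INT_MIN / -1 are undefined in C++; leave them unfolded.
if (rhs == 0 || (lhs == std::numeric_limits<int>::min() && rhs == -1)) return nullptr;
result = lhs / rhs;
break;
case Opcode::Mod:
if (rhs == 0 || (lhs == std::numeric_limits<int>::min() && rhs == -1)) return nullptr;
result = lhs % rhs;
break;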
case Opcode::And: result = (lhs != 0 && rhs != 0) ? 1 : 0; break;
case Opcode::Or: result = (lhs != 0 || rhs != 0) ? 1 : 0; break;
default: return nullptr;
}
return GetIntConstant(inst->GetType(), result);
}
auto* lhs_f = dynamic_cast<ConstantFloat*>(inst->GetLhs());
auto* rhs_f = dynamic_cast<ConstantFloat*>(inst->GetRhs());
if (lhs_f && rhs_f) {
float lhs = lhs_f->GetValue();
float rhs = rhs_f->GetValue();
float result = 0.0f;
switch (inst->GetOpcode()) {
case Opcode::FAdd: result = lhs + rhs; break;
case Opcode::FSub: result = lhs - rhs; break;
case Opcode::FMul: result = lhs * rhs; break;
case Opcode::FDiv:
if (rhs == 0.0f) return nullptr;
result = lhs / rhs;
break;
default: return nullptr;
}
return GetFloatConstant(result);
}
return nullptr;
}
ConstantValue* FoldICmp(IcmpInst* inst) {
auto* lhs_c = dynamic_cast<ConstantInt*>(inst->GetLhs());
auto* rhs_c = dynamic_cast<ConstantInt*>(inst->GetRhs());
if (!lhs_c || !rhs_c) {
return nullptr;
}
int lhs = lhs_c->GetValue();
int rhs = rhs_c->GetValue();
bool result = false;
switch (inst->GetPredicate()) {
case IcmpInst::Predicate::EQ: result = lhs == rhs; break;
case IcmpInst::Predicate::NE: result = lhs != rhs; break;
case IcmpInst::Predicate::LT: result = lhs < rhs; break;
case IcmpInst::Predicate::LE: result = lhs <= rhs; break;
case IcmpInst::Predicate::GT: result = lhs > rhs; break;
case IcmpInst::Predicate::GE: result = lhs >= rhs; break;
}
return GetIntConstant(Type::GetInt1Type(), result ? 1 : 0);
}
ConstantValue* FoldFCmp(FcmpInst* inst) {
auto* lhs_c = dynamic_cast<ConstantFloat*>(inst->GetLhs());
auto* rhs_c = dynamic_cast<ConstantFloat*>(inst->GetRhs());
if (!lhs_c || !rhs_c) {
return nullptr;
}
float lhs = lhs_c->GetValue();
float rhs = rhs_c->GetValue();
bool ordered = !std::isnan(lhs) && !std::isnan(rhs);
bool result = false;
switch (inst->GetPredicate()) {
case FcmpInst::Predicate::FALSE: result = false; break;
case FcmpInst::Predicate::OEQ: result = ordered && lhs == rhs; break;
case FcmpInst::Predicate::OGT: result = ordered && lhs > rhs; break;
case FcmpInst::Predicate::OGE: result = ordered && lhs >= rhs; break;
case FcmpInst::Predicate::OLT: result = ordered && lhs < rhs; break;
case FcmpInst::Predicate::OLE: result = ordered && lhs <= rhs; break;
case FcmpInst::Predicate::ONE: result = ordered && lhs != rhs; break;
case FcmpInst::Predicate::ORD: result = ordered; break;
case FcmpInst::Predicate::UNO: result = !ordered; break;
case FcmpInst::Predicate::UEQ: result = !ordered || lhs == rhs; break;
case FcmpInst::Predicate::UGT: result = !ordered || lhs > rhs; break;
case FcmpInst::Predicate::UGE: result = !ordered || lhs >= rhs; break;
case FcmpInst::Predicate::ULT: result = !ordered || lhs < rhs; break;
case FcmpInst::Predicate::ULE: result = !ordered || lhs <= rhs; break;
case FcmpInst::Predicate::UNE: result = !ordered || lhs != rhs; break;
case FcmpInst::Predicate::TRUE: result = true; break;
}
return GetIntConstant(Type::GetInt1Type(), result ? 1 : 0);
}
ConstantValue* FoldCast(Instruction* inst) {
if (auto* zext = dynamic_cast<ZExtInst*>(inst)) {
if (auto* value = dynamic_cast<ConstantInt*>(zext->GetValue())) {
return GetIntConstant(zext->GetType(), value->GetValue() != 0 ? 1 : 0);
}
} else if (auto* trunc = dynamic_cast<TruncInst*>(inst)) {
if (auto* value = dynamic_cast<ConstantInt*>(trunc->GetValue())) {
int result = trunc->GetType()->IsInt1() ? (value->GetValue() != 0 ? 1 : 0)
: value->GetValue();
return GetIntConstant(trunc->GetType(), result);
}
} else if (auto* sitofp = dynamic_cast<SIToFPInst*>(inst)) {
if (auto* value = dynamic_cast<ConstantInt*>(sitofp->GetValue())) {
return GetFloatConstant(static_cast<float>(value->GetValue()));
}
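} else if (auto* fptosi = dynamic_cast<FPToSIInst*>(inst)) {
if (auto* value = dynamic_cast<ConstantFloat*>(fptosi->GetValue())) {
// Converting NaN or an out-of-range float to int is undefined behavior,
// so only fold values that fit in a 32-bit int.
float f = value->GetValue();
if (!(f >= -2147483648.0f && f < 2147483648.0f)) return nullptr;
return GetIntConstant(fptosi->GetType(), static_cast<int>(f));
}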
}
return nullptr;
}
ConstantValue* FoldPhi(PhiInst* phi) {
if (phi->GetNumIncoming() == 0) {
return nullptr;
}
auto* first = dynamic_cast<ConstantValue*>(phi->GetIncomingValue(0));
if (!first) {
return nullptr;
}
for (size_t i = 1; i < phi->GetNumIncoming(); ++i) {
if (phi->GetIncomingValue(i) != first) {
return nullptr;
}
}
return first;
}
ConstantValue* TryFold(Instruction* inst) {
if (auto* binary = dynamic_cast<BinaryInst*>(inst)) {
if (IsFoldableBinary(binary->GetOpcode())) {
return FoldBinary(binary);
}
}
if (auto* icmp = dynamic_cast<IcmpInst*>(inst)) {
return FoldICmp(icmp);
}
if (auto* fcmp = dynamic_cast<FcmpInst*>(inst)) {
return FoldFCmp(fcmp);
}
if (auto* phi = dynamic_cast<PhiInst*>(inst)) {
return FoldPhi(phi);
}
return FoldCast(inst);
}
} // namespace
bool ConstFoldPass::RunOnFunction(Function& function) {
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
BasicBlock* block = block_ptr.get();
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
Instruction* inst = inst_ptr.get();
ConstantValue* folded = TryFold(inst);
if (!folded) {
continue;
}
ReplaceAllUses(inst, folded);
to_remove.push_back(inst);
changed = true;
}
for (Instruction* inst : to_remove) {
block->RemoveInstruction(inst);
}
}
return changed;
}
} // namespace ir
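
When folding 32-bit integer arithmetic on the host, signed overflow and `INT_MIN / -1` are undefined behavior in C++, so a folder has to avoid evaluating them directly. A minimal standalone sketch, assuming a 32-bit `int`, C++17, and hypothetical names (`IntOp` and `FoldInt` stand in for the IR opcode and folding routine; they are not the repository's API):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

enum class IntOp { Add, Sub, Mul, Div, Mod };  // hypothetical stand-in for ir::Opcode

// Returns the folded value, or std::nullopt when the fold must be skipped.
std::optional<int32_t> FoldInt(IntOp op, int32_t lhs, int32_t rhs) {
  // Unsigned arithmetic wraps modulo 2^32, which matches two's-complement
  // wrap-around once converted back to int32_t (modular on mainstream
  // compilers, guaranteed since C++20).
  const uint32_t a = static_cast<uint32_t>(lhs);
  const uint32_t b = static_cast<uint32_t>(rhs);
  switch (op) {
    case IntOp::Add: return static_cast<int32_t>(a + b);
    case IntOp::Sub: return static_cast<int32_t>(a - b);
    case IntOp::Mul: return static_cast<int32_t>(a * b);
    case IntOp::Div:
    case IntOp::Mod:
      // Leave division by zero and INT_MIN / -1 to the generated program.
      if (rhs == 0) return std::nullopt;
      if (lhs == std::numeric_limits<int32_t>::min() && rhs == -1) return std::nullopt;
      return op == IntOp::Div ? lhs / rhs : lhs % rhs;
  }
  return std::nullopt;
}
```

Returning "no result" rather than a guessed value keeps the original instruction in place, so the runtime behavior of the program is unchanged in the cases the folder declines.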

@ -1,5 +1,212 @@
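// Constant propagation:
// - Propagates known constants along use-def relations
// - Rewrites replaceable SSA values to constants, exposing more folding opportunities
// - Usually iterated together with ConstFold, DCE, and CFGSimplify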
#include "ir/passes/ConstProp.h"
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct IntKey {
Type::Kind kind;
int value;
bool operator==(const IntKey& other) const {
return kind == other.kind && value == other.value;
}
};
struct IntKeyHash {
size_t operator()(const IntKey& key) const {
size_t h = std::hash<int>{}(static_cast<int>(key.kind));
h ^= std::hash<int>{}(key.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
return h;
}
};
ConstantInt* GetIntConstant(std::shared_ptr<Type> ty, int value) {
static std::unordered_map<IntKey, std::unique_ptr<ConstantInt>, IntKeyHash> cache;
IntKey key{ty->GetKind(), value};
auto it = cache.find(key);
if (it != cache.end()) {
return it->second.get();
}
auto constant = std::make_unique<ConstantInt>(ty, value);
auto* ptr = constant.get();
cache.emplace(key, std::move(constant));
return ptr;
}
bool IsZero(Value* value) {
auto* constant = dynamic_cast<ConstantInt*>(value);
return constant && constant->GetValue() == 0;
}
bool IsOne(Value* value) {
auto* constant = dynamic_cast<ConstantInt*>(value);
return constant && constant->GetValue() == 1;
}
bool IsFloatZero(Value* value) {
auto* constant = dynamic_cast<ConstantFloat*>(value);
return constant && constant->GetValue() == 0.0f;
}
bool IsFloatOne(Value* value) {
auto* constant = dynamic_cast<ConstantFloat*>(value);
return constant && constant->GetValue() == 1.0f;
}
void ReplaceUse(User* user, size_t index, Value* value) {
if (auto* phi = dynamic_cast<PhiInst*>(user)) {
phi->SetIncomingValue(index, value);
} else if (auto* br = dynamic_cast<BranchInst*>(user)) {
if (br->IsConditional() && index == 0) {
br->SetCondition(value);
} else {
user->SetOperand(index, value);
}
} else if (auto* call = dynamic_cast<CallInst*>(user)) {
call->SetArg(index, value);
} else {
user->SetOperand(index, value);
}
}
void ReplaceAllUses(Value* old_value, Value* new_value) {
auto uses = old_value->GetUses();
for (const auto& use : uses) {
User* user = use.GetUser();
if (!user) {
continue;
}
size_t index = use.GetOperandIndex();
if (user->GetOperand(index) == old_value) {
ReplaceUse(user, index, new_value);
}
}
}
Value* SimplifyBinary(BinaryInst* inst) {
Value* lhs = inst->GetLhs();
Value* rhs = inst->GetRhs();
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Or:
if (IsZero(lhs)) return rhs;
if (IsZero(rhs)) return lhs;
break;
case Opcode::Sub:
if (IsZero(rhs)) return lhs;
break;
case Opcode::Mul:
if (IsZero(lhs) || IsZero(rhs)) return GetIntConstant(inst->GetType(), 0);
if (IsOne(lhs)) return rhs;
if (IsOne(rhs)) return lhs;
break;
case Opcode::Div:
if (IsZero(lhs)) return GetIntConstant(inst->GetType(), 0);
if (IsOne(rhs)) return lhs;
break;
case Opcode::Mod:
if (IsZero(lhs) || IsOne(rhs)) return GetIntConstant(inst->GetType(), 0);
break;
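case Opcode::And:
if (IsZero(lhs) || IsZero(rhs)) return GetIntConstant(inst->GetType(), 0);
// `1 and x` equals x only when x is already an i1 value.
if (IsOne(lhs) && rhs->GetType()->IsInt1()) return rhs;
if (IsOne(rhs) && lhs->GetType()->IsInt1()) return lhs;
break;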
case Opcode::FAdd:
if (IsFloatZero(lhs)) return rhs;
if (IsFloatZero(rhs)) return lhs;
break;
case Opcode::FSub:
if (IsFloatZero(rhs)) return lhs;
break;
case Opcode::FMul:
if (IsFloatOne(lhs)) return rhs;
if (IsFloatOne(rhs)) return lhs;
break;
case Opcode::FDiv:
if (IsFloatOne(rhs)) return lhs;
break;
default:
break;
}
return nullptr;
}
Value* SimplifyPhi(PhiInst* phi) {
if (phi->GetNumIncoming() == 0) {
return nullptr;
}
Value* same = nullptr;
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
Value* incoming = phi->GetIncomingValue(i);
if (incoming == phi) {
continue;
}
if (!same) {
same = incoming;
continue;
}
if (incoming != same) {
return nullptr;
}
}
return same;
}
Value* TrySimplify(Instruction* inst) {
if (auto* binary = dynamic_cast<BinaryInst*>(inst)) {
return SimplifyBinary(binary);
}
if (auto* phi = dynamic_cast<PhiInst*>(inst)) {
return SimplifyPhi(phi);
}
if (auto* zext = dynamic_cast<ZExtInst*>(inst)) {
if (zext->GetValue()->GetType()->GetKind() == zext->GetType()->GetKind()) {
return zext->GetValue();
}
}
if (auto* trunc = dynamic_cast<TruncInst*>(inst)) {
if (trunc->GetValue()->GetType()->GetKind() == trunc->GetType()->GetKind()) {
return trunc->GetValue();
}
}
return nullptr;
}
} // namespace
bool ConstPropPass::RunOnFunction(Function& function) {
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
BasicBlock* block = block_ptr.get();
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
Instruction* inst = inst_ptr.get();
Value* replacement = TrySimplify(inst);
if (!replacement || replacement == inst) {
continue;
}
ReplaceAllUses(inst, replacement);
to_remove.push_back(inst);
changed = true;
}
for (Instruction* inst : to_remove) {
block->RemoveInstruction(inst);
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,54 @@
// Dead code elimination (DCE):
// - Removes unused instructions and unused basic blocks
// - Usually combined with CFG simplification
#include "ir/passes/DCE.h"
#include <vector>
namespace ir {
namespace {
bool HasSideEffectOrControl(Instruction* inst) {
switch (inst->GetOpcode()) {
case Opcode::Store:
case Opcode::Ret:
case Opcode::Call:
case Opcode::Br:
case Opcode::CondBr:
return true;
default:
return false;
}
}
bool IsRemovable(Instruction* inst) {
return !HasSideEffectOrControl(inst) && inst->GetUses().empty();
}
} // namespace
bool DCEPass::RunOnFunction(Function& function) {
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
BasicBlock* block = block_ptr.get();
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
Instruction* inst = inst_ptr.get();
if (IsRemovable(inst)) {
to_remove.push_back(inst);
}
}
for (Instruction* inst : to_remove) {
block->RemoveInstruction(inst);
local_changed = true;
changed = true;
}
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,276 @@
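// Mem2Reg (SSA construction):
// - Promotes the alloca/load/store form of local variables to SSA form
// - Inserts PHI nodes and rewrites uses; relies on the dominator tree and related analyses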
#include "ir/passes/Mem2Reg.h"
#include "ir/analysis/DominatorTree.h"
#include "utils/Log.h"
#include <algorithm>
#include <functional>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
bool IsScalarAlloca(AllocaInst* alloca) {
if (!alloca) {
return false;
}
auto ty = alloca->GetType();
return ty->IsPtrInt32() || ty->IsPtrFloat() || ty->IsPtrInt1();
}
std::shared_ptr<Type> GetAllocatedElementType(AllocaInst* alloca) {
if (!alloca) {
return nullptr;
}
auto ty = alloca->GetType();
if (ty->IsPtrInt32()) {
return Type::GetInt32Type();
}
if (ty->IsPtrFloat()) {
return Type::GetFloatType();
}
if (ty->IsPtrInt1()) {
return Type::GetInt1Type();
}
return nullptr;
}
bool CollectAllocaUsers(AllocaInst* alloca,
std::vector<LoadInst*>& loads,
std::vector<StoreInst*>& stores) {
loads.clear();
stores.clear();
if (!alloca) {
return false;
}
for (const auto& use : alloca->GetUses()) {
auto* user = use.GetUser();
if (!user) {
return false;
}
if (auto* load = dynamic_cast<LoadInst*>(user)) {
if (load->GetPtr() != alloca) {
return false;
}
loads.push_back(load);
} else if (auto* store = dynamic_cast<StoreInst*>(user)) {
if (store->GetPtr() != alloca) {
return false;
}
stores.push_back(store);
} else {
return false;
}
}
return true;
}
bool RenameBlocks(BasicBlock* block, Value* incoming, AllocaInst* alloca,
const DominatorTree& domtree,
const std::unordered_map<BasicBlock*, PhiInst*>& phi_for_block,
std::unordered_map<BasicBlock*, Value*>& block_out,
bool apply_changes) {
Value* current = incoming;
auto phi_it = phi_for_block.find(block);
if (phi_it != phi_for_block.end()) {
current = phi_it->second;
}
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
Instruction* inst = inst_ptr.get();
if (phi_it != phi_for_block.end() && inst == phi_it->second) {
continue;
}
if (inst->GetOpcode() == Opcode::Load) {
auto* load = static_cast<LoadInst*>(inst);
if (load->GetPtr() == alloca) {
if (!current) {
return false;
}
if (apply_changes) {
load->ReplaceAllUsesWith(current);
}
to_remove.push_back(inst);
continue;
}
}
if (inst->GetOpcode() == Opcode::Store) {
auto* store = static_cast<StoreInst*>(inst);
if (store->GetPtr() == alloca) {
current = store->GetValue();
if (apply_changes) {
to_remove.push_back(inst);
}
continue;
}
}
}
block_out[block] = current;
for (BasicBlock* succ : domtree.GetSuccessors(block)) {
auto succ_phi = phi_for_block.find(succ);
if (succ_phi != phi_for_block.end()) {
if (!current) {
return false;
}
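// Incoming values are added only once, in the dry run. If one of them is
// a load of this alloca, the apply run rewrites it through ReplaceAllUsesWith.
if (!apply_changes) {
succ_phi->second->AddIncoming(current, block);
}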
}
}
if (apply_changes) {
for (Instruction* inst : to_remove) {
block->RemoveInstruction(inst);
}
}
for (BasicBlock* child : domtree.GetChildren(block)) {
if (!RenameBlocks(child, current, alloca, domtree, phi_for_block,
block_out, apply_changes)) {
return false;
}
}
return true;
}
std::string MakePhiName(AllocaInst* alloca, BasicBlock* block, int id) {
std::string base = alloca && !alloca->GetName().empty() ? alloca->GetName()
: "mem2reg";
std::string block_name = block && !block->GetName().empty() ? block->GetName()
: "block";
return base + "." + block_name + ".phi" + std::to_string(id);
}
} // namespace
bool Mem2RegPass::RunOnFunction(Function& function) {
changed_ = false;
DebugStream() << "[DEBUG] Mem2RegPass: starting on function " << function.GetName() << std::endl;
DominatorTree domtree;
domtree.Recalculate(function);
DebugStream() << "[DEBUG] Mem2RegPass: dominator tree built for " << function.GetName() << std::endl;
changed_ = PromoteAllocas(function, domtree);
DebugStream() << "[DEBUG] Mem2RegPass: finished on function " << function.GetName() << " changed=" << changed_ << std::endl;
return changed_;
}
bool Mem2RegPass::RunOnModule(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed = RunOnFunction(*function) || changed;
}
}
return changed;
}
bool Mem2RegPass::PromoteAllocas(Function& function, DominatorTree& domtree) {
BasicBlock* entry = function.GetEntry();
if (!entry) {
return false;
}
std::vector<AllocaInst*> allocas;
for (const auto& inst_ptr : entry->GetInstructions()) {
if (auto* alloca = dynamic_cast<AllocaInst*>(inst_ptr.get())) {
if (IsScalarAlloca(alloca)) {
allocas.push_back(alloca);
}
}
}
bool changed = false;
int phi_id = 0;
for (AllocaInst* alloca : allocas) {
DebugStream() << "[DEBUG] Mem2RegPass: processing alloca " << alloca->GetName() << std::endl;
std::vector<LoadInst*> loads;
std::vector<StoreInst*> stores;
if (!CollectAllocaUsers(alloca, loads, stores)) {
DebugStream() << "[DEBUG] Mem2RegPass: CollectAllocaUsers failed for " << alloca->GetName() << std::endl;
continue;
}
DebugStream() << "[DEBUG] Mem2RegPass: loads=" << loads.size() << " stores=" << stores.size() << std::endl;
if (stores.empty()) {
continue;
}
std::unordered_set<BasicBlock*> def_blocks;
for (StoreInst* store : stores) {
if (store->GetParent()) {
def_blocks.insert(store->GetParent());
}
}
if (def_blocks.empty()) {
continue;
}
auto element_type = GetAllocatedElementType(alloca);
if (!element_type) {
continue;
}
std::unordered_map<BasicBlock*, PhiInst*> phi_for_block;
std::vector<BasicBlock*> worklist(def_blocks.begin(), def_blocks.end());
std::unordered_set<BasicBlock*> has_phi;
while (!worklist.empty()) {
BasicBlock* block = worklist.back();
worklist.pop_back();
DebugStream() << "[DEBUG] Mem2RegPass: worklist block=" << block->GetName() << std::endl;
for (BasicBlock* frontier : domtree.GetDominanceFrontier(block)) {
if (has_phi.insert(frontier).second) {
PhiInst* phi = frontier->InsertAtBeginning<PhiInst>(element_type,
MakePhiName(alloca, frontier, phi_id++));
DebugStream() << "[DEBUG] Mem2RegPass: inserted phi in " << frontier->GetName() << std::endl;
phi_for_block[frontier] = phi;
if (!def_blocks.count(frontier)) {
worklist.push_back(frontier);
}
}
}
}
std::unordered_map<BasicBlock*, Value*> block_out;
DebugStream() << "[DEBUG] Mem2RegPass: before dry run RenameBlocks for " << alloca->GetName() << std::endl;
if (!RenameBlocks(function.GetEntry(), nullptr, alloca, domtree,
phi_for_block, block_out, false)) {
DebugStream() << "[DEBUG] Mem2RegPass: dry run failed for " << alloca->GetName() << std::endl;
for (auto& pair : phi_for_block) {
pair.first->RemoveInstruction(pair.second);
}
continue;
}
DebugStream() << "[DEBUG] Mem2RegPass: before apply RenameBlocks for " << alloca->GetName() << std::endl;
if (!RenameBlocks(function.GetEntry(), nullptr, alloca, domtree,
phi_for_block, block_out, true)) {
DebugStream() << "[DEBUG] Mem2RegPass: apply run failed for " << alloca->GetName() << std::endl;
for (auto& pair : phi_for_block) {
pair.first->RemoveInstruction(pair.second);
}
continue;
}
if (alloca->GetUses().empty()) {
entry->RemoveInstruction(alloca);
}
changed = true;
}
changed_ = changed;
return changed;
}
} // namespace ir
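
The phi placement loop in `PromoteAllocas` is a worklist over dominance frontiers. A minimal sketch of that step alone, assuming precomputed frontiers and hypothetical names (`PlacePhis`, `LoopExamplePhis`, and the integer block ids are illustrative only):

```cpp
#include <set>
#include <vector>

// frontier[b] is the dominance frontier of block b; def_blocks are the blocks
// that store to the variable being promoted. Returns the blocks needing a phi.
std::set<int> PlacePhis(const std::vector<std::vector<int>>& frontier,
                        const std::vector<int>& def_blocks) {
  std::set<int> has_phi;
  std::set<int> queued(def_blocks.begin(), def_blocks.end());
  std::vector<int> worklist(def_blocks.begin(), def_blocks.end());
  while (!worklist.empty()) {
    const int block = worklist.back();
    worklist.pop_back();
    for (int join : frontier[block]) {
      if (!has_phi.insert(join).second) continue;  // phi already placed here
      // A phi is itself a new definition, so its block is processed as well.
      if (queued.insert(join).second) worklist.push_back(join);
    }
  }
  return has_phi;
}

// Loop CFG: 0 (entry) -> 1 (header) -> 2 (body) -> 1, and 1 -> 3 (exit).
// The variable is stored in the entry block and in the loop body.
std::set<int> LoopExamplePhis() {
  const std::vector<std::vector<int>> frontier = {{}, {1}, {1}, {}};
  return PlacePhis(frontier, {0, 2});
}
```

In the loop example only the header (block 1) receives a phi, merging the value from the entry with the value produced by the body.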

@ -1 +1,38 @@
// IR pass manager skeleton.
#include "ir/passes/PassManager.h"
#include "utils/Log.h"
#include <ostream>
#include <utility>
namespace ir {
void PassManager::AddPass(std::unique_ptr<Pass> pass) {
if (pass) {
passes_.push_back(std::move(pass));
}
}
bool PassManager::Run(Function& function) {
bool changed = false;
DebugStream() << "[DEBUG] PassManager: running " << passes_.size() << " pass(es) on function "
<< function.GetName() << std::endl;
for (const auto& pass : passes_) {
if (pass) {
DebugStream() << "[DEBUG] PassManager: before pass" << std::endl;
changed = pass->RunOnFunction(function) || changed;
DebugStream() << "[DEBUG] PassManager: after pass" << std::endl;
}
}
DebugStream() << "[DEBUG] PassManager: finished function " << function.GetName() << std::endl;
return changed;
}
bool PassManager::Run(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed = Run(*function) || changed;
}
}
return changed;
}
} // namespace ir

@ -100,14 +100,14 @@ std::string MakeStaticArrayName(const ir::Function& func,
// visitDecl: handle a declaration
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
DEBUG_MSG("[DEBUG] visitDecl: start handling declaration");
DebugStream() << "[DEBUG] visitDecl: start handling declaration" << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "missing variable declaration"));
}
// Handle varDecl
if (auto* varDecl = ctx->varDecl()) {
DEBUG_MSG("[DEBUG] visitDecl: handling variable declaration");
DebugStream() << "[DEBUG] visitDecl: handling variable declaration" << std::endl;
for (auto* varDef : varDecl->varDef()) {
varDef->accept(this);
}
@ -115,20 +115,20 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
// Handle constDecl
if (ctx->constDecl()) {
DEBUG_MSG("[DEBUG] visitDecl: handling constant declaration");
DebugStream() << "[DEBUG] visitDecl: handling constant declaration" << std::endl;
auto* constDecl = ctx->constDecl();
for (auto* constDef : constDecl->constDef()) {
constDef->accept(this);
}
}
DEBUG_MSG("[DEBUG] visitDecl: declaration handled");
DebugStream() << "[DEBUG] visitDecl: declaration handled" << std::endl;
return {};
}
// visitConstDecl: handle a constant declaration
std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
DEBUG_MSG("[DEBUG] visitConstDecl: start handling constant declaration");
DebugStream() << "[DEBUG] visitConstDecl: start handling constant declaration" << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "invalid constant declaration"));
}
@ -139,13 +139,13 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
}
}
DEBUG_MSG("[DEBUG] visitConstDecl: constant declaration handled");
DebugStream() << "[DEBUG] visitConstDecl: constant declaration handled" << std::endl;
return {};
}
// visitConstDef: handle a constant definition - the constant value comes from the symbol table
std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
DEBUG_MSG("[DEBUG] visitConstDef: start handling constant definition");
DebugStream() << "[DEBUG] visitConstDef: start handling constant definition" << std::endl;
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "invalid constant definition"));
}
@ -158,8 +158,8 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
throw std::runtime_error(FormatError("irgen", "constant symbol not found: " + const_name));
}
DEBUG_MSG("[DEBUG] visitConstDef: fetched constant from symbol table: " << const_name
<< ", is_array_const: " << sym->IsArrayConstant());
DebugStream() << "[DEBUG] visitConstDef: fetched constant from symbol table: " << const_name
<< ", is_array_const: " << sym->IsArrayConstant() << std::endl;
// Create the IR constant from the constant value stored in the symbol table
if (sym->IsArrayConstant()) {
@ -270,12 +270,12 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
ir::ConstantValue* const_value = nullptr;
if (sym->type->IsInt32()) {
const_value = builder_.CreateConstInt(sym->GetIntConstant());
DEBUG_MSG("[DEBUG] visitConstDef: integer constant " << const_name
<< " = " << sym->GetIntConstant());
DebugStream() << "[DEBUG] visitConstDef: integer constant " << const_name
<< " = " << sym->GetIntConstant() << std::endl;
} else if (sym->type->IsFloat()) {
const_value = builder_.CreateConstFloat(sym->GetFloatConstant());
DEBUG_MSG("[DEBUG] visitConstDef: float constant " << const_name
<< " = " << sym->GetFloatConstant());
DebugStream() << "[DEBUG] visitConstDef: float constant " << const_name
<< " = " << sym->GetFloatConstant() << std::endl;
}
const_value_map_[const_name] = const_value;
@ -287,13 +287,13 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
// visitVarDef: process a variable definition - type information comes from the symbol table
std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
DEBUG_MSG("[DEBUG] visitVarDef: 开始处理变量定义");
DebugStream() << "[DEBUG] visitVarDef: 开始处理变量定义" << std::endl;
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "非法变量定义"));
}
std::string varName = ctx->Ident()->getText();
DEBUG_MSG("[DEBUG] visitVarDef: 变量名称: " << varName);
DebugStream() << "[DEBUG] visitVarDef: 变量名称: " << varName << std::endl;
// Guard against allocating storage twice
if (storage_map_.find(ctx) != storage_map_.end()) {
@ -306,17 +306,17 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
throw std::runtime_error(FormatError("irgen", "变量符号未找到: " + varName));
}
DEBUG_MSG("[DEBUG] visitVarDef: 变量类型: "
DebugStream() << "[DEBUG] visitVarDef: 变量类型: "
<< (sym->type->IsInt32() ? "int" :
sym->type->IsFloat() ? "float" :
sym->type->IsArray() ? "array" : "unknown"));
sym->type->IsArray() ? "array" : "unknown") << std::endl;
// Dispatch on scope
if (func_ == nullptr) {
DEBUG_MSG("[DEBUG] visitVarDef: 处理全局变量");
DebugStream() << "[DEBUG] visitVarDef: 处理全局变量" << std::endl;
return HandleGlobalVariable(ctx, varName, sym);
} else {
DEBUG_MSG("[DEBUG] visitVarDef: 处理局部变量");
DebugStream() << "[DEBUG] visitVarDef: 处理局部变量" << std::endl;
return HandleLocalVariable(ctx, varName, sym);
}
}
@ -325,7 +325,7 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx,
const std::string& varName,
const Symbol* sym) {
DEBUG_MSG("[DEBUG] HandleGlobalVariable: 开始处理全局变量 " << varName);
DebugStream() << "[DEBUG] HandleGlobalVariable: 开始处理全局变量 " << varName << std::endl;
if (!sym) {
throw std::runtime_error(FormatError("irgen", "符号表信息缺失: " + varName));
@ -349,9 +349,9 @@ std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx,
const auto& dimensions = array_ty->GetDimensions();
size_t total_size = array_ty->GetElementCount();
DEBUG_MSG("[DEBUG] HandleGlobalVariable: 全局数组 " << varName << " 维度: ");
for (int d : dimensions) DEBUG_MSG(d << " ");
DEBUG_MSG(", 总大小: " << total_size);
DebugStream() << "[DEBUG] HandleGlobalVariable: 全局数组 " << varName << " 维度: ";
for (int d : dimensions) DebugStream() << d << " ";
DebugStream() << ", 总大小: " << total_size << std::endl;
// Create the global array
ir::GlobalValue* global_array = module_.CreateGlobal(varName, sym->type);
@ -359,7 +359,7 @@ std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx,
// Process the initializer (dimension-aware flattening)
std::vector<ir::ConstantValue*> init_consts;
if (auto* initVal = ctx->initVal()) {
DEBUG_MSG("[DEBUG] HandleGlobalVariable: 处理初始化值");
DebugStream() << "[DEBUG] HandleGlobalVariable: 处理初始化值" << std::endl;
// Global initializers must be constant expressions (already guaranteed by semantic analysis)
std::vector<ir::Value*> flat_vals = FlattenInitVal(
initVal, dimensions, is_float);
@ -439,7 +439,7 @@ std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx,
global_map_[varName] = global_var;
}
DEBUG_MSG("[DEBUG] HandleGlobalVariable: 全局变量处理完成");
DebugStream() << "[DEBUG] HandleGlobalVariable: 全局变量处理完成" << std::endl;
return {};
}
@ -447,7 +447,7 @@ std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx,
std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx,
const std::string& varName,
const Symbol* sym) {
DEBUG_MSG("[DEBUG] HandleLocalVariable: 开始处理局部变量 " << varName);
DebugStream() << "[DEBUG] HandleLocalVariable: 开始处理局部变量 " << varName << std::endl;
if (!sym) {
throw std::runtime_error(FormatError("irgen", "符号表信息缺失: " + varName));
@ -473,8 +473,8 @@ std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx,
const bool use_heap_storage =
current_function_is_recursive_ || total_bytes > kLocalArrayHeapThresholdBytes;
DEBUG_MSG("[DEBUG] HandleLocalVariable: 局部数组 " << varName
<< " 总大小: " << total_size);
DebugStream() << "[DEBUG] HandleLocalVariable: 局部数组 " << varName
<< " 总大小: " << total_size << std::endl;
ir::Value* array_slot = nullptr;
if (use_heap_storage) {
@ -520,8 +520,8 @@ std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx,
if (is_all_zero_init && !use_heap_storage) {
builder_.CreateStore(module_.GetContext().GetAggregateZero(sym->type),
array_slot);
DEBUG_MSG("[DEBUG] HandleLocalVariable: aggregate zeroinitializer store for "
<< varName);
DebugStream() << "[DEBUG] HandleLocalVariable: aggregate zeroinitializer store for "
<< varName << std::endl;
return {};
}
@ -617,35 +617,35 @@ std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx,
builder_.CreateStore(init, slot);
}
DEBUG_MSG("[DEBUG] HandleLocalVariable: 局部变量处理完成");
DebugStream() << "[DEBUG] HandleLocalVariable: 局部变量处理完成" << std::endl;
return {};
}
// visitInitVal: process an initializer
std::any IRGenImpl::visitInitVal(SysYParser::InitValContext* ctx) {
DEBUG_MSG("[DEBUG] visitInitVal: 开始处理初始化值");
DebugStream() << "[DEBUG] visitInitVal: 开始处理初始化值" << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法初始化值"));
}
// Single expression
if (ctx->exp()) {
DEBUG_MSG("[DEBUG] visitInitVal: 处理表达式初始化");
DebugStream() << "[DEBUG] visitInitVal: 处理表达式初始化" << std::endl;
return EvalExpr(*ctx->exp());
}
// Aggregate initializer (brace-enclosed list)
else if (!ctx->initVal().empty()) {
DEBUG_MSG("[DEBUG] visitInitVal: 处理聚合初始化");
DebugStream() << "[DEBUG] visitInitVal: 处理聚合初始化" << std::endl;
return ProcessNestedInitVals(ctx);
}
DEBUG_MSG("[DEBUG] visitInitVal: 空初始化列表");
DebugStream() << "[DEBUG] visitInitVal: 空初始化列表" << std::endl;
return std::vector<ir::Value*>{};
}
// ProcessNestedInitVals: process a nested aggregate initializer
std::vector<ir::Value*> IRGenImpl::ProcessNestedInitVals(SysYParser::InitValContext* ctx) {
DEBUG_MSG("[DEBUG] ProcessNestedInitVals: 开始处理嵌套初始化值");
DebugStream() << "[DEBUG] ProcessNestedInitVals: 开始处理嵌套初始化值" << std::endl;
std::vector<ir::Value*> all_values;
for (auto* init_val : ctx->initVal()) {
@ -655,18 +655,18 @@ std::vector<ir::Value*> IRGenImpl::ProcessNestedInitVals(SysYParser::InitValCont
// Try to read a single value
ir::Value* value = std::any_cast<ir::Value*>(result);
all_values.push_back(value);
DEBUG_MSG("[DEBUG] ProcessNestedInitVals: 获取到单个值");
DebugStream() << "[DEBUG] ProcessNestedInitVals: 获取到单个值" << std::endl;
} catch (const std::bad_any_cast&) {
try {
// Try to read a list of values (nested case)
std::vector<ir::Value*> nested_values =
std::any_cast<std::vector<ir::Value*>>(result);
DEBUG_MSG("[DEBUG] ProcessNestedInitVals: 获取到嵌套值列表, 大小: "
<< nested_values.size());
DebugStream() << "[DEBUG] ProcessNestedInitVals: 获取到嵌套值列表, 大小: "
<< nested_values.size() << std::endl;
all_values.insert(all_values.end(),
nested_values.begin(), nested_values.end());
} catch (const std::bad_any_cast&) {
DEBUG_MSG("[ERROR] ProcessNestedInitVals: 不支持的初始化值类型");
std::cerr << "[ERROR] ProcessNestedInitVals: 不支持的初始化值类型" << std::endl;
throw std::runtime_error(
FormatError("irgen", "不支持的初始化值类型"));
}
@ -674,8 +674,8 @@ std::vector<ir::Value*> IRGenImpl::ProcessNestedInitVals(SysYParser::InitValCont
}
}
DEBUG_MSG("[DEBUG] ProcessNestedInitVals: 共获取 " << all_values.size()
<< " 个初始化值");
DebugStream() << "[DEBUG] ProcessNestedInitVals: 共获取 " << all_values.size()
<< " 个初始化值" << std::endl;
return all_values;
}
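Review note: `ProcessNestedInitVals` tells a single value from a nested list by catching `std::bad_any_cast`. A possible alternative, shown here only as a sketch, is the pointer overload of `std::any_cast`, which returns `nullptr` on a type mismatch instead of throwing. `Value` and `AppendInitValues` are hypothetical names used for illustration; they are not part of this repository.

```cpp
#include <any>
#include <vector>

// Hypothetical stand-in for ir::Value (assumption, not the project's class).
struct Value {
  int id;
};

// Appends the contents of `result` to `out`.
// Accepts either a single Value* or a std::vector<Value*>.
// Returns false when `result` holds neither type.
inline bool AppendInitValues(const std::any& result, std::vector<Value*>& out) {
  // Pointer form of any_cast: yields nullptr on mismatch, never throws.
  if (const auto* single = std::any_cast<Value*>(&result)) {
    out.push_back(*single);
    return true;
  }
  if (const auto* nested = std::any_cast<std::vector<Value*>>(&result)) {
    out.insert(out.end(), nested->begin(), nested->end());
    return true;
  }
  return false;
}
```

With this shape the caller decides what to do on `false` (for example, throw the existing "unsupported initializer type" error), and the common paths avoid exception handling entirely.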

@ -4,6 +4,11 @@
#include "SysYParser.h"
#include "ir/IR.h"
#include "ir/passes/ConstFold.h"
#include "ir/passes/ConstProp.h"
#include "ir/passes/DCE.h"
#include "ir/passes/Mem2Reg.h"
#include "ir/passes/PassManager.h"
#include "utils/Log.h"
// Modified GenerateIR function
@ -12,5 +17,16 @@ std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
auto module = std::make_unique<ir::Module>();
IRGenImpl gen(*module, sema_result.context, sema_result.symbol_table);
tree.accept(&gen);
ir::PassManager pass_manager;
pass_manager.AddPass(std::make_unique<ir::Mem2RegPass>());
pass_manager.AddPass(std::make_unique<ir::ConstFoldPass>());
pass_manager.AddPass(std::make_unique<ir::ConstPropPass>());
pass_manager.AddPass(std::make_unique<ir::ConstFoldPass>());
pass_manager.AddPass(std::make_unique<ir::DCEPass>());
DebugStream() << "[DEBUG] IRGenDriver: before mem2reg" << std::endl;
pass_manager.Run(*module);
DebugStream() << "[DEBUG] IRGenDriver: after scalar opts" << std::endl;
return module;
}
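Review note: the `ir::PassManager` interface is not visible in this diff, so the following is only a minimal sketch of what "add passes, then run them in insertion order" could look like. Every type in it (`Module`, `Pass`, `PassManager`, `TracePass`) is a hypothetical stand-in, and the placeholder passes merely record their names so the resulting order can be inspected.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace sketch {

// Hypothetical stand-in for ir::Module: records which passes ran.
struct Module {
  std::vector<std::string> trace;
};

struct Pass {
  virtual ~Pass() = default;
  virtual bool Run(Module& module) = 0;  // true if the pass changed the IR
};

class PassManager {
 public:
  void AddPass(std::unique_ptr<Pass> pass) { passes_.push_back(std::move(pass)); }

  // Runs every registered pass once, in insertion order.
  bool Run(Module& module) {
    bool changed = false;
    for (auto& pass : passes_) {
      changed = pass->Run(module) || changed;
    }
    return changed;
  }

 private:
  std::vector<std::unique_ptr<Pass>> passes_;
};

// Placeholder pass that only logs its name.
struct TracePass : Pass {
  explicit TracePass(std::string name) : name_(std::move(name)) {}
  bool Run(Module& module) override {
    module.trace.push_back(name_);
    return true;
  }
  std::string name_;
};

// Mirrors the order registered in GenerateIR.
inline std::vector<std::string> RunScalarPipeline() {
  Module module;
  PassManager pm;
  for (const char* name : {"Mem2Reg", "ConstFold", "ConstProp", "ConstFold", "DCE"}) {
    pm.AddPass(std::make_unique<TracePass>(name));
  }
  pm.Run(module);
  return module.trace;
}

}  // namespace sketch
```

Returning a "changed" flag from `Run` is a design choice of this sketch; it would let a driver repeat `ConstFold`/`ConstProp` until a fixed point rather than hard-coding a second `ConstFold`.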

@ -23,50 +23,50 @@
// Expression generation
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
DEBUG_MSG("[DEBUG IRGEN] EvalExpr: 开始处理表达式 " << expr.getText());
DebugStream() << "[DEBUG IRGEN] EvalExpr: 开始处理表达式 " << expr.getText() << std::endl;
try {
auto result_any = expr.accept(this);
if (!result_any.has_value()) {
DEBUG_MSG("[ERROR] EvalExpr: result_any has no value");
std::cerr << "[ERROR] EvalExpr: result_any has no value" << std::endl;
throw std::runtime_error("表达式求值结果为空");
}
try {
ir::Value* result = std::any_cast<ir::Value*>(result_any);
DEBUG_MSG("[DEBUG] EvalExpr: success, result = " << (void*)result);
DebugStream() << "[DEBUG] EvalExpr: success, result = " << (void*)result << std::endl;
return result;
} catch (const std::bad_any_cast& e) {
DEBUG_MSG("[ERROR] EvalExpr: bad any_cast - " << e.what());
DEBUG_MSG(" Type info: " << result_any.type().name());
std::cerr << "[ERROR] EvalExpr: bad any_cast - " << e.what() << std::endl;
std::cerr << " Type info: " << result_any.type().name() << std::endl;
throw std::runtime_error(FormatError("irgen", "表达式求值返回了错误的类型"));
}
} catch (const std::exception& e) {
DEBUG_MSG("[ERROR] Exception in EvalExpr: " << e.what());
std::cerr << "[ERROR] Exception in EvalExpr: " << e.what() << std::endl;
throw;
}
}
ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) {
DEBUG_MSG("[DEBUG IRGEN] EvalCond: 开始处理条件表达式 " << cond.getText());
DebugStream() << "[DEBUG IRGEN] EvalCond: 开始处理条件表达式 " << cond.getText() << std::endl;
return std::any_cast<ir::Value*>(cond.accept(this));
}
// Primary expressions: numbers, variables, parenthesized expressions
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitPrimaryExp: 开始处理基本表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitPrimaryExp: 开始处理基本表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少基本表达式"));
}
DEBUG_MSG("[DEBUG] visitPrimaryExp");
DebugStream() << "[DEBUG] visitPrimaryExp" << std::endl;
// Numeric literals
if (ctx->DECIMAL_INT()) {
int value = std::stoi(ctx->DECIMAL_INT()->getText());
ir::Value* const_int = builder_.CreateConstInt(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant int " << value
<< " created as " << (void*)const_int);
DebugStream() << "[DEBUG] visitPrimaryExp: constant int " << value
<< " created as " << (void*)const_int << std::endl;
return static_cast<ir::Value*>(const_int);
}
@ -76,13 +76,13 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
try {
value = std::stof(hex_float_str);
} catch (const std::exception& e) {
DEBUG_MSG("[WARNING] 无法解析十六进制浮点数: " << hex_float_str
<< "使用0.0代替");
std::cerr << "[WARNING] 无法解析十六进制浮点数: " << hex_float_str
<< ", 使用0.0代替" << std::endl;
value = 0.0f;
}
ir::Value* const_float = builder_.CreateConstFloat(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant hex float " << value
<< " created as " << (void*)const_float);
DebugStream() << "[DEBUG] visitPrimaryExp: constant hex float " << value
<< " created as " << (void*)const_float << std::endl;
return static_cast<ir::Value*>(const_float);
}
@ -92,13 +92,13 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
try {
value = std::stof(dec_float_str);
} catch (const std::exception& e) {
DEBUG_MSG("[WARNING] 无法解析十进制浮点数: " << dec_float_str
<< "使用0.0代替");
std::cerr << "[WARNING] 无法解析十进制浮点数: " << dec_float_str
<< ", 使用0.0代替" << std::endl;
value = 0.0f;
}
ir::Value* const_float = builder_.CreateConstFloat(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant dec float " << value
<< " created as " << (void*)const_float);
DebugStream() << "[DEBUG] visitPrimaryExp: constant dec float " << value
<< " created as " << (void*)const_float << std::endl;
return static_cast<ir::Value*>(const_float);
}
@ -106,8 +106,8 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
std::string hex = ctx->HEX_INT()->getText();
int value = std::stoi(hex, nullptr, 16);
ir::Value* const_int = builder_.CreateConstInt(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant hex int " << value
<< " created as " << (void*)const_int);
DebugStream() << "[DEBUG] visitPrimaryExp: constant hex int " << value
<< " created as " << (void*)const_int << std::endl;
return static_cast<ir::Value*>(const_int);
}
@ -115,42 +115,42 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
std::string oct = ctx->OCTAL_INT()->getText();
int value = std::stoi(oct, nullptr, 8);
ir::Value* const_int = builder_.CreateConstInt(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant octal int " << value
<< " created as " << (void*)const_int);
DebugStream() << "[DEBUG] visitPrimaryExp: constant octal int " << value
<< " created as " << (void*)const_int << std::endl;
return static_cast<ir::Value*>(const_int);
}
if (ctx->ZERO()) {
ir::Value* const_int = builder_.CreateConstInt(0);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant zero int created");
DebugStream() << "[DEBUG] visitPrimaryExp: constant zero int created" << std::endl;
return static_cast<ir::Value*>(const_int);
}
// Variables
if (ctx->lVal()) {
DEBUG_MSG("[DEBUG] visitPrimaryExp: visiting lVal");
DebugStream() << "[DEBUG] visitPrimaryExp: visiting lVal" << std::endl;
return ctx->lVal()->accept(this);
}
// Parenthesized expressions
if (ctx->L_PAREN() && ctx->exp()) {
DEBUG_MSG("[DEBUG] visitPrimaryExp: visiting parenthesized expression");
DebugStream() << "[DEBUG] visitPrimaryExp: visiting parenthesized expression" << std::endl;
return EvalExpr(*ctx->exp());
}
DEBUG_MSG("[ERROR] visitPrimaryExp: unsupported primary expression type");
std::cerr << "[ERROR] visitPrimaryExp: unsupported primary expression type" << std::endl;
throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型"));
}
// Lvalue (variable) handling
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitLVal: 开始处理左值 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitLVal: 开始处理左值 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
}
std::string varName = ctx->Ident()->getText();
DEBUG_MSG("[DEBUG] visitLVal: " << varName);
DebugStream() << "[DEBUG] visitLVal: " << varName << std::endl;
// First check the constant binding recorded by semantic analysis
const SysYParser::ConstDefContext* const_decl = sema_.ResolveConstUse(ctx);
@ -166,7 +166,7 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
// If it is a constant, return the constant value directly
if (sym && sym->kind == SymbolKind::Constant) {
DEBUG_MSG("[DEBUG] visitLVal: 找到常量 " << varName);
DebugStream() << "[DEBUG] visitLVal: 找到常量 " << varName << std::endl;
if (sym->IsScalarConstant()) {
if (sym->type->IsInt32()) {
@ -394,7 +394,7 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
}
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitAddExp: 开始处理加法表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitAddExp: 开始处理加法表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
@ -418,10 +418,10 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
}
ir::Value* right = std::any_cast<ir::Value*>(right_any);
DEBUG_MSG("[DEBUG] visitAddExp: left=" << (void*)left
DebugStream() << "[DEBUG] visitAddExp: left=" << (void*)left
<< ", type=" << (left->GetType()->IsFloat() ? "float" : "int")
<< ", right=" << (void*)right
<< ", type=" << (right->GetType()->IsFloat() ? "float" : "int"));
<< ", type=" << (right->GetType()->IsFloat() ? "float" : "int") << std::endl;
// Type conversion: required when the operand types differ
if (left->GetType()->IsFloat() != right->GetType()->IsFloat()) {
@ -458,7 +458,7 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitMulExp: 开始处理乘法表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitMulExp: 开始处理乘法表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
@ -482,10 +482,10 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
}
ir::Value* right = std::any_cast<ir::Value*>(right_any);
DEBUG_MSG("[DEBUG] visitMulExp: left=" << (void*)left
DebugStream() << "[DEBUG] visitMulExp: left=" << (void*)left
<< ", type=" << (left->GetType()->IsFloat() ? "float" : "int")
<< ", right=" << (void*)right
<< ", type=" << (right->GetType()->IsFloat() ? "float" : "int"));
<< ", type=" << (right->GetType()->IsFloat() ? "float" : "int") << std::endl;
// Type conversion: required when the operand types differ
if (left->GetType()->IsFloat() != right->GetType()->IsFloat()) {
@ -532,7 +532,7 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
// Logical AND
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitLAndExp: 开始处理逻辑与表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitLAndExp: 开始处理逻辑与表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
if (!ctx->lAndExp()) {
@ -562,7 +562,7 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
// Logical OR
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitLOrExp: 开始处理逻辑或表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitLOrExp: 开始处理逻辑或表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
if (!ctx->lOrExp()) {
@ -591,32 +591,32 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
}
std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitExp: 开始处理表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitExp: 开始处理表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法表达式"));
return ctx->addExp()->accept(this);
}
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitCond: 开始处理条件 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitCond: 开始处理条件 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
return ctx->lOrExp()->accept(this);
}
std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 开始处理函数调用 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitCallExp: 开始处理函数调用 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "非法函数调用"));
}
std::string funcName = ctx->Ident()->getText();
DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 调用函数 " << funcName);
DebugStream() << "[DEBUG IRGEN] visitCallExp: 调用函数 " << funcName << std::endl;
// Look up the function object
ir::Function* callee = module_.FindFunction(funcName);
// If it is not found, it may be a runtime function that has not been declared yet; try declaring it on the fly
if (!callee) {
DEBUG_MSG("[DEBUG IRGEN] 函数 " << funcName << " 未找到,尝试动态声明");
DebugStream() << "[DEBUG IRGEN] 函数 " << funcName << " 未找到,尝试动态声明" << std::endl;
// Create the runtime function declaration on the fly, based on the function name
callee = CreateRuntimeFunctionDecl(funcName);
@ -631,9 +631,9 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) {
auto argList = ctx->funcRParams()->accept(this);
try {
args = std::any_cast<std::vector<ir::Value*>>(argList);
DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 收集到 " << args.size() << " 个参数");
DebugStream() << "[DEBUG IRGEN] visitCallExp: 收集到 " << args.size() << " 个参数" << std::endl;
} catch (const std::bad_any_cast& e) {
DEBUG_MSG("[ERROR] visitCallExp: 函数调用参数类型错误: " << e.what());
std::cerr << "[ERROR] visitCallExp: 函数调用参数类型错误: " << e.what() << std::endl;
}
}
@ -673,13 +673,13 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) {
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
}
DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 函数调用完成,返回值 " << (void*)callResult);
DebugStream() << "[DEBUG IRGEN] visitCallExp: 函数调用完成,返回值 " << (void*)callResult << std::endl;
return static_cast<ir::Value*>(callResult);
}
// Helper that creates runtime function declarations on the fly
ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) {
DEBUG_MSG("[DEBUG IRGEN] CreateRuntimeFunctionDecl: 开始创建运行时函数声明 " << funcName);
DebugStream() << "[DEBUG IRGEN] CreateRuntimeFunctionDecl: 开始创建运行时函数声明 " << funcName << std::endl;
// Build the function type that matches each well-known runtime function name
if (funcName == "getint" || funcName == "getch") {
@ -792,7 +792,7 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName)
}
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitUnaryExp: 开始处理一元表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitUnaryExp: 开始处理一元表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
}
@ -852,7 +852,7 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// Function call support
std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitFuncRParams: 开始处理函数参数 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitFuncRParams: 开始处理函数参数 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) return std::vector<ir::Value*>{};
std::vector<ir::Value*> args;
for (auto* exp : ctx->exp()) {
@ -863,7 +863,7 @@ std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
// visitConstExp - process a constant expression
std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitConstExp: 开始处理常量表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitConstExp: 开始处理常量表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("irgen", "非法常量表达式"));
}
@ -884,7 +884,7 @@ std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) {
// visitConstInitVal - process a constant initializer
std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitConstInitVal: 开始处理常量初始化值 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitConstInitVal: 开始处理常量初始化值 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法常量初始化值"));
}
@ -929,7 +929,7 @@ std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitRelExp: 开始处理关系表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitRelExp: 开始处理关系表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
@ -940,10 +940,10 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
auto* lhs = std::any_cast<ir::Value*>(left_any);
auto* rhs = std::any_cast<ir::Value*>(right_any);
DEBUG_MSG("[DEBUG] visitRelExp: left=" << (void*)lhs
DebugStream() << "[DEBUG] visitRelExp: left=" << (void*)lhs
<< ", type=" << (lhs->GetType()->IsFloat() ? "float" : "int")
<< ", right=" << (void*)rhs
<< ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int"));
<< ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int") << std::endl;
// Type conversion: required when the operand types differ
if (lhs->GetType()->IsFloat() != rhs->GetType()->IsFloat()) {
@ -1004,7 +1004,7 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
}
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitEqExp: 开始处理相等表达式 " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitEqExp: 开始处理相等表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
@ -1015,10 +1015,10 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
auto* lhs = std::any_cast<ir::Value*>(left_any);
auto* rhs = std::any_cast<ir::Value*>(right_any);
DEBUG_MSG("[DEBUG] visitEqExp: left=" << (void*)lhs
DebugStream() << "[DEBUG] visitEqExp: left=" << (void*)lhs
<< ", type=" << (lhs->GetType()->IsFloat() ? "float" : "int")
<< ", right=" << (void*)rhs
<< ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int"));
<< ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int") << std::endl;
// Type conversion: required when the operand types differ
if (lhs->GetType()->IsFloat() != rhs->GetType()->IsFloat()) {
@ -1062,8 +1062,8 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] EvalAssign: 开始处理赋值语句 " << (ctx ? ctx->getText() : "<null>"));
DEBUG_MSG("[DEBUG IRGEN] visitCond: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] EvalAssign: 开始处理赋值语句 " << (ctx ? ctx->getText() : "<null>") << std::endl;
DebugStream() << "[DEBUG IRGEN] visitCond: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->lVal() || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法赋值语句"));
}
@ -1127,12 +1127,12 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) {
} else {
// Plain scalar assignment
// Debug output of the pointer types
DEBUG_MSG("[DEBUG] base_ptr type: " << base_ptr->GetType());
DEBUG_MSG("[DEBUG] rhs type: " << rhs->GetType());
DebugStream() << "[DEBUG] base_ptr type: " << base_ptr->GetType() << std::endl;
DebugStream() << "[DEBUG] rhs type: " << rhs->GetType() << std::endl;
// If base_ptr is not a scalar pointer type, special handling may be needed
if (!base_ptr->GetType()->IsPtrInt32() && !base_ptr->GetType()->IsPtrFloat()) {
DEBUG_MSG("[ERROR] base_ptr is not a pointer type!");
std::cerr << "[ERROR] base_ptr is not a pointer type!" << std::endl;
throw std::runtime_error("尝试存储到非指针类型");
}
rhs = convert_for_store(rhs, base_ptr);

@ -54,7 +54,7 @@ IRGenImpl::IRGenImpl(ir::Module& module,
}
void IRGenImpl::AddRuntimeFunctions() {
DEBUG_MSG("[DEBUG IRGEN] 添加运行时库函数声明");
DebugStream() << "[DEBUG IRGEN] 添加运行时库函数声明" << std::endl;
// Input functions (return int)
module_.CreateFunction("getint",
@ -155,21 +155,21 @@ void IRGenImpl::AddRuntimeFunctions() {
ir::Type::GetVoidType(),
{ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()}));
DEBUG_MSG("[DEBUG IRGEN] 运行时库函数声明完成");
DebugStream() << "[DEBUG IRGEN] 运行时库函数声明完成" << std::endl;
}
// Fix: there is no mainFuncDef; locate main by function name
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitCompUnit");
DEBUG_MSG("[DEBUG] IRGen: 符号表地址 = " << &symbol_table_);
DEBUG_MSG("[DEBUG] IRGen: 开始生成 IR");
DebugStream() << "[DEBUG IRGEN] visitCompUnit" << std::endl;
DebugStream() << "[DEBUG] IRGen: 符号表地址 = " << &symbol_table_ << std::endl;
DebugStream() << "[DEBUG] IRGen: 开始生成 IR" << std::endl;
// Try to look up the main function
const Symbol* main_sym = symbol_table_.lookup("main");
if (main_sym) {
DEBUG_MSG("[DEBUG] IRGen: 找到 main 函数符号");
DebugStream() << "[DEBUG] IRGen: 找到 main 函数符号" << std::endl;
} else {
DEBUG_MSG("[DEBUG] IRGen: 未找到 main 函数符号");
DebugStream() << "[DEBUG] IRGen: 未找到 main 函数符号" << std::endl;
}
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
@ -193,7 +193,7 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
}
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
@ -255,25 +255,25 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
auto func_type = ir::Type::GetFunctionType(ret_type, param_types);
// Debug output
DEBUG_MSG("[DEBUG] visitFuncDef: 创建函数 " << funcName
DebugStream() << "[DEBUG] visitFuncDef: 创建函数 " << funcName
<< ",返回类型: " << (ret_type->IsVoid() ? "void" : ret_type->IsFloat() ? "float" : "int")
<< ",参数数量: " << param_types.size());
<< ",参数数量: " << param_types.size() << std::endl;
// Create the function object
func_ = module_.CreateFunction(funcName, func_type);
// Check that the function was created successfully
if (!func_) {
DEBUG_MSG("[ERROR] visitFuncDef: 创建函数失败,func_ 为 nullptr!");
std::cerr << "[ERROR] visitFuncDef: 创建函数失败,func_ 为 nullptr!" << std::endl;
throw std::runtime_error(FormatError("irgen", "创建函数失败: " + funcName));
}
DEBUG_MSG("[DEBUG] visitFuncDef: 函数对象地址: " << (void*)func_);
DebugStream() << "[DEBUG] visitFuncDef: 函数对象地址: " << (void*)func_ << std::endl;
// Set the insertion point
auto* entry_block = func_->GetEntry();
if (!entry_block) {
DEBUG_MSG("[ERROR] visitFuncDef: 函数入口基本块为空!");
std::cerr << "[ERROR] visitFuncDef: 函数入口基本块为空!" << std::endl;
throw std::runtime_error(FormatError("irgen", "函数入口基本块为空: " + funcName));
}
@ -324,15 +324,15 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
// Check that the function object is still valid
if (!func_) {
DEBUG_MSG("[ERROR] visitFuncDef: func_ 在添加参数时变为 nullptr!");
std::cerr << "[ERROR] visitFuncDef: func_ 在添加参数时变为 nullptr!" << std::endl;
throw std::runtime_error(FormatError("irgen", "函数对象无效"));
}
DEBUG_MSG("[DEBUG] visitFuncDef: 为函数 " << funcName
DebugStream() << "[DEBUG] visitFuncDef: 为函数 " << funcName
<< " 添加参数 " << name << ",类型: "
<< (param_ty->IsInt32() ? "int32" : param_ty->IsFloat() ? "float" :
param_ty->IsPtrInt32() ? "ptr_int32" : param_ty->IsPtrFloat() ? "ptr_float" : "other")
);
<< std::endl;
// Create the argument and add it to the function
auto arg = std::make_unique<ir::Argument>(param_ty, name);
@ -344,7 +344,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
auto* added_arg = func_->AddArgument(std::move(arg));
if (!added_arg) {
DEBUG_MSG("[ERROR] visitFuncDef: AddArgument 返回 nullptr!");
std::cerr << "[ERROR] visitFuncDef: AddArgument 返回 nullptr!" << std::endl;
throw std::runtime_error(FormatError("irgen", "添加参数失败: " + name));
}
@ -371,17 +371,17 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
pointer_param_names_.erase(name);
}
DEBUG_MSG("[DEBUG] visitFuncDef: 参数 " << name << " 处理完成");
DebugStream() << "[DEBUG] visitFuncDef: 参数 " << name << " 处理完成" << std::endl;
}
}
// Generate the function body
DEBUG_MSG("[DEBUG] visitFuncDef: 开始生成函数体");
DebugStream() << "[DEBUG] visitFuncDef: 开始生成函数体" << std::endl;
ctx->block()->accept(this);
// If the current insert block has no terminator, add a default return
if (auto* cur = builder_.GetInsertBlock(); cur && !cur->HasTerminator()) {
DEBUG_MSG("[DEBUG] visitFuncDef: 函数体没有终止指令,添加默认返回");
DebugStream() << "[DEBUG] visitFuncDef: 函数体没有终止指令,添加默认返回" << std::endl;
if (function_cleanup_block_) {
if (ret_type->IsFloat()) {
builder_.CreateStore(builder_.CreateConstFloat(0.0f), function_return_slot_);
@ -416,11 +416,11 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
try {
VerifyFunctionStructure(*func_);
} catch (const std::exception& e) {
DEBUG_MSG("[ERROR] visitFuncDef: 验证函数结构失败: " << e.what());
std::cerr << "[ERROR] visitFuncDef: 验证函数结构失败: " << e.what() << std::endl;
throw;
}
DEBUG_MSG("[DEBUG] visitFuncDef: 函数 " << funcName << " 生成完成");
DebugStream() << "[DEBUG] visitFuncDef: 函数 " << funcName << " 生成完成" << std::endl;
func_ = nullptr;
current_function_name_.clear();
current_function_is_recursive_ = false;
@ -467,7 +467,7 @@ ir::AllocaInst* IRGenImpl::CreateEntryAllocaFloat(const std::string& name) {
std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
}
@ -482,8 +482,8 @@ std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
}
auto* cur = builder_.GetInsertBlock();
DEBUG_MSG("[DEBUG] current insert block: "
<< (cur ? cur->GetName() : "<null>"));
DebugStream() << "[DEBUG] current insert block: "
<< (cur ? cur->GetName() : "<null>") << std::endl;
if (cur && cur->HasTerminator()) {
break;
}
@ -500,7 +500,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
}
// Visits a block item and returns whether to keep visiting the following items (stops after return/break/continue)
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少块内项"));
}

@ -16,7 +16,7 @@
// - More statement forms beyond empty statements and nested block dispatch
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句"));
}
@ -65,7 +65,7 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
// Modified HandleReturnStmt function
IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少 return 语句"));
}
@ -132,7 +132,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
// if statement
IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
auto* cond = ctx->cond();
auto* thenStmt = ctx->stmt(0);
@ -148,11 +148,11 @@ IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) {
auto* elseBlock = (ctx->Else() && elseStmt) ? func_->CreateBlock(uniq("else")) : nullptr;
auto* mergeBlock = func_->CreateBlock(uniq("merge"));
DEBUG_MSG("[DEBUG IF] thenBlock: " << thenBlock->GetName());
if (elseBlock) DEBUG_MSG("[DEBUG IF] elseBlock: " << elseBlock->GetName());
DEBUG_MSG("[DEBUG IF] mergeBlock: " << mergeBlock->GetName());
DEBUG_MSG("[DEBUG IF] current insert block before cond: "
<< builder_.GetInsertBlock()->GetName());
DebugStream() << "[DEBUG IF] thenBlock: " << thenBlock->GetName() << std::endl;
if (elseBlock) DebugStream() << "[DEBUG IF] elseBlock: " << elseBlock->GetName() << std::endl;
DebugStream() << "[DEBUG IF] mergeBlock: " << mergeBlock->GetName() << std::endl;
DebugStream() << "[DEBUG IF] current insert block before cond: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
// Generate the condition
auto* condValue = EvalCond(*cond);
@ -168,74 +168,74 @@ IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) {
// Create the conditional branch
if (elseBlock) {
DEBUG_MSG("[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << elseBlock->GetName());
DebugStream() << "[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << elseBlock->GetName() << std::endl;
builder_.CreateCondBr(condValue, thenBlock, elseBlock);
} else {
DEBUG_MSG("[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << mergeBlock->GetName());
DebugStream() << "[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << mergeBlock->GetName() << std::endl;
builder_.CreateCondBr(condValue, thenBlock, mergeBlock);
}
// Generate the then branch
DEBUG_MSG("[DEBUG IF] Generating then branch in block: " << thenBlock->GetName());
DebugStream() << "[DEBUG IF] Generating then branch in block: " << thenBlock->GetName() << std::endl;
builder_.SetInsertPoint(thenBlock);
auto thenResult = thenStmt->accept(this);
bool thenTerminated = (std::any_cast<BlockFlow>(thenResult) == BlockFlow::Terminated);
DEBUG_MSG("[DEBUG IF] then branch terminated: " << thenTerminated);
DebugStream() << "[DEBUG IF] then branch terminated: " << thenTerminated << std::endl;
if (!thenTerminated) {
DEBUG_MSG("[DEBUG IF] Adding br to merge block from then");
DebugStream() << "[DEBUG IF] Adding br to merge block from then" << std::endl;
builder_.CreateBr(mergeBlock);
}
DEBUG_MSG("[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator());
DebugStream() << "[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator() << std::endl;
// Generate the else branch
bool elseTerminated = false;
if (elseBlock) {
DEBUG_MSG("[DEBUG IF] Generating else branch in block: " << elseBlock->GetName());
DebugStream() << "[DEBUG IF] Generating else branch in block: " << elseBlock->GetName() << std::endl;
builder_.SetInsertPoint(elseBlock);
auto elseResult = elseStmt->accept(this);
elseTerminated = (std::any_cast<BlockFlow>(elseResult) == BlockFlow::Terminated);
DEBUG_MSG("[DEBUG IF] else branch terminated: " << elseTerminated);
DebugStream() << "[DEBUG IF] else branch terminated: " << elseTerminated << std::endl;
if (!elseTerminated) {
DEBUG_MSG("[DEBUG IF] Adding br to merge block from else");
DebugStream() << "[DEBUG IF] Adding br to merge block from else" << std::endl;
builder_.CreateBr(mergeBlock);
}
DEBUG_MSG("[DEBUG IF] else block has terminator: " << elseBlock->HasTerminator());
DebugStream() << "[DEBUG IF] else block has terminator: " << elseBlock->HasTerminator() << std::endl;
}
// Decide the next insertion point
DEBUG_MSG("[DEBUG IF] thenTerminated=" << thenTerminated
<< ", elseTerminated=" << elseTerminated);
DebugStream() << "[DEBUG IF] thenTerminated=" << thenTerminated
<< ", elseTerminated=" << elseTerminated << std::endl;
if (elseBlock) {
DEBUG_MSG("[DEBUG IF] Setting insert point to merge block: "
<< mergeBlock->GetName());
DebugStream() << "[DEBUG IF] Setting insert point to merge block: "
<< mergeBlock->GetName() << std::endl;
builder_.SetInsertPoint(mergeBlock);
} else {
DEBUG_MSG("[DEBUG IF] No else, setting insert point to merge block: "
<< mergeBlock->GetName());
DebugStream() << "[DEBUG IF] No else, setting insert point to merge block: "
<< mergeBlock->GetName() << std::endl;
builder_.SetInsertPoint(mergeBlock);
}
DEBUG_MSG("[DEBUG IF] Final insert block: "
<< builder_.GetInsertBlock()->GetName());
DebugStream() << "[DEBUG IF] Final insert block: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
return BlockFlow::Continue;
}
// while statement
IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleWhileStmt: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] HandleWhileStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->cond() || !ctx->stmt(0)) {
throw std::runtime_error(FormatError("irgen", "非法 while 语句"));
}
DEBUG_MSG("[DEBUG WHILE] Current insert block before while: "
<< builder_.GetInsertBlock()->GetName());
DebugStream() << "[DEBUG WHILE] Current insert block before while: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
auto uniq = [&](const std::string& prefix) {
std::string t = module_.GetContext().NextTemp();
@ -246,18 +246,18 @@ IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) {
auto* bodyBlock = func_->CreateBlock(uniq("while.body"));
auto* exitBlock = func_->CreateBlock(uniq("while.exit"));
DEBUG_MSG("[DEBUG WHILE] condBlock: " << condBlock->GetName());
DEBUG_MSG("[DEBUG WHILE] bodyBlock: " << bodyBlock->GetName());
DEBUG_MSG("[DEBUG WHILE] exitBlock: " << exitBlock->GetName());
DebugStream() << "[DEBUG WHILE] condBlock: " << condBlock->GetName() << std::endl;
DebugStream() << "[DEBUG WHILE] bodyBlock: " << bodyBlock->GetName() << std::endl;
DebugStream() << "[DEBUG WHILE] exitBlock: " << exitBlock->GetName() << std::endl;
DEBUG_MSG("[DEBUG WHILE] Adding br to condBlock from current block");
DebugStream() << "[DEBUG WHILE] Adding br to condBlock from current block" << std::endl;
builder_.CreateBr(condBlock);
loopStack_.push_back({condBlock, bodyBlock, exitBlock});
DEBUG_MSG("[DEBUG WHILE] loopStack size: " << loopStack_.size());
DebugStream() << "[DEBUG WHILE] loopStack size: " << loopStack_.size() << std::endl;
// Condition block
DEBUG_MSG("[DEBUG WHILE] Generating condition in block: " << condBlock->GetName());
DebugStream() << "[DEBUG WHILE] Generating condition in block: " << condBlock->GetName() << std::endl;
builder_.SetInsertPoint(condBlock);
auto* condValue = EvalCond(*ctx->cond());
if (!condValue->GetType()->IsInt1()) {
@ -270,45 +270,45 @@ IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) {
}
}
builder_.CreateCondBr(condValue, bodyBlock, exitBlock);
DEBUG_MSG("[DEBUG WHILE] condBlock has terminator: " << condBlock->HasTerminator());
DebugStream() << "[DEBUG WHILE] condBlock has terminator: " << condBlock->HasTerminator() << std::endl;
// Loop body
DEBUG_MSG("[DEBUG WHILE] Generating body in block: " << bodyBlock->GetName());
DebugStream() << "[DEBUG WHILE] Generating body in block: " << bodyBlock->GetName() << std::endl;
builder_.SetInsertPoint(bodyBlock);
auto bodyResult = ctx->stmt(0)->accept(this);
bool bodyTerminated = (std::any_cast<BlockFlow>(bodyResult) == BlockFlow::Terminated);
DEBUG_MSG("[DEBUG WHILE] body terminated: " << bodyTerminated);
DebugStream() << "[DEBUG WHILE] body terminated: " << bodyTerminated << std::endl;
if (!bodyTerminated) {
DEBUG_MSG("[DEBUG WHILE] Adding br to condBlock from body");
DebugStream() << "[DEBUG WHILE] Adding br to condBlock from body" << std::endl;
builder_.CreateBr(condBlock);
}
DEBUG_MSG("[DEBUG WHILE] bodyBlock has terminator: " << bodyBlock->HasTerminator());
DebugStream() << "[DEBUG WHILE] bodyBlock has terminator: " << bodyBlock->HasTerminator() << std::endl;
loopStack_.pop_back();
DEBUG_MSG("[DEBUG WHILE] loopStack size after pop: " << loopStack_.size());
DebugStream() << "[DEBUG WHILE] loopStack size after pop: " << loopStack_.size() << std::endl;
// Set the insertion point to exitBlock
DEBUG_MSG("[DEBUG WHILE] Setting insert point to exitBlock: " << exitBlock->GetName());
DebugStream() << "[DEBUG WHILE] Setting insert point to exitBlock: " << exitBlock->GetName() << std::endl;
builder_.SetInsertPoint(exitBlock);
DEBUG_MSG("[DEBUG WHILE] exitBlock has terminator before return: "
<< exitBlock->HasTerminator());
DebugStream() << "[DEBUG WHILE] exitBlock has terminator before return: "
<< exitBlock->HasTerminator() << std::endl;
return BlockFlow::Continue;
}
// break statement
IRGenImpl::BlockFlow IRGenImpl::HandleBreakStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleBreakStmt: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] HandleBreakStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (loopStack_.empty()) {
throw std::runtime_error(FormatError("irgen", "break 语句不在循环中"));
}
DEBUG_MSG("[DEBUG BREAK] Current insert block before break: "
<< builder_.GetInsertBlock()->GetName());
DEBUG_MSG("[DEBUG BREAK] Breaking to exitBlock: "
<< loopStack_.back().exitBlock->GetName());
DebugStream() << "[DEBUG BREAK] Current insert block before break: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
DebugStream() << "[DEBUG BREAK] Breaking to exitBlock: "
<< loopStack_.back().exitBlock->GetName() << std::endl;
// Jump to the loop exit block
builder_.CreateBr(loopStack_.back().exitBlock);
@ -318,16 +318,16 @@ IRGenImpl::BlockFlow IRGenImpl::HandleBreakStmt(SysYParser::StmtContext* ctx) {
}
IRGenImpl::BlockFlow IRGenImpl::HandleContinueStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleContinueStmt: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] HandleContinueStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (loopStack_.empty()) {
throw std::runtime_error(FormatError("irgen", "continue 语句不在循环中"));
}
DEBUG_MSG("[DEBUG CONTINUE] Current insert block before continue: "
<< builder_.GetInsertBlock()->GetName());
DEBUG_MSG("[DEBUG CONTINUE] Continuing to condBlock: "
<< loopStack_.back().condBlock->GetName());
DebugStream() << "[DEBUG CONTINUE] Current insert block before continue: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
DebugStream() << "[DEBUG CONTINUE] Continuing to condBlock: "
<< loopStack_.back().condBlock->GetName() << std::endl;
// Jump to the loop condition block
builder_.CreateBr(loopStack_.back().condBlock);
@ -340,7 +340,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleContinueStmt(SysYParser::StmtContext* ctx)
// Assignment statement
IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleAssignStmt: " << (ctx ? ctx->getText() : "<null>"));
DebugStream() << "[DEBUG IRGEN] HandleAssignStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->lVal() || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法赋值语句"));
@ -354,7 +354,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto* lval = ctx->lVal();
std::string varName = lval->Ident()->getText();
DEBUG_MSG("[DEBUG] HandleAssignStmt: assigning to " << varName);
DebugStream() << "[DEBUG] HandleAssignStmt: assigning to " << varName << std::endl;
// 1. Check whether the target is a constant (constants cannot be assigned to)
auto* const_decl = sema_.ResolveConstUse(lval);
@ -372,8 +372,8 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto it = storage_map_.find(var_decl);
if (it != storage_map_.end()) {
base_ptr = it->second;
DEBUG_MSG("[DEBUG] HandleAssignStmt: found in storage_map_ for " << varName
<< ", ptr = " << (void*)base_ptr);
DebugStream() << "[DEBUG] HandleAssignStmt: found in storage_map_ for " << varName
<< ", ptr = " << (void*)base_ptr << std::endl;
}
}
@ -382,8 +382,8 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto it2 = param_map_.find(varName);
if (it2 != param_map_.end()) {
base_ptr = it2->second;
DEBUG_MSG("[DEBUG] HandleAssignStmt: found in param_map_ for " << varName
<< ", ptr = " << (void*)base_ptr);
DebugStream() << "[DEBUG] HandleAssignStmt: found in param_map_ for " << varName
<< ", ptr = " << (void*)base_ptr << std::endl;
}
}
@ -392,8 +392,8 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto it3 = global_map_.find(varName);
if (it3 != global_map_.end()) {
base_ptr = it3->second;
DEBUG_MSG("[DEBUG] HandleAssignStmt: found in global_map_ for " << varName
<< ", ptr = " << (void*)base_ptr);
DebugStream() << "[DEBUG] HandleAssignStmt: found in global_map_ for " << varName
<< ", ptr = " << (void*)base_ptr << std::endl;
}
}
@ -402,8 +402,8 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto it4 = local_var_map_.find(varName);
if (it4 != local_var_map_.end()) {
base_ptr = it4->second;
DEBUG_MSG("[DEBUG] HandleAssignStmt: found in local_var_map_ for " << varName
<< ", ptr = " << (void*)base_ptr);
DebugStream() << "[DEBUG] HandleAssignStmt: found in local_var_map_ for " << varName
<< ", ptr = " << (void*)base_ptr << std::endl;
}
}
@ -497,21 +497,21 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
builder_.CreateStore(rhs, elem_ptr);
} else {
// Plain scalar assignment
DEBUG_MSG("[DEBUG] HandleAssignStmt: scalar assignment to " << varName
DebugStream() << "[DEBUG] HandleAssignStmt: scalar assignment to " << varName
<< ", ptr = " << (void*)base_ptr
<< ", rhs = " << (void*)rhs);
<< ", rhs = " << (void*)rhs << std::endl;
// Type debugging output before the store in HandleAssignStmt
if (base_ptr && base_ptr->GetType()) {
DEBUG_MSG("[DEBUG] Is int32: " << base_ptr->GetType()->IsInt32());
DEBUG_MSG("[DEBUG] Is float: " << base_ptr->GetType()->IsFloat());
DEBUG_MSG("[DEBUG] Is ptr int32: " << base_ptr->GetType()->IsPtrInt32());
DEBUG_MSG("[DEBUG] Is ptr float: " << base_ptr->GetType()->IsPtrFloat());
DEBUG_MSG("[DEBUG] Is array: " << base_ptr->GetType()->IsArray());
DebugStream() << "[DEBUG] Is int32: " << base_ptr->GetType()->IsInt32() << std::endl;
DebugStream() << "[DEBUG] Is float: " << base_ptr->GetType()->IsFloat() << std::endl;
DebugStream() << "[DEBUG] Is ptr int32: " << base_ptr->GetType()->IsPtrInt32() << std::endl;
DebugStream() << "[DEBUG] Is ptr float: " << base_ptr->GetType()->IsPtrFloat() << std::endl;
DebugStream() << "[DEBUG] Is array: " << base_ptr->GetType()->IsArray() << std::endl;
}
if (rhs && rhs->GetType()) {
DEBUG_MSG("[DEBUG] Value is int32: " << rhs->GetType()->IsInt32());
DebugStream() << "[DEBUG] Value is int32: " << rhs->GetType()->IsInt32() << std::endl;
}
if (base_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) {
rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(),

@ -35,6 +35,8 @@ int main(int argc, char** argv) {
}
auto sema = RunSema(*comp_unit);
SetDebugEnabled(opts.debug);
auto module = GenerateIR(*comp_unit, sema);
if (opts.emit_ir) {
ir::IRPrinter printer;
@ -46,17 +48,13 @@ int main(int argc, char** argv) {
}
if (opts.emit_asm) {
//auto machine_func = mir::LowerToMIR(*module);
auto machine_module = mir::LowerToMIR(*module);
//mir::RunRegAlloc(*machine_func);
mir::RunRegAlloc(*machine_module);
//mir::RunFrameLowering(*machine_func);
mir::RunFrameLowering(*machine_module);
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
if (need_blank_line) {
std::cout << "\n";
}
//mir::PrintAsm(*machine_func, std::cout);
mir::PrintAsm(*machine_module, std::cout);
mir::PrintAsm(*machine_func, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {

@ -2,24 +2,12 @@
#include <ostream>
#include <stdexcept>
#include <set>
#include "utils/Log.h"
//#define DEBUG_Asm
#ifdef DEBUG_Asm
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[Asm Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace mir {
namespace {
static void PrintLoadImm64(std::ostream& os, PhysReg reg, uint64_t imm);
const FrameSlot& GetFrameSlot(const MachineFunction& function,
const Operand& operand) {
if (operand.GetKind() != Operand::Kind::FrameIndex) {
@ -28,458 +16,63 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex());
}
void PrintStackAccess(std::ostream& os, const char* insn, PhysReg reg, int64_t offset) {
// offset is usually negative, e.g. -8, -24, -40
if (offset >= -256 && offset <= 255) {
os << " " << insn << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n";
return;
}
// Large offset: compute x29 + offset in x16, then access indirectly
os << " mov x16, x29\n";
int64_t abs_offset = (offset >= 0) ? offset : -offset;
if (abs_offset <= 4095) {
if (offset >= 0) {
os << " add x16, x16, #" << offset << "\n";
} else {
os << " sub x16, x16, #" << abs_offset << "\n";
}
} else {
// Materialize the large offset in x17
PrintLoadImm64(os, PhysReg::X17, abs_offset);
if (offset >= 0) {
os << " add x16, x16, x17\n";
} else {
os << " sub x16, x16, x17\n";
}
}
os << " " << insn << " " << PhysRegName(reg) << ", [x16]\n";
}
// Print a single operand
void PrintOperand(std::ostream& os, const Operand& op) {
switch (op.GetKind()) {
case Operand::Kind::Reg:
os << PhysRegName(op.GetReg());
break;
case Operand::Kind::Imm:
os << "#" << op.GetImm();
break;
case Operand::Kind::FrameIndex:
os << "[sp, #" << op.GetFrameIndex() << "]";
break;
case Operand::Kind::Cond:
os << CondCodeName(op.GetCondCode());
break;
case Operand::Kind::Label:
DEBUG_MSG("label is" << op.GetLabel());
os << op.GetLabel();
break;
}
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
int offset) {
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
<< "]\n";
}
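// Check whether an immediate fits the 12-bit immediate of an AArch64 ADD/SUB instruction (shifted left by 0 or 12 bits)
static bool IsLegalAddSubImm(int64_t imm) {
if (imm < 0) imm = -imm; // take the absolute value; the encoding rule is symmetric
if (imm <= 4095) return true; // 0-4095 is directly encodable
if ((imm & 0xFFF) == 0 && imm <= 4095 * 4096) return true; // a multiple of 4096 and <= 16773120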
return false;
}
// Helper function added to the anonymous namespace
static void PrintLoadImm64(std::ostream& os, PhysReg reg, uint64_t imm) {
// Emit a movz + movk sequence
uint16_t part0 = imm & 0xFFFF;
uint16_t part1 = (imm >> 16) & 0xFFFF;
uint16_t part2 = (imm >> 32) & 0xFFFF;
uint16_t part3 = (imm >> 48) & 0xFFFF;
os << " movz " << PhysRegName(reg) << ", #" << part0;
if (part1 != 0 || part2 != 0 || part3 != 0) {
os << ", lsl #0";
}
os << "\n";
} // namespace
if (part1 != 0) {
os << " movk " << PhysRegName(reg) << ", #" << part1 << ", lsl #16\n";
}
if (part2 != 0) {
os << " movk " << PhysRegName(reg) << ", #" << part2 << ", lsl #32\n";
}
if (part3 != 0) {
os << " movk " << PhysRegName(reg) << ", #" << part3 << ", lsl #48\n";
}
}
void PrintAsm(const MachineFunction& function, std::ostream& os) {
os << ".text\n";
os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n";
os << function.GetName() << ":\n";
// Print a single instruction
void PrintInstruction(std::ostream& os, const MachineInstr& instr,
const MachineFunction& function) {
const auto& ops = instr.GetOperands();
switch (instr.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
int64_t size = function.GetFrameSize();
if (IsLegalAddSubImm(size)) {
os << " sub sp, sp, #" << size << "\n";
} else {
PrintLoadImm64(os, PhysReg::X16, size);
os << " sub sp, sp, x16\n";
for (const auto& inst : function.GetEntry().GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
}
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
int64_t size = function.GetFrameSize();
if (IsLegalAddSubImm(size)) {
os << " add sp, sp, #" << size << "\n";
} else {
PrintLoadImm64(os, PhysReg::X16, size);
os << " add sp, sp, x16\n";
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
}
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::MovReg:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::StoreStack: {
// Check the kind of the second operand
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
// Store to a stack slot
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
// Indirect store: store to the address held in a register
// STR W9, [X8]
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
} else {
throw std::runtime_error("StoreStack: 无效的操作数类型");
}
break;
}
case Opcode::LoadStack: {
// Check the kind of the second operand
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
// Load from a stack slot
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
// Indirect load: load from the address held in a register
// LDR W9, [X8]
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
} else {
throw std::runtime_error("LoadStack: 无效的操作数类型");
}
break;
}
case Opcode::StoreStackPair:
// stp x29, x30, [sp, #-16]!
os << " stp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", [sp";
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
int offset = ops.at(2).GetImm();
os << ", #" << offset;
}
os << "]!\n"; // note the trailing ! for pre-indexed addressing
break;
case Opcode::LoadStackPair:
// ldp x29, x30, [sp], #16
os << " ldp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", [sp]";
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
int offset = ops.at(2).GetImm();
os << ", #" << offset;
}
os << "\n";
break;
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AddRI:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SubRI:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::MulRR:
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SDivRR:
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::UDivRR:
os << " udiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FAddRR:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSubRR:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMulRR:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDivRR:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::CmpRR:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::CmpRI:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::FCmpRR:
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::ZExt:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #1\n";
break;
case Opcode::AndRR:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::OrRR:
os << " orr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::EorRR:
os << " eor " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::LslRR:
os << " lsl " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::LsrRR:
os << " lsr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AsrRR:
os << " asr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::B:
os << " b ";
PrintOperand(os, ops.at(0));
os << "\n";
break;
case Opcode::BCond:
os << " b.";
PrintOperand(os, ops.at(0));
os << " ";
PrintOperand(os, ops.at(1));
os << "\n";
break;
case Opcode::Call:
os << " bl ";
PrintOperand(os, ops.at(0));
os << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
case Opcode::Nop:
os << " nop\n";
break;
case Opcode::Label:
os << ops.at(0).GetLabel() << ":\n";
break;
case Opcode::Movk:
os << " movk " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << ", lsl #" << ops.at(2).GetImm() << "\n";
break;
case Opcode::LoadStackAddr: {
const FrameSlot& slot = GetFrameSlot(function, ops.at(1));
int64_t offset = slot.offset; // negative, e.g. -8
PhysReg dst = ops.at(0).GetReg();
auto tryEmitSimple = [&]() -> bool {
if (offset >= 0 && offset <= 4095) {
os << " add " << PhysRegName(dst) << ", x29, #" << offset << "\n";
return true;
} else if (offset < 0 && offset >= -4095) {
os << " sub " << PhysRegName(dst) << ", x29, #" << (-offset) << "\n";
return true;
}
return false;
};
if (tryEmitSimple()) break;
// Offset too large for a single add/sub immediate
uint64_t absOffset = (offset >= 0) ? offset : -offset;
PrintLoadImm64(os, PhysReg::X16, absOffset);
if (offset >= 0) {
os << " add " << PhysRegName(dst) << ", x29, x16\n";
} else {
os << " sub " << PhysRegName(dst) << ", x29, x16\n";
}
break;
}
case Opcode::Adrp: {
// adrp Xd, label
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetLabel() << "\n";
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::AddLabel: {
// add Xd, Xn, :lo12:label
}
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", :lo12:"
<< ops.at(2).GetLabel() << "\n";
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
}
case Opcode::Sxtw:
os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
default:
os << " // unknown instruction\n";
case Opcode::Ret:
os << " ret\n";
break;
}
}
// Print a single function (single-function version)
void PrintAsm(const MachineFunction& function, std::ostream& os) {
// Emit the function label
os << ".text\n";
os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n";
os << function.GetName() << ":\n";
// Compute the stack frame size
int frameSize = function.GetFrameSize();
// Emit each basic block
const auto& blocks = function.GetBasicBlocks();
bool firstBlock = true;
for (const auto& bb : blocks) {
DEBUG_MSG("block");
// Emit the basic block label (except for the first block)
if (!firstBlock) {
os << bb->GetName() << ":\n";
}
firstBlock = false;
// Emit the instructions in the basic block
for (const auto& inst : bb->GetInstructions()) {
DEBUG_MSG("inst");
PrintInstruction(os, inst, function);
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
}
} // namespace
// Print a module (module version)
void PrintAsm(const MachineModule& module, std::ostream& os) {
// Emit the file header
os << ".arch armv8-a\n";
// Emit the data section: global variables
const auto& globals = module.GetGlobals();
if (!globals.empty()) {
os << "\n.data\n";
for (const auto& g : globals) {
os << ".global " << g.name << "\n";
os << ".type " << g.name << ", %object\n";
os << ".align " << g.alignment << "\n";
os << g.name << ":\n";
if (g.is_zero_init) {
os << " .zero " << g.size << "\n";
} else if (g.has_init_data) {
if (g.size == 4) {
os << " .word " << static_cast<uint32_t>(g.init_data) << "\n";
} else if (g.size == 8) {
os << " .quad " << g.init_data << "\n";
} else {
// Unsupported scalar size; fall back to zero initialization
os << " .zero " << g.size << " // unhandled init size\n";
}
} else {
// Has an initializer that cannot be extracted (e.g. arrays, structs)
os << " .zero " << g.size << " // unhandled initializer\n";
}
os << ".size " << g.name << ", " << g.size << "\n\n";
}
}
static const std::set<std::string> externalFuncs = {
"getint", "getch", "getarray", "putint", "putch", "putarray", "puts",
"_sysy_starttime", "_sysy_stoptime", "starttime", "stoptime",
"getfloat", "putfloat", "getfarray", "putfarray", "memset",
"sysy_alloc_i32", "sysy_alloc_f32", "sysy_free_i32", "sysy_free_f32",
"sysy_zero_i32", "sysy_zero_f32"
};
DEBUG_MSG("module");
// Iterate over all functions and emit assembly
for (const auto& func : module.GetFunctions()) {
if (externalFuncs.count(func->GetName())) {
continue; // skip library function stubs
}
DEBUG_MSG("func");
PrintAsm(*func, os);
os << "\n";
}
os << ".size " << function.GetName() << ", .-" << function.GetName()
<< "\n";
}
} // namespace mir
} // namespace mir
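Review note: the data-section branch in this hunk can be exercised in isolation. The following is a minimal sketch, assuming a hypothetical `GlobalVar` record whose field names mirror the ones used above (`name`, `size`, `alignment`, `is_zero_init`, `has_init_data`, `init_data`); the project's real type is not visible in this diff, so the field types and the `EmitGlobal` / `EmitGlobalToString` helpers are illustrative only.

```cpp
#include <cstdint>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical stand-in for the record iterated by PrintAsm; the field names
// follow the hunk, the types are assumptions.
struct GlobalVar {
  std::string name;
  int size = 4;                // size in bytes
  int alignment = 2;           // operand passed to .align
  bool is_zero_init = false;
  bool has_init_data = false;
  std::int64_t init_data = 0;
};

// Writes one global to the data section using the same branch order as the hunk.
void EmitGlobal(const GlobalVar& g, std::ostream& os) {
  os << ".global " << g.name << "\n";
  os << ".type " << g.name << ", %object\n";
  os << ".align " << g.alignment << "\n";
  os << g.name << ":\n";
  if (g.is_zero_init) {
    os << "  .zero " << g.size << "\n";
  } else if (g.has_init_data && g.size == 4) {
    // Print 32-bit values as unsigned so negative values keep their bit pattern.
    os << "  .word " << static_cast<std::uint32_t>(g.init_data) << "\n";
  } else if (g.has_init_data && g.size == 8) {
    os << "  .quad " << g.init_data << "\n";
  } else {
    // Unsupported initializer: fall back to zero fill, as the hunk does.
    os << "  .zero " << g.size << "\n";
  }
  os << ".size " << g.name << ", " << g.size << "\n";
}

// Convenience wrapper so the output can be inspected as a string.
std::string EmitGlobalToString(const GlobalVar& g) {
  std::ostringstream os;
  EmitGlobal(g, os);
  return os.str();
}
```

One thing worth double-checking against the assembler in use: GNU as on Arm targets treats the `.align` operand as a power of two, so if `alignment` holds a byte count, `.balign` would be the unambiguous directive.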

@ -5,15 +5,6 @@
#include "utils/Log.h"
//#define DEBUG_Frame
#ifdef DEBUG_Frame
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[Frame Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace mir {
namespace {
@ -24,49 +15,31 @@ int AlignTo(int value, int align) {
} // namespace
void RunFrameLowering(MachineFunction& function) {
DEBUG_MSG("function RunFrameLowering");
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
if (-cursor < -256) {
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
}
}
cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
function.GetFrameSlot(slot.index).offset = -cursor;
}
function.SetFrameSize(AlignTo(cursor, 16));
// Basic blocks
const auto& blocks = function.GetBasicBlocks();
bool firstBlock = true;
for (const auto& bb : blocks) {
DEBUG_MSG("block");
auto& insts = bb->GetInstructions();
std::vector<MachineInstr> lowered;
// Insert the prologue at the start of the first basic block
if (firstBlock) {
DEBUG_MSG("emplace Prologue");
lowered.emplace_back(Opcode::Prologue);
}
firstBlock = false;
// Copy the block's instructions, inserting an epilogue before each ret
for (const auto& inst : insts) {
DEBUG_MSG("inst");
if (inst.GetOpcode() == Opcode::Ret) {
DEBUG_MSG("emplace Epilogue");
lowered.emplace_back(Opcode::Epilogue);
}
lowered.push_back(inst);
auto& insts = function.GetEntry().GetInstructions();
std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) {
lowered.emplace_back(Opcode::Epilogue);
}
insts = std::move(lowered);
}
}
// Module overload of frame lowering
void RunFrameLowering(MachineModule& module) {
// Run frame lowering on every function in the module
DEBUG_MSG("module RunFrameLowering");
for (auto& func : module.GetFunctions()) {
RunFrameLowering(*func);
lowered.push_back(inst);
}
insts = std::move(lowered);
}
} // namespace mir
} // namespace mir
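Review note: the slot layout in `RunFrameLowering` boils down to a running cursor, negative offsets, and a final round-up to 16 bytes. Below is a minimal sketch under stated assumptions: `FrameSlot` and `LayoutFrame` are hypothetical names, and the body of `AlignTo` is not shown in this diff, so the round-up formula is the conventional one rather than the project's confirmed implementation. Unlike the hunk, which checks the size in a first pass, this sketch checks and assigns in a single pass.

```cpp
#include <stdexcept>
#include <vector>

// Hypothetical reduced frame slot: a size in bytes and an assigned offset.
struct FrameSlot {
  int size = 0;
  int offset = 0;
};

// Assumed definition: rounds value up to the next multiple of align (align > 0).
int AlignTo(int value, int align) {
  return (value + align - 1) / align * align;
}

// Assigns each slot a negative offset below the frame base and returns the
// 16-byte aligned frame size. Throws if the raw size exceeds max_frame.
int LayoutFrame(std::vector<FrameSlot>& slots, int max_frame = 256) {
  int cursor = 0;
  for (auto& slot : slots) {
    cursor += slot.size;
    if (cursor > max_frame) {
      throw std::runtime_error("stack frame too large");
    }
    slot.offset = -cursor;
  }
  return AlignTo(cursor, 16);
}
```

The 256-byte bound presumably reflects the signed 9-bit immediate of the unscaled load/store forms (`ldur`/`stur`), which reach offsets from -256 to 255; larger frames would need address materialization in a scratch register.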

File diff suppressed because it is too large

@ -9,8 +9,7 @@ MachineBasicBlock::MachineBasicBlock(std::string name)
MachineInstr& MachineBasicBlock::Append(Opcode opcode,
std::initializer_list<Operand> operands) {
//instructions_.emplace_back(opcode, std::vector<Operand>(operands));
instructions_.emplace_back(opcode, std::move(operands));
instructions_.emplace_back(opcode, std::vector<Operand>(operands));
return instructions_.back();
}
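Review note: on the `Append` change, the elements of a `std::initializer_list` are const, so `std::move(operands)` only moves the lightweight list handle and the operands are copied either way. A minimal sketch with hypothetical, reduced stand-ins for the MIR types (`Operand`, `Instr`, `Block` are illustrative, not the project's classes) shows the explicit-vector form:

```cpp
#include <cstddef>
#include <initializer_list>
#include <utility>
#include <vector>

// Hypothetical, reduced stand-ins for the MIR types.
struct Operand {
  int value = 0;
};

struct Instr {
  int opcode;
  std::vector<Operand> operands;
  Instr(int op, std::vector<Operand> ops)
      : opcode(op), operands(std::move(ops)) {}
};

class Block {
 public:
  Instr& Append(int opcode, std::initializer_list<Operand> operands) {
    // initializer_list elements are const, so they are copied here in any
    // case; naming the vector makes that copy visible at the call site.
    instructions_.emplace_back(opcode, std::vector<Operand>(operands));
    return instructions_.back();
  }
  std::size_t Size() const { return instructions_.size(); }

 private:
  std::vector<Instr> instructions_;
};
```

A call such as `block.Append(1, {Operand{7}, Operand{9}})` yields an instruction with two operands; the temporary vector is then moved into the instruction by the constructor.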

@ -4,25 +4,17 @@
namespace mir {
Operand::Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string& label)
: kind_(kind), reg_(reg), imm_(imm), cc_(cc), label_(label) {}
Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm) {}
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0, CondCode::EQ, ""); }
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value, CondCode::EQ, "");
return Operand(Kind::Imm, PhysReg::W0, value);
}
Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::W0, index, CondCode::EQ, "");
}
Operand Operand::Cond(CondCode cc) {
return Operand(Kind::Cond, PhysReg::W0, 0, cc, "");
}
Operand Operand::Label(const std::string& label) {
return Operand(Kind::Label, PhysReg::W0, 0, CondCode::EQ, label);
return Operand(Kind::FrameIndex, PhysReg::W0, index);
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)

@ -12,8 +12,8 @@ bool IsAllowedReg(PhysReg reg) {
case PhysReg::W0:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::X29: // FP = X29, frame pointer
case PhysReg::X30: // LR = X30, link register
case PhysReg::X29:
case PhysReg::X30:
case PhysReg::SP:
return true;
}
@ -22,61 +22,15 @@ bool IsAllowedReg(PhysReg reg) {
} // namespace
//void RunRegAlloc(MachineFunction& function) {
// for (const auto& inst : function.GetEntry().GetInstructions()) {
// for (const auto& operand : inst.GetOperands()) {
// if (operand.GetKind() == Operand::Kind::Reg &&
// !IsAllowedReg(operand.GetReg())) {
// throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
// }
// }
// }
//}
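// Single-function overload of register allocation (original logic)
void RunRegAlloc(MachineFunction& function) {
// For now this only performs a minimal consistency check; no real register allocation is implemented
// Lab3 keeps the stack-slot model, so real register allocation is not needed
// Check the instructions in every basic block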
for (auto& bb : function.GetBasicBlocks()) {
for (auto& instr : bb->GetInstructions()) {
// Check that the instruction's operands are valid
for (const auto& operand : instr.GetOperands()) {
switch (operand.GetKind()) {
case Operand::Kind::Reg:
// Register operand: check that it is in the allowed set
// Fixed registers are currently used: w0, w8, w9, s0, s1, etc.
break;
case Operand::Kind::FrameIndex:
// Frame slot index: check that it is valid
if (operand.GetFrameIndex() < 0 ||
operand.GetFrameIndex() >= static_cast<int>(function.GetFrameSlots().size())) {
throw std::runtime_error(
FormatError("regalloc", "无效的栈槽索引: " +
std::to_string(operand.GetFrameIndex())));
}
break;
case Operand::Kind::Imm:
case Operand::Kind::Cond:
case Operand::Kind::Label:
// Immediates, condition codes, and labels need no checking
break;
}
for (const auto& inst : function.GetEntry().GetInstructions()) {
for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) {
throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
}
}
}
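// Note: real register allocation is not implemented in Lab3
// All values still use the stack-slot model; registers serve only as scratch for computation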
}
// Module overload of register allocation
void RunRegAlloc(MachineModule& module) {
// Run register allocation on every function in the module
for (auto& func : module.GetFunctions()) {
RunRegAlloc(*func);
}
}
} // namespace mir

@ -8,111 +8,20 @@ namespace mir {
const char* PhysRegName(PhysReg reg) {
switch (reg) {
// 32-bit registers
case PhysReg::W0: return "w0";
case PhysReg::W1: return "w1";
case PhysReg::W2: return "w2";
case PhysReg::W3: return "w3";
case PhysReg::W4: return "w4";
case PhysReg::W5: return "w5";
case PhysReg::W6: return "w6";
case PhysReg::W7: return "w7";
case PhysReg::W8: return "w8";
case PhysReg::W9: return "w9";
case PhysReg::W10: return "w10";
case PhysReg::W11: return "w11";
case PhysReg::W12: return "w12";
case PhysReg::W13: return "w13";
case PhysReg::W14: return "w14";
case PhysReg::W15: return "w15";
case PhysReg::W16: return "w16"; // added
case PhysReg::W17: return "w17"; // added
case PhysReg::W18: return "w18"; // added
case PhysReg::W19: return "w19"; // added
case PhysReg::W20: return "w20"; // added
case PhysReg::W21: return "w21"; // added
case PhysReg::W22: return "w22"; // added
case PhysReg::W23: return "w23"; // added
case PhysReg::W24: return "w24"; // added
case PhysReg::W25: return "w25"; // added
case PhysReg::W26: return "w26"; // added
case PhysReg::W27: return "w27"; // added
case PhysReg::W28: return "w28"; // added
case PhysReg::W29: return "w29";
case PhysReg::W30: return "w30";
// 64-bit registers
case PhysReg::X0: return "x0";
case PhysReg::X1: return "x1";
case PhysReg::X2: return "x2";
case PhysReg::X3: return "x3";
case PhysReg::X4: return "x4";
case PhysReg::X5: return "x5";
case PhysReg::X6: return "x6";
case PhysReg::X7: return "x7";
case PhysReg::X8: return "x8";
case PhysReg::X9: return "x9";
case PhysReg::X10: return "x10"; // added
case PhysReg::X11: return "x11"; // added
case PhysReg::X12: return "x12"; // added
case PhysReg::X13: return "x13"; // added
case PhysReg::X14: return "x14"; // added
case PhysReg::X15: return "x15"; // added
case PhysReg::X16: return "x16"; // added
case PhysReg::X17: return "x17"; // added
case PhysReg::X18: return "x18"; // added
case PhysReg::X19: return "x19"; // added
case PhysReg::X20: return "x20"; // added
case PhysReg::X21: return "x21"; // added
case PhysReg::X22: return "x22"; // added
case PhysReg::X23: return "x23"; // added
case PhysReg::X24: return "x24"; // added
case PhysReg::X25: return "x25"; // added
case PhysReg::X26: return "x26"; // added
case PhysReg::X27: return "x27"; // added
case PhysReg::X28: return "x28"; // added
case PhysReg::X29: return "x29";
case PhysReg::X30: return "x30";
// Floating-point registers
case PhysReg::S0: return "s0";
case PhysReg::S1: return "s1";
case PhysReg::S2: return "s2";
case PhysReg::S3: return "s3";
case PhysReg::S4: return "s4";
case PhysReg::S5: return "s5";
case PhysReg::S6: return "s6";
case PhysReg::S7: return "s7";
// Special registers
case PhysReg::SP: return "sp";
case PhysReg::ZR: return "xzr";
default: return "unknown";
case PhysReg::W0:
return "w0";
case PhysReg::W8:
return "w8";
case PhysReg::W9:
return "w9";
case PhysReg::X29:
return "x29";
case PhysReg::X30:
return "x30";
case PhysReg::SP:
return "sp";
}
throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
}
const char* CondCodeName(CondCode cc) {
switch (cc) {
case CondCode::EQ: return "eq";
case CondCode::NE: return "ne";
case CondCode::CS: return "cs";
case CondCode::CC: return "cc";
case CondCode::MI: return "mi";
case CondCode::PL: return "pl";
case CondCode::VS: return "vs";
case CondCode::VC: return "vc";
case CondCode::HI: return "hi";
case CondCode::LS: return "ls";
case CondCode::GE: return "ge";
case CondCode::LT: return "lt";
case CondCode::GT: return "gt";
case CondCode::LE: return "le";
case CondCode::AL: return "al";
default: return "unknown";
}
throw std::runtime_error(FormatError("mir", "未知条件码"));
}
} // namespace mir
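Review note: one side of this hunk ends the switch with `default: return "unknown"`, the other lists each enumerator and throws after the switch. A minimal sketch of the second shape, using a hypothetical reduced `PhysReg` enum and a plain `std::runtime_error` message in place of the project's `FormatError` helper:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical reduced register set; the real enum has more members.
enum class PhysReg { W0, W8, W9, X29, X30, SP };

// No default label: with warnings such as -Wswitch enabled, the compiler can
// point out an enumerator that has no case. The throw after the switch only
// runs for values outside the listed enumerators.
const char* PhysRegName(PhysReg reg) {
  switch (reg) {
    case PhysReg::W0:  return "w0";
    case PhysReg::W8:  return "w8";
    case PhysReg::W9:  return "w9";
    case PhysReg::X29: return "x29";
    case PhysReg::X30: return "x30";
    case PhysReg::SP:  return "sp";
  }
  throw std::runtime_error("unknown physical register");
}
```

The trade-off is that a silent `"unknown"` string can end up in emitted assembly and fail much later in the assembler, whereas the throwing form stops at the point where the bad value appears.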

@ -4,21 +4,11 @@
#include <stdexcept>
#include <string>
#include <sstream>
#include <iostream>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
//#define DEBUG_SEMA
#ifdef DEBUG_SEMA
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[Sema Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace {
// Helper that returns the name of an lvalue
@ -77,9 +67,10 @@ public:
} else {
return_type = ir::Type::GetInt32Type();
}
DEBUG_MSG("[DEBUG] 进入函数: " << name
DebugStream() << "[DEBUG] 进入函数: " << name
<< " 返回类型: " << (return_type->IsInt32() ? "int" :
return_type->IsFloat() ? "float" : "void"));
return_type->IsFloat() ? "float" : "void")
<< std::endl;
// Record the current function's return type (used for return checks)
current_func_return_type_ = return_type;
@ -92,9 +83,10 @@ public:
if (ctx->block()) { // Process the function body
ctx->block()->accept(this);
}
DEBUG_MSG("[DEBUG] 函数 " << name
DebugStream() << "[DEBUG] 函数 " << name
<< " has_return: " << current_func_has_return_
<< " return_type_is_void: " << return_type->IsVoid());
<< " return_type_is_void: " << return_type->IsVoid()
<< std::endl;
if (!return_type->IsVoid() && !current_func_has_return_) { // Check that a non-void function has a return
throw std::runtime_error(FormatError("sema", "非 void 函数 " + name + " 缺少 return 语句"));
}
@ -178,10 +170,10 @@ public:
std::vector<int> dims;
bool is_array = !ctx->constExp().empty();
// Debug output
DEBUG_MSG("[DEBUG] CheckVarDef: " << name
DebugStream() << "[DEBUG] CheckVarDef: " << name
<< " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown")
<< " is_array: " << is_array
<< " dim_count: " << ctx->constExp().size());
<< " dim_count: " << ctx->constExp().size() << std::endl;
if (is_array) {
// Process the array dimensions
for (auto* dim_exp : ctx->constExp()) {
@ -193,24 +185,26 @@ public:
throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
}
dims.push_back(dim);
DEBUG_MSG("[DEBUG] dim[" << dims.size() - 1 << "] = " << dim);
DebugStream() << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl;
}
// Create the array type
type = ir::Type::GetArrayType(base_type, dims);
DEBUG_MSG("[DEBUG] 创建数组类型完成");
DEBUG_MSG("[DEBUG] type->IsArray(): " << type->IsArray());
DEBUG_MSG("[DEBUG] type->GetKind(): " << (int)type->GetKind());
DebugStream() << "[DEBUG] 创建数组类型完成" << std::endl;
DebugStream() << "[DEBUG] type->IsArray(): " << type->IsArray() << std::endl;
DebugStream() << "[DEBUG] type->GetKind(): " << (int)type->GetKind() << std::endl;
// Verify the array type
if (type->IsArray()) {
auto* arr_type = dynamic_cast<ir::ArrayType*>(type.get());
if (arr_type) {
DEBUG_MSG("[DEBUG] ArrayType dimensions: ");
DebugStream() << "[DEBUG] ArrayType dimensions: ";
for (int d : arr_type->GetDimensions()) {
DEBUG_MSG(d << " ");
DebugStream() << d << " ";
}
DEBUG_MSG("[DEBUG] Element type: "
DebugStream() << std::endl;
DebugStream() << "[DEBUG] Element type: "
<< (arr_type->GetElementType()->IsInt32() ? "int" :
arr_type->GetElementType()->IsFloat() ? "float" : "unknown"));
arr_type->GetElementType()->IsFloat() ? "float" : "unknown")
<< std::endl;
}
}
}
@ -236,10 +230,10 @@ public:
sym.param_types.clear(); // Clear to avoid confusion
}
table_.addSymbol(sym); // Add to the symbol table
DEBUG_MSG("[DEBUG] 符号添加完成: " << name
DebugStream() << "[DEBUG] 符号添加完成: " << name
<< " type_kind: " << (int)sym.type->GetKind()
<< " is_array: " << sym.type->IsArray()
);
<< std::endl;
}
void CheckConstDef(SysYParser::ConstDefContext* ctx,
@ -256,10 +250,10 @@ public:
std::shared_ptr<ir::Type> type = base_type;
std::vector<int> dims;
bool is_array = !ctx->constExp().empty();
DEBUG_MSG("[DEBUG] CheckConstDef: " << name
DebugStream() << "[DEBUG] CheckConstDef: " << name
<< " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown")
<< " is_array: " << is_array
<< " dim_count: " << ctx->constExp().size());
<< " dim_count: " << ctx->constExp().size() << std::endl;
if (is_array) {
for (auto* dim_exp : ctx->constExp()) {
@ -268,10 +262,10 @@ public:
throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
}
dims.push_back(dim);
DEBUG_MSG("[DEBUG] dim[" << dims.size() - 1 << "] = " << dim);
DebugStream() << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl;
}
type = ir::Type::GetArrayType(base_type, dims);
DEBUG_MSG("[DEBUG] 创建数组类型完成IsArray: " << type->IsArray());
DebugStream() << "[DEBUG] 创建数组类型完成IsArray: " << type->IsArray() << std::endl;
}
// ========== Bind dimension expressions ==========
@ -286,7 +280,7 @@ public:
BindConstInitVal(ctx->constInitVal());
init_values = table_.EvaluateConstInitVal(ctx->constInitVal(), dims, base_type);
DEBUG_MSG("[DEBUG] 初始化值数量: " << init_values.size());
DebugStream() << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl;
}
// Compute the expected element count
@ -294,12 +288,12 @@ public:
if (is_array) {
expected_count = 1;
for (int d : dims) expected_count *= d;
DEBUG_MSG("[DEBUG] 期望元素数量: " << expected_count);
DebugStream() << "[DEBUG] 期望元素数量: " << expected_count << std::endl;
}
// Pad with zeros if there are too few initializer values
if (is_array && init_values.size() < expected_count) {
DEBUG_MSG("[DEBUG] 初始化值不足,补零");
DebugStream() << "[DEBUG] 初始化值不足,补零" << std::endl;
SymbolTable::ConstValue zero;
if (base_type->IsInt32()) {
zero.kind = SymbolTable::ConstValue::INT;
@ -320,13 +314,13 @@ public:
Symbol sym;
sym.name = name;
sym.kind = SymbolKind::Constant;
DEBUG_MSG("CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind);
DebugStream() << "CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind << std::endl;
sym.type = type;
sym.scope_level = table_.currentScopeLevel();
sym.is_initialized = true;
sym.var_def_ctx = nullptr;
sym.const_def_ctx = ctx;
DEBUG_MSG("保存常量定义上下文: " << name << ", ctx: " << ctx);
DebugStream() << "保存常量定义上下文: " << name << ", ctx: " << ctx << std::endl;
// ========== Store constant values ==========
if (is_array) {
@ -344,19 +338,19 @@ public:
sym.array_const_values.push_back(cv);
}
DEBUG_MSG("[DEBUG] 存储数组常量,共 " << sym.array_const_values.size()
<< " 个元素");
DebugStream() << "[DEBUG] 存储数组常量,共 " << sym.array_const_values.size()
<< " 个元素" << std::endl;
} else if (!init_values.empty()) {
// Store a scalar constant
if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::INT) {
sym.is_int_const = true;
sym.const_value.i32 = init_values[0].int_val;
DEBUG_MSG("[DEBUG] 存储整型常量: " << init_values[0].int_val);
DebugStream() << "[DEBUG] 存储整型常量: " << init_values[0].int_val << std::endl;
} else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) {
sym.is_int_const = false;
sym.const_value.f32 = init_values[0].float_val;
DEBUG_MSG("[DEBUG] 存储浮点常量: " << init_values[0].float_val);
DebugStream() << "[DEBUG] 存储浮点常量: " << init_values[0].float_val << std::endl;
} else if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) {
// Integer constant initialized from a float (must check that the value is integral)
float f = init_values[0].float_val;
@ -367,30 +361,30 @@ public:
}
sym.is_int_const = true;
sym.const_value.i32 = i;
DEBUG_MSG("[DEBUG] 浮点转整型常量: " << f << " -> " << i);
DebugStream() << "[DEBUG] 浮点转整型常量: " << f << " -> " << i << std::endl;
} else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::INT) {
// Float constant initialized from an integer: implicit conversion
sym.is_int_const = false;
sym.const_value.f32 = static_cast<float>(init_values[0].int_val);
DEBUG_MSG("[DEBUG] 整型转浮点常量: " << init_values[0].int_val
<< " -> " << static_cast<float>(init_values[0].int_val));
DebugStream() << "[DEBUG] 整型转浮点常量: " << init_values[0].int_val
<< " -> " << static_cast<float>(init_values[0].int_val) << std::endl;
}
} else {
// No initializer: this is an error for a scalar constant
if (!is_array) {
throw std::runtime_error(FormatError("sema", "常量必须有初始化值: " + name));
}
DEBUG_MSG("[DEBUG] 数组常量无初始化器,将全部补零");
DebugStream() << "[DEBUG] 数组常量无初始化器,将全部补零" << std::endl;
}
table_.addSymbol(sym);
DEBUG_MSG("CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind);
DebugStream() << "CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind << std::endl;
auto* stored = table_.lookup(name);
DEBUG_MSG("CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx);
DebugStream() << "CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx << std::endl;
DEBUG_MSG("[DEBUG] 常量符号添加完成: " << name
DebugStream() << "[DEBUG] 常量符号添加完成: " << name
<< " is_array_const: " << sym.is_array_const
<< " element_count: " << sym.array_const_values.size());
<< " element_count: " << sym.array_const_values.size() << std::endl;
}
// ==================== Constant declarations ====================
@ -413,19 +407,20 @@ public:
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx) return {};
// Debug output
DEBUG_MSG("[DEBUG] visitStmt: ");
if (ctx->Return()) DEBUG_MSG("Return ");
if (ctx->If()) DEBUG_MSG("If ");
if (ctx->While()) DEBUG_MSG("While ");
if (ctx->Break()) DEBUG_MSG("Break ");
if (ctx->Continue()) DEBUG_MSG("Continue ");
if (ctx->lVal() && ctx->Assign()) DEBUG_MSG("Assign ");
if (ctx->exp() && ctx->Semi()) DEBUG_MSG("ExpStmt ");
if (ctx->block()) DEBUG_MSG("Block ");
DebugStream() << "[DEBUG] visitStmt: ";
if (ctx->Return()) DebugStream() << "Return ";
if (ctx->If()) DebugStream() << "If ";
if (ctx->While()) DebugStream() << "While ";
if (ctx->Break()) DebugStream() << "Break ";
if (ctx->Continue()) DebugStream() << "Continue ";
if (ctx->lVal() && ctx->Assign()) DebugStream() << "Assign ";
if (ctx->exp() && ctx->Semi()) DebugStream() << "ExpStmt ";
if (ctx->block()) DebugStream() << "Block ";
DebugStream() << std::endl;
// Determine the statement kind. Note: Return() returns a TerminalNode*
if (ctx->Return() != nullptr) {
// return statement
DEBUG_MSG("[DEBUG] 检测到 return 语句");
DebugStream() << "[DEBUG] 检测到 return 语句" << std::endl;
return visitReturnStmtInternal(ctx);
} else if (ctx->lVal() != nullptr && ctx->Assign() != nullptr) {
// Assignment statement
@ -454,14 +449,14 @@ public:
// Internal implementation of the return statement
std::any visitReturnStmtInternal(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG] visitReturnStmtInternal 被调用");
DebugStream() << "[DEBUG] visitReturnStmtInternal 被调用" << std::endl;
std::shared_ptr<ir::Type> expected = current_func_return_type_;
if (!expected) {
throw std::runtime_error(FormatError("sema", "return 语句不在函数体内"));
}
if (ctx->exp() != nullptr) {
// return with a value
DEBUG_MSG("[DEBUG] 有返回值的 return");
DebugStream() << "[DEBUG] 有返回值的 return" << std::endl;
ExprInfo ret_val = CheckExp(ctx->exp());
if (expected->IsVoid()) {
throw std::runtime_error(FormatError("sema", "void 函数不能返回值"));
@ -474,23 +469,23 @@ public:
}
// Set the has_return flag
current_func_has_return_ = true;
DEBUG_MSG("[DEBUG] 设置 current_func_has_return_ = true");
DebugStream() << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl;
} else {
// return without a value
DEBUG_MSG("[DEBUG] 无返回值的 return");
DebugStream() << "[DEBUG] 无返回值的 return" << std::endl;
if (!expected->IsVoid()) {
throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值"));
}
// Set the has_return flag
current_func_has_return_ = true;
DEBUG_MSG("[DEBUG] 设置 current_func_has_return_ = true");
DebugStream() << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl;
}
return {};
}
// Lvalue expression (variable reference)
std::any visitLVal(SysYParser::LValContext* ctx) override {
DEBUG_MSG("[DEBUG] visitLVal: " << ctx->getText());
DebugStream() << "[DEBUG] visitLVal: " << ctx->getText() << std::endl;
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
}
@ -501,17 +496,17 @@ public:
}
// Check for array access
bool is_array_access = !ctx->exp().empty();
DEBUG_MSG("[DEBUG] name: " << name
DebugStream() << "[DEBUG] name: " << name
<< ", is_array_access: " << is_array_access
<< ", subscript_count: " << ctx->exp().size());
<< ", subscript_count: " << ctx->exp().size() << std::endl;
ExprInfo result;
// Determine whether this is an array type or a pointer type (array parameter)
bool is_array_or_ptr = false;
if (sym->type) {
is_array_or_ptr = sym->type->IsArray() || sym->type->IsPtrInt32() || sym->type->IsPtrFloat();
DEBUG_MSG("[DEBUG] type_kind: " << (int)sym->type->GetKind()
DebugStream() << "[DEBUG] type_kind: " << (int)sym->type->GetKind()
<< ", is_array: " << sym->type->IsArray()
<< ", is_ptr: " << (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()));
<< ", is_ptr: " << (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) << std::endl;
}
if (is_array_or_ptr) {
@ -522,7 +517,7 @@ public:
if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
dim_count = arr_type->GetDimensions().size();
elem_type = arr_type->GetElementType();
DEBUG_MSG("[DEBUG] 数组维度: " << dim_count);
DebugStream() << "[DEBUG] 数组维度: " << dim_count << std::endl;
}
} else if (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) {
dim_count = 1;
@ -531,12 +526,12 @@ public:
} else if (sym->type->IsPtrFloat()) {
elem_type = ir::Type::GetFloatType();
}
DEBUG_MSG("[DEBUG] 指针类型, dim_count: 1");
DebugStream() << "[DEBUG] 指针类型, dim_count: 1" << std::endl;
}
if (is_array_access) {
DEBUG_MSG("[DEBUG] 有下标访问,期望维度: " << dim_count
<< ", 实际下标数: " << ctx->exp().size());
DebugStream() << "[DEBUG] 有下标访问,期望维度: " << dim_count
<< ", 实际下标数: " << ctx->exp().size() << std::endl;
if (ctx->exp().size() != dim_count) {
throw std::runtime_error(FormatError("sema", "数组下标个数不匹配"));
}
@ -550,9 +545,9 @@ public:
result.is_lvalue = true;
result.is_const = false;
} else {
DEBUG_MSG("[DEBUG] 无下标访问");
DebugStream() << "[DEBUG] 无下标访问" << std::endl;
if (sym->type->IsArray()) {
DEBUG_MSG("[DEBUG] 数组名作为地址,转换为指针");
DebugStream() << "[DEBUG] 数组名作为地址,转换为指针" << std::endl;
if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
if (arr_type->GetElementType()->IsInt32()) {
result.type = ir::Type::GetPtrInt32Type();
@ -674,7 +669,7 @@ public:
// Primary expression
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
DEBUG_MSG("[DEBUG] visitPrimaryExp: " << ctx->getText());
DebugStream() << "[DEBUG] visitPrimaryExp: " << ctx->getText() << std::endl;
ExprInfo result;
if (ctx->lVal()) { // Lvalue expression
result = CheckLValue(ctx->lVal());
@ -706,14 +701,14 @@ public:
// Unary expression
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
DEBUG_MSG("[DEBUG] visitUnaryExp: " << ctx->getText());
DebugStream() << "[DEBUG] visitUnaryExp: " << ctx->getText() << std::endl;
ExprInfo result;
if (ctx->primaryExp()) {
ctx->primaryExp()->accept(this);
auto* info = sema_.GetExprType(ctx->primaryExp());
if (info) result = *info;
} else if (ctx->Ident() && ctx->L_PAREN()) { // Function call
DEBUG_MSG("[DEBUG] 函数调用: " << ctx->Ident()->getText());
DebugStream() << "[DEBUG] 函数调用: " << ctx->Ident()->getText() << std::endl;
result = CheckFuncCall(ctx);
} else if (ctx->unaryOp()) { // Unary operation
ctx->unaryExp()->accept(this);
@ -1079,8 +1074,8 @@ public:
// New: return both at once
SemaResult TakeResult() {
DEBUG_MSG("[DEBUG] TakeResult 前: 符号表作用域数量 = "
<< table_.getScopeCount());
DebugStream() << "[DEBUG] TakeResult 前: 符号表作用域数量 = "
<< table_.getScopeCount() << std::endl;
// Optional: dump the symbol table contents
// table_.dump();
@ -1089,8 +1084,8 @@ public:
result.context = std::move(sema_);
result.symbol_table = std::move(table_);
DEBUG_MSG("[DEBUG] TakeResult 后: 符号表作用域数量 = "
<< result.symbol_table.getScopeCount());
DebugStream() << "[DEBUG] TakeResult 后: 符号表作用域数量 = "
<< result.symbol_table.getScopeCount() << std::endl;
return result;
}
@ -1111,7 +1106,7 @@ private:
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("sema", "无效表达式"));
}
DEBUG_MSG("[DEBUG] CheckExp: " << ctx->getText());
DebugStream() << "[DEBUG] CheckExp: " << ctx->getText() << std::endl;
ctx->addExp()->accept(this);
auto* info = sema_.GetExprType(ctx->addExp());
if (!info) {
@ -1162,21 +1157,21 @@ private:
if (!sym) {
throw std::runtime_error(FormatError("sema", "未定义的变量: " + name));
}
DEBUG_MSG("CheckLValue: found sym->name = " << sym->name
<< ", sym->kind = " << (int)sym->kind);
DebugStream() << "CheckLValue: found sym->name = " << sym->name
<< ", sym->kind = " << (int)sym->kind << std::endl;
if (sym->kind == SymbolKind::Variable && sym->var_def_ctx) {
sema_.BindVarUse(ctx, sym->var_def_ctx);
DEBUG_MSG("绑定变量: " << name << " -> VarDefContext");
DebugStream() << "绑定变量: " << name << " -> VarDefContext" << std::endl;
}
else if (sym->kind == SymbolKind::Constant && sym->const_def_ctx) {
sema_.BindConstUse(ctx, sym->const_def_ctx);
DEBUG_MSG("绑定常量: " << name << " -> ConstDefContext");
DebugStream() << "绑定常量: " << name << " -> ConstDefContext" << std::endl;
}
DEBUG_MSG("CheckLValue 绑定变量: " << name
DebugStream() << "CheckLValue 绑定变量: " << name
<< ", sym->kind: " << (int)sym->kind
<< ", sym->var_def_ctx: " << sym->var_def_ctx
<< ", sym->const_def_ctx: " << sym->const_def_ctx);
<< ", sym->const_def_ctx: " << sym->const_def_ctx << std::endl;
bool is_array_access = !ctx->exp().empty();
bool is_const = (sym->kind == SymbolKind::Constant);
@ -1208,8 +1203,9 @@ private:
} else if (sym->type->IsPtrFloat()) {
elem_type = ir::Type::GetFloatType();
}
DEBUG_MSG("数组参数维度: " << dim_count << " 维, dims: ");
for (int d : dims) DEBUG_MSG(d << " ");
DebugStream() << "数组参数维度: " << dim_count << " 维, dims: ";
for (int d : dims) DebugStream() << d << " ";
DebugStream() << std::endl;
} else if (sym->type && (sym->type->IsPtrInt32() || sym->type->IsPtrFloat())) {
// Plain pointer: only one subscript is allowed
dim_count = 1;
@ -1222,7 +1218,7 @@ private:
size_t subscript_count = ctx->exp().size();
DEBUG_MSG("dim_count: " << dim_count << ", subscript_count: " << subscript_count);
DebugStream() << "dim_count: " << dim_count << ", subscript_count: " << subscript_count << std::endl;
if (dim_count > 0 || sym->is_array_param || sym->type->IsArray() ||
sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) {
@ -1243,11 +1239,11 @@ private:
if (subscript_count == dim_count) {
// Fully indexed: return the element type
DEBUG_MSG("完全索引,返回元素类型");
DebugStream() << "完全索引,返回元素类型" << std::endl;
return {elem_type, true, false};
} else {
// Partially indexed: return a pointer type for the sub-array
DEBUG_MSG("部分索引,返回指针类型");
DebugStream() << "部分索引,返回指针类型" << std::endl;
// Compute the pointer type for the remaining dimensions
if (elem_type->IsInt32()) {
return {ir::Type::GetPtrInt32Type(), false, false};
@ -1261,7 +1257,7 @@ private:
// No subscript access
if (sym->type && sym->type->IsArray()) {
// Array name used as an address
DEBUG_MSG("数组名作为地址");
DebugStream() << "数组名作为地址" << std::endl;
if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
if (arr_type->GetElementType()->IsInt32()) {
return {ir::Type::GetPtrInt32Type(), false, true};
@ -1272,7 +1268,7 @@ private:
return {ir::Type::GetPtrInt32Type(), false, true};
} else if (sym->is_array_param) {
// Array parameter name used as an address
DEBUG_MSG("数组参数名作为地址");
DebugStream() << "数组参数名作为地址" << std::endl;
if (sym->type->IsPtrInt32()) {
return {ir::Type::GetPtrInt32Type(), false, true};
} else {
@ -1296,14 +1292,14 @@ private:
throw std::runtime_error(FormatError("sema", "非法函数调用"));
}
std::string func_name = ctx->Ident()->getText();
DEBUG_MSG("[DEBUG] CheckFuncCall: " << func_name);
DebugStream() << "[DEBUG] CheckFuncCall: " << func_name << std::endl;
auto* func_sym = table_.lookup(func_name);
if (!func_sym || func_sym->kind != SymbolKind::Function) {
throw std::runtime_error(FormatError("sema", "未定义的函数: " + func_name));
}
std::vector<ExprInfo> args;
if (ctx->funcRParams()) {
DEBUG_MSG("[DEBUG] 处理函数调用参数:");
DebugStream() << "[DEBUG] 处理函数调用参数:" << std::endl;
for (auto* exp : ctx->funcRParams()->exp()) {
if (exp) {
args.push_back(CheckExp(exp));
@ -1314,8 +1310,8 @@ private:
throw std::runtime_error(FormatError("sema", "参数个数不匹配"));
}
for (size_t i = 0; i < std::min(args.size(), func_sym->param_types.size()); ++i) {
DEBUG_MSG("[DEBUG] 检查参数 " << i << ": 实参类型 " << (int)args[i].type->GetKind()
<< " 形参类型 " << (int)func_sym->param_types[i]->GetKind());
DebugStream() << "[DEBUG] 检查参数 " << i << ": 实参类型 " << (int)args[i].type->GetKind()
<< " 形参类型 " << (int)func_sym->param_types[i]->GetKind() << std::endl;
if (!IsTypeCompatible(args[i].type, func_sym->param_types[i])) {
throw std::runtime_error(FormatError("sema", "参数类型不匹配"));
}
@ -1515,8 +1511,10 @@ private:
sym.array_dims = dims;
table_.addSymbol(sym);
DEBUG_MSG("[DEBUG] 添加参数: " << name << " type_kind: " << (int)param_type->GetKind());
for (int d : dims) DEBUG_MSG(d << " ");
DebugStream() << "[DEBUG] 添加参数: " << name << " type_kind: " << (int)param_type->GetKind()
<< " is_array: " << is_array << " dims: ";
for (int d : dims) DebugStream() << d << " ";
DebugStream() << std::endl;
}
}

@ -1,4 +1,5 @@
#include "sem/SymbolTable.h"
#include "utils/Log.h"
#include <antlr4-runtime.h> // Used to access parent nodes
#include <cctype>
#include <stdexcept>
@ -6,11 +7,11 @@
#include <cmath>
#include <functional>
//#define DEBUG_SYMBOL_TABLE
#define DEBUG_SYMBOL_TABLE
#ifdef DEBUG_SYMBOL_TABLE
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[SymbolTable Debug] " << msg << std::endl
#define DEBUG_MSG(msg) DebugStream() << "[SymbolTable Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
@ -48,9 +49,10 @@ bool SymbolTable::addSymbol(const Symbol& sym) {
// Immediately verify the stored symbol
const auto& stored = current_scope[sym.name];
DEBUG_MSG("SymbolTable::addSymbol: stored " << sym.name
DebugStream() << "SymbolTable::addSymbol: stored " << sym.name
<< " with kind=" << (int)stored.kind
<< ", const_def_ctx=" << stored.const_def_ctx);
<< ", const_def_ctx=" << stored.const_def_ctx
<< std::endl;
return true;
}
@ -68,10 +70,11 @@ const Symbol* SymbolTable::lookup(const std::string& name) const {
const auto& scope = scopes_[*it];
auto found = scope.find(name);
if (found != scope.end()) {
DEBUG_MSG("SymbolTable::lookup: found " << name
DebugStream() << "SymbolTable::lookup: found " << name
<< " in active scope index " << *it
<< ", kind=" << (int)found->second.kind
<< ", const_def_ctx=" << found->second.const_def_ctx);
<< ", const_def_ctx=" << found->second.const_def_ctx
<< std::endl;
return &found->second;
}
}

@ -58,6 +58,11 @@ CLIOptions ParseCLI(int argc, char** argv) {
continue;
}
if (std::strcmp(arg, "--debug") == 0) {
opt.debug = true;
continue;
}
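Review note: the `--debug` branch follows the same compare-then-`continue` pattern as the other flags, with unknown `-` arguments rejected afterwards. A minimal sketch, assuming a hypothetical reduced `CLIOptions` and a `ParseArgs` helper (the real `ParseCLI` handles more flags and formats its errors through `FormatError`):

```cpp
#include <cstring>
#include <stdexcept>
#include <string>

// Hypothetical reduced options record.
struct CLIOptions {
  bool debug = false;
  std::string input;
};

// Recognized flags are handled first; any other argument starting with '-'
// is rejected, and everything else is treated as the input path.
CLIOptions ParseArgs(int argc, const char* const* argv) {
  CLIOptions opt;
  for (int i = 1; i < argc; ++i) {
    const char* arg = argv[i];
    if (std::strcmp(arg, "--debug") == 0) {
      opt.debug = true;
      continue;
    }
    if (arg[0] == '-') {
      throw std::runtime_error(std::string("unknown option: ") + arg);
    }
    opt.input = arg;
  }
  return opt;
}
```

Setting `opt.debug` has no effect by itself; some caller still has to forward it to `SetDebugEnabled`, and that call site is not part of this hunk.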
if (arg[0] == '-') {
throw std::runtime_error(
FormatError("cli", std::string("未知参数: ") + arg +

@ -2,17 +2,44 @@
#include "utils/Log.h"
#include <iostream>
#include <ostream>
#include <streambuf>
#include <string>
bool g_debug_enabled = false;
namespace {
class NullBuffer : public std::streambuf {
protected:
int overflow(int c) override { return traits_type::not_eof(c); }
};
std::ostream& NullStream() {
static NullBuffer null_buffer;
static std::ostream null_stream(&null_buffer);
return null_stream;
}
bool IsCLIError(const std::string_view msg) {
return HasErrorPrefix(msg, "cli");
}
} // namespace
void SetDebugEnabled(bool enabled) {
g_debug_enabled = enabled;
}
bool IsDebugEnabled() {
return g_debug_enabled;
}
std::ostream& DebugStream() {
return g_debug_enabled ? std::cerr : NullStream();
}
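Review note: the null-stream pattern behind `DebugStream()` can be tested without touching `std::cerr`. The sketch below assumes a hypothetical `DebugStreamTo` variant that takes the flag and the sink as parameters; it is not the project's API, only a way to make the behavior observable.

```cpp
#include <ostream>
#include <sstream>
#include <streambuf>

// Stream buffer that accepts every character and stores nothing.
class NullBuffer : public std::streambuf {
 protected:
  // not_eof maps EOF to a non-EOF value, so every call reports success.
  int_type overflow(int_type c) override { return traits_type::not_eof(c); }
};

// Shared stream that silently discards its input.
std::ostream& NullStream() {
  static NullBuffer buffer;
  static std::ostream stream(&buffer);
  return stream;
}

// Hypothetical variant of DebugStream() with the flag and sink passed in.
std::ostream& DebugStreamTo(bool enabled, std::ostream& sink) {
  return enabled ? sink : NullStream();
}
```

A design point to keep in mind: unlike a `DEBUG_MSG` macro that expands to nothing, the operands of `<<` are still evaluated when the stream is disabled. Calls such as `ctx->getText()` therefore run on every invocation, and a log statement placed before a null check now dereferences the pointer even with debugging off. Guarding expensive or unsafe arguments with `IsDebugEnabled()` avoids that cost.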
void LogInfo(const std::string_view msg, std::ostream& os) {
os << "[info] " << msg << "\n";
}
@ -57,6 +84,7 @@ void PrintHelp(std::ostream& os) {
<< " --emit-parse-tree 仅在显式模式下启用语法树输出\n"
<< " --emit-ir 仅在显式模式下启用 IR 输出\n"
<< " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n"
<< " --debug 启用调试日志输出\n"
<< "\n"
<< "说明:\n"
<< " - 默认输出 IR\n"

@ -1,5 +0,0 @@
Build libsysy.a
```
aarch64-linux-gnu-gcc -c sylib.c -o sylib.o
aarch64-linux-gnu-ar rcs libsysy.a sylib.o
```