Compare commits

...

17 Commits

Author SHA1 Message Date
ptabmhn4l 23c274eab6 Merge pull request '完成lab4' (#9) from ptabmhn4l/nudt-compiler-cpp:develop into develop
5 days ago
Junhe Wu 99826566e6 Merge branch 'feat/ir-opt' into develop
5 days ago
Junhe Wu 19928c4945 feat(ir-opt): 完成了lab4
5 days ago
Junhe Wu 827558938b fix(sylib): 使用官方提供的库文件
5 days ago
Junhe Wu c7e8b28d29 fix(testdata): 添加了2026年的测试用例
5 days ago
Junhe Wu e3de2c59af fix(ir): 修复了最后通不过的测试样例。
5 days ago
Junhe Wu 3c6ffe8e3e fix(ir):修复了一些ir的错误
3 weeks ago
ppxf25tqu de126b93d6 Merge pull request 'feat(mir):修正并完善功能' (#7) from pt9wfaocb/nudt-compiler-cpp:tansiping into develop
3 weeks ago
tansiping 310c7c3697 feat(mir):修正并完善功能
3 weeks ago
ptabmhn4l 248db05cf4 Merge pull request 'feat(mir):实现MIR后端' (#6) from pfwvrotsf/nudt-compiler-cpp:feature/mir into develop
3 weeks ago
cy feaba9abd4 fix(mir):修正测试用例
3 weeks ago
cy 1ff1b543d1 feat(mir): MIR 后端(RISC-V架构)
3 weeks ago
ptabmhn4l 80c46cee7e Merge pull request '初步通过verify测试' (#5) from ptabmhn4l/nudt-compiler-cpp:fix/irgen into develop
1 month ago
Junhe Wu 19ef82738f fix(irgen):通过了除了性能测试外的测试用例。
1 month ago
Junhe Wu 4693253459 Merge branch 'develop' of https://bdgit.educoder.net/ppxf25tqu/nudt-compiler-cpp into develop
1 month ago
ptabmhn4l fd45b74e2e Merge pull request '基本完成了ir生成' (#4) from ptabmhn4l/nudt-compiler-cpp:feature/ir into develop
1 month ago
ptabmhn4l 74bcb45776 Merge pull request '把比赛的测试用例放进来' (#2) from ptabmhn4l/nudt-compiler-cpp:fix/testdata into develop
1 month ago

1
.gitignore vendored

@ -54,6 +54,7 @@ compile_commands.json
.fleet/
.vs/
*.code-workspace
CLAUDE.md
# CLion
cmake-build-debug/

@ -0,0 +1,3 @@
bash scripts/run_ir_test.sh --run # 优化模式,计时
bash scripts/run_ir_test.sh --run --O0 # 无优化,计时
bash scripts/bench_ir.sh # 同时对比 O0 vs O1

@ -15,6 +15,7 @@
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@ -60,12 +61,14 @@ class Context {
~Context();
ConstantInt* GetConstInt(int v);
ConstantFloat* GetConstFloat(float v);
std::string NextTemp();
std::string NextTemp(); // 用于指令名(数字,连续)
std::string NextLabel(); // 用于块名(字母前缀,独立计数)
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
int temp_index_ = -1;
int label_index_ = -1;
};
// ─── Type ─────────────────────────────────────────────────────────────────────
@ -160,6 +163,8 @@ enum class Opcode {
Gep,
// 控制流
Ret, Br, CondBr,
// PHI 节点
Phi,
// 函数调用
Call,
// 类型转换
@ -198,16 +203,30 @@ class GlobalValue : public User {
// Module-level variable. Stores a const flag, a float flag, a scalar int and a
// scalar float initial value, an element count, an "declared as array" flag and
// optional per-element initializer lists (int and float).
// NOTE(review): this span is a rendered diff without +/- markers; the two
// constructor parameter-list tails and the two IsArray() definitions below look
// like the pre-change and post-change versions of the same line — confirm
// against the real header before editing.
class GlobalVariable : public Value {
public:
GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements = 1);
int num_elements = 1, bool is_array_decl = false,
bool is_float = false);
bool IsConst() const { return is_const_; }
bool IsFloat() const { return is_float_; }
// Scalar initial values (integer and float variants).
int GetInitVal() const { return init_val_; }
float GetInitValF() const { return init_val_f_; }
int GetNumElements() const { return num_elements_; }
bool IsArray() const { return num_elements_ > 1; }
// A GlobalVariable's "pointer type" is i32*; it is accessed via load/store.
// This variant also reports an array when it was declared as one, even if it
// has a single element.
bool IsArray() const { return is_array_decl_ || num_elements_ > 1; }
// Per-element initializers; the vectors are taken by value and moved in.
void SetInitVals(std::vector<int> v) { init_vals_ = std::move(v); }
void SetInitValsF(std::vector<float> v) { init_vals_f_ = std::move(v); }
void SetInitValF(float v){ init_val_f_ = v; }
void SetInitVal(int v){ init_val_ = v;}
const std::vector<int>& GetInitVals() const { return init_vals_; }
const std::vector<float>& GetInitValsF() const { return init_vals_f_; }
// True when either per-element initializer list is non-empty.
bool HasInitVals() const { return !init_vals_.empty() || !init_vals_f_.empty(); }
private:
bool is_const_;
bool is_float_;
int init_val_;
float init_val_f_;
int num_elements_;
bool is_array_decl_;
std::vector<int> init_vals_;
std::vector<float> init_vals_f_;
};
// ─── Instruction ──────────────────────────────────────────────────────────────
@ -218,6 +237,7 @@ class Instruction : public User {
bool IsTerminator() const;
BasicBlock* GetParent() const;
void SetParent(BasicBlock* parent);
void RemoveFromParent();
private:
Opcode opcode_;
@ -358,6 +378,18 @@ class StoreInst : public Instruction {
Value* GetPtr() const;
};
// PHI node: selects a value at a point where control flow merges.
// Operand layout: [val0, bb0, val1, bb1, ...] (even indices hold values,
// odd indices hold basic blocks).
class PhiInst : public Instruction {
public:
PhiInst(std::shared_ptr<Type> ty, std::string name);
// Registers one (value, block) incoming pair. Defined out of line; presumably
// appends both as operands in the layout described above — confirm in the .cpp.
void AddIncoming(Value* val, BasicBlock* bb);
// Number of incoming pairs: half the operand count.
size_t GetNumIncoming() const { return GetNumOperands() / 2; }
// Value of the i-th incoming pair (operand 2*i). No bounds check is done here.
Value* GetIncomingValue(size_t i) const { return GetOperand(i * 2); }
// Block of the i-th incoming pair. Defined out of line.
BasicBlock* GetIncomingBlock(size_t i) const;
// Replaces the value of the i-th incoming pair (operand 2*i).
void SetIncomingValue(size_t i, Value* val) { SetOperand(i * 2, val); }
};
// ─── BasicBlock ───────────────────────────────────────────────────────────────
class BasicBlock : public Value {
public:
@ -368,6 +400,16 @@ class BasicBlock : public Value {
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const;
void AddPredecessor(BasicBlock* bb);
void RemovePredecessor(BasicBlock* bb);
void ClearPredecessors();
void AddSuccessor(BasicBlock* bb);
void ClearSuccessors();
// 指令管理
void RemoveInstruction(Instruction* inst);
// 在 before 之前插入指令before 为 nullptr 时追加到末尾
void InsertBefore(Instruction* inst, Instruction* before);
template <typename T, typename... Args>
T* Append(Args&&... args) {
@ -381,6 +423,29 @@ class BasicBlock : public Value {
return ptr;
}
template <typename T, typename... Args>
T* Prepend(Args&&... args) {
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin(), std::move(inst));
return ptr;
}
// Insert before terminator (or append if no terminator)
template <typename T, typename... Args>
T* InsertBeforeTerminator(Args&&... args) {
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
if (!instructions_.empty() && instructions_.back()->IsTerminator()) {
instructions_.insert(instructions_.end() - 1, std::move(inst));
} else {
instructions_.push_back(std::move(inst));
}
return ptr;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
@ -409,6 +474,12 @@ class Function : public Value {
Argument* GetArgument(size_t i) const;
size_t GetNumArgs() const { return args_.size(); }
bool IsVoidReturn() const { return type_->IsVoid(); }
// 将某个块移动到 blocks_ 列表末尾(用于确保块顺序正确)
void MoveBlockToEnd(BasicBlock* bb);
// 重建 CFG根据终结指令计算所有块的前驱/后继
void RebuildCFG();
// 从函数中移除一个基本块
void RemoveBlock(BasicBlock* bb);
private:
BasicBlock* entry_ = nullptr;
@ -437,7 +508,9 @@ class Module {
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
GlobalVariable* CreateGlobalVariable(const std::string& name, bool is_const,
int init_val, int num_elements = 1);
int init_val, int num_elements = 1,
bool is_array_decl = false,
bool is_float = false);
GlobalVariable* GetGlobalVariable(const std::string& name) const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobalVariables() const;
@ -494,9 +567,12 @@ class IRBuilder {
AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaF32(const std::string& name);
AllocaInst* CreateAllocaArray(int num_elements, const std::string& name);
AllocaInst* CreateAllocaArrayF32(int num_elements, const std::string& name);
GepInst* CreateGep(Value* base_ptr, Value* index, const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);
// 零初始化数组emit memset call
void CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod);
// 控制流
ReturnInst* CreateRet(Value* v);
@ -504,6 +580,8 @@ class IRBuilder {
BrInst* CreateBr(BasicBlock* target);
CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_bb,
BasicBlock* false_bb);
// PHI 节点(添加到当前块开头)
PhiInst* CreatePhi(std::shared_ptr<Type> ty, const std::string& name);
// 调用
CallInst* CreateCall(Function* callee, std::vector<Value*> args,
@ -518,9 +596,31 @@ class IRBuilder {
SIToFPInst* CreateSIToFP(Value* val, const std::string& name);
FPToSIInst* CreateFPToSI(Value* val, const std::string& name);
void SetAllocaBlock(BasicBlock* bb) { alloca_block_ = bb; }
private:
Context& ctx_;
BasicBlock* insert_block_;
BasicBlock* alloca_block_ = nullptr;
};
// ─── DominatorTree ────────────────────────────────────────────────────────────
// Dominator information for one function. All queries are declared here and
// defined elsewhere; the notes below are read off the names and member types —
// NOTE(review): confirm the exact semantics against the implementation.
class DominatorTree {
public:
// Presumably (re)builds every table below for `func`.
void Compute(Function& func);
// Lookup accessors; presumably backed by idom_, children_ and df_ respectively.
BasicBlock* GetIDom(BasicBlock* bb) const;
const std::vector<BasicBlock*>& GetChildren(BasicBlock* bb) const;
const std::vector<BasicBlock*>& GetDominanceFrontier(BasicBlock* bb) const;
// Presumably true when block `a` dominates block `b`.
bool Dominates(BasicBlock* a, BasicBlock* b) const;
// Returns the stored block ordering (df_order_) by const reference.
const std::vector<BasicBlock*>& GetDFOrder() const { return df_order_; }
private:
std::unordered_map<BasicBlock*, BasicBlock*> idom_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> children_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> df_;
std::unordered_map<BasicBlock*, size_t> dom_level_;
std::vector<BasicBlock*> df_order_;
std::unordered_set<BasicBlock*> visited_;
};
// ─── IRPrinter ────────────────────────────────────────────────────────────────
@ -529,4 +629,7 @@ class IRPrinter {
void Print(const Module& module, std::ostream& os);
};
// ─── Pass Manager ────────────────────────────────────────────────────────────
void RunPasses(Module& module);
} // namespace ir

@ -113,5 +113,4 @@ class IRGenImpl final : public SysYBaseVisitor {
std::vector<LoopCtx> loop_stack_;
};
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema);
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, const SemanticContext& sema);

@ -19,39 +19,161 @@ class MIRContext {
MIRContext& DefaultContext();
enum class PhysReg { W0, W8, W9, X29, X30, SP };
// RISC-V 64-bit register definitions.
// The enumerators carry no explicit values, so their order defines the
// underlying numbering; do not reorder them.
enum class PhysReg {
// General-purpose registers
ZERO, // x0, hard-wired to 0
RA, // x1, return address
SP, // x2, stack pointer
GP, // x3, global pointer
TP, // x4, thread pointer
T0, // x5, temporary register
T1, // x6, temporary register
T2, // x7, temporary register
S0, // x8, frame pointer / saved register
S1, // x9, saved register
A0, // x10, argument / return value
A1, // x11, argument
A2, // x12, argument
A3, // x13, argument
A4, // x14, argument
A5, // x15, argument
A6, // x16, argument
A7, // x17, argument
S2, // x18, saved register
S3, // x19, saved register
S4, // x20, saved register
S5, // x21, saved register
S6, // x22, saved register
S7, // x23, saved register
S8, // x24, saved register
S9, // x25, saved register
S10, // x26, saved register
S11, // x27, saved register
T3, // x28, temporary register
T4, // x29, temporary register
T5, // x30, temporary register
T6, // x31, temporary register
// Floating-point registers (FT*/FS*/FA* names).
// NOTE(review): only FS0 and FS1 are listed; there are no FS2..FS11
// enumerators, so these values presumably do not map 1:1 onto f0..f31 —
// confirm against PhysRegName().
FT0, FT1, FT2, FT3, FT4, FT5, FT6, FT7,
FS0, FS1,
FA0, FA1, FA2, FA3, FA4, FA5, FA6, FA7,
FT8, FT9, FT10, FT11,
};
const char* PhysRegName(PhysReg reg);
// Plain description of one global variable: its name, scalar initial value
// (integer and float variants), const/array/float flags, and per-element array
// initializers with the element count.
// Every scalar member has a default initializer so that a default-constructed
// GlobalVarInfo never exposes indeterminate values; the struct remains an
// aggregate (C++14 and later), so brace initialization keeps working.
struct GlobalVarInfo {
  std::string name;
  int value = 0;        // scalar integer initial value
  float valueF = 0.0f;  // scalar float initial value
  bool isConst = false;
  bool isArray = false;
  bool isFloat = false;
  std::vector<int> arrayValues;     // per-element integer initializers
  std::vector<float> arrayValuesF;  // per-element float initializers
  int arraySize = 0;
};
// Machine-level opcodes.
// NOTE(review): this span is a rendered diff without +/- markers, so some
// enumerators below may be lines the change removed — confirm against the
// real header. Enumerators have no explicit values; order defines numbering.
enum class Opcode {
Prologue,
Epilogue,
MovImm,
LoadStack,
StoreStack,
AddRR,
Load,
Store,
Add,
Addi,
Sub,
Mul,
Div,
Rem,
Slt,
Slti,
Slli,
Sltu, // unsigned less-than
Sltiu,
Xori,
LoadGlobalAddr,
LoadGlobal,
StoreGlobal,
LoadIndirect, // lw rd, 0(rs1): load from the address held in a register
StoreIndirect, // sw rs2, 0(rs1)
Call,
GEP,
LoadAddr,
Ret,
// Floating-point instructions
FMov, // floating-point move
FMovWX, // fmv.w.x fs, x: move an integer register into a floating-point register
FMovXW, // fmv.x.w x, fs: move a floating-point register into an integer register
FAdd,
FSub,
FMul,
FDiv,
FEq, // floating-point equality comparison
FLt, // floating-point less-than comparison
FLe, // floating-point less-than-or-equal comparison
FNeg, // floating-point negation
FAbs, // floating-point absolute value
SIToFP, // int to float conversion
FPToSI, // float to int conversion
LoadFloat, // floating-point load (flw)
StoreFloat, // floating-point store (fsw)
Br,
CondBr,
Label,
};
// Output section a global is placed in.
enum class GlobalKind {
Data, // .data section (initialized)
BSS, // .bss section (uninitialized, starts as 0)
RoData // .rodata section (read-only constants)
};
// Global variable information: name, target section kind, size, scalar initial
// value and array shape.
// Every scalar member has a default initializer so that a default-constructed
// GlobalInfo never exposes indeterminate values; the struct remains an
// aggregate (C++14 and later), so brace initialization keeps working.
struct GlobalInfo {
  std::string name;
  GlobalKind kind = GlobalKind::Data;
  int size = 0;   // size in bytes
  int value = 0;  // initial value (for simple variables)
  bool isArray = false;
  int arraySize = 0;
  std::vector<int> dimensions;  // array dimensions
};
// One machine-instruction operand: a physical register, an immediate, a frame
// index, a global symbol name or a function name. Instances are created through
// the static factory functions; the constructors are private.
// NOTE(review): this span is a rendered diff without +/- markers; the two
// `enum class Kind` lines look like the pre-change and post-change versions of
// the same declaration — confirm against the real header.
class Operand {
public:
enum class Kind { Reg, Imm, FrameIndex };
enum class Kind { Reg, Imm, FrameIndex, Global, Func };
static Operand Reg(PhysReg reg);
static Operand Imm(int value);
static Operand Imm64(int64_t value); // added: stores a 64-bit value
static Operand FrameIndex(int index);
static Operand Global(const std::string& name);
static Operand Func(const std::string& name);
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; }
int64_t GetImm64() const { return imm64_; } // added
// The frame index is stored in the same field as the 32-bit immediate.
int GetFrameIndex() const { return imm_; }
const std::string& GetGlobalName() const { return global_name_; }
const std::string& GetFuncName() const { return func_name_; }
private:
Operand(Kind kind, PhysReg reg, int imm);
Operand(Kind kind, PhysReg reg, int64_t imm64); // added constructor
Operand(Kind kind, PhysReg reg, int imm, const std::string& name);
Kind kind_;
PhysReg reg_;
int imm_;
int64_t imm64_; // added
std::string global_name_;
std::string func_name_;
};
class MachineInstr {
@ -71,7 +193,6 @@ struct FrameSlot {
int size = 4;
int offset = 0;
};
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
@ -93,9 +214,14 @@ class MachineFunction {
explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; }
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
// 基本块管理
MachineBasicBlock* CreateBlock(const std::string& name);
MachineBasicBlock* GetEntry() { return entry_; }
const MachineBasicBlock* GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const { return blocks_; }
// 栈帧管理
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
@ -106,14 +232,15 @@ class MachineFunction {
private:
std::string name_;
MachineBasicBlock entry_;
MachineBasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0;
};
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
//std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os);
} // namespace mir
//void PrintAsm(const MachineFunction& function, std::ostream& os);
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module);
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os);
} // namespace mir

@ -1,14 +1,16 @@
// 简易命令行解析:支持帮助、输入文件与输出阶段选择。
// 命令行解析compiler <input.sy> -S|-IR -o <output> [-O1]
// 同时兼容 --emit-ir / --emit-asm
#pragma once
#include <string>
// Parsed command-line options.
// NOTE(review): this span is a rendered diff without +/- markers; emit_ir and
// emit_asm each appear twice, presumably as the pre-change and post-change
// versions of the same field — confirm against the real header.
struct CLIOptions {
std::string input;
bool emit_parse_tree = false;
bool emit_ir = true;
bool emit_asm = false;
std::string output; // -o <file>; empty means write to stdout
bool emit_ir = false; // -IR / --emit-ir
bool emit_asm = false; // -S / --emit-asm
bool show_help = false;
bool opt = false; // -O1
};
CLIOptions ParseCLI(int argc, char** argv);

@ -0,0 +1,309 @@
.text
.global main
.type main, @function
main:
addi sp, sp, -272
sw ra, 264(sp)
sw s0, 256(sp)
addi a0, sp, -4
li a1, 0
li a2, 32
call
addi a0, sp, -8
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -8
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -8
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -8
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -8
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -8
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -8
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -8
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -8
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
addi a0, sp, -44
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -44
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -44
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -44
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -44
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -44
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -44
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -44
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -44
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
addi a0, sp, -80
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -80
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -80
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -80
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -80
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -80
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -80
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -112(sp)
li t0, 1
lw t1, -112(sp)
add t0, t0, t1
sw t0, -116(sp)
addi t0, sp, -80
lw t1, -116(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -124(sp)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -128(sp)
li t0, 1
lw t1, -128(sp)
add t0, t0, t1
sw t0, -132(sp)
addi t0, sp, -44
lw t1, -132(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -140(sp)
addi a0, sp, -108
li a1, 0
li a2, 32
call
lw t2, -124(sp)
addi t0, sp, -108
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
lw t2, -140(sp)
addi t0, sp, -108
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -108
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -108
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -108
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -108
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -108
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -108
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t0, 3
li t1, 2
mul t0, t0, t1
sw t0, -176(sp)
li t0, 1
lw t1, -176(sp)
add t0, t0, t1
sw t0, -180(sp)
addi t0, sp, -108
lw t1, -180(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -188(sp)
li t0, 0
li t1, 2
mul t0, t0, t1
sw t0, -192(sp)
li t0, 0
lw t1, -192(sp)
add t0, t0, t1
sw t0, -196(sp)
addi t0, sp, -108
lw t1, -196(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -204(sp)
lw t0, -188(sp)
lw t1, -204(sp)
add t0, t0, t1
sw t0, -208(sp)
li t0, 0
li t1, 2
mul t0, t0, t1
sw t0, -212(sp)
li t0, 1
lw t1, -212(sp)
add t0, t0, t1
sw t0, -216(sp)
addi t0, sp, -108
lw t1, -216(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -224(sp)
lw t0, -208(sp)
lw t1, -224(sp)
add t0, t0, t1
sw t0, -228(sp)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -232(sp)
li t0, 0
lw t1, -232(sp)
add t0, t0, t1
sw t0, -236(sp)
addi t0, sp, -4
lw t1, -236(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -244(sp)
lw t0, -228(sp)
lw t1, -244(sp)
add t0, t0, t1
sw t0, -248(sp)
lw a0, -248(sp)
lw ra, 264(sp)
lw s0, 256(sp)
addi sp, sp, 272
ret
.size main, .-main

@ -0,0 +1,147 @@
#!/usr/bin/env bash
# Optimization benchmark: measures compile time and run time at O0 versus O1
# for every *.sy test case and writes one CSV row per (test, opt level).
# Usage: bash scripts/bench_ir.sh [--test-dir=<dir>] [--result-dir=<dir>]
#
# errexit (-e) is not enabled: the status of each step is checked explicitly so
# that one failing test case does not abort the whole run.
set -uo pipefail

PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
TEST_CASE_DIR="${PROJECT_ROOT}/test/test_case"
RESULT_DIR="${PROJECT_ROOT}/test/test_result/bench"

while [[ $# -gt 0 ]]; do
  case "$1" in
    --test-dir=*) TEST_CASE_DIR="${1#*=}" ;;
    --result-dir=*) RESULT_DIR="${1#*=}" ;;
    *) echo "未知参数: $1" >&2; exit 1 ;;
  esac
  shift
done

compiler="${PROJECT_ROOT}/build/bin/compiler"
[[ -x "$compiler" ]] || { echo "错误:未找到编译器 $compiler" >&2; exit 1; }
command -v llc >/dev/null 2>&1 || { echo "错误:未找到 llc" >&2; exit 1; }
command -v clang >/dev/null 2>&1 || { echo "错误:未找到 clang" >&2; exit 1; }
mkdir -p "$RESULT_DIR"

# Timing helpers: wall-clock timestamps from date +%s.%N.
now() { date +%s.%N; }
# elapsed <start> <end>: prints end-start with 4 decimals (python3, awk fallback).
elapsed() { python3 -c "print(f'{float($2)-float($1):.4f}')" 2>/dev/null || awk "BEGIN{printf \"%.4f\\n\",$2-$1}"; }

# run_case <exe> <stdin-file>: runs the executable with an unlimited stack,
# feeding <stdin-file> when it exists, discarding all output. Returns the
# program's exit status.
run_case() {
  if [[ -f "$2" ]]; then
    (ulimit -s unlimited; "$1" < "$2") > /dev/null 2>&1
  else
    (ulimit -s unlimited; "$1") > /dev/null 2>&1
  fi
}

summary_file="${RESULT_DIR}/summary.csv"
echo "test,opt,compile_s,exec_s,compile+exec_s" > "$summary_file"

total=0
o0_ct_total=0; o1_ct_total=0
o0_et_total=0; o1_et_total=0

echo "=== 优化效果对比 O0 vs O1 ==="
echo ""

# Resolved once: it does not change between test cases.
tcdir=$(readlink -f "$TEST_CASE_DIR")

while IFS= read -r test_file; do
  full_path=$(readlink -f "$test_file")
  # Quote the pattern so glob characters in the directory name are literal.
  rel="${full_path#"$tcdir"}"
  [[ "${rel:0:1}" != "/" ]] && rel="/$rel"
  base=$(basename "$test_file")
  stem="${base%.sy}"
  idir=$(dirname "$test_file")
  stdin="${idir}/${stem}.in"

  total=$((total+1))
  printf "[%4d] %s" "$total" "$rel"

  o0_ll="${RESULT_DIR}/O0/${rel%.sy}.ll"
  o1_ll="${RESULT_DIR}/O1/${rel%.sy}.ll"
  mkdir -p "$(dirname "$o0_ll")" "$(dirname "$o1_ll")"

  # --- compile O0 ---
  t1=$(now)
  "$compiler" "$test_file" -IR -o "$o0_ll" 2>/dev/null; rc0=$?
  t2=$(now)
  if [[ $rc0 -ne 0 ]]; then
    echo " | O0编译失败"
    echo "$stem,O0,-,-,-" >> "$summary_file"
    echo "$stem,O1,-,-,-" >> "$summary_file"
    continue
  fi
  o0_ct=$(elapsed "$t1" "$t2")

  # --- compile O1 ---
  t1=$(now)
  "$compiler" "$test_file" -IR -o "$o1_ll" -O1 2>/dev/null; rc1=$?
  t2=$(now)
  if [[ $rc1 -ne 0 ]]; then
    echo " | O1编译失败"
    echo "$stem,O0,$o0_ct,-,-" >> "$summary_file"
    echo "$stem,O1,-,-,-" >> "$summary_file"
    continue
  fi
  o1_ct=$(elapsed "$t1" "$t2")

  # --- llc + clang for both levels ---
  o0_obj="${RESULT_DIR}/O0/${stem}.o"
  o1_obj="${RESULT_DIR}/O1/${stem}.o"
  o0_exe="${RESULT_DIR}/O0/${stem}.exe"
  o1_exe="${RESULT_DIR}/O1/${stem}.exe"
  # A failed llc/clang step must not be timed: without this check the run
  # phase would execute a missing (exit 127) or stale executable and report
  # its time as a real measurement.
  build_ok=true
  llc -filetype=obj "$o0_ll" -o "$o0_obj" 2>/dev/null || build_ok=false
  llc -filetype=obj "$o1_ll" -o "$o1_obj" 2>/dev/null || build_ok=false
  if [[ "$build_ok" == true ]]; then
    clang "$o0_obj" "${PROJECT_ROOT}/sylib/sylib.c" -o "$o0_exe" -lm 2>/dev/null || build_ok=false
    clang "$o1_obj" "${PROJECT_ROOT}/sylib/sylib.c" -o "$o1_exe" -lm 2>/dev/null || build_ok=false
  fi
  if [[ "$build_ok" != true ]]; then
    echo " | 目标文件生成或链接失败"
    echo "$stem,O0,$o0_ct,-,-" >> "$summary_file"
    echo "$stem,O1,$o1_ct,-,-" >> "$summary_file"
    continue
  fi

  # --- run O0 ---
  t1=$(now)
  sr0=0
  run_case "$o0_exe" "$stdin" || sr0=$?
  t2=$(now)
  o0_et=$(elapsed "$t1" "$t2")

  # --- run O1 ---
  t1=$(now)
  sr1=0
  run_case "$o1_exe" "$stdin" || sr1=$?
  t2=$(now)
  o1_et=$(elapsed "$t1" "$t2")

  # Consistency check: both builds must exit with the same status.
  flag=""
  if [[ $sr0 -ne $sr1 ]]; then
    flag=" EXIT:O0=$sr0 O1=$sr1"
  fi

  # Running totals and speed-up ratios.
  o0_ct_total=$(awk "BEGIN{printf \"%.4f\",$o0_ct_total+$o0_ct}")
  o1_ct_total=$(awk "BEGIN{printf \"%.4f\",$o1_ct_total+$o1_ct}")
  o0_et_total=$(awk "BEGIN{printf \"%.4f\",$o0_et_total+$o0_et}")
  o1_et_total=$(awk "BEGIN{printf \"%.4f\",$o1_et_total+$o1_et}")
  cspd=$(awk "BEGIN{if($o1_ct>0)printf \"%.1fx\",$o0_ct/$o1_ct; else print \"-\"}")
  espd=$(awk "BEGIN{if($o1_et>0)printf \"%.1fx\",$o0_et/$o1_et; else print \"-\"}")

  printf " | 编译 O0:%.4fs O1:%.4fs(%s) 运行 O0:%.4fs O1:%.4fs(%s)%s\n" \
    "$o0_ct" "$o1_ct" "$cspd" "$o0_et" "$o1_et" "$espd" "$flag"
  echo "$stem,O0,$o0_ct,$o0_et,$(awk "BEGIN{printf \"%.4f\",$o0_ct+$o0_et}")" >> "$summary_file"
  echo "$stem,O1,$o1_ct,$o1_et,$(awk "BEGIN{printf \"%.4f\",$o1_ct+$o1_et}")" >> "$summary_file"
done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort)

echo ""
echo "============================================"
echo "总用例: $total"
echo "O0 编译总耗时: ${o0_ct_total}s"
echo "O1 编译总耗时: ${o1_ct_total}s"
echo "O0 运行总耗时: ${o0_et_total}s"
echo "O1 运行总耗时: ${o1_et_total}s"
echo "CSV: $summary_file"

@ -0,0 +1,100 @@
#!/usr/bin/env python3
import sys
def read_file(filepath):
    """Return the lines of `filepath`, decoded as UTF-8.

    Prints an error message and exits with status 1 when the file does not
    exist.
    """
    try:
        handle = open(filepath, 'r', encoding='utf-8')
    except FileNotFoundError:
        print(f"错误:文件 '{filepath}' 不存在")
        sys.exit(1)
    with handle:
        return handle.readlines()
def get_context(line, pos, context_len=10):
    """Return the characters around index `pos` of `line`.

    Up to `context_len` characters are kept on each side of `pos`; a "..."
    marker is added on whichever side was cut off.
    """
    lo = pos - context_len
    if lo < 0:
        lo = 0
    hi = pos + context_len + 1
    if hi > len(line):
        hi = len(line)
    head = "..." if lo > 0 else ""
    tail = "..." if hi < len(line) else ""
    return head + line[lo:hi] + tail
def compare_files(file1, file2):
    """Compare two text files line by line and print every difference.

    Lines present in only one file are reported as added/deleted. Lines that
    differ are compared character by character; each run of consecutive
    differing characters is reported once, at the position of its first
    character, and counts as one difference. A final summary prints the total.

    The character scan uses an explicit index (`while` loop): reassigning the
    target of a `for ... in range(...)` loop does not affect the next
    iteration, so a `for` loop cannot skip the rest of a differing run.
    """
    lines1 = read_file(file1)
    lines2 = read_file(file2)

    print(f"比较文件: {file1} vs {file2}")
    print("=" * 80)

    max_lines = max(len(lines1), len(lines2))
    differences = 0

    for line_num in range(max_lines):
        line1 = lines1[line_num] if line_num < len(lines1) else None
        line2 = lines2[line_num] if line_num < len(lines2) else None

        if line1 is None:
            print(f"\n[新增行] 第 {line_num + 1}")
            print(f" + {repr(line2)}")
            differences += 1
            continue

        if line2 is None:
            print(f"\n[删除行] 第 {line_num + 1}")
            print(f" - {repr(line1)}")
            differences += 1
            continue

        if line1 == line2:
            continue

        # The lines differ: compare them character by character.
        print(f"\n[差异行] 第 {line_num + 1}")
        max_chars = max(len(line1), len(line2))

        char_pos = 0
        while char_pos < max_chars:
            char1 = line1[char_pos] if char_pos < len(line1) else None
            char2 = line2[char_pos] if char_pos < len(line2) else None

            if char1 == char2:
                char_pos += 1
                continue

            # First character of a differing run.
            context1 = get_context(line1, char_pos) if line1 else ""
            context2 = get_context(line2, char_pos) if line2 else ""

            print(f" 字符位置 {char_pos + 1}:")
            if char1 is not None:
                print(f" - {repr(char1)} | 上下文: {repr(context1)}")
            else:
                print(f" - (缺失)")
            if char2 is not None:
                print(f" + {repr(char2)} | 上下文: {repr(context2)}")
            else:
                print(f" + (缺失)")

            differences += 1

            # Skip the remainder of this run of consecutive differences so the
            # output stays readable.
            char_pos += 1
            while char_pos < max_chars:
                next_char1 = line1[char_pos] if char_pos < len(line1) else None
                next_char2 = line2[char_pos] if char_pos < len(line2) else None
                if next_char1 == next_char2:
                    break
                char_pos += 1

    print("\n" + "=" * 80)
    print(f"比较完成,共发现 {differences} 处差异")
def main():
    """Command-line entry point: expects exactly two file paths and compares them."""
    paths = sys.argv[1:]
    if len(paths) != 2:
        print("用法: python diff.py <文件1> <文件2>")
        sys.exit(1)

    first, second = paths
    compare_files(first, second)


if __name__ == "__main__":
    main()

@ -0,0 +1,137 @@
#!/bin/bash
# RISC-V backend test driver.
# Phase 1 compiles every *.sy test case (performance directories excluded) to
# assembly. Phase 2 links each assembly file with riscv64-linux-gnu-gcc, runs it
# under qemu-riscv64 with a 5 second timeout, and compares the result with the
# matching .out file: a purely numeric .out is treated as the expected exit
# code, anything else as the expected standard output.
#
# Each program is executed exactly once; its standard output is captured to a
# temporary file, and a matching .in file, when present, is fed on stdin.

PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/mir"

RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'

if [ ! -x "$COMPILER" ]; then
  echo "错误:编译器不存在或不可执行: $COMPILER"
  exit 1
fi

mkdir -p "$TEST_RESULT_DIR"

echo "=========================================="
echo "RISC-V 后端测试"
echo "=========================================="
echo ""

# Collect the test cases.
mapfile -t test_files < <(find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort)
total=${#test_files[@]}
pass_gen=0
fail_gen=0
pass_run=0
fail_run=0
timeout_cnt=0

echo "=== 阶段1汇编生成 ==="
echo ""

for test_file in "${test_files[@]}"; do
  relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
  output_file="$TEST_RESULT_DIR/${relative_path%.sy}.s"
  mkdir -p "$(dirname "$output_file")"

  # Success requires both a zero exit status and a non-empty assembly file.
  if "$COMPILER" --emit-asm "$test_file" 2>/dev/null > "$output_file" && [ -s "$output_file" ]; then
    echo -e " ${GREEN}${NC} $relative_path"
    pass_gen=$((pass_gen + 1))
  else
    echo -e " ${RED}${NC} $relative_path"
    fail_gen=$((fail_gen + 1))
  fi
done

echo ""
echo "--- 汇编生成: 通过 $pass_gen / 失败 $fail_gen / 总计 $total ---"
echo ""
echo "=== 阶段2运行验证 ==="
echo ""

for test_file in "${test_files[@]}"; do
  relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
  stem="${relative_path%.sy}"
  asm_file="$TEST_RESULT_DIR/${stem}.s"
  exe_file="$TEST_RESULT_DIR/${stem}"
  expected_file="${test_file%.sy}.out"
  input_file="${test_file%.sy}.in"

  if [ ! -s "$asm_file" ]; then
    echo -e " ${YELLOW}${NC} $relative_path (跳过)"
    continue
  fi

  if ! riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe_file" -no-pie 2>/dev/null; then
    echo -e " ${RED}${NC} $relative_path (链接失败)"
    fail_run=$((fail_run + 1))
    continue
  fi

  # Run the program once with a 5 second timeout, keeping stdout for the
  # comparison below. The status of the if/else is that of the command run.
  tmp_out=$(mktemp)
  if [ -f "$input_file" ]; then
    timeout 5 qemu-riscv64 "$exe_file" < "$input_file" > "$tmp_out" 2>/dev/null
  else
    timeout 5 qemu-riscv64 "$exe_file" > "$tmp_out" 2>/dev/null
  fi
  exit_code=$?
  # Newlines are removed from the actual output exactly as they are removed
  # from the expected text, so multi-line output compares correctly.
  program_output=$(tr -d '\n' < "$tmp_out")
  rm -f "$tmp_out"

  # timeout(1) reports an expired time limit with status 124.
  if [ $exit_code -eq 124 ]; then
    echo -e " ${YELLOW}${NC} $relative_path (超时)"
    timeout_cnt=$((timeout_cnt + 1))
    continue
  fi

  if [ -f "$expected_file" ]; then
    expected=$(tr -d '\n' < "$expected_file")
    # A purely numeric expectation is an exit code; anything else (including
    # an empty file) is expected output.
    if [[ "$expected" =~ ^[0-9]+$ ]]; then
      if [ "$exit_code" -eq "$expected" ] 2>/dev/null; then
        echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
        pass_run=$((pass_run + 1))
      else
        echo -e " ${RED}${NC} $relative_path (退出码: 期望 $expected, 实际 $exit_code)"
        fail_run=$((fail_run + 1))
      fi
    else
      if [ "$program_output" = "$expected" ]; then
        echo -e " ${GREEN}${NC} $relative_path (输出匹配)"
        pass_run=$((pass_run + 1))
      else
        echo -e " ${RED}${NC} $relative_path (输出不匹配)"
        fail_run=$((fail_run + 1))
      fi
    fi
  else
    # No expectation file: counted as a pass.
    echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
    pass_run=$((pass_run + 1))
  fi
done

echo ""
echo "--- 运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt ---"
echo ""
echo "=========================================="
echo "测试完成"
echo "汇编生成: 通过 $pass_gen / 失败 $fail_gen"
echo "运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt"
echo "=========================================="

@ -0,0 +1,153 @@
#!/bin/bash
# RISC-V backend test driver (links against the sylib runtime).
# Phase 1 compiles every *.sy test case (performance directories excluded) to
# assembly. Phase 2 links each assembly file together with the sylib object,
# runs it under qemu-riscv64 with a 10 second timeout (feeding the matching .in
# file when present) and compares the result with the matching .out file.

PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/mir"
SYLIB_C="$PROJECT_ROOT/sylib/sylib.c"
SYLIB_O="/tmp/sylib.o"

RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'

if [ ! -x "$COMPILER" ]; then
  echo "错误:编译器不存在或不可执行: $COMPILER"
  exit 1
fi

# Build the sylib runtime object. It is rebuilt when it is missing or older
# than sylib.c, so a cached object from an earlier run cannot go stale. On a
# failed build the (possibly partial) object is removed so it is never linked.
echo "编译运行时库..."
if [ ! -f "$SYLIB_O" ] || [ "$SYLIB_C" -nt "$SYLIB_O" ]; then
  if ! riscv64-linux-gnu-gcc -c "$SYLIB_C" -o "$SYLIB_O" 2>/dev/null; then
    rm -f "$SYLIB_O"
    echo "警告:无法编译 sylib.c部分测试可能链接失败"
  fi
fi
echo ""

mkdir -p "$TEST_RESULT_DIR"

echo "=========================================="
echo "RISC-V 后端测试"
echo "=========================================="
echo ""

# Collect the test cases.
mapfile -t test_files < <(find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort)
total=${#test_files[@]}
pass_gen=0
fail_gen=0
pass_run=0
fail_run=0
timeout_cnt=0

echo "=== 阶段1汇编生成 ==="
echo ""

for test_file in "${test_files[@]}"; do
  relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
  output_file="$TEST_RESULT_DIR/${relative_path%.sy}.s"
  mkdir -p "$(dirname "$output_file")"

  # Success requires both a zero exit status and a non-empty assembly file.
  if "$COMPILER" --emit-asm "$test_file" 2>/dev/null > "$output_file" && [ -s "$output_file" ]; then
    echo -e " ${GREEN}${NC} $relative_path"
    pass_gen=$((pass_gen + 1))
  else
    echo -e " ${RED}${NC} $relative_path"
    fail_gen=$((fail_gen + 1))
  fi
done

echo ""
echo "--- 汇编生成: 通过 $pass_gen / 失败 $fail_gen / 总计 $total ---"
echo ""

for test_file in "${test_files[@]}"; do
  relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
  stem="${relative_path%.sy}"
  asm_file="$TEST_RESULT_DIR/${stem}.s"
  exe_file="$TEST_RESULT_DIR/${stem}"
  expected_file="${test_file%.sy}.out"

  if [ ! -s "$asm_file" ]; then
    echo -e " ${YELLOW}${NC} $relative_path (跳过)"
    continue
  fi

  # Link, with the runtime object when it is available.
  if [ -f "$SYLIB_O" ]; then
    riscv64-linux-gnu-gcc -static "$asm_file" "$SYLIB_O" -o "$exe_file" -no-pie 2>/dev/null
  else
    riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe_file" -no-pie 2>/dev/null
  fi
  # The status of the if/else above is that of the gcc command that ran.
  if [ $? -ne 0 ]; then
    echo -e " ${RED}${NC} $relative_path (链接失败)"
    fail_run=$((fail_run + 1))
    continue
  fi

  # Run the program, capturing stdout and stderr together.
  input_file="${test_file%.sy}.in"
  tmp_out=$(mktemp)
  if [ -f "$input_file" ]; then
    timeout 10 qemu-riscv64 "$exe_file" < "$input_file" > "$tmp_out" 2>&1
  else
    timeout 10 qemu-riscv64 "$exe_file" > "$tmp_out" 2>&1
  fi
  exit_code=$?

  # timeout(1) reports an expired time limit with status 124.
  if [ $exit_code -eq 124 ]; then
    echo -e " ${YELLOW}${NC} $relative_path (超时)"
    timeout_cnt=$((timeout_cnt + 1))
    rm -f "$tmp_out"
    continue
  fi

  # Normalize: drop newlines, then trailing whitespace.
  program_output=$(tr -d '\n' < "$tmp_out" | sed 's/[[:space:]]*$//')
  rm -f "$tmp_out"

  if [ -f "$expected_file" ]; then
    expected=$(tr -d '\n' < "$expected_file" | sed 's/[[:space:]]*$//')

    if [[ "$expected" =~ ^[0-9]+$ ]] && [ "$expected" -ge 0 ] && [ "$expected" -le 255 ] && [ -z "$program_output" ]; then
      # Expected exit code (and the program printed nothing).
      if [ "$exit_code" -eq "$expected" ] 2>/dev/null; then
        echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
        pass_run=$((pass_run + 1))
      else
        echo -e " ${RED}${NC} $relative_path (退出码: 期望 $expected, 实际 $exit_code)"
        fail_run=$((fail_run + 1))
      fi
    else
      # Expected output text.
      if [ "$program_output" = "$expected" ]; then
        echo -e " ${GREEN}${NC} $relative_path (输出匹配)"
        pass_run=$((pass_run + 1))
      else
        echo -e " ${RED}${NC} $relative_path (输出不匹配: 期望 '$expected', 实际 '$program_output')"
        fail_run=$((fail_run + 1))
      fi
    fi
  else
    # No expectation file: counted as a pass.
    echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code, 输出: '$program_output')"
    pass_run=$((pass_run + 1))
  fi
done

echo ""
echo "--- 运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt ---"
echo ""
echo "=========================================="
echo "测试完成"
echo "汇编生成: 通过 $pass_gen / 失败 $fail_gen"
echo "运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt"
echo "=========================================="

@ -0,0 +1,65 @@
#!/bin/bash
# RISC-V floating-point conversion smoke test.
#
# For every "<name>:<expected exit code>" entry in TESTS this script
#   1. compiles test/test_case/basic/<name>.sy to assembly with the project
#      compiler (--emit-asm),
#   2. links the assembly statically with riscv64-linux-gnu-gcc,
#   3. runs the binary under qemu-riscv64 and compares its exit code with the
#      expected value.
#
# Exit status: 0 when every test passed, 1 when the compiler or toolchain is
# missing or at least one test failed, so callers can detect failures.
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_DIR="$PROJECT_ROOT/test/test_case/basic"

# ANSI colour escapes used by `echo -e` below.
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m'

if [ ! -f "$COMPILER" ]; then
    echo "错误: 编译器不存在: $COMPILER"
    exit 1
fi

# Without the cross toolchain every case would be reported as a link failure,
# so stop early with a clear message instead.
if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
    echo "错误: 未找到 riscv64-linux-gnu-gcc"
    exit 1
fi
if ! command -v qemu-riscv64 >/dev/null 2>&1; then
    echo "错误: 未找到 qemu-riscv64"
    exit 1
fi

echo "=========================================="
echo "RISC-V 浮点转换测试"
echo "=========================================="

# One "<test name>:<expected exit code>" entry per line.
TESTS="
float_conv:3
float_add:13
float_mul:30
"
PASS=0
FAIL=0
for test in $TESTS; do
    name=$(echo "$test" | cut -d: -f1)
    expected=$(echo "$test" | cut -d: -f2)
    asm_file="/tmp/test_$name.s"
    exe_file="/tmp/test_$name"
    echo -n "测试 $name (期望 $expected) ... "

    # Compiler diagnostics are captured in the .s file as well, so that the
    # first lines can be shown when code generation fails.
    if ! "$COMPILER" "$TEST_DIR/$name.sy" --emit-asm > "$asm_file" 2>&1; then
        echo -e "${RED}失败 (汇编错误)${NC}"
        head -3 "$asm_file"
        FAIL=$((FAIL + 1))
        continue
    fi

    if ! riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe_file" -no-pie 2>/dev/null; then
        echo -e "${RED}失败 (链接错误)${NC}"
        FAIL=$((FAIL + 1))
        continue
    fi

    # Only the exit code matters; program output is discarded.
    qemu-riscv64 "$exe_file" > /dev/null 2>&1
    exit_code=$?
    if [ "$exit_code" -eq "$expected" ]; then
        echo -e "${GREEN}通过${NC}"
        PASS=$((PASS + 1))
    else
        echo -e "${RED}失败 (实际 $exit_code)${NC}"
        FAIL=$((FAIL + 1))
    fi
done
echo "=========================================="
echo -e "测试结果: ${GREEN}通过 $PASS${NC} / ${RED}失败 $FAIL${NC}"
echo "=========================================="

# Report failures through the exit status instead of always returning 0.
if [ "$FAIL" -ne 0 ]; then
    exit 1
fi
exit 0

@ -1,58 +1,308 @@
#!/bin/bash
#!/usr/bin/env bash
# 串行执行IR测试脚本实时输出结果
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/ir"
if [ ! -x "$COMPILER" ]; then
echo "错误:编译器不存在或不可执行: $COMPILER"
echo "请先构建项目cmake --build build -j\$(nproc)"
exit 1
# 默认参数
TEST_CASE_DIR="${PROJECT_ROOT}/test/test_case"
TEST_RESULT_DIR="${PROJECT_ROOT}/test/test_result/ir"
RUN_EXEC=false
VERBOSE=false
OPT_FLAG="-O1" # 默认开启优化
# 解析命令行参数
while [[ $# -gt 0 ]]; do
case "$1" in
--run)
RUN_EXEC=true
;;
--verbose|-v)
VERBOSE=true
;;
--O0)
OPT_FLAG=""
;;
--O1)
OPT_FLAG="-O1"
;;
--test-dir=*)
TEST_CASE_DIR="${1#*=}"
;;
--result-dir=*)
TEST_RESULT_DIR="${1#*=}"
;;
*)
echo "未知参数: $1" >&2
echo "用法: $0 [--run] [--O0|--O1] [--verbose] [--test-dir=<dir>] [--result-dir=<dir>]" >&2
exit 1
;;
esac
shift
done
# 检查编译器是否存在
compiler="${PROJECT_ROOT}/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "错误:未找到编译器 $compiler" >&2
echo "请先构建项目: mkdir -p build && cd build && cmake .. && make -j" >&2
exit 1
fi
# 创建输出目录
mkdir -p "$TEST_RESULT_DIR"
pass_count=0
fail_count=0
failed_cases=()
echo "=== 开始测试 IR 生成 ==="
echo ""
# 统计变量
total_tests=0
passed_tests=0
failed_tests=0
while IFS= read -r test_file; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.ll"
# 汇总日志文件
summary_log="${TEST_RESULT_DIR}/summary.log"
> "$summary_log"
mkdir -p "$(dirname "$output_file")"
# 失败测试列表
failed_list=""
echo -n "测试: $relative_path ... "
echo "=== 开始IR测试 ==="
echo "测试目录: $TEST_CASE_DIR"
echo "结果目录: $TEST_RESULT_DIR"
echo "优化级别: ${OPT_FLAG:--O0}"
echo "运行可执行文件: $RUN_EXEC"
echo ""
"$COMPILER" --emit-ir "$test_file" > "$output_file" 2>&1
exit_code=$?
# 串行遍历所有测试用例
while read -r test_file; do
total_tests=$((total_tests + 1))
# 计算相对路径
full_path=$(readlink -f "$test_file")
test_case_path=$(readlink -f "$TEST_CASE_DIR")
relative_path="${full_path#$test_case_path}"
# 确保路径以 / 开头
if [[ "${relative_path:0:1}" != "/" ]]; then
relative_path="/$relative_path"
fi
# 计算输出文件路径
base=$(basename "$test_file")
stem="${base%.sy}"
output_file="${TEST_RESULT_DIR}/${relative_path%.sy}.ll"
output_dir=$(dirname "$output_file")
# 创建输出目录
mkdir -p "$output_dir"
# 获取输入和预期输出文件路径
input_dir=$(dirname "$test_file")
stdin_file="${input_dir}/${stem}.in"
expected_file="${input_dir}/${stem}.out"
# 每个测试用例的详细日志文件
test_log="${output_dir}/${stem}.log"
> "$test_log"
echo "[$total_tests] 处理: $relative_path"
echo "[$(date '+%Y-%m-%d %H:%M:%S')] 开始处理: $relative_path" >> "$test_log"
echo "输入文件: $test_file" >> "$test_log"
echo "输出目录: $output_dir" >> "$test_log"
# 生成IR
if $VERBOSE; then
echo " 生成IR..."
fi
echo "步骤1: 生成IR" >> "$test_log"
if [ $exit_code -eq 0 ] && [ -s "$output_file" ] && ! grep -q '\[error\]' "$output_file"; then
echo "通过"
pass_count=$((pass_count + 1))
# 计时编译
compile_start=$(date +%s.%N)
set +e
"$compiler" "$test_file" -IR -o "$output_file" $OPT_FLAG 2>&1
ir_status=$?
set -e
compile_end=$(date +%s.%N)
compile_time=$(python3 -c "print(f'{float($compile_end)-float($compile_start):.3f}s')" 2>/dev/null || echo "-")
if [[ $ir_status -ne 0 ]]; then
echo " ✗ IR生成失败"
echo "结果: FAILED (IR生成失败)" >> "$test_log"
echo "错误信息:" >> "$test_log"
cat "$output_file" >> "$test_log"
failed_tests=$((failed_tests + 1))
failed_list="$failed_list\n[$total_tests] $relative_path - IR生成失败"
if $VERBOSE; then
cat "$output_file"
fi
echo ""
continue
fi
echo "IR文件: $output_file" >> "$test_log"
if $VERBOSE; then
echo " ✓ IR已生成: $output_file"
fi
# 如果需要运行可执行文件
if [[ "$RUN_EXEC" == true ]]; then
echo "步骤2: 编译和运行" >> "$test_log"
if ! command -v llc >/dev/null 2>&1; then
echo " 警告: 未找到 llc跳过执行测试" >&2
echo "警告: 未找到 llc跳过执行测试" >> "$test_log"
echo "结果: SKIPPED (缺少llc)" >> "$test_log"
passed_tests=$((passed_tests + 1))
echo ""
continue
fi
if ! command -v clang >/dev/null 2>&1; then
echo " 警告: 未找到 clang跳过执行测试" >&2
echo "警告: 未找到 clang跳过执行测试" >> "$test_log"
echo "结果: SKIPPED (缺少clang)" >> "$test_log"
passed_tests=$((passed_tests + 1))
echo ""
continue
fi
obj="${output_dir}/${stem}.o"
exe="${output_dir}/${stem}"
stdout_file="${output_dir}/${stem}.stdout"
actual_file="${output_dir}/${stem}.actual.out"
# 编译IR为目标文件
if $VERBOSE; then
echo " 编译IR..."
fi
echo "编译IR: llc -filetype=obj $output_file -o $obj" >> "$test_log"
llc -filetype=obj "$output_file" -o "$obj" 2>/dev/null
if [[ $? -ne 0 ]]; then
echo " ✗ IR编译失败"
echo "结果: FAILED (IR编译失败)" >> "$test_log"
failed_tests=$((failed_tests + 1))
failed_list="$failed_list\n[$total_tests] $relative_path - IR编译失败"
echo ""
continue
fi
echo "目标文件: $obj" >> "$test_log"
# 链接为可执行文件
echo "链接: clang $obj ${PROJECT_ROOT}/sylib/sylib.c -o $exe -lm" >> "$test_log"
clang "$obj" "${PROJECT_ROOT}/sylib/sylib.c" -o "$exe" -lm 2>/dev/null
if [[ $? -ne 0 ]]; then
echo " ✗ 链接失败"
echo "结果: FAILED (链接失败)" >> "$test_log"
failed_tests=$((failed_tests + 1))
failed_list="$failed_list\n[$total_tests] $relative_path - 链接失败"
echo ""
continue
fi
echo "可执行文件: $exe" >> "$test_log"
# 运行可执行文件
if $VERBOSE; then
echo " 运行..."
fi
echo "运行命令: $exe" >> "$test_log"
if [[ -f "$stdin_file" ]]; then
echo "标准输入: $stdin_file" >> "$test_log"
fi
run_start=$(date +%s.%N)
set +e
if [[ -f "$stdin_file" ]]; then
(ulimit -s unlimited; "$exe" < "$stdin_file") > "$stdout_file"
else
(ulimit -s unlimited; "$exe") > "$stdout_file"
fi
status=$?
set -e
run_end=$(date +%s.%N)
run_time=$(python3 -c "print(f'{float($run_end)-float($run_start):.3f}s')" 2>/dev/null || echo "-")
echo "退出码: $status" >> "$test_log"
echo "标准输出:" >> "$test_log"
cat "$stdout_file" >> "$test_log"
# 保存实际输出(包含退出码)
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
# 比对输出
echo "步骤3: 比对输出" >> "$test_log"
if [[ -f "$expected_file" ]]; then
echo "预期输出: $expected_file" >> "$test_log"
echo "实际输出: $actual_file" >> "$test_log"
if diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
<(tr -d '\r' < "$actual_file" | sed -e '$a\') > /dev/null 2>&1; then
echo " ✓ 编译:${compile_time} 运行:${run_time}"
echo "结果: PASSED (输出匹配)" >> "$test_log"
passed_tests=$((passed_tests + 1))
else
echo " ✗ 输出不匹配"
echo "结果: FAILED (输出不匹配)" >> "$test_log"
echo "差异:" >> "$test_log"
diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
<(tr -d '\r' < "$actual_file" | sed -e '$a\') >> "$test_log" 2>&1 || true
failed_tests=$((failed_tests + 1))
failed_list="$failed_list\n[$total_tests] $relative_path - 输出不匹配"
if $VERBOSE; then
echo " 预期:"
cat "$expected_file"
echo " 实际:"
cat "$actual_file"
fi
fi
else
echo "失败"
fail_count=$((fail_count + 1))
failed_cases+=("$relative_path")
echo " 错误信息已保存到: $output_file"
echo " ? 无预期输出文件,跳过比对 (编译:${compile_time})"
echo "警告: 无预期输出文件,跳过比对" >> "$test_log"
echo "结果: SKIPPED (无预期输出)" >> "$test_log"
passed_tests=$((passed_tests + 1))
fi
else
echo "步骤2: 跳过执行测试 (--run未启用)" >> "$test_log"
echo "结果: PASSED (仅IR生成)" >> "$test_log"
passed_tests=$((passed_tests + 1))
echo " ✓ 编译:${compile_time}"
fi
echo ""
done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort)
echo ""
# 输出统计结果到终端
echo "=== 测试完成 ==="
echo "通过: $pass_count"
echo "失败: $fail_count"
echo "结果保存在: $TEST_RESULT_DIR"
echo "总测试数: $total_tests"
echo "通过: $passed_tests"
echo "失败: $failed_tests"
if [ ${#failed_cases[@]} -gt 0 ]; then
echo ""
echo "=== 失败的用例 ==="
for f in "${failed_cases[@]}"; do
echo " - $f"
done
exit 1
# 写入汇总日志
echo "=== IR测试汇总报告 ===" > "$summary_log"
echo "测试时间: $(date '+%Y-%m-%d %H:%M:%S')" >> "$summary_log"
echo "测试目录: $TEST_CASE_DIR" >> "$summary_log"
echo "结果目录: $TEST_RESULT_DIR" >> "$summary_log"
echo "运行可执行文件: $RUN_EXEC" >> "$summary_log"
echo "" >> "$summary_log"
echo "=== 统计结果 ===" >> "$summary_log"
echo "总测试数: $total_tests" >> "$summary_log"
echo "通过: $passed_tests" >> "$summary_log"
echo "失败: $failed_tests" >> "$summary_log"
echo "成功率: $((passed_tests * 100 / total_tests))%" >> "$summary_log"
echo "" >> "$summary_log"
if [[ $failed_tests -gt 0 ]]; then
echo "=== 失败测试列表 ===" >> "$summary_log"
echo -e "$failed_list" >> "$summary_log"
fi
echo "详细日志已保存到各测试用例目录"
echo "汇总日志: $summary_log"
if [[ $failed_tests -eq 0 ]]; then
echo "所有测试通过!"
exit 0
else
echo "$failed_tests 个测试失败"
exit 1
fi

@ -0,0 +1,85 @@
#!/bin/bash
# RISC-V basic functionality test.
#
# Compiles each listed SysY case from test/test_case/basic to assembly, links
# it statically with the RISC-V cross compiler, runs it under qemu-riscv64 and
# compares the exit code with the expected value.
# Exits 0 when all cases pass, 1 otherwise (or when a prerequisite is missing).

# Project root is the parent of the directory containing this script.
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_DIR="$PROJECT_ROOT/test/test_case/basic"

# Colour escapes for `echo -e`.
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m'

# Prerequisite: the compiler binary.
[[ -f "$COMPILER" ]] || {
    echo "错误: 编译器不存在: $COMPILER"
    exit 1
}

# Prerequisite: cross compiler and user-mode emulator, checked in this order.
for tool in riscv64-linux-gnu-gcc qemu-riscv64; do
    if ! command -v "$tool" >/dev/null 2>&1; then
        echo "错误: 未找到 $tool"
        exit 1
    fi
done

echo "=========================================="
echo "RISC-V 基础功能测试"
echo "=========================================="

# Test table: "<case name>:<expected exit code>".
cases=(arith:50 add:30 sub:7 mul:50 div:25 mod:2 var:43)
passed=0
failed=0

for entry in "${cases[@]}"; do
    name=${entry%%:*}
    expected=${entry#*:}
    asm_path="/tmp/test_$name.s"
    exe_path="/tmp/test_$name"
    echo -n "测试 $name (期望 $expected) ... "

    # Stage 1: generate assembly (stderr is folded into the same file).
    if ! "$COMPILER" "$TEST_DIR/$name.sy" --emit-asm > "$asm_path" 2>&1; then
        echo -e "${RED}失败 (汇编错误)${NC}"
        failed=$((failed + 1))
        continue
    fi

    # Stage 2: static link.
    if ! riscv64-linux-gnu-gcc -static "$asm_path" -o "$exe_path" -no-pie 2>/dev/null; then
        echo -e "${RED}失败 (链接错误)${NC}"
        failed=$((failed + 1))
        continue
    fi

    # Stage 3: run and compare the exit code.
    qemu-riscv64 "$exe_path" > /dev/null 2>&1
    exit_code=$?
    if [[ "$exit_code" -ne "$expected" ]]; then
        echo -e "${RED}失败 (实际 $exit_code)${NC}"
        failed=$((failed + 1))
    else
        echo -e "${GREEN}通过${NC}"
        passed=$((passed + 1))
    fi
done

echo "=========================================="
echo -e "测试结果: ${GREEN}通过 $passed${NC} / ${RED}失败 $failed${NC}"
echo "=========================================="

if [[ "$failed" -gt 0 ]]; then
    echo -e "${RED}✗ 有 $failed 个测试失败${NC}"
    exit 1
fi
echo -e "${GREEN}✓ 所有基础测试通过!${NC}"
exit 0

@ -3,6 +3,8 @@
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
@ -31,7 +33,7 @@ if [[ ! -f "$input" ]]; then
exit 1
fi
compiler="./build/bin/compiler"
compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建(如: mkdir -p build && cd build && cmake .. && make -j" >&2
exit 1
@ -43,7 +45,7 @@ stem=${base%.sy}
out_file="$out_dir/$stem.ll"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
"$compiler" --emit-ir "$input" > "$out_file"
"$compiler" "$input" -IR -o "$out_file" -O1
echo "IR 已生成: $out_file"
if [[ "$run_exec" == true ]]; then
@ -60,13 +62,13 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" -o "$exe"
clang "$obj" "$PROJECT_ROOT/sylib/sylib.c" -o "$exe" -lm
echo "运行 $exe ..."
set +e
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
(ulimit -s unlimited; "$exe" < "$stdin_file") > "$stdout_file"
else
"$exe" > "$stdout_file"
(ulimit -s unlimited; "$exe") > "$stdout_file"
fi
status=$?
set -e
@ -81,9 +83,11 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then
if diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
<(tr -d '\r' < "$actual_file" | sed -e '$a\') > /dev/null 2>&1; then
echo "输出匹配: $expected_file"
else
diff -u "$expected_file" "$actual_file" || true
echo "输出不匹配: $expected_file" >&2
echo "实际输出已保存: $actual_file" >&2
exit 1

@ -0,0 +1,101 @@
#!/usr/bin/env bash
# Compile one SysY source file to MIR and RISC-V assembly and, with --run,
# link it, execute it under qemu-riscv64 and compare the result with the
# <stem>.out file located next to the input.
#
# Usage: <script> <input.sy> [output_dir] [--run]
#   input.sy    source file to compile
#   output_dir  where generated files go (default: test/test_result/mir,
#               relative to the current directory)
#   --run       also link, run and compare
#
# The comparison file consists of the program's stdout followed by its exit
# code on a separate line. Exits non-zero on any failure or mismatch.
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)

if [[ $# -lt 1 || $# -gt 3 ]]; then
    echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
    exit 1
fi

input=$1
out_dir="test/test_result/mir"
run_exec=false
input_dir=$(dirname "$input")
shift

# Remaining arguments: --run toggles execution, anything else is output_dir.
while [[ $# -gt 0 ]]; do
    case "$1" in
        --run)
            run_exec=true
            ;;
        *)
            out_dir="$1"
            ;;
    esac
    shift
done

if [[ ! -f "$input" ]]; then
    echo "输入文件不存在: $input" >&2
    exit 1
fi

compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
    echo "未找到编译器: $compiler ,请先构建。" >&2
    exit 1
fi

mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
mir_file="$out_dir/$stem.mir"
asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"

# Generate MIR
"$compiler" --emit-mir "$input" > "$mir_file"
echo "MIR 已生成: $mir_file"

# Generate assembly
"$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file"

if [[ "$run_exec" == true ]]; then
    if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
        echo "未找到 riscv64-linux-gnu-gcc" >&2
        exit 1
    fi
    if ! command -v qemu-riscv64 >/dev/null 2>&1; then
        echo "未找到 qemu-riscv64" >&2
        exit 1
    fi

    # Link the SysY runtime library when the project ships it, so programs
    # that call runtime I/O helpers resolve their symbols; this mirrors the
    # IR runner, which links sylib/sylib.c as well.
    link_inputs=("$asm_file")
    sylib_src="$PROJECT_ROOT/sylib/sylib.c"
    if [[ -f "$sylib_src" ]]; then
        link_inputs+=("$sylib_src")
    fi
    riscv64-linux-gnu-gcc -static "${link_inputs[@]}" -o "$exe" -no-pie

    stdout_file="$out_dir/$stem.stdout"
    actual_file="$out_dir/$stem.actual.out"
    echo "运行 $exe ..."

    # The program's exit code is data here, not an error: suspend `set -e`.
    set +e
    if [[ -f "$stdin_file" ]]; then
        qemu-riscv64 "$exe" < "$stdin_file" > "$stdout_file"
    else
        qemu-riscv64 "$exe" > "$stdout_file"
    fi
    status=$?
    set -e

    cat "$stdout_file"
    echo "退出码: $status"

    # actual = stdout (newline-terminated if non-empty) + exit code line.
    {
        cat "$stdout_file"
        if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
            printf '\n'
        fi
        printf '%s\n' "$status"
    } > "$actual_file"

    if [[ -f "$expected_file" ]]; then
        # Compare after stripping carriage returns and forcing a final
        # newline on both sides, so CRLF reference files or a missing
        # trailing newline do not cause false mismatches.
        if diff -u --label "$expected_file" --label "$actual_file" \
                <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
                <(tr -d '\r' < "$actual_file" | sed -e '$a\'); then
            echo "输出匹配: $expected_file"
        else
            echo "输出不匹配: $expected_file" >&2
            exit 1
        fi
    else
        echo "未找到预期输出文件,跳过比对: $expected_file"
    fi
fi

@ -0,0 +1,101 @@
#!/usr/bin/env bash
# Compile one SysY source file to RISC-V assembly, link it into a static
# executable and, with --run, execute it under qemu-riscv64 and compare the
# result with the <stem>.out file located next to the input.
#
# Usage: <script> <input.sy> [output_dir] [--run]
#   input.sy    source file to compile
#   output_dir  where generated files go (default: test/test_result/riscv_asm,
#               relative to the current directory)
#   --run       also run the executable and compare
#
# The comparison file consists of the program's stdout followed by its exit
# code on a separate line. Exits non-zero on any failure or mismatch.
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)

if [[ $# -lt 1 || $# -gt 3 ]]; then
    echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
    exit 1
fi

input=$1
out_dir="test/test_result/riscv_asm"
run_exec=false
input_dir=$(dirname "$input")
shift

# Remaining arguments: --run toggles execution, anything else is output_dir.
while [[ $# -gt 0 ]]; do
    case "$1" in
        --run)
            run_exec=true
            ;;
        *)
            out_dir="$1"
            ;;
    esac
    shift
done

if [[ ! -f "$input" ]]; then
    echo "输入文件不存在: $input" >&2
    exit 1
fi

compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
    echo "未找到编译器: $compiler ,请先构建。" >&2
    exit 1
fi

if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
    echo "未找到 riscv64-linux-gnu-gcc无法汇编/链接。" >&2
    echo "请安装: sudo apt install gcc-riscv64-linux-gnu" >&2
    exit 1
fi

mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"

"$compiler" --emit-asm "$input" 2>/dev/null > "$asm_file"
echo "汇编已生成: $asm_file"

# Static linking avoids needing a RISC-V dynamic loader at run time.
# The SysY runtime library is linked in when the project ships it, so
# programs that call runtime I/O helpers resolve their symbols; this mirrors
# the IR runner, which links sylib/sylib.c as well.
link_inputs=("$asm_file")
sylib_src="$PROJECT_ROOT/sylib/sylib.c"
if [[ -f "$sylib_src" ]]; then
    link_inputs+=("$sylib_src")
fi
riscv64-linux-gnu-gcc -static "${link_inputs[@]}" -o "$exe" -no-pie
echo "可执行文件已生成: $exe"

if [[ "$run_exec" == true ]]; then
    if ! command -v qemu-riscv64 >/dev/null 2>&1; then
        echo "未找到 qemu-riscv64无法运行生成的可执行文件。" >&2
        echo "请安装: sudo apt install qemu-user" >&2
        exit 1
    fi

    stdout_file="$out_dir/$stem.stdout"
    actual_file="$out_dir/$stem.actual.out"
    echo "运行 $exe ..."

    # The program's exit code is data here, not an error: suspend `set -e`.
    set +e
    if [[ -f "$stdin_file" ]]; then
        qemu-riscv64 "$exe" < "$stdin_file" > "$stdout_file"
    else
        qemu-riscv64 "$exe" > "$stdout_file"
    fi
    status=$?
    set -e

    cat "$stdout_file"
    echo "退出码: $status"

    # actual = stdout (newline-terminated if non-empty) + exit code line.
    {
        cat "$stdout_file"
        if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
            printf '\n'
        fi
        printf '%s\n' "$status"
    } > "$actual_file"

    if [[ -f "$expected_file" ]]; then
        # Compare after stripping carriage returns and forcing a final
        # newline on both sides, so CRLF reference files or a missing
        # trailing newline do not cause false mismatches.
        if diff -u --label "$expected_file" --label "$actual_file" \
                <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
                <(tr -d '\r' < "$actual_file" | sed -e '$a\'); then
            echo "输出匹配: $expected_file"
        else
            echo "输出不匹配: $expected_file" >&2
            echo "实际输出已保存: $actual_file" >&2
            exit 1
        fi
    else
        echo "未找到预期输出文件,跳过比对: $expected_file"
    fi
fi

@ -21,6 +21,7 @@ if(NOT COMPILER_PARSE_ONLY)
target_link_libraries(compiler PRIVATE
sem
irgen
ir
mir
)
target_compile_definitions(compiler PRIVATE COMPILER_PARSE_ONLY=0)

@ -9,6 +9,7 @@
#include "ir/IR.h"
#include <algorithm>
#include <utility>
namespace ir {
@ -42,4 +43,53 @@ const std::vector<BasicBlock*>& BasicBlock::GetSuccessors() const {
return successors_;
}
// Appends `bb` to this block's predecessor list. A null pointer is ignored.
// No de-duplication is performed.
void BasicBlock::AddPredecessor(BasicBlock* bb) {
  if (bb == nullptr) {
    return;
  }
  predecessors_.push_back(bb);
}
// Drops every occurrence of `bb` from this block's predecessor list.
// Does nothing when `bb` is not listed.
void BasicBlock::RemovePredecessor(BasicBlock* bb) {
  const auto kept_end =
      std::remove(predecessors_.begin(), predecessors_.end(), bb);
  predecessors_.erase(kept_end, predecessors_.end());
}
// Empties this block's predecessor list.
void BasicBlock::ClearPredecessors() {
  predecessors_.clear();
}
// Appends `bb` to this block's successor list. A null pointer is ignored.
// No de-duplication is performed.
void BasicBlock::AddSuccessor(BasicBlock* bb) {
  if (bb == nullptr) {
    return;
  }
  successors_.push_back(bb);
}
// Empties this block's successor list.
void BasicBlock::ClearSuccessors() {
  successors_.clear();
}
// Erases the first owned instruction whose address equals `inst`.
// The owning unique_ptr is erased with it, so the instruction is destroyed.
// Does nothing when `inst` is not in this block.
void BasicBlock::RemoveInstruction(Instruction* inst) {
  const auto pos =
      std::find_if(instructions_.begin(), instructions_.end(),
                   [inst](const std::unique_ptr<Instruction>& owned) {
                     return owned.get() == inst;
                   });
  if (pos != instructions_.end()) {
    instructions_.erase(pos);
  }
}
// Takes ownership of `inst` and inserts it into this block.
//
// @param inst    Instruction to insert; this block wraps it in a unique_ptr.
//                A null pointer is ignored.
// @param before  Insertion anchor:
//                - nullptr: append, but keep an existing trailing terminator
//                  last by inserting in front of it;
//                - an instruction of this block: insert directly in front
//                  of it;
//                - anything else (not found): append at the very end.
//
// On every path the inserted instruction's parent is set to this block.
// Previously the "anchor found" path returned before doing so.
void BasicBlock::InsertBefore(Instruction* inst, Instruction* before) {
  if (inst == nullptr) {
    return;
  }

  // Set the parent up front so no insertion path can skip it.
  inst->SetParent(this);
  std::unique_ptr<Instruction> owned(inst);

  // Default position: the end (plain append / anchor not found).
  auto pos = instructions_.end();
  if (before != nullptr) {
    pos = std::find_if(instructions_.begin(), instructions_.end(),
                       [before](const std::unique_ptr<Instruction>& cur) {
                         return cur.get() == before;
                       });
  } else if (!instructions_.empty() && instructions_.back()->IsTerminator()) {
    // Appending must not place anything after the terminator.
    pos = instructions_.end() - 1;
  }
  instructions_.insert(pos, std::move(owned));
}
} // namespace ir

@ -29,4 +29,10 @@ std::string Context::NextTemp() {
return oss.str();
}
// Returns a fresh label name: "L" followed by the incremented label counter.
// The counter is advanced before it is formatted.
std::string Context::NextLabel() {
  std::ostringstream label;
  ++label_index_;
  label << "L" << label_index_;
  return label.str();
}
} // namespace ir

@ -1,6 +1,8 @@
// IR Function
#include "ir/IR.h"
#include <algorithm>
namespace ir {
Function::Function(std::string name, std::shared_ptr<Type> ret_type)
@ -17,6 +19,17 @@ BasicBlock* Function::CreateBlock(const std::string& name) {
return ptr;
}
// Moves the block `bb` to the last position of this function's block list,
// keeping the relative order of all other blocks. Does nothing when `bb` is
// not one of this function's blocks.
void Function::MoveBlockToEnd(BasicBlock* bb) {
  const auto pos =
      std::find_if(blocks_.begin(), blocks_.end(),
                   [bb](const std::unique_ptr<BasicBlock>& owned) {
                     return owned.get() == bb;
                   });
  if (pos == blocks_.end()) {
    return;
  }
  // Rotating [pos, end) left by one shifts the followers forward and puts
  // the found block last.
  std::rotate(pos, pos + 1, blocks_.end());
}
BasicBlock* Function::GetEntry() { return entry_; }
const BasicBlock* Function::GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
@ -34,4 +47,84 @@ Argument* Function::GetArgument(size_t i) const {
return args_[i].get();
}
void Function::RebuildCFG() {
// 清除所有块的前驱/后继
for (auto& bb : blocks_) {
bb->ClearPredecessors();
bb->ClearSuccessors();
}
// 根据终结指令重新计算
for (auto& bb : blocks_) {
const auto& insts = bb->GetInstructions();
if (insts.empty()) continue;
auto* term = insts.back().get();
if (!term->IsTerminator()) continue;
switch (term->GetOpcode()) {
case Opcode::Br: {
auto* target = static_cast<BrInst*>(term)->GetTarget();
bb->AddSuccessor(target);
target->AddPredecessor(bb.get());
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<CondBrInst*>(term);
auto* t = cbr->GetTrueBB();
auto* f = cbr->GetFalseBB();
bb->AddSuccessor(t);
bb->AddSuccessor(f);
t->AddPredecessor(bb.get());
f->AddPredecessor(bb.get());
break;
}
case Opcode::Ret:
// 无后继
break;
default:
break;
}
}
}
// Removes the block `bb` from this function. The entry block is never
// removed: the call returns immediately when `bb` is the entry.
// Before the block is erased from blocks_ (which releases its owning
// unique_ptr), references between it and the rest of the function are cut in
// the order below. Predecessor/successor lists of other blocks are not
// updated here.
// NOTE(review): branch instructions in other blocks that still target `bb`
// are not rewritten here; presumably callers only remove blocks that are no
// longer branched to. Confirm against callers.
void Function::RemoveBlock(BasicBlock* bb) {
if (entry_ == bb) return;
// Step 1: in every other block, null out PHI incoming entries that name this
// block. Incoming entry i occupies operands 2*i (value) and 2*i+1 (block).
for (auto& other_bb : blocks_) {
if (other_bb.get() == bb) continue;
for (auto& inst : other_bb->GetInstructions()) {
if (auto* phi = dynamic_cast<PhiInst*>(inst.get())) {
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (phi->GetIncomingBlock(i) == bb) {
phi->SetOperand(i * 2, nullptr); // value
phi->SetOperand(i * 2 + 1, nullptr); // block
}
}
}
}
}
// Step 2: "undefine" every instruction in the block by replacing all of its
// uses with nullptr. Other places that referenced these instructions then
// hold null; per the original note, a later SanitizePhis step repairs them.
for (auto& inst : bb->GetInstructions()) {
inst->ReplaceAllUsesWith(nullptr);
}
// Step 3: detach this block's instructions from their own operands.
for (auto& inst : bb->GetInstructions()) {
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
inst->SetOperand(i, nullptr);
}
}
// Step 4: erase the block from blocks_.
blocks_.erase(
std::remove_if(blocks_.begin(), blocks_.end(),
[bb](const std::unique_ptr<BasicBlock>& b) {
return b.get() == bb;
}),
blocks_.end());
}
} // namespace ir

@ -136,6 +136,15 @@ AllocaInst* IRBuilder::CreateAllocaArray(int num_elements,
num_elements, name);
}
// Appends an alloca for an array of `num_elements` floats to the current
// insertion block and returns it. The instruction is created with the
// pointer-to-float32 type.
// Throws std::runtime_error when no insertion point has been set.
AllocaInst* IRBuilder::CreateAllocaArrayF32(int num_elements,
                                            const std::string& name) {
  if (insert_block_ == nullptr) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }
  AllocaInst* array_alloca = insert_block_->Append<AllocaInst>(
      Type::GetPtrFloat32Type(), num_elements, name);
  return array_alloca;
}
GepInst* IRBuilder::CreateGep(Value* base_ptr, Value* index,
const std::string& name) {
if (!insert_block_) {
@ -237,4 +246,28 @@ FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) {
return insert_block_->Append<FPToSIInst>(val, name);
}
// Creates a PHI instruction of type `ty` and returns it. Unlike the other
// builders it uses Prepend rather than Append on the current insertion block.
// Throws std::runtime_error when no insertion point has been set.
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> ty,
                              const std::string& name) {
  if (insert_block_ == nullptr) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }
  PhiInst* phi = insert_block_->Prepend<PhiInst>(std::move(ty), name);
  return phi;
}
// Emits a call to the external function "memset" that zero-fills
// `num_elements` 4-byte elements starting at `ptr`.
//
// The module gets an external declaration for "memset" (void return; i32
// pointer, i32, i32 parameters) the first time it is needed. The call passes
// `ptr`, the constant 0 and the byte count (num_elements * 4).
// Throws std::runtime_error when no insertion point has been set.
// NOTE(review): the declared signature differs from the C library prototype
// (pointer return, size_t length). Confirm this is what the backends expect.
void IRBuilder::CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod) {
  if (insert_block_ == nullptr) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }

  // Declare memset lazily, once per module.
  const bool already_declared = mod.HasExternalDecl("memset");
  if (!already_declared) {
    mod.DeclareExternalFunc("memset", Type::GetVoidType(),
                            {Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
  }

  // Every element occupies four bytes.
  constexpr int kBytesPerElement = 4;
  const int byte_count = num_elements * kBytesPerElement;

  std::vector<Value*> call_args{ptr, ctx.GetConstInt(0), ctx.GetConstInt(byte_count)};
  insert_block_->Append<CallInst>(std::string("memset"), Type::GetVoidType(),
                                  std::move(call_args), std::string(""));
}
} // namespace ir

@ -1,8 +1,13 @@
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <algorithm>
#include <unordered_map>
#include "utils/Log.h"
@ -44,87 +49,274 @@ static const char* FPredToStr(FCmpPredicate pred) {
return "?";
}
static std::string ValStr(const Value* v) {
using RenameMap = std::unordered_map<const Value*, int>;
static std::string ValStr(const Value* v, const RenameMap& rm) {
if (!v) return "<null>";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
if (dynamic_cast<const ConstantInt*>(v))
return std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::to_string(cf->GetValue());
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << bits;
return oss.str();
}
// BasicBlock: 打印为 label %name
if (dynamic_cast<const BasicBlock*>(v)) {
if (dynamic_cast<const BasicBlock*>(v))
return "%" + v->GetName();
}
// GlobalVariable: 打印为 @name
if (auto* gv = dynamic_cast<const GlobalVariable*>(v)) {
if (gv->IsArray()) {
// 数组全局变量的指针getelementptr [N x i32], [N x i32]* @name, i32 0, i32 0
return "getelementptr ([" + std::to_string(gv->GetNumElements()) +
" x i32], [" + std::to_string(gv->GetNumElements()) +
" x i32]* @" + gv->GetName() + ", i32 0, i32 0)";
const char* et = gv->IsFloat() ? "float" : "i32";
return std::string("getelementptr ([") + std::to_string(gv->GetNumElements()) +
" x " + et + "], [" + std::to_string(gv->GetNumElements()) +
" x " + et + "]* @" + gv->GetName() + ", i32 0, i32 0)";
}
return "@" + v->GetName();
}
auto it = rm.find(v);
if (it != rm.end()) return "%" + std::to_string(it->second);
return "%" + v->GetName();
}
static std::string TypeVal(const Value* v) {
static std::string TypeVal(const Value* v, const RenameMap& rm) {
if (!v) return "void";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::string(TypeToStr(*ci->GetType())) + " " +
std::to_string(ci->GetValue());
}
if (dynamic_cast<const ConstantInt*>(v))
return std::string(TypeToStr(*v->GetType())) + " " +
std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::string(TypeToStr(*cf->GetType())) + " " +
std::to_string(cf->GetValue());
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
return oss.str();
}
return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v, rm);
}
// Prints one instruction as a line of LLVM-style textual IR to `os`.
// `rm` maps values to numeric ids: a value found in the map is printed by its
// id, any other value by its own name. Every opcode listed in the switch
// below is handled, including Alloca; the switch has no default case.
static void PrintInst(const Instruction* inst, std::ostream& os,
const RenameMap& rm) {
// N: bare result name (numeric id from rm when present, else the value's name).
auto N = [&](const Value* v) -> std::string {
auto it = rm.find(v);
if (it != rm.end()) return std::to_string(it->second);
return v->GetName();
};
// VS / TV: operand text without / with a leading type, via ValStr / TypeVal.
auto VS = [&](const Value* v) { return ValStr(v, rm); };
auto TV = [&](const Value* v) { return TypeVal(v, rm); };
switch (inst->GetOpcode()) {
// Integer arithmetic: always printed with the i32 type; Div/Mod print as
// sdiv/srem.
case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
case Opcode::Div: case Opcode::Mod: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op = nullptr;
switch (bin->GetOpcode()) {
case Opcode::Add: op = "add"; break;
case Opcode::Sub: op = "sub"; break;
case Opcode::Mul: op = "mul"; break;
case Opcode::Div: op = "sdiv"; break;
case Opcode::Mod: op = "srem"; break;
default: op = "?"; break;
}
os << " %" << N(bin) << " = " << op << " i32 "
<< VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
break;
}
// Floating-point arithmetic: always printed with the float type.
case Opcode::FAdd: case Opcode::FSub:
case Opcode::FMul: case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op = nullptr;
switch (bin->GetOpcode()) {
case Opcode::FAdd: op = "fadd"; break;
case Opcode::FSub: op = "fsub"; break;
case Opcode::FMul: op = "fmul"; break;
case Opcode::FDiv: op = "fdiv"; break;
default: op = "?"; break;
}
os << " %" << N(bin) << " = " << op << " float "
<< VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
break;
}
// Comparisons: predicate text comes from PredToStr / FPredToStr.
case Opcode::ICmp: {
auto* cmp = static_cast<const ICmpInst*>(inst);
os << " %" << N(cmp) << " = icmp " << PredToStr(cmp->GetPredicate())
<< " i32 " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
break;
}
case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst);
os << " %" << N(cmp) << " = fcmp " << FPredToStr(cmp->GetPredicate())
<< " float " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
break;
}
// Alloca: element type is float when the result type is pointer-to-float32,
// otherwise i32; arrays additionally print their element count.
case Opcode::Alloca: {
auto* al = static_cast<const AllocaInst*>(inst);
const char* et = al->GetType()->IsPtrFloat32() ? "float" : "i32";
if (al->IsArray())
os << " %" << N(al) << " = alloca " << et << ", i32 " << al->GetNumElements() << "\n";
else
os << " %" << N(al) << " = alloca " << et << "\n";
break;
}
// Gep / Load: the element type is chosen from the pointer operand's type.
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
bool fp = gep->GetBasePtr()->GetType()->IsPtrFloat32();
os << " %" << N(gep) << " = getelementptr " << (fp ? "float" : "i32")
<< ", " << (fp ? "float*" : "i32*") << " "
<< VS(gep->GetBasePtr()) << ", i32 " << VS(gep->GetIndex()) << "\n";
break;
}
case Opcode::Load: {
auto* ld = static_cast<const LoadInst*>(inst);
bool fp = ld->GetPtr()->GetType()->IsPtrFloat32();
os << " %" << N(ld) << " = load " << (fp ? "float" : "i32")
<< ", " << (fp ? "float*" : "i32*") << " " << VS(ld->GetPtr()) << "\n";
break;
}
// Store: the value is printed with its type; the pointer type comes from
// TypeToStr on the pointer operand's type.
case Opcode::Store: {
auto* st = static_cast<const StoreInst*>(inst);
os << " store " << TV(st->GetValue()) << ", "
<< TypeToStr(*st->GetPtr()->GetType()) << " " << VS(st->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
if (!ret->HasValue()) os << " ret void\n";
else os << " ret " << TV(ret->GetValue()) << "\n";
break;
}
// Branches: block labels are printed by block name, not through rm.
case Opcode::Br: {
auto* br = static_cast<const BrInst*>(inst);
os << " br label %" << br->GetTarget()->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<const CondBrInst*>(inst);
os << " br i1 " << VS(cbr->GetCond()) << ", label %"
<< cbr->GetTrueBB()->GetName() << ", label %"
<< cbr->GetFalseBB()->GetName() << "\n";
break;
}
// Call: a result name is printed only for non-void calls with a non-empty
// name; arguments are printed with their types.
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
if (!call->IsVoid() && !call->GetName().empty())
os << " %" << N(call) << " = ";
else
os << " ";
os << "call " << (call->IsVoid() ? "void" : TypeToStr(*call->GetType()))
<< " @" << call->GetCalleeName() << "(";
for (size_t i = 0; i < call->GetNumArgs(); ++i) {
if (i > 0) os << ", ";
os << TV(call->GetArg(i));
}
os << ")\n";
break;
}
// Conversions: fixed source/destination types (i1->i32, i32->float,
// float->i32).
case Opcode::ZExt: {
auto* ze = static_cast<const ZExtInst*>(inst);
os << " %" << N(ze) << " = zext i1 " << VS(ze->GetSrc()) << " to i32\n";
break;
}
case Opcode::SIToFP: {
auto* si = static_cast<const SIToFPInst*>(inst);
os << " %" << N(si) << " = sitofp i32 " << VS(si->GetSrc()) << " to float\n";
break;
}
case Opcode::FPToSI: {
auto* fp = static_cast<const FPToSIInst*>(inst);
os << " %" << N(fp) << " = fptosi float " << VS(fp->GetSrc()) << " to i32\n";
break;
}
// Phi: one "[ value, %block ]" pair per incoming edge.
// NOTE(review): an incoming block that is null would be dereferenced here;
// presumably PHIs are sanitized before printing. Confirm.
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " %" << N(phi) << " = phi " << TypeToStr(*phi->GetType()) << " ";
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (i > 0) os << ", ";
os << "[ " << VS(phi->GetIncomingValue(i)) << ", %"
<< phi->GetIncomingBlock(i)->GetName() << " ]";
}
os << "\n";
break;
}
}
// NOTE(review): the return statement below does not fit a void function and
// calls ValStr with a single argument; it looks like a leftover line of the
// previous TypeVal body. Confirm against the actual file.
return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v);
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
// 1. 全局变量/常量
for (const auto& g : module.GetGlobalVariables()) {
if (g->IsArray()) {
// 全局数组zeroinitializer
if (g->IsConst()) {
os << "@" << g->GetName() << " = constant [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
for (const auto& gv : module.GetGlobalVariables()) {
if (!gv) continue;
if (gv->IsConstant()) {
os << "@" << gv->GetName() << " = constant i32 " << gv->GetInitVal() << "\n";
} else if (gv->IsArray()) {
const char* et = gv->IsFloat() ? "float" : "i32";
os << "@" << gv->GetName() << " = global [" << gv->GetNumElements()
<< " x " << et << "] ";
if (!gv->HasInitVals()) {
os << "zeroinitializer\n";
} else if (gv->IsFloat()) {
const auto& vals = gv->GetInitValsF();
bool all_zero = std::all_of(vals.begin(), vals.end(), [](float f){ return f == 0.0f; });
if (all_zero) {
os << "zeroinitializer\n";
} else {
os << "[";
for (int i = 0; i < gv->GetNumElements(); ++i) {
if (i > 0) os << ", ";
float fv = (i < (int)vals.size()) ? vals[i] : 0.0f;
double d = static_cast<double>(fv);
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
os << oss.str();
}
os << "]\n";
}
} else {
os << "@" << g->GetName() << " = global [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
const auto& vals = gv->GetInitVals();
bool all_zero = std::all_of(vals.begin(), vals.end(), [](int v){ return v == 0; });
if (all_zero) {
os << "zeroinitializer\n";
} else {
os << "[";
for (int i = 0; i < gv->GetNumElements(); ++i) {
if (i > 0) os << ", ";
os << "i32 " << (i < (int)vals.size() ? vals[i] : 0);
}
os << "]\n";
}
}
} else {
if (g->IsConst()) {
os << "@" << g->GetName() << " = constant i32 " << g->GetInitVal()
<< "\n";
} else {
os << "@" << g->GetName() << " = global i32 " << g->GetInitVal()
<< "\n";
if(gv->IsFloat()) {
double d = static_cast<double>(gv->GetInitValF());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << bits;
os << "@" << gv->GetName() << " = global float " << oss.str() << "\n";
} else
{
os << "@" << gv->GetName() << " = global i32 " << gv->GetInitVal() << "\n";
}
}
}
if (!module.GetGlobalVariables().empty()) os << "\n";
// 2. 外部函数声明
// 2. 外部声明
for (const auto& decl : module.GetExternalDecls()) {
os << "declare " << TypeToStr(*decl.ret_type) << " @" << decl.name << "(";
for (size_t i = 0; i < decl.param_types.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToStr(*decl.param_types[i]);
}
if (decl.is_variadic) {
if (!decl.param_types.empty()) os << ", ";
os << "...";
}
os << ")\n";
}
if (!module.GetExternalDecls().empty()) os << "\n";
// 3. 函数定义
for (const auto& func : module.GetFunctions()) {
os << "define " << TypeToStr(*func->GetType()) << " @" << func->GetName()
<< "(";
os << "define " << TypeToStr(*func->GetType()) << " @" << func->GetName() << "(";
for (size_t i = 0; i < func->GetNumArgs(); ++i) {
if (i > 0) os << ", ";
auto* arg = func->GetArgument(i);
@ -132,172 +324,63 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
}
os << ") {\n";
// Build rename map: alloca instructions first (in block order), then rest
RenameMap rm;
int next_id = 0;
auto assign = [&](const Value* v) {
if (!v) return;
if (dynamic_cast<const ConstantInt*>(v)) return;
if (dynamic_cast<const ConstantFloat*>(v)) return;
if (dynamic_cast<const BasicBlock*>(v)) return;
if (dynamic_cast<const GlobalVariable*>(v)) return;
if (dynamic_cast<const Argument*>(v)) return;
if (rm.count(v) == 0) rm[v] = next_id++;
};
// Pass 1: all allocas across all blocks
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Alloca) assign(ip.get());
}
// Pass 2: all non-alloca instructions in block order
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca) assign(ip.get());
}
// Print: entry block first with all allocas hoisted, then rest
bool first_bb = true;
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
os << bb->GetName() << ":\n";
for (const auto& instPtr : bb->GetInstructions()) {
const auto* inst = instPtr.get();
switch (inst->GetOpcode()) {
// ── 算术 ──────────────────────────────────────────────────────────
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op_str = nullptr;
switch (bin->GetOpcode()) {
case Opcode::Add: op_str = "add"; break;
case Opcode::Sub: op_str = "sub"; break;
case Opcode::Mul: op_str = "mul"; break;
case Opcode::Div: op_str = "sdiv"; break;
case Opcode::Mod: op_str = "srem"; break;
default: op_str = "?"; break;
}
os << " %" << bin->GetName() << " = " << op_str << " i32 "
<< ValStr(bin->GetLhs()) << ", " << ValStr(bin->GetRhs())
<< "\n";
break;
}
// ── 浮点算术 ──────────────────────────────────────────────────────
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op_str = nullptr;
switch (bin->GetOpcode()) {
case Opcode::FAdd: op_str = "fadd"; break;
case Opcode::FSub: op_str = "fsub"; break;
case Opcode::FMul: op_str = "fmul"; break;
case Opcode::FDiv: op_str = "fdiv"; break;
default: op_str = "?"; break;
}
os << " %" << bin->GetName() << " = " << op_str << " float "
<< ValStr(bin->GetLhs()) << ", " << ValStr(bin->GetRhs())
<< "\n";
break;
}
// ── 比较 ──────────────────────────────────────────────────────────
case Opcode::ICmp: {
auto* cmp = static_cast<const ICmpInst*>(inst);
os << " %" << cmp->GetName() << " = icmp "
<< PredToStr(cmp->GetPredicate()) << " i32 "
<< ValStr(cmp->GetLhs()) << ", " << ValStr(cmp->GetRhs())
<< "\n";
break;
}
case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst);
os << " %" << cmp->GetName() << " = fcmp "
<< FPredToStr(cmp->GetPredicate()) << " float "
<< ValStr(cmp->GetLhs()) << ", " << ValStr(cmp->GetRhs())
<< "\n";
break;
}
// ── 内存 ──────────────────────────────────────────────────────────
case Opcode::Alloca: {
auto* al = static_cast<const AllocaInst*>(inst);
const char* elem_type = al->GetType()->IsPtrFloat32() ? "float" : "i32";
if (al->IsArray()) {
os << " %" << al->GetName() << " = alloca " << elem_type << ", i32 "
<< al->GetNumElements() << "\n";
} else {
os << " %" << al->GetName() << " = alloca " << elem_type << "\n";
}
break;
}
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
os << " %" << gep->GetName()
<< " = getelementptr i32, i32* "
<< ValStr(gep->GetBasePtr()) << ", i32 "
<< ValStr(gep->GetIndex()) << "\n";
break;
}
case Opcode::Load: {
auto* ld = static_cast<const LoadInst*>(inst);
const char* val_type = ld->GetType()->IsFloat32() ? "float" : "i32";
const char* ptr_type = ld->GetPtr()->GetType()->IsPtrFloat32() ? "float*" : "i32*";
os << " %" << ld->GetName() << " = load " << val_type << ", " << ptr_type << " "
<< ValStr(ld->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* st = static_cast<const StoreInst*>(inst);
os << " store " << TypeVal(st->GetValue()) << ", "
<< TypeToStr(*st->GetPtr()->GetType()) << " "
<< ValStr(st->GetPtr()) << "\n";
break;
}
// ── 控制流 ────────────────────────────────────────────────────────
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
if (!ret->HasValue()) {
os << " ret void\n";
} else {
auto* v = ret->GetValue();
os << " ret " << TypeVal(v) << "\n";
}
break;
}
case Opcode::Br: {
auto* br = static_cast<const BrInst*>(inst);
os << " br label %" << br->GetTarget()->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<const CondBrInst*>(inst);
os << " br i1 " << ValStr(cbr->GetCond()) << ", label %"
<< cbr->GetTrueBB()->GetName() << ", label %"
<< cbr->GetFalseBB()->GetName() << "\n";
break;
}
// ── 调用 ──────────────────────────────────────────────────────────
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
std::string ret_type_str;
if (call->IsVoid()) {
ret_type_str = "void";
} else {
ret_type_str = TypeToStr(*call->GetType());
}
// 打印赋值部分(仅当有返回值时)
if (!call->IsVoid() && !call->GetName().empty()) {
os << " %" << call->GetName() << " = ";
} else {
os << " ";
}
os << "call " << ret_type_str << " @" << call->GetCalleeName()
<< "(";
for (size_t i = 0; i < call->GetNumArgs(); ++i) {
if (i > 0) os << ", ";
auto* arg = call->GetArg(i);
os << TypeVal(arg);
}
os << ")\n";
break;
}
// ── 类型转换 ──────────────────────────────────────────────────────
case Opcode::ZExt: {
auto* ze = static_cast<const ZExtInst*>(inst);
os << " %" << ze->GetName() << " = zext i1 "
<< ValStr(ze->GetSrc()) << " to i32\n";
break;
}
case Opcode::SIToFP: {
auto* si = static_cast<const SIToFPInst*>(inst);
os << " %" << si->GetName() << " = sitofp i32 "
<< ValStr(si->GetSrc()) << " to float\n";
break;
}
case Opcode::FPToSI: {
auto* fp = static_cast<const FPToSIInst*>(inst);
os << " %" << fp->GetName() << " = fptosi float "
<< ValStr(fp->GetSrc()) << " to i32\n";
break;
}
if (first_bb) {
first_bb = false;
// Print all allocas from all blocks (only for entry block)
for (const auto& bb2 : func->GetBlocks()) {
if (!bb2) continue;
for (const auto& ip : bb2->GetInstructions())
if (ip->GetOpcode() == Opcode::Alloca)
PrintInst(ip.get(), os, rm);
}
// Print PHI nodes of entry block
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Phi)
PrintInst(ip.get(), os, rm);
// Print non-alloca non-phi instructions of entry block
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi)
PrintInst(ip.get(), os, rm);
} else {
// Non-entry blocks: skip allocas (already printed)
// Print PHI nodes first
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Phi)
PrintInst(ip.get(), os, rm);
// Print non-alloca non-phi instructions
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi)
PrintInst(ip.get(), os, rm);
}
}
os << "}\n\n";

@ -48,6 +48,12 @@ bool Instruction::IsTerminator() const {
opcode_ == Opcode::CondBr;
}
// Asks the owning basic block to remove this instruction.
// An instruction that has no parent block is left untouched.
void Instruction::RemoveFromParent() {
  if (parent_ == nullptr) return;
  parent_->RemoveInstruction(this);
}
// Returns the basic block recorded as this instruction's parent (may be null).
BasicBlock* Instruction::GetParent() const { return parent_; }
// Records `parent` as the owning basic block; only the back-pointer is updated here.
void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
@ -224,7 +230,11 @@ AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, int num_elements,
// ─── GepInst ──────────────────────────────────────────────────────────────────
GepInst::GepInst(Value* base_ptr, Value* index, std::string name)
: Instruction(Opcode::Gep, Type::GetPtrInt32Type(), std::move(name)) {
: Instruction(Opcode::Gep,
(base_ptr && base_ptr->GetType()->IsPtrFloat32())
? Type::GetPtrFloat32Type()
: Type::GetPtrInt32Type(),
std::move(name)) {
if (!base_ptr || !index) {
throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数"));
}
@ -263,12 +273,30 @@ Value* StoreInst::GetPtr() const { return GetOperand(1); }
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {}
// ─── PhiInst ──────────────────────────────────────────────────────────────────
// Creates a PHI instruction of result type `ty` named `name`.
// The node starts without incoming edges; they are added through AddIncoming().
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
// Appends one incoming edge to the PHI.
// Operands are laid out as consecutive (value, block) pairs, so the value is
// appended first and the block it arrives from immediately after it.
void PhiInst::AddIncoming(Value* val, BasicBlock* bb) {
AddOperand(val);
AddOperand(bb);
}
// Returns the predecessor block of the i-th incoming edge.
// Incoming edges are stored as (value, block) operand pairs, so the block of
// edge `i` is the odd-indexed operand that follows the i-th value.
BasicBlock* PhiInst::GetIncomingBlock(size_t i) const {
  const size_t block_slot = i * 2 + 1;
  Value* raw_block = GetOperand(block_slot);
  return static_cast<BasicBlock*>(raw_block);
}
// ─── GlobalVariable ────────────────────────────────────────────────────────────
GlobalVariable::GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements)
: Value(Type::GetPtrInt32Type(), std::move(name)),
int num_elements, bool is_array_decl,
bool is_float)
: Value(is_float ? Type::GetPtrFloat32Type() : Type::GetPtrInt32Type(),
std::move(name)),
is_const_(is_const),
is_float_(is_float),
init_val_(init_val),
num_elements_(num_elements) {}
init_val_f_(0.0f),
num_elements_(num_elements),
is_array_decl_(is_array_decl) {}
} // namespace ir

@ -28,9 +28,10 @@ const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
// ─── 全局变量管理 ─────────────────────────────────────────────────────────────
GlobalVariable* Module::CreateGlobalVariable(const std::string& name,
bool is_const, int init_val,
int num_elements) {
int num_elements, bool is_array_decl,
bool is_float) {
globals_.push_back(
std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements));
std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements, is_array_decl, is_float));
GlobalVariable* g = globals_.back().get();
global_map_[name] = g;
return g;

@ -1,4 +1,205 @@
// 支配树分析:
// - 构建/查询 Dominator Tree 及相关关系
// - 为 mem2reg、CFG 优化与循环分析提供基础能力
// - 使用 Cooper-Harvey-Kennedy 算法,近线性时间复杂度
#include "ir/IR.h"
#include <algorithm>
#include <functional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
// Computes the dominator information of `func`:
//   1. immediate dominators (Cooper-Harvey-Kennedy iterative algorithm),
//   2. the children map of the dominator tree and the depth of each block,
//   3. the dominance frontier of each block,
//   4. a preorder (DFS) walk of the dominator tree.
// The CFG is rebuilt first. Only blocks reachable from the entry take part in
// the dominator tree; the entry's immediate dominator is reported as nullptr.
// All results of a previous call are discarded, even when the function has no
// entry block and nothing can be computed.
void DominatorTree::Compute(Function& func) {
  // Reset every cached result up front so that an early return below never
  // leaves the answers of a previous function (or a previous CFG) behind.
  idom_.clear();
  children_.clear();
  dom_level_.clear();
  df_.clear();
  df_order_.clear();
  visited_.clear();

  func.RebuildCFG();

  auto* entry = func.GetEntry();
  if (!entry) return;

  // Reverse postorder of the blocks reachable from the entry, plus the
  // position of every block in that order.
  std::vector<BasicBlock*> blocks;
  std::unordered_map<BasicBlock*, int> rpo;
  {
    std::unordered_set<BasicBlock*> visited;
    std::function<void(BasicBlock*)> dfs = [&](BasicBlock* bb) {
      if (!bb || visited.count(bb)) return;
      visited.insert(bb);
      for (auto* succ : bb->GetSuccessors()) {
        dfs(succ);
      }
      blocks.push_back(bb);  // postorder
    };
    dfs(entry);
    std::reverse(blocks.begin(), blocks.end());
    for (int i = 0; i < static_cast<int>(blocks.size()); ++i) {
      rpo[blocks[i]] = i;
    }
  }
  const int n = static_cast<int>(blocks.size());

  // ─── 1. CHK algorithm for immediate dominators ─────────────────────────
  idom_[entry] = entry;  // the entry temporarily dominates itself

  // Walks both blocks up the (partial) dominator tree until they meet.
  // A dominator always has a smaller RPO index than the blocks it dominates,
  // so the block with the larger index is the one that has to move up.
  auto intersect = [&](BasicBlock* b1, BasicBlock* b2) -> BasicBlock* {
    auto i1 = rpo.find(b1), i2 = rpo.find(b2);
    if (i1 == rpo.end() || i2 == rpo.end()) return entry;
    int r1 = i1->second, r2 = i2->second;
    while (b1 != b2) {
      while (r1 > r2) {
        auto it = idom_.find(b1);
        if (it == idom_.end() || it->second == b1) return b1;
        b1 = it->second;
        r1 = rpo[b1];
      }
      while (r2 > r1) {
        auto it = idom_.find(b2);
        if (it == idom_.end() || it->second == b2) return b2;
        b2 = it->second;
        r2 = rpo[b2];
      }
    }
    return b1;
  };

  bool changed = true;
  while (changed) {
    changed = false;
    // Visit blocks in RPO; index 0 is the entry and is skipped.
    for (int i = 1; i < n; ++i) {
      auto* bb = blocks[i];

      // Seed with the first predecessor that already has an idom.
      BasicBlock* new_idom = nullptr;
      for (auto* pred : bb->GetPredecessors()) {
        if (idom_.count(pred) && pred != bb) {
          new_idom = pred;
          break;
        }
      }
      if (!new_idom) continue;

      // Fold in every other predecessor that already has an idom.
      for (auto* pred : bb->GetPredecessors()) {
        if (pred == new_idom || pred == bb) continue;
        if (idom_.count(pred)) {
          new_idom = intersect(pred, new_idom);
        }
      }

      auto old = idom_.find(bb);
      if (old == idom_.end() || old->second != new_idom) {
        idom_[bb] = new_idom;
        changed = true;
      }
    }
  }

  // External queries see the entry as having no immediate dominator.
  idom_[entry] = nullptr;

  // Safety net: any listed block that still has no idom hangs off the entry.
  for (auto* bb : blocks) {
    if (!idom_.count(bb)) idom_[bb] = entry;
  }

  // ─── 2. Build children map and dom levels ──────────────────────────────
  // Iterate the RPO list instead of the pointer-keyed hash map so that the
  // order of children (and the tree preorder derived from it in step 4) is
  // the same on every run.
  for (auto* bb : blocks) {
    if (auto* parent = GetIDom(bb)) children_[parent].push_back(bb);
  }

  // Breadth-first walk from the entry assigns the depth of every tree node.
  std::queue<BasicBlock*> q;
  dom_level_[entry] = 0;
  q.push(entry);
  while (!q.empty()) {
    auto* cur = q.front();
    q.pop();
    size_t cur_level = dom_level_[cur];
    auto it = children_.find(cur);
    if (it != children_.end()) {
      for (auto* child : it->second) {
        dom_level_[child] = cur_level + 1;
        q.push(child);
      }
    }
  }

  // ─── 3. Compute dominance frontier ─────────────────────────────────────
  // For every join block b, walk up from each predecessor until b's idom is
  // reached; every block passed on the way has b in its frontier.
  for (int i = 0; i < n; ++i) {
    auto* b = blocks[i];
    if (b->GetPredecessors().size() < 2) continue;
    for (auto* p : b->GetPredecessors()) {
      auto* runner = p;
      auto* b_idom = GetIDom(b);
      while (runner != b_idom) {
        if (!runner) break;
        df_[runner].push_back(b);
        runner = GetIDom(runner);
      }
    }
  }

  // Deduplicate DF entries.
  for (auto& [bb, vec] : df_) {
    std::sort(vec.begin(), vec.end());
    vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
  }

  // ─── 4. Compute DFS order of dominator tree ────────────────────────────
  std::function<void(BasicBlock*)> dfs_tree = [&](BasicBlock* bb) {
    if (!bb || visited_.count(bb)) return;
    visited_.insert(bb);
    df_order_.push_back(bb);
    auto it = children_.find(bb);
    if (it != children_.end()) {
      for (auto* child : it->second) {
        dfs_tree(child);
      }
    }
  };
  dfs_tree(entry);
}
// Returns the immediate dominator recorded for `bb`, or nullptr when none is
// stored (which is also what the entry block reports).
BasicBlock* DominatorTree::GetIDom(BasicBlock* bb) const {
  const auto found = idom_.find(bb);
  if (found == idom_.end()) return nullptr;
  return found->second;
}
// Returns the blocks whose immediate dominator is `bb`.
// A block without children yields a reference to a shared empty vector.
const std::vector<BasicBlock*>& DominatorTree::GetChildren(
    BasicBlock* bb) const {
  static const std::vector<BasicBlock*> kNoChildren;
  const auto found = children_.find(bb);
  if (found == children_.end()) return kNoChildren;
  return found->second;
}
// Returns the dominance frontier computed for `bb`.
// A block with no recorded frontier yields a reference to a shared empty vector.
const std::vector<BasicBlock*>& DominatorTree::GetDominanceFrontier(
    BasicBlock* bb) const {
  static const std::vector<BasicBlock*> kNoFrontier;
  const auto found = df_.find(bb);
  if (found == df_.end()) return kNoFrontier;
  return found->second;
}
// Returns true when `a` dominates `b`, i.e. `a` is `b` itself or one of the
// blocks on b's immediate-dominator chain.
// A null `a` never dominates a different block: the walk up the chain ends on
// nullptr, and that terminator must not be mistaken for a match.
bool DominatorTree::Dominates(BasicBlock* a, BasicBlock* b) const {
  if (a == b) return true;
  if (a == nullptr) return false;
  for (BasicBlock* ancestor = GetIDom(b); ancestor != nullptr;
       ancestor = GetIDom(ancestor)) {
    if (ancestor == a) return true;
  }
  return false;
}
} // namespace ir

@ -2,3 +2,143 @@
// - 删除不可达块、合并空块、简化分支等
// - 改善 IR 结构,便于后续优化与后端生成
#include "ir/IR.h"
#include <queue>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// Rewrites every conditional branch whose condition is a ConstantInt into an
// unconditional branch to the successor that is actually taken.
// Returns true when at least one branch was rewritten. `ctx` is not used.
// NOTE(review): PHI nodes in the successor that is no longer targeted are not
// updated here — confirm a later step prunes their incoming edge.
bool SimplifyConstantBranches(Function& func, Context& ctx) {
  bool changed = false;
  for (auto& bb : func.GetBlocks()) {
    const auto& insts = bb->GetInstructions();
    if (insts.empty()) continue;

    auto* cbr = dynamic_cast<CondBrInst*>(insts.back().get());
    if (!cbr) continue;
    auto* ci = dynamic_cast<ConstantInt*>(cbr->GetCond());
    if (!ci) continue;

    // Read the surviving target before the branch is torn down.
    BasicBlock* target =
        (ci->GetValue() != 0) ? cbr->GetTrueBB() : cbr->GetFalseBB();

    // Drop the branch's references to its condition and both successors
    // before it is erased, so none of them keeps a use entry that points at
    // the deleted instruction (same protocol as the other passes).
    for (size_t i = 0; i < cbr->GetNumOperands(); ++i) {
      cbr->SetOperand(i, nullptr);
    }
    cbr->RemoveFromParent();

    bb->Append<BrInst>(target);
    changed = true;
  }
  return changed;
}
// Bypasses "forwarding" blocks: a non-entry block whose only instruction is an
// unconditional branch is removed and its predecessors jump straight to the
// branch target. One block is removed per round and the CFG is rebuilt before
// the next round, because each removal changes the predecessor lists.
//
// When the target has PHI nodes that name the forwarding block, the block is
// only bypassed if it has exactly one predecessor and that predecessor does not
// already branch to the target. Otherwise the PHI would need one incoming edge
// per predecessor (or would end up with two edges from the same block), which a
// plain rename of the incoming block cannot express.
//
// Returns true when at least one block was removed.
bool MergeEmptyBlocks(Function& func) {
  bool changed = false;
  bool local_changed = true;

  while (local_changed) {
    local_changed = false;
    func.RebuildCFG();
    BasicBlock* block_to_remove = nullptr;

    for (auto& bb : func.GetBlocks()) {
      auto* block = bb.get();
      if (block == func.GetEntry()) continue;

      const auto& insts = block->GetInstructions();
      if (insts.size() != 1) continue;
      auto* br = dynamic_cast<BrInst*>(insts[0].get());
      if (!br) continue;

      auto* target = br->GetTarget();
      if (target == block) continue;

      // A target that is itself a forwarding block is handled in a later
      // round, once it has been bypassed.
      if (target->GetInstructions().size() == 1 &&
          dynamic_cast<BrInst*>(target->GetInstructions()[0].get()) &&
          target != func.GetEntry()) {
        continue;
      }

      // Distinct predecessors of the forwarding block (a conditional branch
      // with both arms on `block` may be listed more than once).
      std::vector<BasicBlock*> preds;
      std::unordered_set<BasicBlock*> seen;
      for (auto* pred : block->GetPredecessors()) {
        if (pred == block) continue;
        if (seen.insert(pred).second) preds.push_back(pred);
      }

      // Does any PHI in the target receive a value through `block`?
      bool phi_uses_block = false;
      for (auto& t_inst : target->GetInstructions()) {
        auto* phi = dynamic_cast<PhiInst*>(t_inst.get());
        if (!phi) continue;
        for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
          if (phi->GetIncomingBlock(i) == block) {
            phi_uses_block = true;
            break;
          }
        }
        if (phi_uses_block) break;
      }

      if (phi_uses_block) {
        // Renaming the incoming block is only sound for a single predecessor
        // that is not yet an incoming block of the target.
        if (preds.size() != 1) continue;
        bool already_reaches_target = false;
        for (auto* t_pred : target->GetPredecessors()) {
          if (t_pred == preds.front()) {
            already_reaches_target = true;
            break;
          }
        }
        if (already_reaches_target) continue;
      }

      // Retarget every predecessor terminator that still points at `block`.
      for (auto* pred : preds) {
        auto& p_insts = pred->GetInstructions();
        if (p_insts.empty()) continue;
        auto* p_term = p_insts.back().get();
        if (auto* cbr = dynamic_cast<CondBrInst*>(p_term)) {
          if (cbr->GetTrueBB() == block) cbr->SetOperand(1, target);
          if (cbr->GetFalseBB() == block) cbr->SetOperand(2, target);
        } else if (auto* p_br = dynamic_cast<BrInst*>(p_term)) {
          if (p_br->GetTarget() == block) p_br->SetOperand(0, target);
        }
      }

      // The value that used to arrive through `block` now arrives directly
      // from its single predecessor.
      if (phi_uses_block) {
        BasicBlock* pred = preds.front();
        for (auto& t_inst : target->GetInstructions()) {
          auto* phi = dynamic_cast<PhiInst*>(t_inst.get());
          if (!phi) continue;
          for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
            if (phi->GetIncomingBlock(i) == block) {
              phi->SetOperand(i * 2 + 1, pred);
            }
          }
        }
      }

      block_to_remove = block;
      local_changed = true;
      changed = true;
      break;  // one removal per round; the next round rebuilds the CFG
    }

    if (block_to_remove) {
      func.RemoveBlock(block_to_remove);
    }
  }

  if (changed) func.RebuildCFG();
  return changed;
}
// Deletes every basic block that cannot be reached from the entry block.
// The CFG is rebuilt before the search and again after any removal.
// Returns true when at least one block was removed; a function without an
// entry block is left unchanged.
bool RemoveUnreachableBlocks(Function& func) {
  func.RebuildCFG();

  BasicBlock* const entry = func.GetEntry();
  if (entry == nullptr) return false;

  // Breadth-first search over successor edges, starting at the entry.
  std::unordered_set<BasicBlock*> live;
  std::queue<BasicBlock*> pending;
  live.insert(entry);
  pending.push(entry);
  while (!pending.empty()) {
    BasicBlock* const current = pending.front();
    pending.pop();
    for (auto* next : current->GetSuccessors()) {
      if (next == nullptr) continue;
      const bool first_visit = live.insert(next).second;
      if (first_visit) pending.push(next);
    }
  }

  // Collect first, delete afterwards: the block list must not change while
  // it is being iterated.
  std::vector<BasicBlock*> dead;
  for (auto& owned : func.GetBlocks()) {
    BasicBlock* const candidate = owned.get();
    if (live.find(candidate) == live.end()) dead.push_back(candidate);
  }
  if (dead.empty()) return false;

  for (auto* victim : dead) {
    func.RemoveBlock(victim);
  }
  func.RebuildCFG();
  return true;
}
} // namespace
// Runs the CFG clean-up steps in a fixed order: fold constant conditional
// branches, bypass forwarding blocks, then delete unreachable blocks.
// Every step always runs; the result is true when any of them changed `func`.
bool RunCFGSimplify(Function& func, Context& ctx) {
  const bool folded_branches = SimplifyConstantBranches(func, ctx);
  const bool bypassed_blocks = MergeEmptyBlocks(func);
  const bool pruned_blocks = RemoveUnreachableBlocks(func);
  return folded_branches || bypassed_blocks || pruned_blocks;
}
} // namespace ir

@ -1,4 +1,123 @@
// 公共子表达式消除CSE
// - 识别并复用重复计算的等价表达式
// - 典型放置在 ConstFold 之后、DCE 之前
// - 当前为 Lab4 的框架占位,具体算法由实验实现
// - 局部值编号:在单个基本块内消除重复计算
#include "ir/IR.h"
#include <cstdint>
#include <sstream>
#include <string>
#include <unordered_map>
namespace ir {
namespace {
// Builds the textual identity of an operand for value numbering:
//   - integer constants:  "ci" + decimal value
//   - float constants:    "cf" + decimal form of the 32-bit pattern
//   - everything else:    "p"  + decimal address of the value object
std::string ValKey(Value* v) {
  if (auto* int_const = dynamic_cast<ConstantInt*>(v)) {
    std::string key = "ci";
    key += std::to_string(int_const->GetValue());
    return key;
  }

  if (auto* float_const = dynamic_cast<ConstantFloat*>(v)) {
    // Key on the bit pattern rather than on the numeric value.
    union { float f; uint32_t i; } pun;
    pun.f = float_const->GetValue();
    std::string key = "cf";
    key += std::to_string(pun.i);
    return key;
  }

  // Any non-constant value is identified by where it lives.
  std::ostringstream text;
  text << "p" << reinterpret_cast<uintptr_t>(v);
  return text.str();
}
// Builds the lookup key of an instruction that may be shared with an earlier,
// identical one. The key is "<opcode>|<operand keys...>", with the predicate
// inserted after the opcode for comparisons. Instructions that are not handled
// here produce an empty string.
std::string MakeKey(Instruction* inst) {
  const std::string opcode_tag =
      std::to_string(static_cast<int>(inst->GetOpcode()));

  // Key shapes shared by the cases below.
  const auto one_operand = [&opcode_tag](Value* src) {
    return opcode_tag + "|" + ValKey(src);
  };
  const auto two_operands = [&opcode_tag](Value* lhs, Value* rhs) {
    return opcode_tag + "|" + ValKey(lhs) + "|" + ValKey(rhs);
  };
  const auto comparison = [&opcode_tag](int predicate, Value* lhs, Value* rhs) {
    return opcode_tag + "|" + std::to_string(predicate) + "|" + ValKey(lhs) +
           "|" + ValKey(rhs);
  };

  switch (inst->GetOpcode()) {
    case Opcode::Add:
    case Opcode::Sub:
    case Opcode::Mul:
    case Opcode::Div:
    case Opcode::Mod:
    case Opcode::FAdd:
    case Opcode::FSub:
    case Opcode::FMul:
    case Opcode::FDiv: {
      auto* binary = static_cast<BinaryInst*>(inst);
      return two_operands(binary->GetLhs(), binary->GetRhs());
    }
    case Opcode::ICmp: {
      auto* icmp = static_cast<ICmpInst*>(inst);
      return comparison(static_cast<int>(icmp->GetPredicate()), icmp->GetLhs(),
                        icmp->GetRhs());
    }
    case Opcode::FCmp: {
      auto* fcmp = static_cast<FCmpInst*>(inst);
      return comparison(static_cast<int>(fcmp->GetPredicate()), fcmp->GetLhs(),
                        fcmp->GetRhs());
    }
    case Opcode::Gep: {
      auto* gep = static_cast<GepInst*>(inst);
      return two_operands(gep->GetBasePtr(), gep->GetIndex());
    }
    case Opcode::Load:
      return one_operand(static_cast<LoadInst*>(inst)->GetPtr());
    case Opcode::ZExt:
      return one_operand(static_cast<ZExtInst*>(inst)->GetSrc());
    case Opcode::SIToFP:
      return one_operand(static_cast<SIToFPInst*>(inst)->GetSrc());
    case Opcode::FPToSI:
      return one_operand(static_cast<FPToSIInst*>(inst)->GetSrc());
    default:
      return "";
  }
}
} // namespace
// Local common-subexpression elimination: inside each basic block, an
// instruction whose key (see MakeKey) matches an earlier instruction is
// replaced by that earlier one and erased.
//
// Loads are only shared while memory is known to be unchanged: every Store or
// Call forgets all loads seen so far in the block, so a load after a write is
// never replaced by a value read before it.
//
// Returns true when at least one instruction was eliminated.
bool RunCSE(Function& func) {
  bool changed = false;

  for (auto& bb : func.GetBlocks()) {
    std::unordered_map<std::string, Value*> available;
    std::vector<std::string> load_keys;  // keys in `available` that are loads
    std::vector<Instruction*> to_remove;

    for (auto& inst : bb->GetInstructions()) {
      auto* ip = inst.get();
      const auto opcode = ip->GetOpcode();

      // A store or a call may change what any pointer reads back.
      if (opcode == Opcode::Store || opcode == Opcode::Call) {
        for (const auto& stale : load_keys) available.erase(stale);
        load_keys.clear();
        continue;
      }

      const std::string key = MakeKey(ip);
      if (key.empty()) continue;  // not a candidate for sharing

      auto it = available.find(key);
      if (it != available.end()) {
        // An identical computation already exists: reuse its result.
        ip->ReplaceAllUsesWith(it->second);
        to_remove.push_back(ip);
        changed = true;
      } else {
        available[key] = ip;
        if (opcode == Opcode::Load) load_keys.push_back(key);
      }
    }

    // Erase after the scan so the instruction list is not modified while it
    // is being iterated; operands are released first to keep use lists clean.
    for (auto* ip : to_remove) {
      for (size_t i = 0; i < ip->GetNumOperands(); ++i) {
        ip->SetOperand(i, nullptr);
      }
      ip->RemoveFromParent();
    }
  }

  return changed;
}
} // namespace ir

@ -1,4 +1,187 @@
// IR 常量折叠:
// - 折叠可判定的常量表达式
// - 简化常量控制流分支(按实现范围裁剪)
// - 简化常量控制流分支
#include "ir/IR.h"
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// Erases `inst` safely: its operand slots are cleared first so that no other
// value keeps a use entry pointing at the instruction once it is gone, and
// only then is it taken out of its parent block (if it has one).
static void DetachAndRemove(Instruction* inst) {
  size_t slot = 0;
  while (slot < inst->GetNumOperands()) {
    inst->SetOperand(slot, nullptr);
    ++slot;
  }

  BasicBlock* owner = inst->GetParent();
  if (owner == nullptr) return;
  owner->RemoveInstruction(inst);
}
// Folds an integer comparison whose two operands are integer constants.
// All users are redirected to the constant 1 (true) or 0 (false) and the
// comparison is erased. Returns false when either operand is not a constant.
bool FoldICmp(ICmpInst* cmp, Context& ctx) {
  auto* lhs = dynamic_cast<ConstantInt*>(cmp->GetLhs());
  auto* rhs = dynamic_cast<ConstantInt*>(cmp->GetRhs());
  if (lhs == nullptr || rhs == nullptr) return false;

  const int a = lhs->GetValue();
  const int b = rhs->GetValue();
  const bool truth = [&]() -> bool {
    switch (cmp->GetPredicate()) {
      case ICmpPredicate::EQ:  return a == b;
      case ICmpPredicate::NE:  return a != b;
      case ICmpPredicate::SLT: return a < b;
      case ICmpPredicate::SLE: return a <= b;
      case ICmpPredicate::SGT: return a > b;
      case ICmpPredicate::SGE: return a >= b;
    }
    return false;  // predicate not listed above
  }();

  cmp->ReplaceAllUsesWith(ctx.GetConstInt(truth ? 1 : 0));
  DetachAndRemove(cmp);
  return true;
}
// Folds a floating-point comparison whose two operands are float constants.
// All users are redirected to the constant 1 (true) or 0 (false) and the
// comparison is erased. Returns false when either operand is not a constant.
//
// Every predicate handled here is an *ordered* one: the result must be false
// as soon as an operand is NaN. The C++ relational operators already behave
// that way, except `!=`, which is true for NaN; ONE is therefore evaluated as
// "less than or greater than".
bool FoldFCmp(FCmpInst* cmp, Context& ctx) {
  auto* lhs = dynamic_cast<ConstantFloat*>(cmp->GetLhs());
  auto* rhs = dynamic_cast<ConstantFloat*>(cmp->GetRhs());
  if (!lhs || !rhs) return false;

  const float lv = lhs->GetValue();
  const float rv = rhs->GetValue();
  bool result = false;
  switch (cmp->GetPredicate()) {
    case FCmpPredicate::OEQ: result = lv == rv; break;
    case FCmpPredicate::ONE: result = (lv < rv) || (lv > rv); break;
    case FCmpPredicate::OLT: result = lv < rv; break;
    case FCmpPredicate::OLE: result = lv <= rv; break;
    case FCmpPredicate::OGT: result = lv > rv; break;
    case FCmpPredicate::OGE: result = lv >= rv; break;
  }

  cmp->ReplaceAllUsesWith(ctx.GetConstInt(result ? 1 : 0));
  DetachAndRemove(cmp);
  return true;
}
// Folds a zero-extension of an integer constant: users receive the constant 1
// when the source is non-zero and 0 otherwise, and the instruction is erased.
// Returns false when the source is not an integer constant.
bool FoldZExt(ZExtInst* zext, Context& ctx) {
  auto* source = dynamic_cast<ConstantInt*>(zext->GetSrc());
  if (source == nullptr) return false;

  const int widened = (source->GetValue() != 0) ? 1 : 0;
  zext->ReplaceAllUsesWith(ctx.GetConstInt(widened));
  DetachAndRemove(zext);
  return true;
}
// Folds an int-to-float conversion of an integer constant: users receive the
// converted float constant and the instruction is erased.
// Returns false when the source is not an integer constant.
bool FoldSIToFP(SIToFPInst* inst, Context& ctx) {
  auto* source = dynamic_cast<ConstantInt*>(inst->GetSrc());
  if (source == nullptr) return false;

  const float converted = static_cast<float>(source->GetValue());
  inst->ReplaceAllUsesWith(ctx.GetConstFloat(converted));
  DetachAndRemove(inst);
  return true;
}
// Folds a float-to-int conversion of a float constant: users receive the
// truncated integer constant and the instruction is erased.
// Returns false when the source is not a float constant, or when its value is
// NaN or lies outside [-2^31, 2^31): converting such a value with a C++ cast
// is undefined behaviour, so the instruction is left in place instead.
// NOTE(review): the bounds assume a 32-bit `int`, matching the i32 IR type.
bool FoldFPToSI(FPToSIInst* inst, Context& ctx) {
  auto* src = dynamic_cast<ConstantFloat*>(inst->GetSrc());
  if (!src) return false;

  const float value = src->GetValue();
  // Both bounds are exactly representable as float; NaN fails the test.
  const bool representable = value >= -2147483648.0f && value < 2147483648.0f;
  if (!representable) return false;

  inst->ReplaceAllUsesWith(ctx.GetConstInt(static_cast<int>(value)));
  DetachAndRemove(inst);
  return true;
}
// Folds a binary arithmetic instruction whose two operands are constants of
// the same kind (both integer or both float). Users are redirected to the
// computed constant and the instruction is erased.
//
// Integer add/sub/mul wrap around (two's complement) instead of relying on
// signed overflow, which is undefined in C++. Integer division and remainder
// are not folded when the divisor is zero or for INT_MIN / -1, whose quotient
// is not representable and can trap on the host. Float division by zero is
// not folded either.
//
// Returns true when the instruction was folded.
bool FoldBinaryWithCtx(BinaryInst* bin, Context& ctx) {
  auto* lhs_c = dynamic_cast<ConstantInt*>(bin->GetLhs());
  auto* rhs_c = dynamic_cast<ConstantInt*>(bin->GetRhs());
  auto* lhs_f = dynamic_cast<ConstantFloat*>(bin->GetLhs());
  auto* rhs_f = dynamic_cast<ConstantFloat*>(bin->GetRhs());

  if (lhs_c && rhs_c) {
    const int lv = lhs_c->GetValue();
    const int rv = rhs_c->GetValue();
    // Unsigned copies: arithmetic on them is defined to wrap.
    const unsigned ul = static_cast<unsigned>(lv);
    const unsigned ur = static_cast<unsigned>(rv);
    // Bit pattern of the most negative int (only the sign bit set).
    const unsigned min_bits = ~(~0u >> 1);
    const bool quotient_overflows = (rv == -1) && (ul == min_bits);

    int result = 0;
    bool valid = true;
    switch (bin->GetOpcode()) {
      case Opcode::Add: result = static_cast<int>(ul + ur); break;
      case Opcode::Sub: result = static_cast<int>(ul - ur); break;
      case Opcode::Mul: result = static_cast<int>(ul * ur); break;
      case Opcode::Div:
        if (rv != 0 && !quotient_overflows) result = lv / rv;
        else valid = false;
        break;
      case Opcode::Mod:
        if (rv != 0 && !quotient_overflows) result = lv % rv;
        else valid = false;
        break;
      default: valid = false; break;
    }
    if (valid) {
      bin->ReplaceAllUsesWith(ctx.GetConstInt(result));
      DetachAndRemove(bin);
      return true;
    }
  }

  if (lhs_f && rhs_f) {
    const float lv = lhs_f->GetValue();
    const float rv = rhs_f->GetValue();
    float result = 0.0f;
    bool valid = true;
    switch (bin->GetOpcode()) {
      case Opcode::FAdd: result = lv + rv; break;
      case Opcode::FSub: result = lv - rv; break;
      case Opcode::FMul: result = lv * rv; break;
      case Opcode::FDiv:
        if (rv != 0.0f) result = lv / rv;
        else valid = false;
        break;
      default: valid = false; break;
    }
    if (valid) {
      bin->ReplaceAllUsesWith(ctx.GetConstFloat(result));
      DetachAndRemove(bin);
      return true;
    }
  }

  return false;
}
} // namespace
// Constant folding driver: repeatedly sweeps every block of `func` and folds
// arithmetic, comparison and conversion instructions whose operands are all
// constants, until a complete sweep folds nothing more.
// Returns true when at least one instruction was folded.
bool RunConstFold(Function& func, Context& ctx) {
  // Routes one instruction to the folder for its opcode; true when it was
  // folded (and therefore erased).
  const auto try_fold = [&ctx](Instruction* inst) -> bool {
    switch (inst->GetOpcode()) {
      case Opcode::Add:
      case Opcode::Sub:
      case Opcode::Mul:
      case Opcode::Div:
      case Opcode::Mod:
      case Opcode::FAdd:
      case Opcode::FSub:
      case Opcode::FMul:
      case Opcode::FDiv:
        return FoldBinaryWithCtx(static_cast<BinaryInst*>(inst), ctx);
      case Opcode::ICmp:
        return FoldICmp(static_cast<ICmpInst*>(inst), ctx);
      case Opcode::FCmp:
        return FoldFCmp(static_cast<FCmpInst*>(inst), ctx);
      case Opcode::ZExt:
        return FoldZExt(static_cast<ZExtInst*>(inst), ctx);
      case Opcode::SIToFP:
        return FoldSIToFP(static_cast<SIToFPInst*>(inst), ctx);
      case Opcode::FPToSI:
        return FoldFPToSI(static_cast<FPToSIInst*>(inst), ctx);
      default:
        return false;
    }
  };

  std::unordered_set<void*> removed;  // addresses of instructions already erased
  bool changed = false;

  for (bool progress = true; progress;) {
    progress = false;
    for (auto& bb : func.GetBlocks()) {
      // Work on a snapshot: folding erases instructions from the live list.
      std::vector<Instruction*> snapshot;
      for (auto& owned : bb->GetInstructions()) {
        snapshot.push_back(owned.get());
      }

      for (auto* inst : snapshot) {
        if (removed.count(inst) != 0) continue;
        if (!try_fold(inst)) continue;
        removed.insert(inst);
        progress = true;
        changed = true;
      }
    }
  }

  return changed;
}
} // namespace ir

@ -1,5 +1,63 @@
// 常量传播Constant Propagation
// - 沿 use-def 关系传播已知常量
// - 将可替换的 SSA 值改写为常量,暴露更多折叠机会
// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用
#include "ir/IR.h"
#include <queue>
#include <unordered_map>
#include <vector>
namespace ir {
// Replaces conversion instructions (ZExt, SIToFP, FPToSI) whose source is a
// constant by the converted constant and erases them.
// Returns true when at least one instruction was replaced.
//
// Before an instruction is erased its operand slots are cleared, so that no
// operand keeps a use entry pointing at the deleted instruction. An FPToSI
// whose source is NaN or outside [-2^31, 2^31) is left alone, because the C++
// cast that would compute the result is undefined for such values.
// NOTE(review): the bounds assume a 32-bit `int`, matching the i32 IR type.
bool RunConstProp(Function& func, Context& ctx) {
  // Redirects all users of `inst` to `replacement`, then detaches and erases it.
  const auto replace_and_erase = [](Instruction* inst, Value* replacement) {
    inst->ReplaceAllUsesWith(replacement);
    for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
      inst->SetOperand(i, nullptr);
    }
    inst->RemoveFromParent();
  };

  bool changed = false;
  for (auto& bb : func.GetBlocks()) {
    // Work on a snapshot: erasing modifies the block's instruction list.
    std::vector<Instruction*> insts;
    for (auto& inst : bb->GetInstructions()) insts.push_back(inst.get());

    for (auto* inst : insts) {
      if (inst->GetParent() == nullptr) continue;

      switch (inst->GetOpcode()) {
        case Opcode::ZExt: {
          auto* ze = static_cast<ZExtInst*>(inst);
          if (auto* ci = dynamic_cast<ConstantInt*>(ze->GetSrc())) {
            replace_and_erase(ze, ctx.GetConstInt(ci->GetValue() != 0 ? 1 : 0));
            changed = true;
          }
          break;
        }
        case Opcode::SIToFP: {
          auto* si = static_cast<SIToFPInst*>(inst);
          if (auto* ci = dynamic_cast<ConstantInt*>(si->GetSrc())) {
            replace_and_erase(
                si, ctx.GetConstFloat(static_cast<float>(ci->GetValue())));
            changed = true;
          }
          break;
        }
        case Opcode::FPToSI: {
          auto* fp = static_cast<FPToSIInst*>(inst);
          if (auto* cf = dynamic_cast<ConstantFloat*>(fp->GetSrc())) {
            const float value = cf->GetValue();
            // Both bounds are exactly representable; NaN fails the test.
            if (value >= -2147483648.0f && value < 2147483648.0f) {
              replace_and_erase(fp, ctx.GetConstInt(static_cast<int>(value)));
              changed = true;
            }
          }
          break;
        }
        default:
          break;
      }
    }
  }
  return changed;
}
} // namespace ir

@ -1,4 +1,68 @@
// 死代码删除DCE
// - 删除无用指令与无用基本块
// - 通常与 CFG 简化配合使用
// - 标记-清扫算法:先标记有用指令,再清除未标记的
#include "ir/IR.h"
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// Tells whether an instruction must be kept even when nothing uses its result:
// stores, calls and the three control-transfer opcodes (Br, CondBr, Ret).
bool HasSideEffect(Instruction* inst) {
  const auto opcode = inst->GetOpcode();
  return opcode == Opcode::Store || opcode == Opcode::Call ||
         opcode == Opcode::Br || opcode == Opcode::CondBr ||
         opcode == Opcode::Ret;
}
} // namespace
// Dead-code elimination: erases every instruction that is not a terminator,
// has no side effect (see HasSideEffect) and has no user.
//
// Erasing an instruction releases its operands, which may leave the
// instructions that produced them without users. The sweep is therefore
// repeated over the whole function until nothing more can be erased, so a
// chain of dead computations disappears in a single call.
//
// Returns true when at least one instruction was erased.
bool RunDCE(Function& func) {
  bool changed = false;

  bool progress = true;
  while (progress) {
    progress = false;

    for (auto& bb : func.GetBlocks()) {
      // Collect first, erase afterwards: the instruction list must not change
      // while it is being iterated.
      std::vector<Instruction*> to_remove;
      for (auto& inst : bb->GetInstructions()) {
        auto* ip = inst.get();
        if (ip->IsTerminator()) continue;
        if (HasSideEffect(ip)) continue;

        // Live as soon as one use entry still has a user attached.
        bool has_use = false;
        for (const auto& use : ip->GetUses()) {
          if (use.GetUser()) {
            has_use = true;
            break;
          }
        }
        if (!has_use) to_remove.push_back(ip);
      }

      for (auto* ip : to_remove) {
        // Release the operands so their use lists no longer mention `ip`.
        for (size_t i = 0; i < ip->GetNumOperands(); ++i) {
          ip->SetOperand(i, nullptr);
        }
        ip->RemoveFromParent();
        progress = true;
        changed = true;
      }
    }
  }

  return changed;
}
} // namespace ir

@ -1,4 +1,190 @@
// Mem2RegSSA 构造):
// - 将局部变量的 alloca/load/store 提升为 SSA 形式
// - 插入 PHI 并重写使用,依赖支配树等分析
// - 插入 PHI 节点并重写使用,所有可提升的 alloca 在一次重命名遍中处理
#include "ir/IR.h"
#include <functional>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// Decides whether an alloca can be turned into SSA values: it must not be an
// array, and every user must be a load or a store that accesses the alloca as
// its pointer operand. Any other kind of user disqualifies it.
bool IsPromotable(AllocaInst* alloca) {
  if (alloca->IsArray()) return false;

  for (const auto& use : alloca->GetUses()) {
    auto* user = use.GetUser();

    Value* accessed = nullptr;
    if (auto* load = dynamic_cast<LoadInst*>(user)) {
      accessed = load->GetPtr();
    } else if (auto* store = dynamic_cast<StoreInst*>(user)) {
      accessed = store->GetPtr();
    } else {
      return false;  // neither a load nor a store
    }

    // The alloca must be the address being accessed, not e.g. a stored value.
    if (accessed != alloca) return false;
  }
  return true;
}
// Splits the users of `alloca` into two output lists: stores that use it as
// operand 1 (the pointer operand) are appended to `stores`, loads are appended
// to `loads`. Other users are ignored.
void CollectStoresAndLoads(AllocaInst* alloca,
                           std::vector<StoreInst*>& stores,
                           std::vector<LoadInst*>& loads) {
  for (const auto& use : alloca->GetUses()) {
    auto* user = use.GetUser();

    if (auto* as_store = dynamic_cast<StoreInst*>(user)) {
      // Only count the store when the alloca is its pointer operand.
      if (use.GetOperandIndex() == 1) stores.push_back(as_store);
      continue;
    }

    if (auto* as_load = dynamic_cast<LoadInst*>(user)) {
      loads.push_back(as_load);
    }
  }
}
// Computes the iterated dominance frontier of `def_blocks`: the dominance
// frontier of those blocks, plus the frontier of every block found that way,
// repeated until no new block appears.
std::set<BasicBlock*> ComputeIDF(const std::set<BasicBlock*>& def_blocks,
                                 DominatorTree& dt) {
  std::set<BasicBlock*> closure;                                   // result
  std::set<BasicBlock*> enqueued(def_blocks.begin(), def_blocks.end());
  std::vector<BasicBlock*> pending(def_blocks.begin(), def_blocks.end());

  while (!pending.empty()) {
    BasicBlock* current = pending.back();
    pending.pop_back();

    for (BasicBlock* frontier_block : dt.GetDominanceFrontier(current)) {
      const bool newly_found = closure.insert(frontier_block).second;
      if (!newly_found) continue;
      // A block is expanded at most once, even if it was a defining block.
      const bool first_enqueue = enqueued.insert(frontier_block).second;
      if (first_enqueue) pending.push_back(frontier_block);
    }
  }
  return closure;
}
// Per-alloca bookkeeping used while promoting one stack slot to SSA values.
struct AllocaInfo {
// The alloca being promoted.
AllocaInst* alloca = nullptr;
// Stores that write through the alloca (it is their pointer operand).
std::vector<StoreInst*> stores;
// Loads that read through the alloca.
std::vector<LoadInst*> loads;
// PHI node created for this alloca, keyed by the block it was inserted into.
std::unordered_map<BasicBlock*, PhiInst*> phis;
// Stack of candidate current values during renaming; seeded with undef_val.
std::vector<Value*> value_stack;
// Placeholder value used before any store: a zero constant of the slot's type.
Value* undef_val = nullptr;
};
} // namespace
// Promotes scalar stack slots of `func` to SSA values.
//
// Outline:
//   1. Collect every alloca accepted by IsPromotable.
//   2. For each slot, insert a PHI at the head of every block in the iterated
//      dominance frontier of the blocks that store to it.
//   3. Walk the dominator tree once from the entry block, keeping one value
//      stack per slot: a PHI or store pushes a new current value, a load is
//      replaced by the current value, and PHIs in successor blocks receive
//      the current value as the incoming for this block.
//   4. Erase the promoted loads, stores and allocas.
//
// A slot that is read before anything was stored yields the constant 0
// (0.0f for float slots).
//
// Returns false when no alloca is promotable, true otherwise.
bool RunMem2Reg(Function& func, Context& ctx) {
  DominatorTree dt;
  dt.Compute(func);

  // Collect every promotable alloca.
  std::vector<AllocaInst*> promotable;
  for (auto& bb : func.GetBlocks()) {
    for (auto& inst : bb->GetInstructions()) {
      if (auto* alloca = dynamic_cast<AllocaInst*>(inst.get())) {
        if (IsPromotable(alloca)) promotable.push_back(alloca);
      }
    }
  }
  if (promotable.empty()) return false;

  // Build the per-slot records and place the PHI nodes.
  std::vector<AllocaInfo> infos(promotable.size());
  std::unordered_map<StoreInst*, int> store_to_info;
  std::unordered_map<LoadInst*, int> load_to_info;
  for (size_t i = 0; i < promotable.size(); ++i) {
    auto* alloca = promotable[i];
    auto& info = infos[i];
    info.alloca = alloca;
    CollectStoresAndLoads(alloca, info.stores, info.loads);

    std::set<BasicBlock*> def_blocks;
    for (auto* s : info.stores) def_blocks.insert(s->GetParent());

    const bool is_float = alloca->GetType()->IsPtrFloat32();
    auto val_type = is_float ? Type::GetFloat32Type() : Type::GetInt32Type();
    info.undef_val = is_float ? static_cast<Value*>(ctx.GetConstFloat(0.0f))
                              : static_cast<Value*>(ctx.GetConstInt(0));
    // Bottom of the stack: the value seen when no store has happened yet.
    info.value_stack.push_back(info.undef_val);

    // One PHI per block of the iterated dominance frontier.
    auto df_plus = ComputeIDF(def_blocks, dt);
    for (auto* bb : df_plus) {
      auto* phi = bb->Prepend<PhiInst>(val_type, "");
      info.phis[bb] = phi;
    }

    // Instruction -> slot index maps for O(1) lookup during renaming.
    for (auto* s : info.stores) store_to_info[s] = static_cast<int>(i);
    for (auto* l : info.loads) load_to_info[l] = static_cast<int>(i);
  }

  // Loads whose uses were rewritten by the renaming walk below.
  std::unordered_set<LoadInst*> renamed_loads;

  // Single renaming pass: depth-first walk of the dominator tree that handles
  // all slots at once.
  std::function<void(BasicBlock*)> rename = [&](BasicBlock* bb) {
    // Remember every stack height so it can be restored on the way out, and
    // make this block's PHIs the current values.
    std::vector<size_t> saved_sizes(infos.size());
    for (size_t i = 0; i < infos.size(); ++i) {
      saved_sizes[i] = infos[i].value_stack.size();
      auto phi_it = infos[i].phis.find(bb);
      if (phi_it != infos[i].phis.end()) {
        infos[i].value_stack.push_back(phi_it->second);
      }
    }

    // Rewrite the instructions of this block.
    for (auto& inst_up : bb->GetInstructions()) {
      auto* inst = inst_up.get();
      // PHI nodes were already pushed onto the stacks above.
      if (dynamic_cast<PhiInst*>(inst)) continue;
      if (auto* store = dynamic_cast<StoreInst*>(inst)) {
        auto it = store_to_info.find(store);
        if (it != store_to_info.end()) {
          infos[it->second].value_stack.push_back(store->GetValue());
        }
      } else if (auto* load = dynamic_cast<LoadInst*>(inst)) {
        auto it = load_to_info.find(load);
        if (it != load_to_info.end()) {
          load->ReplaceAllUsesWith(infos[it->second].value_stack.back());
          renamed_loads.insert(load);
        }
      }
    }

    // Feed the current values into the PHIs of the successor blocks.
    for (auto* succ : bb->GetSuccessors()) {
      for (size_t i = 0; i < infos.size(); ++i) {
        auto phi_it = infos[i].phis.find(succ);
        if (phi_it != infos[i].phis.end()) {
          phi_it->second->AddIncoming(infos[i].value_stack.back(), bb);
        }
      }
    }

    // Recurse into the dominator-tree children.
    for (auto* child : dt.GetChildren(bb)) rename(child);

    // Restore the stacks to their state on entry.
    for (size_t i = 0; i < infos.size(); ++i) {
      infos[i].value_stack.resize(saved_sizes[i]);
    }
  };
  rename(func.GetEntry());

  // Erase the promoted loads, stores and allocas. Operands are cleared before
  // removal so that no other value keeps a use entry pointing at a deleted
  // instruction.
  for (auto& info : infos) {
    for (auto* ld : info.loads) {
      if (!renamed_loads.count(ld)) {
        // The walk never visited this load (presumably its block is not
        // reachable from the entry block), so its users still reference it.
        // Redirect them to the placeholder value before the load is erased;
        // otherwise they would be left with a dangling operand.
        ld->ReplaceAllUsesWith(info.undef_val);
      }
      ld->SetOperand(0, nullptr);  // drop the reference to the alloca
      ld->RemoveFromParent();
    }
    for (auto* st : info.stores) {
      st->SetOperand(0, nullptr);  // drop the reference to the stored value
      st->SetOperand(1, nullptr);  // drop the reference to the alloca
      st->RemoveFromParent();
    }
    info.alloca->RemoveFromParent();
  }
  return true;
}
} // namespace ir

@ -1 +1,89 @@
// IR Pass 管理骨架。
// IR Pass 管理:按顺序执行优化遍,支持迭代至不动点
#include "ir/IR.h"
#include <queue>
#include <unordered_set>
namespace ir {
// Forward declarations of the optimisation passes; each one is defined in its
// own pass source file.
bool RunMem2Reg(Function& func, Context& ctx);
bool RunConstFold(Function& func, Context& ctx);
bool RunConstProp(Function& func, Context& ctx);
bool RunDCE(Function& func);
bool RunCFGSimplify(Function& func, Context& ctx);
bool RunCSE(Function& func);
// Repairs the PHI nodes of `func`.
//
// After rebuilding the CFG, for every PHI:
//   - each predecessor of the PHI's block that has no incoming entry gets one
//     carrying a placeholder constant (a PHI must cover all predecessors);
//   - each incoming pair whose value or block is null is overwritten with the
//     placeholder constant and the entry block.
//
// The placeholder matches the PHI's own type: 0.0f for float PHIs and 0
// otherwise, so a float PHI never receives an integer incoming value.
// Does nothing when the function has no entry block.
static void SanitizePhis(Function& func, Context& ctx) {
  func.RebuildCFG();
  auto* entry = func.GetEntry();
  if (!entry) return;

  for (auto& bb : func.GetBlocks()) {
    // Collect the PHI nodes of this block.
    std::vector<PhiInst*> phis;
    for (auto& inst : bb->GetInstructions()) {
      if (auto* phi = dynamic_cast<PhiInst*>(inst.get())) phis.push_back(phi);
    }
    if (phis.empty()) continue;

    auto& preds = bb->GetPredecessors();
    for (auto* phi : phis) {
      // Type-correct placeholder for this particular PHI.
      Value* undef_val = phi->IsFloat32()
                             ? static_cast<Value*>(ctx.GetConstFloat(0.0f))
                             : static_cast<Value*>(ctx.GetConstInt(0));

      // Blocks that already have an incoming entry.
      std::unordered_set<BasicBlock*> existing;
      for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
        existing.insert(phi->GetIncomingBlock(i));
      }

      // Add the missing entry for every predecessor not yet covered.
      for (auto* pred : preds) {
        if (!existing.count(pred)) {
          phi->AddIncoming(undef_val, pred);
        }
      }

      // Patch incoming pairs that lost their value or block. Operands are laid
      // out as (value, block) pairs: value at 2*i, block at 2*i + 1.
      for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
        auto* inc_val = phi->GetIncomingValue(i);
        auto* inc_bb = phi->GetIncomingBlock(i);
        if (!inc_val || !inc_bb) {
          phi->SetOperand(i * 2, undef_val);
          phi->SetOperand(i * 2 + 1, entry);
        }
      }
    }
  }
}
// Runs the IR optimisation pipeline over every function of `module`.
//
// Mem2Reg runs once per function up front. Afterwards constant folding,
// constant propagation, CSE and DCE (followed by PHI sanitising) are repeated
// over all functions until a full round reports no change, or until
// kMaxIterations rounds have been executed.
void RunPasses(Module& module) {
  Context& ctx = module.GetContext();
  constexpr int kMaxIterations = 10;

  for (auto& func : module.GetFunctions()) RunMem2Reg(*func, ctx);

  bool changed = true;
  for (int iteration = 0; changed && iteration < kMaxIterations; ++iteration) {
    changed = false;
    for (auto& func : module.GetFunctions()) {
      // Every pass runs unconditionally each round; the results are only
      // combined afterwards.
      const bool folded = RunConstFold(*func, ctx);
      const bool propagated = RunConstProp(*func, ctx);
      const bool deduplicated = RunCSE(*func);
      // CFGSimplify is deliberately left out: according to the original
      // author's note it leaves dangling pointers on CFGs with many PHI
      // nodes. Its job (merging empty blocks, deleting unreachable blocks) is
      // partly covered by DCE followed by SanitizePhis.
      const bool pruned = RunDCE(*func);
      SanitizePhis(*func, ctx);

      changed = changed || folded || propagated || deduplicated || pruned;
    }
  }
}
} // namespace ir

@ -117,33 +117,101 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
for (int d : dims) total *= (d > 0 ? d : 1);
if (in_global_scope_) {
auto* gv = module_.CreateGlobalVariable(name, true, 0, total);
auto* gv = module_.CreateGlobalVariable(name, true, 0, total, true);
global_storage_map_[constDef] = gv;
global_array_dims_[constDef] = dims;
// 计算初始值并存入全局变量
if (constDef->constInitVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::vector<int> flat(total, 0);
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) { flat[pos] = EvalConstExprInt(iv->constExp()); return; }
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos; fill(sub,a,sub_stride); cur=a+sub_stride; }
}
};
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur+top_stride-1)/top_stride)*top_stride; fill(sub,a,top_stride); cur=a+top_stride; }
}
}
gv->SetInitVals(flat);
}
} else {
auto* slot = builder_.CreateAllocaArray(total, name);
storage_map_[constDef] = slot;
array_dims_[constDef] = dims;
// 扁平化初始化
// 按 C 语义扁平化初始化(子列表对齐到维度边界)
if (constDef->constInitVal()) {
std::vector<int> flat;
flat.reserve(total);
std::function<void(SysYParser::ConstInitValContext*)> flatten =
[&](SysYParser::ConstInitValContext* iv) {
if (!iv) return;
if (iv->constExp()) {
flat.push_back(EvalConstExprInt(iv->constExp()));
} else {
for (auto* sub : iv->constInitVal()) flatten(sub);
}
};
flatten(constDef->constInitVal());
std::vector<int> flat(total, 0);
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) {
flat[pos] = EvalConstExprInt(iv->constExp());
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else {
int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
}
}
};
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else {
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
for (int i = 0; i < total; ++i) {
int v = (i < (int)flat.size()) ? flat[i] : 0;
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(v), ptr);
builder_.CreateStore(builder_.CreateConstInt(flat[i]), ptr);
}
}
}
@ -179,24 +247,119 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (in_global_scope_) {
if (!is_array) {
// 全局标量初始化器必须是常量简化处理为0
int init_val = 0;
if (ctx->initVal() && ctx->initVal()->exp()) {
try {
auto cv = sem::EvaluateExp(*ctx->initVal()->exp()->addExp());
init_val = static_cast<int>(cv.int_val);
} catch (...) {
init_val = 0;
auto* gv = module_.CreateGlobalVariable(name, false, 0, 1, false, is_float);
// 全局标量:初始化器必须是常量
if(is_float){
float init_valf = 0;
try { init_valf = static_cast<float>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).float_val); } catch (...) {}
gv->SetInitValF(init_valf);
} else {
int init_val = 0;
if (ctx->initVal() && ctx->initVal()->exp()) {
try {
auto cv = sem::EvaluateExp(*ctx->initVal()->exp()->addExp());
init_val = static_cast<int>(cv.int_val);
} catch (...) {
init_val = 0;
}
}
gv->SetInitVal(init_val);
}
auto* gv = module_.CreateGlobalVariable(name, false, init_val);
global_storage_map_[ctx] = gv;
} else {
int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1);
auto* gv = module_.CreateGlobalVariable(name, false, 0, total);
auto* gv = module_.CreateGlobalVariable(name, false, 0, total, true, is_float);
global_storage_map_[ctx] = gv;
global_array_dims_[ctx] = dims;
// 计算初始值
if (ctx->initVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
if (is_float) {
std::vector<float> flat(total, 0.0f);
std::function<void(SysYParser::InitValContext*, int, int)> fill_f;
fill_f = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<float>(sem::EvaluateExp(*iv->exp()->addExp()).float_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill_f(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<float>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).float_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill_f(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitValsF(flat);
} else {
std::vector<int> flat(total, 0);
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<int>(sem::EvaluateExp(*iv->exp()->addExp()).int_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<int>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).int_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitVals(flat);
}
}
}
} else {
if (storage_map_.count(ctx)) {
@ -211,6 +374,14 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
ir::Value* init;
if (ctx->initVal() && ctx->initVal()->exp()) {
init = EvalExpr(*ctx->initVal()->exp());
// Coerce init value to slot type
if (!is_float && init->IsFloat32()) {
init = ToInt(init);
} else if (is_float && init->IsInt32()) {
init = ToFloat(init);
} else if (!is_float && init->IsInt1()) {
init = ToI32(init);
}
} else {
init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
@ -219,40 +390,95 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
} else {
int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1);
auto* slot = builder_.CreateAllocaArray(total, name);
auto* slot = is_float ? builder_.CreateAllocaArrayF32(total, module_.GetContext().NextTemp())
: builder_.CreateAllocaArray(total, module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
array_dims_[ctx] = dims;
ir::Value* zero_init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
if (ctx->initVal()) {
// 收集扁平化初始值
std::vector<ir::Value*> flat;
flat.reserve(total);
std::function<void(SysYParser::InitValContext*)> flatten =
[&](SysYParser::InitValContext* iv) {
if (!iv) return;
if (iv->exp()) {
flat.push_back(EvalExpr(*iv->exp()));
} else {
for (auto* sub : iv->initVal()) flatten(sub);
}
};
flatten(ctx->initVal());
for (int i = 0; i < total; ++i) {
ir::Value* v = (i < (int)flat.size()) ? flat[i]
: builder_.CreateConstInt(0);
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(v, ptr);
// 按 C 语义扁平化初始值:子列表对齐到对应维度边界
std::vector<ir::Value*> flat(total, zero_init);
// 计算各维度的 stridestride[i] = dims[i]*dims[i+1]*...*dims[n-1]
// 但我们需要「子列表对应第几维的 stride」
// 顶层stride = total / dims[0](即每行的元素数)
// 递归时 stride 继续除以当前维度大小
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0]; // 每个顶层子列表占用的元素数
// fill(iv, pos, stride):将 iv 的内容填入 flat[pos..pos+stride)
// stride 表示当前层子列表对应的元素个数
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
flat[pos] = EvalExpr(*iv->exp());
return;
}
// 子列表内的 stride = stride / (当前层首维大小)
// 找到对应的 strides 层stride == strides[k] → 子stride = strides[k+1]
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k) {
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
}
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else {
// 对齐到 sub_stride 边界
int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
}
}
};
// 顶层扫描
int cur = 0;
if (ctx->initVal()->exp()) {
flat[0] = EvalExpr(*ctx->initVal()->exp());
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else {
// 对齐到 top_stride 边界
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
} else {
// 零初始化
// 先 memset 归零,再只写入非零元素
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
for (int i = 0; i < total; ++i) {
bool is_zero = false;
if (auto* ci = dynamic_cast<ir::ConstantInt*>(flat[i])) {
is_zero = (ci->GetValue() == 0);
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(flat[i])) {
is_zero = (cf->GetValue() == 0.0f);
}
if (is_zero) continue;
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(0), ptr);
builder_.CreateStore(flat[i], ptr);
}
} else {
// 零初始化:用 memset 归零
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
(void)zero_init;
}
}
}

@ -10,10 +10,15 @@
// ─── 辅助 ─────────────────────────────────────────────────────────────────────
// 把 i32 值转成 i1icmp ne i32 v, 0
// 把 i32/float 值转成 i1
ir::Value* IRGenImpl::ToI1(ir::Value* v) {
if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value"));
if (v->IsInt1()) return v;
if (v->IsFloat32()) {
return builder_.CreateFCmp(ir::FCmpPredicate::ONE, v,
builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
return builder_.CreateICmp(ir::ICmpPredicate::NE, v,
builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
@ -87,7 +92,13 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
} else if (name == "getch") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {});
} else if (name == "getfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); // 近似
module_.DeclareExternalFunc(name, ir::Type::GetFloat32Type(), {});
} else if (name == "getarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrInt32Type()});
} else if (name == "getfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrFloat32Type()});
} else if (name == "putint") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
@ -95,10 +106,16 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
} else if (name == "putfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {});
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetFloat32Type()});
} else if (name == "putarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
{ir::Type::GetInt32Type(),
ir::Type::GetPtrInt32Type()});
} else if (name == "putfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type(),
ir::Type::GetPtrFloat32Type()});
} else if (name == "starttime" || name == "stoptime") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
@ -227,13 +244,113 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
std::vector<ir::Value*> args;
if (ctx->funcRParams()) {
for (auto* exp : ctx->funcRParams()->exp()) {
args.push_back(EvalExpr(*exp));
// 检查是否是数组变量(无索引的 lVar若是则传指针而非 load
ir::Value* arg = nullptr;
auto* add = exp->addExp();
if (add && add->mulExp().size() == 1) {
auto* mul = add->mulExp(0);
if (mul && mul->unaryExp().size() == 1) {
auto* unary = mul->unaryExp(0);
if (unary && !unary->unaryOp() && unary->primaryExp()) {
auto* primary = unary->primaryExp();
if (primary && primary->lVar() && primary->lVar()->exp().empty()) {
auto* lvar = primary->lVar();
auto* decl = sema_.ResolveVarUse(lvar->Ident());
if (decl) {
// 检查是否是数组参数storage_map_ 里存的是指针)
auto it = storage_map_.find(decl);
if (it != storage_map_.end()) {
auto* val = it->second;
if (val && (val->IsPtrInt32() || val->IsPtrFloat32())) {
// 检查是否是 Argument数组参数直接传指针
if (dynamic_cast<ir::Argument*>(val)) {
arg = val;
} else if (array_dims_.count(decl)) {
// 本地数组(含 dims 记录):传首元素地址
arg = builder_.CreateGep(val, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
// 检查全局数组
if (!arg) {
auto git = global_storage_map_.find(decl);
if (git != global_storage_map_.end()) {
auto* gv = dynamic_cast<ir::GlobalVariable*>(git->second);
if (gv && gv->IsArray()) {
arg = builder_.CreateGep(git->second,
builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
}
}
}
}
}
// Also handle partially-indexed multi-dim arrays: arr[i] where arr is
// int arr[M][N] should pass &arr[i*N] as i32*, not load arr[i] as i32.
if (!arg) {
auto* add2 = exp->addExp();
if (add2 && add2->mulExp().size() == 1) {
auto* mul2 = add2->mulExp(0);
if (mul2 && mul2->unaryExp().size() == 1) {
auto* unary2 = mul2->unaryExp(0);
if (unary2 && !unary2->unaryOp() && unary2->primaryExp()) {
auto* primary2 = unary2->primaryExp();
if (primary2 && primary2->lVar() && !primary2->lVar()->exp().empty()) {
auto* lvar2 = primary2->lVar();
auto* decl2 = sema_.ResolveVarUse(lvar2->Ident());
if (decl2) {
std::vector<int> dims2;
ir::Value* base2 = nullptr;
auto it2 = array_dims_.find(decl2);
if (it2 != array_dims_.end()) {
dims2 = it2->second;
auto sit = storage_map_.find(decl2);
if (sit != storage_map_.end()) base2 = sit->second;
} else {
auto git2 = global_array_dims_.find(decl2);
if (git2 != global_array_dims_.end()) {
dims2 = git2->second;
auto gsit = global_storage_map_.find(decl2);
if (gsit != global_storage_map_.end()) base2 = gsit->second;
}
}
// Partially indexed: fewer indices than dimensions -> pass pointer
bool is_param = !dims2.empty() && dims2[0] == -1;
size_t effective_dims = is_param ? dims2.size() - 1 : dims2.size();
if (base2 && !dims2.empty() &&
lvar2->exp().size() < effective_dims + (is_param ? 1 : 0)) {
arg = EvalLVarAddr(lvar2);
}
}
}
}
}
}
}
if (!arg) arg = EvalExpr(*exp);
args.push_back(arg);
}
}
// 模块内已知函数?
ir::Function* callee = module_.GetFunction(callee_name);
if (callee) {
// Coerce args to match parameter types
for (size_t i = 0; i < args.size() && i < callee->GetNumArgs(); ++i) {
auto* param = callee->GetArgument(i);
if (!param || !args[i]) continue;
if (param->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name =
callee->IsVoidReturn() ? "" : module_.GetContext().NextTemp();
auto* call =
@ -246,15 +363,28 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// 外部函数
EnsureExternalDecl(callee_name);
// 获取返回类型
// 获取返回类型和参数类型
std::shared_ptr<ir::Type> ret_type = ir::Type::GetInt32Type();
std::vector<std::shared_ptr<ir::Type>> param_types;
for (const auto& decl : module_.GetExternalDecls()) {
if (decl.name == callee_name) {
ret_type = decl.ret_type;
param_types = decl.param_types;
break;
}
}
bool is_void = ret_type->IsVoid();
// Coerce args to match external function parameter types
for (size_t i = 0; i < args.size() && i < param_types.size(); ++i) {
if (!args[i]) continue;
if (param_types[i]->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param_types[i]->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param_types[i]->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name = is_void ? "" : module_.GetContext().NextTemp();
auto* call = builder_.CreateCallExternal(callee_name, ret_type,
std::move(args), ret_name);
@ -331,40 +461,26 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
throw std::runtime_error(FormatError("irgen", "数组索引维度过多"));
}
ir::Value* offset = builder_.CreateConstInt(0);
ir::Value* offset = nullptr;
if (is_array_param) {
// 数组参数dims[0]=-1, dims[1..n]是已知维度
// 索引indices[0]对应第一维indices[1]对应第二维...
for (size_t i = 0; i < indices.size(); ++i) {
ir::Value* idx = EvalExpr(*indices[i]);
if (i == 0) {
// 第一维stride = dims[1] * dims[2] * ... (如果有的话)
int stride = 1;
for (size_t j = 1; j < dims.size(); ++j) {
stride *= dims[j];
}
if (stride > 1) {
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
} else {
offset = builder_.CreateAdd(offset, idx,
module_.GetContext().NextTemp());
}
int stride = 1;
size_t start = (i == 0) ? 1 : i + 1;
for (size_t j = start; j < dims.size(); ++j) stride *= dims[j];
ir::Value* term;
if (stride == 1) {
term = idx;
} else {
// 后续维度
int stride = 1;
for (size_t j = i + 1; j < dims.size(); ++j) {
stride *= dims[j];
}
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
}
if (!offset) {
offset = term;
} else {
offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
}
}
} else {
@ -374,15 +490,24 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
stride = (i == (int)dims.size() - 1) ? 1 : stride * dims[i + 1];
if (i < (int)indices.size()) {
ir::Value* idx = EvalExpr(*indices[i]);
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
ir::Value* term;
if (stride == 1) {
term = idx;
} else {
term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
}
if (!offset) {
offset = term;
} else {
offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
}
}
}
}
if (!offset) offset = builder_.CreateConstInt(0);
return builder_.CreateGep(base, offset, module_.GetContext().NextTemp());
}
@ -486,8 +611,8 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.end");
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.end");
builder_.CreateCondBr(result, end_bb, rhs_bb);
builder_.SetInsertPoint(rhs_bb);
@ -498,6 +623,7 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
builder_.CreateBr(end_bb);
}
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
}
@ -523,8 +649,8 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.end");
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.end");
builder_.CreateCondBr(result, rhs_bb, end_bb);
builder_.SetInsertPoint(rhs_bb);
@ -535,6 +661,7 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
builder_.CreateBr(end_bb);
}
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
}

@ -76,6 +76,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
// 设置插入点到入口块
builder_.SetInsertPoint(func_->GetEntry());
builder_.SetAllocaBlock(func_->GetEntry());
// 处理参数
if (ctx->funcFParams()) {

@ -25,6 +25,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->Return()) {
if (ctx->exp()) {
ir::Value* v = EvalExpr(*ctx->exp());
// Coerce return value to function return type
if (func_->GetType()->IsInt32() && v->IsFloat32()) {
v = ToInt(v);
} else if (func_->GetType()->IsFloat32() && v->IsInt32()) {
v = ToFloat(v);
} else if (func_->GetType()->IsInt32() && v->IsInt1()) {
v = ToI32(v);
}
builder_.CreateRet(v);
} else {
builder_.CreateRetVoid();
@ -54,6 +62,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->lVar() && ctx->Assign()) {
ir::Value* rhs = EvalExpr(*ctx->exp());
ir::Value* addr = EvalLVarAddr(ctx->lVar());
// Coerce rhs to match slot type
if (addr->IsPtrInt32() && rhs->IsFloat32()) {
rhs = ToInt(rhs);
} else if (addr->IsPtrFloat32() && rhs->IsInt32()) {
rhs = ToFloat(rhs);
} else if (addr->IsPtrInt32() && rhs->IsInt1()) {
rhs = ToI32(rhs);
}
builder_.CreateStore(rhs, addr);
return BlockFlow::Continue;
}
@ -74,32 +90,47 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
auto stmts = ctx->stmt();
// Step 1: evaluate condition (may create short-circuit blocks with lower
// SSA numbers — must happen before any branch-target blocks are created).
ir::Value* cond_val = EvalCond(*ctx->cond());
// Step 2: create then_bb now (its label number will be >= all short-circuit
// block numbers allocated during EvalCond).
ir::BasicBlock* then_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".if.then");
module_.GetContext().NextLabel() + ".if.then");
// Step 3: create else_bb/merge_bb as placeholders. They will be moved to
// the end of the block list after their predecessors are filled in, so the
// block ordering in the output will be correct even though their label
// numbers are allocated here (before then-body sub-blocks are created).
ir::BasicBlock* else_bb = nullptr;
ir::BasicBlock* merge_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".if.end");
ir::BasicBlock* merge_bb = nullptr;
// 求值条件(可能创建短路求值块)
ir::Value* cond_val = EvalCond(*ctx->cond());
if (stmts.size() >= 2) {
else_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.else");
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
} else {
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
}
// 检查当前块是否已终结(短路求值可能导致)
// Check if current block already terminated (short-circuit may do this)
if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) {
// 条件求值已经终结了当前块,无法继续
// 这种情况下我们需要在merge_bb继续
func_->MoveBlockToEnd(then_bb);
if (else_bb) func_->MoveBlockToEnd(else_bb);
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
if (stmts.size() >= 2) {
// if-else
else_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".if.else");
builder_.CreateCondBr(cond_val, then_bb, else_bb);
} else {
builder_.CreateCondBr(cond_val, then_bb, merge_bb);
}
// then 分支
// then 分支 — visit body (may create many sub-blocks with higher numbers)
func_->MoveBlockToEnd(then_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = VisitStmt(*stmts[0]);
if (then_flow != BlockFlow::Terminated) {
@ -108,6 +139,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
// else 分支
if (else_bb) {
func_->MoveBlockToEnd(else_bb);
builder_.SetInsertPoint(else_bb);
auto else_flow = VisitStmt(*stmts[1]);
if (else_flow != BlockFlow::Terminated) {
@ -115,6 +147,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
}
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
@ -124,28 +157,32 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (!ctx->cond()) {
throw std::runtime_error(FormatError("irgen", "while 缺少条件"));
}
ir::BasicBlock* cond_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.cond");
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.end");
module_.GetContext().NextLabel() + ".while.cond");
// 跳转到条件块
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateBr(cond_bb);
}
// 条件块
// EvalCond MUST come before creating body_bb/after_bb so that
// short-circuit blocks get lower SSA numbers than the loop body blocks.
builder_.SetInsertPoint(cond_bb);
ir::Value* cond_val = EvalCond(*ctx->cond());
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.end");
// 检查条件求值后是否已终结
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateCondBr(cond_val, body_bb, after_bb);
}
// 循环体(压入循环栈)
func_->MoveBlockToEnd(body_bb);
loop_stack_.push_back({cond_bb, after_bb});
builder_.SetInsertPoint(body_bb);
auto stmts = ctx->stmt();
@ -159,6 +196,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
loop_stack_.pop_back();
func_->MoveBlockToEnd(after_bb);
builder_.SetInsertPoint(after_bb);
return BlockFlow::Continue;
}

@ -1,4 +1,5 @@
#include <exception>
#include <fstream>
#include <iostream>
#include <stdexcept>
@ -21,13 +22,20 @@ int main(int argc, char** argv) {
return 0;
}
auto antlr = ParseFileWithAntlr(opts.input);
bool need_blank_line = false;
if (opts.emit_parse_tree) {
PrintSyntaxTree(antlr.tree, antlr.parser.get(), std::cout);
need_blank_line = true;
// 确定输出流
std::ofstream ofs;
std::ostream* out = &std::cout;
if (!opts.output.empty()) {
ofs.open(opts.output);
if (!ofs) {
throw std::runtime_error(
FormatError("main", "无法打开输出文件: " + opts.output));
}
out = &ofs;
}
auto antlr = ParseFileWithAntlr(opts.input);
#if !COMPILER_PARSE_ONLY
auto* comp_unit = dynamic_cast<SysYParser::CompUnitContext*>(antlr.tree);
if (!comp_unit) {
@ -36,23 +44,23 @@ int main(int argc, char** argv) {
auto sema = RunSema(*comp_unit);
auto module = GenerateIR(*comp_unit, sema);
if (opts.opt) {
ir::RunPasses(*module);
}
if (opts.emit_ir) {
ir::IRPrinter printer;
if (need_blank_line) {
std::cout << "\n";
}
printer.Print(*module, std::cout);
need_blank_line = true;
printer.Print(*module, *out);
}
if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
if (need_blank_line) {
std::cout << "\n";
auto machine_funcs = mir::LowerToMIR(*module);
for (auto& mf : machine_funcs) {
mir::RunRegAlloc(*mf);
mir::RunFrameLowering(*mf);
}
mir::PrintAsm(*machine_func, std::cout);
mir::PrintAsm(machine_funcs, *out);
}
#else
if (opts.emit_ir || opts.emit_asm) {
@ -65,4 +73,4 @@ int main(int argc, char** argv) {
return 1;
}
return 0;
}
}

@ -2,9 +2,16 @@
#include <ostream>
#include <stdexcept>
#include <iostream>
#include <vector>
#include <unordered_map>
#include "utils/Log.h"
// 引用全局变量(定义在 Lowering.cpp 中)
extern std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir {
namespace {
@ -16,63 +23,498 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex());
}
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
int offset) {
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
<< "]\n";
// Emits the assembly text for a word load (`lw`) of `dst` from `offset(base)`.
//
// os:     stream receiving the assembly text.
// dst:    register that receives the loaded word.
// offset: byte offset relative to `base`.
// base:   base address register; defaults to the stack pointer.
//
// An offset inside [-2048, 2047] is written directly into the `lw`. Any other
// offset is materialised with `li` into t4, added to `base`, and the load is
// done through `0(t4)`; t4 is overwritten on that path.
void EmitStackLoad(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) {
  const bool needs_scratch = offset < -2048 || offset > 2047;
  if (needs_scratch) {
    // Out-of-range offset: compute the effective address in t4 first.
    os << " li t4, " << offset << "\n"
       << " add t4, " << PhysRegName(base) << ", t4\n"
       << " lw " << PhysRegName(dst) << ", 0(t4)\n";
    return;
  }
  os << " lw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n";
}
} // namespace
// Emits the assembly text for a word store (`sw`) of `src` to `offset(base)`.
//
// os:     stream receiving the assembly text.
// src:    register whose value is stored.
// offset: byte offset relative to `base`.
// base:   base address register; defaults to the stack pointer.
//
// An offset inside [-2048, 2047] is written directly into the `sw`. Any other
// offset is materialised with `li` into t4, added to `base`, and the store is
// done through `0(t4)`; t4 is overwritten on that path.
void EmitStackStore(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) {
  const bool needs_scratch = offset < -2048 || offset > 2047;
  if (needs_scratch) {
    // Out-of-range offset: compute the effective address in t4 first.
    os << " li t4, " << offset << "\n"
       << " add t4, " << PhysRegName(base) << ", t4\n"
       << " sw " << PhysRegName(src) << ", 0(t4)\n";
    return;
  }
  os << " sw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n";
}
void PrintAsm(const MachineFunction& function, std::ostream& os) {
os << ".text\n";
os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n";
os << function.GetName() << ":\n";
for (const auto& inst : function.GetEntry().GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
// Emits the assembly text for a single-precision load (`flw`) of `dst` from
// `offset(base)`.
//
// os:     stream receiving the assembly text.
// dst:    register that receives the loaded value.
// offset: byte offset relative to `base`.
// base:   base address register; defaults to the stack pointer.
//
// An offset inside [-2048, 2047] is written directly into the `flw`. Any other
// offset is materialised with `li` into t4, added to `base`, and the load is
// done through `0(t4)`; t4 is overwritten on that path.
void EmitStackLoadFloat(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) {
  const bool needs_scratch = offset < -2048 || offset > 2047;
  if (needs_scratch) {
    // Out-of-range offset: compute the effective address in t4 first.
    os << " li t4, " << offset << "\n"
       << " add t4, " << PhysRegName(base) << ", t4\n"
       << " flw " << PhysRegName(dst) << ", 0(t4)\n";
    return;
  }
  os << " flw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n";
}
// Emits the assembly text for a single-precision store (`fsw`) of `src` to
// `offset(base)`.
//
// os:     stream receiving the assembly text.
// src:    register whose value is stored.
// offset: byte offset relative to `base`.
// base:   base address register; defaults to the stack pointer.
//
// An offset inside [-2048, 2047] is written directly into the `fsw`. Any other
// offset is materialised with `li` into t4, added to `base`, and the store is
// done through `0(t4)`; t4 is overwritten on that path.
void EmitStackStoreFloat(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) {
  const bool needs_scratch = offset < -2048 || offset > 2047;
  if (needs_scratch) {
    // Out-of-range offset: compute the effective address in t4 first.
    os << " li t4, " << offset << "\n"
       << " add t4, " << PhysRegName(base) << ", t4\n"
       << " fsw " << PhysRegName(src) << ", 0(t4)\n";
    return;
  }
  os << " fsw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n";
}
// 输出单个函数的汇编
void PrintAsmFunction(const MachineFunction& function, std::ostream& os) {
// 收集所有基本块名称
std::unordered_map<const MachineBasicBlock*, std::string> block_names;
for (const auto& block_ptr : function.GetBlocks()) {
block_names[block_ptr.get()] = block_ptr->GetName();
}
int total_frame_size = 16 + function.GetFrameSize();
bool prologue_done = false;
for (const auto& block_ptr : function.GetBlocks()) {
const auto& block = *block_ptr;
// 输出基本块标签(入口块不输出,因为函数名已经是标签)
if (block.GetName() != "entry") {
os << block.GetName() << ":\n";
}
for (const auto& inst : block.GetInstructions()) {
const auto& ops = inst.GetOperands();
// 在入口块的第一条指令前输出序言
if (!prologue_done && block.GetName() == "entry") {
// 处理大栈帧的情况
if (total_frame_size <= 2047) {
os << " addi sp, sp, -" << total_frame_size << "\n";
} else {
os << " li t4, -" << total_frame_size << "\n";
os << " add sp, sp, t4\n";
}
// 保存 ra 和 s0
int ra_offset = total_frame_size - 8;
int s0_offset = total_frame_size - 16;
if (ra_offset <= 2047) {
os << " sw ra, " << ra_offset << "(sp)\n";
} else {
os << " li t4, " << ra_offset << "\n";
os << " add t4, sp, t4\n";
os << " sw ra, 0(t4)\n";
}
if (s0_offset <= 2047) {
os << " sw s0, " << s0_offset << "(sp)\n";
} else {
os << " li t4, " << s0_offset << "\n";
os << " add t4, sp, t4\n";
os << " sw s0, 0(t4)\n";
}
prologue_done = true;
}
switch (inst.GetOpcode()) {
case Opcode::Prologue:
case Opcode::Epilogue:
break;
case Opcode::MovImm:
os << " li " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::Load: {
if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
} else {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackLoad(os, ops.at(0).GetReg(), slot.offset);
}
break;
}
case Opcode::Store: {
if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
} else {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackStore(os, ops.at(0).GetReg(), slot.offset);
}
break;
}
case Opcode::Add:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Sub:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Mul: {
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
} else {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
}
break;
}
case Opcode::Div:
os << " div " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Rem:
os << " rem " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slt:
os << " slt " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slti:
os << " slti " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Sltu:
os << " sltu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Sltiu: // <-- 添加这个
os << " sltiu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Xori:
os << " xori " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::LoadGlobalAddr: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la " << PhysRegName(ops.at(0).GetReg()) << ", " << global_name << "\n";
break;
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
case Opcode::LoadGlobal:
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::StoreGlobal: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la t1, " << global_name << "\n";
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0(t1)\n";
break;
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break;
case Opcode::GEP:
break;
case Opcode::LoadIndirect:
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::Call: {
std::string func_name = "memset";
if (!ops.empty() && ops[0].GetKind() == Operand::Kind::Func) {
func_name = ops[0].GetFuncName();
}
os << " call " << func_name << "\n";
break;
}
case Opcode::LoadAddr: {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
if (slot.offset >= -2048 && slot.offset <= 2047) {
os << " addi " << PhysRegName(ops.at(0).GetReg()) << ", sp, " << slot.offset << "\n";
} else {
os << " li " << PhysRegName(ops.at(0).GetReg()) << ", " << slot.offset << "\n";
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", sp, "
<< PhysRegName(ops.at(0).GetReg()) << "\n";
}
break;
}
case Opcode::Slli:
os << " slli " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::StoreIndirect:
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::Ret:{
// 恢复 ra 和 s0
int ra_offset = total_frame_size - 8;
int s0_offset = total_frame_size - 16;
if (ra_offset <= 2047) {
os << " lw ra, " << ra_offset << "(sp)\n";
} else {
os << " li t3, " << ra_offset << "\n";
os << " add t3, sp, t3\n";
os << " lw ra, 0(t3)\n";
}
if (s0_offset <= 2047) {
os << " lw s0, " << s0_offset << "(sp)\n";
} else {
os << " li t3, " << s0_offset << "\n";
os << " add t3, sp, t3\n";
os << " lw s0, 0(t3)\n";
}
// 恢复 sp
if (total_frame_size <= 2047) {
os << " addi sp, sp, " << total_frame_size << "\n";
} else {
os << " li t3, " << total_frame_size << "\n";
os << " add sp, sp, t3\n";
}
os << " ret\n";
break;
}
case Opcode::Br: {
auto* target = reinterpret_cast<MachineBasicBlock*>(ops[0].GetImm64());
os << " j " << target->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* true_target = reinterpret_cast<MachineBasicBlock*>(ops[1].GetImm64());
auto* false_target = reinterpret_cast<MachineBasicBlock*>(ops[2].GetImm64());
auto true_it = block_names.find(true_target);
auto false_it = block_names.find(false_target);
if (true_it == block_names.end() || false_it == block_names.end()) {
throw std::runtime_error(FormatError("mir", "CondBr: 找不到基本块名称"));
}
os << " bnez " << PhysRegName(ops[0].GetReg()) << ", "
<< true_it->second << "\n";
os << " j " << false_it->second << "\n";
break;
}
// 浮点运算
case Opcode::FAdd:
os << " fadd.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FSub:
os << " fsub.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMul:
os << " fmul.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FDiv:
os << " fdiv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FEq:
os << " feq.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLt:
os << " flt.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLe:
os << " fle.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMov:
os << " fmv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovWX:
os << " fmv.w.x " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovXW:
os << " fmv.x.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " fcvt.s.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvt.w.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::LoadFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " flw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackLoadFloat(os, ops[0].GetReg(), slot.offset);
}
break;
case Opcode::StoreFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " fsw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackStoreFloat(os, ops[0].GetReg(), slot.offset);
}
break;
default:
break;
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
}
} // namespace
// 输出多个函数的汇编
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os) {
// ========== 输出全局变量 ==========
// 输出 .data 段(已初始化的全局变量)
bool hasData = false;
for (const auto& gv : g_globalVars) {
if (!gv.isConst) {
if (!hasData) {
os << ".data\n";
hasData = true;
}
os << " .global " << gv.name << "\n";
os << " .type " << gv.name << ", @object\n";
if (gv.isArray && gv.arraySize > 1) {
int totalSize = gv.arraySize * 4;
os << " .size " << gv.name << ", " << totalSize << "\n";
os << gv.name << ":\n";
if (!gv.arrayValues.empty()) {
for (int val : gv.arrayValues) {
os << " .word " << val << "\n";
}
} else {
for (int i = 0; i < gv.arraySize; i++) {
os << " .word 0\n";
}
}
} else {
os << " .size " << gv.name << ", 4\n";
os << gv.name << ":\n";
os << " .word " << gv.value << "\n";
}
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName()
<< "\n";
// 输出 .rodata 段(只读常量)
bool hasRodata = false;
for (const auto& gv : g_globalVars) {
if (gv.isConst) {
if (!hasRodata) {
os << ".section .rodata\n";
hasRodata = true;
}
os << " .global " << gv.name << "\n";
os << " .type " << gv.name << ", @object\n";
if (gv.isArray && gv.arraySize > 1) {
int totalSize = gv.arraySize * 4;
os << " .size " << gv.name << ", " << totalSize << "\n";
os << gv.name << ":\n";
if (!gv.arrayValues.empty()) {
for (int val : gv.arrayValues) {
os << " .word " << val << "\n";
}
} else {
for (int i = 0; i < gv.arraySize; i++) {
os << " .word 0\n";
}
}
} else {
os << " .size " << gv.name << ", 4\n";
os << gv.name << ":\n";
os << " .word " << gv.value << "\n";
}
}
}
// ========== 输出代码段 ==========
os << ".text\n";
// 输出每个函数
for (const auto& func_ptr : functions) {
os << ".global " << func_ptr->GetName() << "\n";
os << ".type " << func_ptr->GetName() << ", @function\n";
os << func_ptr->GetName() << ":\n";
PrintAsmFunction(*func_ptr, os);
os << "\n"; // 函数之间加空行
}
}
} // namespace mir
} // namespace mir

@ -15,10 +15,12 @@ target_link_libraries(mir_core PUBLIC
ir
)
target_compile_options(mir_core PRIVATE -Wno-unused-parameter)
add_subdirectory(passes)
add_library(mir INTERFACE)
target_link_libraries(mir INTERFACE
mir_core
mir_passes
)
)

@ -18,9 +18,9 @@ void RunFrameLowering(MachineFunction& function) {
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
if (-cursor < -256) {
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
}
//if (-cursor < -2048) {
//throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
//}
}
cursor = 0;
@ -30,7 +30,8 @@ void RunFrameLowering(MachineFunction& function) {
}
function.SetFrameSize(AlignTo(cursor, 16));
auto& insts = function.GetEntry().GetInstructions();
// 修复GetEntry() 返回指针,使用 ->
auto& insts = function.GetEntry()->GetInstructions();
std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) {
@ -42,4 +43,4 @@ void RunFrameLowering(MachineFunction& function) {
insts = std::move(lowered);
}
} // namespace mir
} // namespace mir

@ -1,123 +1,642 @@
#include "mir/MIR.h"
#include <iostream>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include <cstring>
#include "ir/IR.h"
#include "utils/Log.h"
std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir {
namespace {
using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
static std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> block_map;
// Returns the machine basic block associated with `ir_block`, creating it in
// `function` on first request.
//
// The association is cached in the file-level `block_map`, so repeated calls
// with the same IR block return the same machine block. A new block takes the
// IR block's name; when that name is empty it is named "block_<N>", where N is
// the number of entries in `block_map` at the time of creation.
MachineBasicBlock* GetOrCreateBlock(const ir::BasicBlock* ir_block,
                                    MachineFunction& function) {
  const auto cached = block_map.find(ir_block);
  if (cached == block_map.end()) {
    std::string label = ir_block->GetName();
    if (label.empty()) {
      // Unnamed IR block: derive a label from the current cache size.
      label = "block_" + std::to_string(block_map.size());
    }
    auto* created = function.CreateBlock(label);
    // The key is known to be absent here, so emplace always inserts.
    block_map.emplace(ir_block, created);
    return created;
  }
  return cached->second;
}
void EmitValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
const ValueSlotMap& slots, MachineBasicBlock& block,
bool for_address=false) {
if (auto* arg = dynamic_cast<const ir::Argument*>(value)) {
auto it = slots.find(arg);
if (it != slots.end()) {
// 从栈槽加载参数值
if (value->GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
} else {
block.Append(Opcode::Load,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
}
return;
}
}
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
int64_t val = constant->GetValue();
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(constant->GetValue())});
{Operand::Reg(target), Operand::Imm(static_cast<int>(val))});
return;
}
// 处理浮点常量
if (auto* fconstant = dynamic_cast<const ir::ConstantFloat*>(value)) {
float val = fconstant->GetValue();
uint32_t bits;
memcpy(&bits, &val, sizeof(val));
// 检查目标是否是浮点寄存器
bool target_is_fp = (target == PhysReg::FT0 || target == PhysReg::FT1 ||
target == PhysReg::FT2 || target == PhysReg::FT3 ||
target == PhysReg::FT4 || target == PhysReg::FT5 ||
target == PhysReg::FT6 || target == PhysReg::FT7 ||
target == PhysReg::FA0 || target == PhysReg::FA1 ||
target == PhysReg::FA2 || target == PhysReg::FA3 ||
target == PhysReg::FA4 || target == PhysReg::FA5 ||
target == PhysReg::FA6 || target == PhysReg::FA7);
if (target_is_fp) {
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::T0), Operand::Imm(static_cast<int>(bits))});
block.Append(Opcode::FMovWX, {Operand::Reg(target), Operand::Reg(PhysReg::T0)});
} else {
// 目标是整数寄存器,直接加载
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(static_cast<int>(bits))});
}
return;
}
if (auto* gep = dynamic_cast<const ir::GepInst*>(value)) {
EmitValueToReg(gep->GetBasePtr(), target, slots, block, true);
EmitValueToReg(gep->GetIndex(), PhysReg::T1, slots, block);
block.Append(Opcode::Slli, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Imm(2)});
block.Append(Opcode::Add, {Operand::Reg(target),
Operand::Reg(target),
Operand::Reg(PhysReg::T1)});
return;
}
if (auto* alloca = dynamic_cast<const ir::AllocaInst*>(value)) {
auto it = slots.find(alloca);
if (it != slots.end()) {
block.Append(Opcode::LoadAddr,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
return;
}
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(value)) {
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(target), Operand::Global(global->GetName())});
if (!for_address) {
if (global->IsFloat()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::Reg(target)});
} else {
block.Append(Opcode::LoadGlobal,
{Operand::Reg(target), Operand::Reg(target)});
}
}
return;
}
// 关键:在 slots 中查找,并根据类型生成正确的加载指令
auto it = slots.find(value);
if (it == slots.end()) {
throw std::runtime_error(
FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
if (it != slots.end()) {
if (value->GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
} else {
block.Append(Opcode::Load,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
}
return;
}
std::cerr << "未找到的值: " << value << std::endl;
std::cerr << " 名称: " << value->GetName() << std::endl;
std::cerr << " 类型: " << (value->GetType()->IsFloat32() ? "float" : "int") << std::endl;
std::cerr << " 是否是 ConstantInt: " << (dynamic_cast<const ir::ConstantInt*>(value) != nullptr) << std::endl;
std::cerr << " 是否是 ConstantFloat: " << (dynamic_cast<const ir::ConstantFloat*>(value) != nullptr) << std::endl;
std::cerr << " 是否是 Instruction: " << (dynamic_cast<const ir::Instruction*>(value) != nullptr) << std::endl;
throw std::runtime_error(
FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
}
block.Append(Opcode::LoadStack,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
// Appends to `block` an instruction that stores `reg` into frame slot `slot`.
//
// reg:     register holding the value to spill.
// slot:    frame index of the destination stack slot.
// block:   machine basic block that receives the instruction.
// isFloat: when true a StoreFloat is appended, otherwise a Store. The opcode
//          is chosen from this flag alone, not from `reg`.
void StoreRegToSlot(PhysReg reg, int slot, MachineBasicBlock& block, bool isFloat = false) {
  const Opcode store_op = isFloat ? Opcode::StoreFloat : Opcode::Store;
  block.Append(store_op, {Operand::Reg(reg), Operand::FrameIndex(slot)});
}
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots) {
auto& block = function.GetEntry();
// 将 LowerInstruction 重命名为 LowerInstructionToBlock并添加 MachineBasicBlock 参数
void LowerInstructionToBlock(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots, MachineBasicBlock& block) {
switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: {
slots.emplace(&inst, function.CreateFrameIndex());
auto& alloca = static_cast<const ir::AllocaInst&>(inst);
int size = 4;
if (alloca.GetNumElements() > 1) {
size = alloca.GetNumElements() * 4;
}
slots.emplace(&inst, function.CreateFrameIndex(size));
return;
}
case ir::Opcode::Store: {
auto& store = static_cast<const ir::StoreInst&>(inst);
if (dynamic_cast<const ir::GepInst*>(store.GetPtr())) {
EmitValueToReg(store.GetValue(), PhysReg::T2, slots, block);
EmitValueToReg(store.GetPtr(), PhysReg::T0, slots, block, true);
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::T2), Operand::Reg(PhysReg::T0)});
return;
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(store.GetPtr())) {
EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block);
std::string global_name = global->GetName();
block.Append(Opcode::StoreGlobal,
{Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
return;
}
auto dst = slots.find(store.GetPtr());
if (dst == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行写入"));
if (dst != slots.end()) {
EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block);
StoreRegToSlot(PhysReg::T0, dst->second, block);
return;
}
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
return;
throw std::runtime_error(FormatError("mir", "Store: 无法处理的指针类型"));
}
case ir::Opcode::Load: {
auto& load = static_cast<const ir::LoadInst&>(inst);
if (dynamic_cast<const ir::GepInst*>(load.GetPtr())) {
EmitValueToReg(load.GetPtr(), PhysReg::T0, slots, block, true);
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
int dst_slot = function.CreateFrameIndex(4);
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
slots.emplace(&inst, dst_slot);
return;
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(load.GetPtr())) {
int dst_slot = function.CreateFrameIndex(4);
std::string global_name = global->GetName();
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
if (global->IsFloat()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, true);
} else {
block.Append(Opcode::LoadGlobal,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
return;
}
auto src = slots.find(load.GetPtr());
if (src == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行读取"));
if (src != slots.end()) {
int dst_slot = function.CreateFrameIndex(4);
if (load.GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(src->second)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
} else {
block.Append(Opcode::Load,
{Operand::Reg(PhysReg::T0), Operand::FrameIndex(src->second)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
return;
}
throw std::runtime_error(FormatError("mir", "Load: 无法处理的指针类型"));
}
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Mod: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
EmitValueToReg(bin.GetLhs(), PhysReg::T0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::T1, slots, block);
Opcode op;
switch (inst.GetOpcode()) {
case ir::Opcode::Add: op = Opcode::Add; break;
case ir::Opcode::Sub: op = Opcode::Sub; break;
case ir::Opcode::Mul: op = Opcode::Mul; break;
case ir::Opcode::Div: op = Opcode::Div; break;
case ir::Opcode::Mod: op = Opcode::Rem; break;
default: op = Opcode::Add; break;
}
block.Append(op, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Add: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
case ir::Opcode::Gep: {
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
block.Append(Opcode::Ret);
case ir::Opcode::Call: {
auto& call = static_cast<const ir::CallInst&>(inst);
for (size_t i = 0; i < call.GetNumArgs() && i < 8; i++) {
PhysReg argReg = static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + i);
EmitValueToReg(call.GetArg(i), argReg, slots, block);
}
std::string func_name = call.GetCalleeName();
block.Append(Opcode::Call, {Operand::Func(func_name)});
if (!call.GetType()->IsVoid()) {
int dst_slot = function.CreateFrameIndex();
StoreRegToSlot(PhysReg::A0, dst_slot, block);
slots.emplace(&inst, dst_slot);
}
return;
}
case ir::Opcode::Sub:
case ir::Opcode::Mul:
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算"));
}
case ir::Opcode::ICmp: {
auto& icmp = static_cast<const ir::ICmpInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(icmp.GetLhs(), PhysReg::T0, slots, block);
EmitValueToReg(icmp.GetRhs(), PhysReg::T1, slots, block);
ir::ICmpPredicate pred = icmp.GetPredicate();
switch (pred) {
case ir::ICmpPredicate::EQ:
block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
break;
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
}
case ir::ICmpPredicate::NE:
block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
case ir::ICmpPredicate::SLT:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
break;
case ir::ICmpPredicate::SLE:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
case ir::ICmpPredicate::SGT:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
break;
case ir::ICmpPredicate::SGE:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
}
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::ZExt: {
auto& zext = static_cast<const ir::ZExtInst&>(inst);
int dst_slot = function.CreateFrameIndex(4); // i32 是 4 字节
// 获取源操作数的值
EmitValueToReg(zext.GetSrc(), PhysReg::T0, slots, block);
// 存储到新栈槽
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
EmitValueToReg(bin.GetLhs(), PhysReg::FT0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::FT1, slots, block);
Opcode op;
switch (inst.GetOpcode()) {
case ir::Opcode::FAdd: op = Opcode::FAdd; break;
case ir::Opcode::FSub: op = Opcode::FSub; break;
case ir::Opcode::FMul: op = Opcode::FMul; break;
case ir::Opcode::FDiv: op = Opcode::FDiv; break;
default: op = Opcode::FAdd; break;
}
block.Append(op, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
} // namespace
case ir::Opcode::FCmp: {
auto& fcmp = static_cast<const ir::FCmpInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
EmitValueToReg(fcmp.GetLhs(), PhysReg::FT0, slots, block);
EmitValueToReg(fcmp.GetRhs(), PhysReg::FT1, slots, block);
ir::FCmpPredicate pred = fcmp.GetPredicate();
switch (pred) {
case ir::FCmpPredicate::OEQ:
block.Append(Opcode::FEq, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
case ir::FCmpPredicate::OLT:
block.Append(Opcode::FLt, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
case ir::FCmpPredicate::OLE:
block.Append(Opcode::FLe, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
default:
block.Append(Opcode::FEq, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
}
block.Append(Opcode::FMov, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) {
DefaultContext();
case ir::Opcode::SIToFP: {
auto& conv = static_cast<const ir::SIToFPInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
auto src_it = slots.find(conv.GetSrc());
if (src_it == slots.end()) {
throw std::runtime_error(FormatError("mir", "SIToFP: 找不到源操作数的栈槽"));
}
block.Append(Opcode::Load,
{Operand::Reg(PhysReg::T0), Operand::FrameIndex(src_it->second)});
block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
slots.emplace(&inst, dst_slot);
return;
}
if (module.GetFunctions().size() != 1) {
throw std::runtime_error(FormatError("mir", "暂不支持多个函数"));
}
case ir::Opcode::FPToSI: {
auto& conv = static_cast<const ir::FPToSIInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
auto src_it = slots.find(conv.GetSrc());
if (src_it == slots.end()) {
throw std::runtime_error(FormatError("mir", "FPToSI: 找不到源操作数的栈槽"));
}
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(src_it->second)});
block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Br: {
auto& br = static_cast<const ir::BrInst&>(inst);
auto* target = br.GetTarget();
MachineBasicBlock* target_block = GetOrCreateBlock(target, function);
block.Append(Opcode::Br, {Operand::Imm64(reinterpret_cast<intptr_t>(target_block))});
return;
}
case ir::Opcode::CondBr: {
auto& condbr = static_cast<const ir::CondBrInst&>(inst);
auto* true_bb = condbr.GetTrueBB();
auto* false_bb = condbr.GetFalseBB();
const auto& func = *module.GetFunctions().front();
if (func.GetName() != "main") {
throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数"));
// 如果条件涉及函数调用,需要特殊处理
// 简单方案:将条件值保存到栈槽
int cond_slot = function.CreateFrameIndex(4);
EmitValueToReg(condbr.GetCond(), PhysReg::T0, slots, block);
// 保存条件值到栈
block.Append(Opcode::Store, {Operand::Reg(PhysReg::T0), Operand::FrameIndex(cond_slot)});
// 从栈加载条件值(确保函数调用后还能获取)
block.Append(Opcode::Load, {Operand::Reg(PhysReg::T0), Operand::FrameIndex(cond_slot)});
block.Append(Opcode::Sltu, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::ZERO),
Operand::Reg(PhysReg::T0)});
MachineBasicBlock* true_block = GetOrCreateBlock(true_bb, function);
MachineBasicBlock* false_block = GetOrCreateBlock(false_bb, function);
block.Append(Opcode::CondBr, {Operand::Reg(PhysReg::T1),
Operand::Imm64(reinterpret_cast<intptr_t>(true_block)),
Operand::Imm64(reinterpret_cast<intptr_t>(false_block))});
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (ret.GetValue()) {
auto val = ret.GetValue();
if (val->GetType()->IsFloat32()) {
auto it = slots.find(val);
if (it != slots.end()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(it->second)});
block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::A0),
Operand::Reg(PhysReg::FT0)});
} else {
throw std::runtime_error(FormatError("mir", "Ret: 找不到浮点返回值的栈槽"));
}
} else {
EmitValueToReg(val, PhysReg::A0, slots, block);
}
} else {
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::A0), Operand::Imm(0)});
}
block.Append(Opcode::Ret);
return;
}
default: {
break;
}
}
}
} // namespace
// Lowers one IR function to a MachineFunction carrying the same name.
//
// Steps visible in this body: clear block_map, give every IR argument a stack
// slot and spill the incoming argument register into it, create one MIR block
// per IR block, then lower every instruction of every block via
// LowerInstructionToBlock. Returns the owning pointer to the new function.
std::unique_ptr<MachineFunction> LowerFunctionToMIR(const ir::Function& func) {
// Start from an empty IR-block -> MIR-block map for this function.
block_map.clear();
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
ValueSlotMap slots;
// NOTE(review): the next three lines open an `if` that is never closed here and
// look like a remnant of the earlier entry-block-only lowering — verify against
// the checked-in file.
const auto* entry = func.GetEntry();
if (!entry) {
throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块"));
// ========== Added: allocate a stack slot for each function argument ==========
for (size_t i = 0; i < func.GetNumArgs(); i++) {
ir::Argument* arg = func.GetArgument(i);
// Every argument gets a 4-byte slot (the original author treats int and
// pointer as 4 bytes).
// NOTE(review): pointers are 8 bytes on a 64-bit RISC-V target — confirm.
int slot = machine_func->CreateFrameIndex(4); // int and pointer are both 4 bytes
// Pick the register the i-th argument arrives in: A0 + i.
// NOTE(review): assumes A0.. are consecutive PhysReg enumerators and that
// i stays within the argument registers; there is no bound check — confirm.
PhysReg argReg = static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + i);
MachineBasicBlock* entry = machine_func->GetEntry();
// Spill the argument register into its stack slot in the entry block.
if (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat32()) {
// Pointer argument
entry->Append(Opcode::Store, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
} else if (arg->GetType()->IsInt32()) {
// Integer argument
entry->Append(Opcode::Store, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
} else if (arg->GetType()->IsFloat32()) {
// Float argument
// NOTE(review): this stores from the same A-register index as integer
// arguments; presumably float arguments arrive in FA registers — confirm.
entry->Append(Opcode::StoreFloat, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
}
// Remember where this argument lives so later uses can load it.
slots[arg] = slot;
}
// NOTE(review): this loop header and call are also unclosed here and appear to
// belong to the earlier single-block lowering — verify.
for (const auto& inst : entry->GetInstructions()) {
LowerInstruction(*inst, *machine_func, slots);
// Pass 1: create the MIR basic block for every IR basic block, so branch
// targets exist before any instruction is lowered.
for (const auto& ir_block : func.GetBlocks()) {
GetOrCreateBlock(ir_block.get(), *machine_func);
}
// Pass 2: walk every basic block and lower its instructions into the
// matching MIR block.
for (const auto& ir_block : func.GetBlocks()) {
MachineBasicBlock* mbb = GetOrCreateBlock(ir_block.get(), *machine_func);
for (const auto& inst : ir_block->GetInstructions()) {
LowerInstructionToBlock(*inst, *machine_func, slots, *mbb);
}
}
return machine_func;
}
} // namespace mir
// Lowers every function of `module` to MIR.
//
// Side effect: g_globalVars is cleared and refilled with one GlobalVarInfo per
// module-level global variable (name, const/array/float flags, element count
// and initial value(s)).
//
// Returns one MachineFunction per IR function, in the order the module lists
// them. Throws std::runtime_error when the module has no functions.
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module) {
  DefaultContext();

  // Collect global variables (done once per module).
  g_globalVars.clear();
  for (const auto& global : module.GetGlobalVariables()) {
    GlobalVarInfo info;
    info.name = global->GetName();
    info.isConst = global->IsConst();
    info.isArray = global->IsArray();
    info.arraySize = global->GetNumElements();
    info.isFloat = global->IsFloat();
    info.value = 0;
    info.valueF = 0.0f;

    if (info.isArray) {
      if (info.isFloat) {
        // NOTE(review): unlike the int path below, this is not guarded by
        // HasInitVals() — confirm GetInitValsF() is safe without initializers.
        const auto& initVals = global->GetInitValsF();
        for (float val : initVals) {
          info.arrayValuesF.push_back(val);
        }
      } else if (global->HasInitVals()) {
        const auto& initVals = global->GetInitVals();
        for (int val : initVals) {
          info.arrayValues.push_back(val);
        }
      }
    } else if (info.isFloat) {
      info.valueF = global->GetInitValF();
    } else {
      info.value = global->GetInitVal();
    }

    // Move rather than copy: `info` owns the initializer vectors and is not
    // used again in this iteration.
    g_globalVars.push_back(std::move(info));
  }

  const auto& functions = module.GetFunctions();
  if (functions.empty()) {
    throw std::runtime_error(FormatError("mir", "模块中没有函数"));
  }

  // Produce one MachineFunction per IR function.
  std::vector<std::unique_ptr<MachineFunction>> result;
  result.reserve(functions.size());
  for (const auto& func : functions) {
    result.push_back(LowerFunctionToMIR(*func));
  }
  return result;
}
} // namespace mir

@ -8,7 +8,16 @@
namespace mir {
MachineFunction::MachineFunction(std::string name)
: name_(std::move(name)), entry_("entry") {}
: name_(std::move(name)) {
entry_ = CreateBlock("entry");
}
// Creates a basic block called `name`, appends it to blocks_ (which takes
// ownership), and hands back a non-owning pointer to the new block.
MachineBasicBlock* MachineFunction::CreateBlock(const std::string& name) {
  blocks_.push_back(std::make_unique<MachineBasicBlock>(name));
  return blocks_.back().get();
}
int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size());
@ -30,4 +39,4 @@ const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
return frame_slots_[index];
}
} // namespace mir
} // namespace mir

@ -6,18 +6,34 @@ namespace mir {
// Builds an operand of `kind` holding a register and a 32-bit immediate.
Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm) {}
// Builds an operand of `kind` whose payload is the 64-bit value `imm64`.
// NOTE(review): the `reg` argument is ignored — reg_ is always PhysReg::ZERO
// and imm_ is always 0 here; confirm that is intended.
Operand::Operand(Kind kind, PhysReg reg, int64_t imm64)
: kind_(kind), reg_(PhysReg::ZERO), imm_(0), imm64_(imm64) {}
// Added constructor: like the (kind, reg, imm) form, and also records `name`
// in global_name_.
Operand::Operand(Kind kind, PhysReg reg, int imm, const std::string& name)
: kind_(kind), reg_(reg), imm_(imm), global_name_(name) {}
// Factory for a register operand; the immediate field is set to 0.
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value);
return Operand(Kind::Imm, PhysReg::ZERO, value);
}
// Factory for a 64-bit immediate operand. The operand kind is Kind::Imm (the
// same kind Operand::Imm uses); because `value` is an int64_t, this selects the
// (Kind, PhysReg, int64_t) constructor, which stores it in imm64_.
Operand Operand::Imm64(int64_t value) {
return Operand(Kind::Imm, PhysReg::ZERO, value);
}
Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::W0, index);
return Operand(Kind::FrameIndex, PhysReg::ZERO, index);
}
// Added: factory for an operand that refers to a global symbol by name.
// The register field is PhysReg::ZERO and the immediate field is 0.
Operand Operand::Global(const std::string& name) {
return Operand(Kind::Global, PhysReg::ZERO, 0, name);
}
// Factory for an operand that refers to a function by name.
// Built through the (kind, reg, imm) constructor and then given its name by
// assigning func_name_ directly.
Operand Operand::Func(const std::string& name) {
Operand op(Kind::Func, PhysReg::ZERO, 0);
op.func_name_ = name;
return op;
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
: opcode_(opcode), operands_(std::move(operands)) {}
} // namespace mir
} // namespace mir

@ -9,12 +9,66 @@ namespace {
bool IsAllowedReg(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::X29:
case PhysReg::X30:
// 临时寄存器
case PhysReg::T0:
case PhysReg::T1:
case PhysReg::T2:
case PhysReg::T3:
case PhysReg::T4:
case PhysReg::T5:
case PhysReg::T6:
// 参数/返回值寄存器
case PhysReg::A0:
case PhysReg::A1:
case PhysReg::A2:
case PhysReg::A3:
case PhysReg::A4:
case PhysReg::A5:
case PhysReg::A6:
case PhysReg::A7:
// 保存寄存器
case PhysReg::S0:
case PhysReg::S1:
case PhysReg::S2:
case PhysReg::S3:
case PhysReg::S4:
case PhysReg::S5:
case PhysReg::S6:
case PhysReg::S7:
case PhysReg::S8:
case PhysReg::S9:
case PhysReg::S10:
case PhysReg::S11:
// 特殊寄存器
case PhysReg::ZERO:
case PhysReg::RA:
case PhysReg::SP:
case PhysReg::GP:
case PhysReg::TP:
case PhysReg::FT0:
case PhysReg::FT1:
case PhysReg::FT2:
case PhysReg::FT3:
case PhysReg::FT4:
case PhysReg::FT5:
case PhysReg::FT6:
case PhysReg::FT7:
case PhysReg::FT8:
case PhysReg::FT9:
case PhysReg::FT10:
case PhysReg::FT11:
// 浮点保存寄存器
case PhysReg::FS0:
case PhysReg::FS1:
// 浮点参数寄存器
case PhysReg::FA0:
case PhysReg::FA1:
case PhysReg::FA2:
case PhysReg::FA3:
case PhysReg::FA4:
case PhysReg::FA5:
case PhysReg::FA6:
case PhysReg::FA7:
return true;
}
return false;
@ -23,7 +77,8 @@ bool IsAllowedReg(PhysReg reg) {
} // namespace
void RunRegAlloc(MachineFunction& function) {
for (const auto& inst : function.GetEntry().GetInstructions()) {
// 修复GetEntry() 返回指针,使用 ->
for (const auto& inst : function.GetEntry()->GetInstructions()) {
for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) {
@ -33,4 +88,4 @@ void RunRegAlloc(MachineFunction& function) {
}
}
} // namespace mir
} // namespace mir

@ -8,20 +8,96 @@ namespace mir {
const char* PhysRegName(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
return "w0";
case PhysReg::W8:
return "w8";
case PhysReg::W9:
return "w9";
case PhysReg::X29:
return "x29";
case PhysReg::X30:
return "x30";
// 整数寄存器
case PhysReg::ZERO:
return "zero";
case PhysReg::RA:
return "ra";
case PhysReg::SP:
return "sp";
case PhysReg::GP:
return "gp";
case PhysReg::TP:
return "tp";
case PhysReg::T0:
return "t0";
case PhysReg::T1:
return "t1";
case PhysReg::T2:
return "t2";
case PhysReg::S0:
return "s0";
case PhysReg::S1:
return "s1";
case PhysReg::A0:
return "a0";
case PhysReg::A1:
return "a1";
case PhysReg::A2:
return "a2";
case PhysReg::A3:
return "a3";
case PhysReg::A4:
return "a4";
case PhysReg::A5:
return "a5";
case PhysReg::A6:
return "a6";
case PhysReg::A7:
return "a7";
case PhysReg::S2:
return "s2";
case PhysReg::S3:
return "s3";
case PhysReg::S4:
return "s4";
case PhysReg::S5:
return "s5";
case PhysReg::S6:
return "s6";
case PhysReg::S7:
return "s7";
case PhysReg::S8:
return "s8";
case PhysReg::S9:
return "s9";
case PhysReg::S10:
return "s10";
case PhysReg::S11:
return "s11";
case PhysReg::T3:
return "t3";
case PhysReg::T4:
return "t4";
case PhysReg::T5:
return "t5";
case PhysReg::T6:
return "t6";
// 浮点寄存器
case PhysReg::FT0: return "ft0";
case PhysReg::FT1: return "ft1";
case PhysReg::FT2: return "ft2";
case PhysReg::FT3: return "ft3";
case PhysReg::FT4: return "ft4";
case PhysReg::FT5: return "ft5";
case PhysReg::FT6: return "ft6";
case PhysReg::FT7: return "ft7";
case PhysReg::FS0: return "fs0";
case PhysReg::FS1: return "fs1";
case PhysReg::FA0: return "fa0";
case PhysReg::FA1: return "fa1";
case PhysReg::FA2: return "fa2";
case PhysReg::FA3: return "fa3";
case PhysReg::FA4: return "fa4";
case PhysReg::FA5: return "fa5";
case PhysReg::FA6: return "fa6";
case PhysReg::FA7: return "fa7";
case PhysReg::FT8: return "ft8";
case PhysReg::FT9: return "ft9";
case PhysReg::FT10: return "ft10";
case PhysReg::FT11: return "ft11";
}
throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
}
} // namespace mir
} // namespace mir

@ -66,8 +66,8 @@ class SemaVisitor final : public SysYBaseVisitor {
}
func->accept(this);
}
if (!has_main && ctx->funcDef().empty()) {
throw std::runtime_error(FormatError("sema", "缺少函数定义"));
if (!has_main) {
throw std::runtime_error(FormatError("sema", "缺少main函数定义"));
}
scope_.PopScope();
return {};

@ -1,5 +1,6 @@
#include "sem/func.h"
#include <cstring>
#include <stdexcept>
#include <string>
@ -7,6 +8,12 @@
namespace sem {
// Round-trips `v` through a 32-bit float so the returned double carries only
// single-precision accuracy (mirrors what C `float` arithmetic would produce).
static double ToFloat32(double v) {
  return static_cast<double>(static_cast<float>(v));
}
// 编译时求值常量表达式
ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) {
return EvaluateExp(*ctx.addExp());
@ -73,14 +80,65 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp()) {
return EvaluateExp(*ctx.exp()->addExp());
} else if (ctx.lVar()) {
// 处理变量引用(必须是已定义的常量)
// 处理变量引用:向上遍历 AST 找到对应的常量定义并求值
auto* ident = ctx.lVar()->Ident();
if (!ident) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
}
std::string name = ident->getText();
// 这里简化处理,实际应该在符号表中查找常量
// 暂时假设常量已经在前面被处理过
// 向上遍历 AST 找到作用域内的 constDef
antlr4::ParserRuleContext* scope =
dynamic_cast<antlr4::ParserRuleContext*>(ctx.lVar()->parent);
while (scope) {
// 检查当前作用域中的所有 constDecl
for (auto* tree_child : scope->children) {
auto* child = dynamic_cast<antlr4::ParserRuleContext*>(tree_child);
if (!child) continue;
auto* block_item = dynamic_cast<SysYParser::BlockItemContext*>(child);
if (block_item && block_item->decl()) {
auto* decl = block_item->decl();
if (decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
// compUnit 级别的 constDecl
auto* decl = dynamic_cast<SysYParser::DeclContext*>(child);
if (decl && decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
// If declared as int, truncate to integer
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
scope = dynamic_cast<antlr4::ParserRuleContext*>(scope->parent);
}
// 未找到常量定义,返回 0
ConstValue val;
val.is_int = true;
val.int_val = 0;
@ -94,11 +152,11 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
ConstValue val;
if (int_const) {
val.is_int = true;
val.int_val = std::stoll(int_const->getText());
val.int_val = std::stoll(int_const->getText(), nullptr, 0);
val.float_val = static_cast<double>(val.int_val);
} else if (float_const) {
val.is_int = false;
val.float_val = std::stod(float_const->getText());
val.float_val = ToFloat32(std::stod(float_const->getText()));
val.int_val = static_cast<long long>(val.float_val);
} else {
throw std::runtime_error(FormatError("sema", "非法数字字面量"));
@ -127,8 +185,9 @@ ConstValue AddValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) +
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l + r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -143,8 +202,9 @@ ConstValue SubValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) -
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l - r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -159,8 +219,9 @@ ConstValue MulValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) *
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l * r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -175,8 +236,9 @@ ConstValue DivValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) /
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l / r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;

@ -1,4 +1,6 @@
// 解析帮助、输入文件和输出阶段选项。
// 解析命令行: compiler <input.sy> -S -o <output.s> [-O1]
// 或: compiler <input.sy> -IR -o <output.ll> [-O1]
// 同时兼容 --emit-ir / --emit-asm 旧格式
#include "utils/CLI.h"
@ -15,30 +17,31 @@ CLIOptions ParseCLI(int argc, char** argv) {
if (argc <= 1) {
throw std::runtime_error(FormatError(
"cli",
"用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>"));
"用法: compiler <input.sy> -S -o <output.s> [-O1]\n"
" 或: compiler <input.sy> -IR -o <output.ll> [-O1]"));
}
for (int i = 1; i < argc; ++i) {
const char* arg = argv[i];
if (std::strcmp(arg, "-h") == 0 || std::strcmp(arg, "--help") == 0) {
opt.show_help = true;
return opt;
}
if (std::strcmp(arg, "--emit-parse-tree") == 0) {
// 输出阶段(新格式)
if (std::strcmp(arg, "-S") == 0) {
if (!explicit_emit) {
opt.emit_parse_tree = false;
opt.emit_ir = false;
opt.emit_asm = false;
explicit_emit = true;
}
opt.emit_parse_tree = true;
opt.emit_asm = true;
continue;
}
if (std::strcmp(arg, "--emit-ir") == 0) {
if (std::strcmp(arg, "-IR") == 0) {
if (!explicit_emit) {
opt.emit_parse_tree = false;
opt.emit_ir = false;
opt.emit_asm = false;
explicit_emit = true;
@ -47,9 +50,9 @@ CLIOptions ParseCLI(int argc, char** argv) {
continue;
}
// 输出阶段(兼容旧格式)
if (std::strcmp(arg, "--emit-asm") == 0) {
if (!explicit_emit) {
opt.emit_parse_tree = false;
opt.emit_ir = false;
opt.emit_asm = false;
explicit_emit = true;
@ -58,6 +61,32 @@ CLIOptions ParseCLI(int argc, char** argv) {
continue;
}
if (std::strcmp(arg, "--emit-ir") == 0) {
if (!explicit_emit) {
opt.emit_ir = false;
opt.emit_asm = false;
explicit_emit = true;
}
opt.emit_ir = true;
continue;
}
// 优化级别
if (std::strcmp(arg, "-O1") == 0) {
opt.opt = true;
continue;
}
// 输出文件
if (std::strcmp(arg, "-o") == 0) {
if (i + 1 >= argc) {
throw std::runtime_error(
FormatError("cli", "-o 缺少输出文件名"));
}
opt.output = argv[++i];
continue;
}
if (arg[0] == '-') {
throw std::runtime_error(
FormatError("cli", std::string("未知参数: ") + arg +
@ -73,11 +102,12 @@ CLIOptions ParseCLI(int argc, char** argv) {
if (opt.input.empty() && !opt.show_help) {
throw std::runtime_error(
FormatError("cli", "缺少输入文件:请提供 <input.sy>(使用 --help 查看用法)"));
FormatError("cli", "缺少输入文件:请提供 <input.sy>"));
}
if (!opt.emit_parse_tree && !opt.emit_ir && !opt.emit_asm) {
throw std::runtime_error(FormatError(
"cli", "未选择任何输出:请使用 --emit-parse-tree / --emit-ir / --emit-asm"));
if (!explicit_emit) {
// 未显式选择输出阶段时默认输出 IR
opt.emit_ir = true;
}
return opt;
}

@ -50,17 +50,22 @@ void PrintHelp(std::ostream& os) {
os << "SysY Compiler\n"
<< "\n"
<< "用法:\n"
<< " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>\n"
<< " compiler <input.sy> -IR -o <output.ll> [-O1] # 输出 IR\n"
<< " compiler <input.sy> -S -o <output.s> [-O1] # 输出汇编\n"
<< "\n"
<< "选项:\n"
<< " -h, --help 打印帮助信息并退出\n"
<< " --emit-parse-tree 仅在显式模式下启用语法树输出\n"
<< " --emit-ir 仅在显式模式下启用 IR 输出\n"
<< " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n"
<< " -IR 输出中间代码IR 文本)\n"
<< " -S 输出 AArch64 汇编码\n"
<< " -o <file> 输出文件(默认 stdout\n"
<< " -O1 启用 IR 优化Mem2Reg + 标量优化)\n"
<< " -h, --help 打印帮助信息并退出\n"
<< "\n"
<< "说明:\n"
<< " - 默认输出 IR\n"
<< " - 若使用 --emit-parse-tree/--emit-ir/--emit-asm则仅输出显式选择的阶段\n"
<< " - 可使用重定向写入文件:\n"
<< " compiler --emit-asm test/test_case/functional/simple_add.sy > out.s\n";
<< "兼容格式(仍可使用):\n"
<< " --emit-ir 同 -IR\n"
<< " --emit-asm 同 -S\n"
<< "\n"
<< "示例:\n"
<< " compiler test.sy -IR -o test.ll -O1 # 生成优化 IR\n"
<< " compiler test.sy -S -o test.s -O1 # 生成优化汇编\n"
<< " compiler test.sy -IR # IR 输出到 stdout\n";
}

Binary file not shown.

@ -1,4 +1,83 @@
// SysY 运行库实现:
// - 按实验/评测规范提供 I/O 等函数实现
// - 与编译器生成的目标代码链接,支撑运行时行为
#include<stdio.h>
#include<stdarg.h>
#include<sys/time.h>
#include"sylib.h"
/* Input & output functions */
/* Reads one decimal integer from stdin and returns it.
 * Returns 0 when no integer could be read (EOF or malformed input) instead of
 * returning an uninitialized value. */
int getint(){
  int t = 0;
  if (scanf("%d",&t) != 1) return 0;
  return t;
}
/* Reads one character from stdin and returns it as an int.
 * Returns EOF when no character could be read, instead of returning an
 * uninitialized value. */
int getch(){
  char c;
  if (scanf("%c",&c) != 1) return EOF;
  return (int)c;
}
/* Reads one floating-point number from stdin using the "%a" conversion and
 * returns it. Returns 0 when nothing could be read (EOF or malformed input)
 * instead of returning an uninitialized value. */
float getfloat(){
  float n = 0.0f;
  if (scanf("%a", &n) != 1) return 0.0f;
  return n;
}
/* Reads an element count n from stdin, then n integers into a[0..n-1].
 * Returns n. If the count itself cannot be read, returns 0 and leaves `a`
 * untouched (previously an uninitialized n drove the loop and the writes
 * into `a`). The caller must provide room for n elements. */
int getarray(int a[]){
  int n = 0;
  if (scanf("%d",&n) != 1) return 0;
  for(int i=0;i<n;i++)scanf("%d",&a[i]);
  return n;
}
/* Reads an element count n from stdin, then n floats (via "%a") into
 * a[0..n-1]. Returns n. If the count itself cannot be read, returns 0 and
 * leaves `a` untouched (previously an uninitialized n drove the loop and the
 * writes into `a`). The caller must provide room for n elements. */
int getfarray(float a[]) {
  int n = 0;
  if (scanf("%d", &n) != 1) return 0;
  for (int i = 0; i < n; i++) {
    scanf("%a", &a[i]);
  }
  return n;
}
/* Writes `a` to stdout in decimal, with no separator or newline. */
void putint(int a){
  fprintf(stdout, "%d", a);
}
/* Writes the single character whose code is `a` to stdout. */
void putch(int a){
  fprintf(stdout, "%c", a);
}
/* Prints "n:" followed by the first n elements of `a`, each preceded by a
 * space, then a newline. */
void putarray(int n,int a[]){
  int idx = 0;
  printf("%d:",n);
  while (idx < n) {
    printf(" %d",a[idx]);
    ++idx;
  }
  printf("\n");
}
/* Writes `a` to stdout in hexadecimal floating-point ("%a") form. */
void putfloat(float a) {
  fprintf(stdout, "%a", a);
}
/* Prints "n:" followed by the first n elements of `a` in "%a" form, each
 * preceded by a space, then a newline. */
void putfarray(int n, float a[]) {
  int idx = 0;
  printf("%d:", n);
  while (idx < n) {
    printf(" %a", a[idx]);
    ++idx;
  }
  printf("\n");
}
/* printf-style formatted output to stdout: `a` is the format string and the
 * remaining arguments are its values. */
void putf(char a[], ...) {
  va_list ap;
  va_start(ap, a);
  vprintf(a, ap);
  va_end(ap);
}
/* Timing function implementation */
/* Constructor hook: zeroes every hour/minute/second/microsecond timer slot and
 * sets _sysy_idx to 1, so slot 0 stays free for the running total. */
__attribute((constructor)) void before_main(){
  int slot;
  for (slot = 0; slot < _SYSY_N; ++slot) {
    _sysy_us[slot] = 0;
    _sysy_s[slot] = 0;
    _sysy_m[slot] = 0;
    _sysy_h[slot] = 0;
  }
  _sysy_idx = 1;
}
/* Destructor hook: prints one "Timer@start-stop" line per recorded timer
 * (slots 1 .. _sysy_idx-1) to stderr, then a TOTAL line summed into slot 0.
 *
 * The total is kept normalized while accumulating: overflow from microseconds
 * is carried into seconds, seconds into minutes and minutes into hours.
 * (Previously each unit was reduced with %= and the overflow was discarded,
 * so TOTAL under-reported whenever a unit wrapped.) */
__attribute((destructor)) void after_main(){
  for(int i=1;i<_sysy_idx;i++){
    fprintf(stderr,"Timer@%04d-%04d: %dH-%dM-%dS-%dus\n",
        _sysy_l1[i],_sysy_l2[i],_sysy_h[i],_sysy_m[i],_sysy_s[i],_sysy_us[i]);
    _sysy_us[0] += _sysy_us[i];
    _sysy_s[0]  += _sysy_s[i] + _sysy_us[0] / 1000000;
    _sysy_us[0] %= 1000000;
    _sysy_m[0]  += _sysy_m[i] + _sysy_s[0] / 60;
    _sysy_s[0]  %= 60;
    _sysy_h[0]  += _sysy_h[i] + _sysy_m[0] / 60;
    _sysy_m[0]  %= 60;
  }
  fprintf(stderr,"TOTAL: %dH-%dM-%dS-%dus\n",_sysy_h[0],_sysy_m[0],_sysy_s[0],_sysy_us[0]);
}
void _sysy_starttime(int lineno){
_sysy_l1[_sysy_idx] = lineno;
gettimeofday(&_sysy_start,NULL);
}
void _sysy_stoptime(int lineno){
gettimeofday(&_sysy_end,NULL);
_sysy_l2[_sysy_idx] = lineno;
_sysy_us[_sysy_idx] += 1000000 * ( _sysy_end.tv_sec - _sysy_start.tv_sec ) + _sysy_end.tv_usec - _sysy_start.tv_usec;
_sysy_s[_sysy_idx] += _sysy_us[_sysy_idx] / 1000000 ; _sysy_us[_sysy_idx] %= 1000000;
_sysy_m[_sysy_idx] += _sysy_s[_sysy_idx] / 60 ; _sysy_s[_sysy_idx] %= 60;
_sysy_h[_sysy_idx] += _sysy_m[_sysy_idx] / 60 ; _sysy_m[_sysy_idx] %= 60;
_sysy_idx ++;
}

@ -1,4 +1,31 @@
// SysY 运行库头文件:
// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用)
// - 与 sylib.c 配套,按规范逐步补齐声明
#ifndef __SYLIB_H_
#define __SYLIB_H_
#include<stdio.h>
#include<stdarg.h>
#include<sys/time.h>
/* Input & output functions */
int getint(),getch(),getarray(int a[]);
float getfloat();
int getfarray(float a[]);
void putint(int a),putch(int a),putarray(int n,int a[]);
void putfloat(float a);
void putfarray(int n, float a[]);
void putf(char a[], ...);
/* Timing function implementation */
struct timeval _sysy_start,_sysy_end;
#define starttime() _sysy_starttime(__LINE__)
#define stoptime() _sysy_stoptime(__LINE__)
#define _SYSY_N 1024
int _sysy_l1[_SYSY_N],_sysy_l2[_SYSY_N];
int _sysy_h[_SYSY_N], _sysy_m[_SYSY_N],_sysy_s[_SYSY_N],_sysy_us[_SYSY_N];
int _sysy_idx;
__attribute((constructor)) void before_main();
__attribute((destructor)) void after_main();
void _sysy_starttime(int lineno);
void _sysy_stoptime(int lineno);
#endif

@ -0,0 +1,7 @@
// test/test_case/functional/test_riscv.sy
// Adds two local ints and returns the sum as the program's exit value.
int main() {
int a = 10;
int b = 20;
int c = a + b;
return c; // should return 30
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -0,0 +1,8 @@
//test scope of global and local variable definitions (the local a shadows the global a)
int a = 3;
int b = 5;
int main(){
int a = 5;
return a + b;
}

@ -0,0 +1,8 @@
//test local var define
int main(){
int a, b0, _c;
a = 1;
b0 = 2;
_c = 3;
return b0 + _c;
}

@ -0,0 +1,4 @@
int a[10][10];
int main(){
return 0;
}

@ -0,0 +1,9 @@
//test array define
int main(){
int a[4][2] = {};
int b[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
int c[4][2] = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
int d[4][2] = {1, 2, {3}, {5}, 7 , 8};
int e[4][2] = {{d[2][1], c[2][1]}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1] + e[0][0] + e[0][1] + a[2][0];
}

@ -0,0 +1,9 @@
int main(){
const int a[4][2] = {{1, 2}, {3, 4}, {}, 7};
int b[4][2] = {};
int c[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
int d[3 + 1][2] = {1, 2, {3}, {5}, a[3][0], 8};
int e[4][2][1] = {{d[2][1], {c[2][1]}}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1][0] + e[0][0][0] + e[0][1][0] + d[3][0];
}

@ -0,0 +1,6 @@
//test const global var define
const int a = 10, b = 5;
int main(){
return b;
}

@ -0,0 +1,5 @@
//test const local var define
int main(){
const int a = 10, b = 5;
return b;
}

@ -0,0 +1,5 @@
const int a[5]={0,1,2,3,4};
int main(){
return a[4];
}

@ -0,0 +1,11 @@
int a;
int func(int p){
p = p - 1;
return p;
}
int main(){
int b;
a = 10;
b = func(a);
return b;
}

@ -0,0 +1,8 @@
int defn(){
return 4;
}
int main(){
int a=defn();
return a;
}

@ -0,0 +1,7 @@
//test add
int main(){
int a, b;
a = 10;
b = -1;
return a + b;
}

@ -0,0 +1,5 @@
//test addc
const int a = 10;
int main(){
return a + 5;
}

@ -0,0 +1,7 @@
//test sub
const int a = 10;
int main(){
int b;
b = 2;
return b - a;
}

@ -0,0 +1,6 @@
//test subc
int main(){
int a;
a = 10;
return a - 2;
}

@ -0,0 +1,7 @@
//test mul
int main(){
int a, b;
a = 10;
b = 5;
return a * b;
}

@ -0,0 +1,5 @@
//test mulc
const int a = 5;
int main(){
return a * 5;
}

@ -0,0 +1,7 @@
//test div
int main(){
int a, b;
a = 10;
b = 5;
return a / b;
}

@ -0,0 +1,5 @@
//test divc
const int a = 10;
int main(){
return a / 5;
}

@ -0,0 +1,6 @@
//test mod (note: the body computes a / 3, i.e. integer division, not a % 3)
int main(){
int a;
a = 10;
return a / 3;
}

@ -0,0 +1,6 @@
//test rem
int main(){
int a;
a = 10;
return a % 3;
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save