Compare commits

...

24 Commits

Author SHA1 Message Date
cy f27138e6fb test(mir):删除已不需要的测试代码
2 hours ago
cy e26b9f8a43 fix(mir):通过了所有功能测试样例
3 hours ago
cy 6b6de49fcf fix(mir):修复浮点数的问题
2 days ago
cy f4658ec3fa fix(mir):修复浮点数的错误
2 days ago
cy 8ce783c967 fix(mir):修复一部分链接错误
2 days ago
cy 539e92d6bb fix(mir):修复部分大数组的排序错误
3 days ago
cy 2660960674 fix(mir):修复RV栈帧布局方向错误
6 days ago
ptabmhn4l 23c274eab6 Merge pull request '完成lab4' (#9) from ptabmhn4l/nudt-compiler-cpp:develop into develop
1 week ago
Junhe Wu 99826566e6 Merge branch 'feat/ir-opt' into develop
1 week ago
Junhe Wu 19928c4945 feat(ir-opt): 完成了lab4
1 week ago
Junhe Wu 827558938b fix(sylib): 使用官方提供的库文件
1 week ago
Junhe Wu c7e8b28d29 fix(testdata): 添加了2026年的测试用例
1 week ago
Junhe Wu e3de2c59af fix(ir): 修复了最后通不过的测试样例。
1 week ago
Junhe Wu 3c6ffe8e3e fix(ir):修复了一些ir的错误
4 weeks ago
ppxf25tqu de126b93d6 Merge pull request 'feat(mir):修正并完善功能' (#7) from pt9wfaocb/nudt-compiler-cpp:tansiping into develop
4 weeks ago
tansiping 310c7c3697 feat(mir):修正并完善功能
4 weeks ago
ptabmhn4l 248db05cf4 Merge pull request 'feat(mir):实现MIR后端' (#6) from pfwvrotsf/nudt-compiler-cpp:feature/mir into develop
4 weeks ago
cy feaba9abd4 fix(mir):修正测试用例
4 weeks ago
cy 1ff1b543d1 feat(mir): MIR 后端(RISC-V架构)
4 weeks ago
ptabmhn4l 80c46cee7e Merge pull request '初步通过verify测试' (#5) from ptabmhn4l/nudt-compiler-cpp:fix/irgen into develop
1 month ago
Junhe Wu 19ef82738f fix(irgen):通过了除了性能测试外的测试用例。
1 month ago
Junhe Wu 4693253459 Merge branch 'develop' of https://bdgit.educoder.net/ppxf25tqu/nudt-compiler-cpp into develop
1 month ago
ptabmhn4l fd45b74e2e Merge pull request '基本完成了ir生成' (#4) from ptabmhn4l/nudt-compiler-cpp:feature/ir into develop
1 month ago
ptabmhn4l 74bcb45776 Merge pull request '把比赛的测试用例放进来' (#2) from ptabmhn4l/nudt-compiler-cpp:fix/testdata into develop
1 month ago

1
.gitignore vendored

@ -54,6 +54,7 @@ compile_commands.json
.fleet/ .fleet/
.vs/ .vs/
*.code-workspace *.code-workspace
CLAUDE.md
# CLion # CLion
cmake-build-debug/ cmake-build-debug/

5745
30.txt

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,3 @@
bash scripts/run_ir_test.sh --run # 优化模式,计时
bash scripts/run_ir_test.sh --run --O0 # 无优化,计时
bash scripts/bench_ir.sh # 同时对比 O0 vs O1

@ -15,6 +15,7 @@
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -60,12 +61,14 @@ class Context {
~Context(); ~Context();
ConstantInt* GetConstInt(int v); ConstantInt* GetConstInt(int v);
ConstantFloat* GetConstFloat(float v); ConstantFloat* GetConstFloat(float v);
std::string NextTemp(); std::string NextTemp(); // 用于指令名(数字,连续)
std::string NextLabel(); // 用于块名(字母前缀,独立计数)
private: private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_; std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_; std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
int temp_index_ = -1; int temp_index_ = -1;
int label_index_ = -1;
}; };
// ─── Type ───────────────────────────────────────────────────────────────────── // ─── Type ─────────────────────────────────────────────────────────────────────
@ -160,6 +163,8 @@ enum class Opcode {
Gep, Gep,
// 控制流 // 控制流
Ret, Br, CondBr, Ret, Br, CondBr,
// PHI 节点
Phi,
// 函数调用 // 函数调用
Call, Call,
// 类型转换 // 类型转换
@ -198,16 +203,30 @@ class GlobalValue : public User {
class GlobalVariable : public Value { class GlobalVariable : public Value {
public: public:
GlobalVariable(std::string name, bool is_const, int init_val, GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements = 1); int num_elements = 1, bool is_array_decl = false,
bool is_float = false);
bool IsConst() const { return is_const_; } bool IsConst() const { return is_const_; }
bool IsFloat() const { return is_float_; }
int GetInitVal() const { return init_val_; } int GetInitVal() const { return init_val_; }
float GetInitValF() const { return init_val_f_; }
int GetNumElements() const { return num_elements_; } int GetNumElements() const { return num_elements_; }
bool IsArray() const { return num_elements_ > 1; } bool IsArray() const { return is_array_decl_ || num_elements_ > 1; }
// GlobalVariable 的"指针类型"是 i32*,访问时使用 load/store void SetInitVals(std::vector<int> v) { init_vals_ = std::move(v); }
void SetInitValsF(std::vector<float> v) { init_vals_f_ = std::move(v); }
void SetInitValF(float v){ init_val_f_ = v; }
void SetInitVal(int v){ init_val_ = v;}
const std::vector<int>& GetInitVals() const { return init_vals_; }
const std::vector<float>& GetInitValsF() const { return init_vals_f_; }
bool HasInitVals() const { return !init_vals_.empty() || !init_vals_f_.empty(); }
private: private:
bool is_const_; bool is_const_;
bool is_float_;
int init_val_; int init_val_;
float init_val_f_;
int num_elements_; int num_elements_;
bool is_array_decl_;
std::vector<int> init_vals_;
std::vector<float> init_vals_f_;
}; };
// ─── Instruction ────────────────────────────────────────────────────────────── // ─── Instruction ──────────────────────────────────────────────────────────────
@ -218,6 +237,7 @@ class Instruction : public User {
bool IsTerminator() const; bool IsTerminator() const;
BasicBlock* GetParent() const; BasicBlock* GetParent() const;
void SetParent(BasicBlock* parent); void SetParent(BasicBlock* parent);
void RemoveFromParent();
private: private:
Opcode opcode_; Opcode opcode_;
@ -358,6 +378,18 @@ class StoreInst : public Instruction {
Value* GetPtr() const; Value* GetPtr() const;
}; };
// PHI node: picks a value at a point where control-flow paths merge.
// Operand layout: [val0, bb0, val1, bb1, ...] -- even indices hold the incoming
// values, odd indices hold the matching basic blocks.
class PhiInst : public Instruction {
 public:
  PhiInst(std::shared_ptr<Type> ty, std::string name);

  // Defined out of line. NOTE(review): given the layout above this presumably
  // appends `val` followed by `bb` as operands -- confirm in the implementation.
  void AddIncoming(Value* val, BasicBlock* bb);

  // Number of (value, block) pairs: two operands make up one incoming entry.
  size_t GetNumIncoming() const {
    const auto operand_count = GetNumOperands();
    return operand_count / 2;
  }

  // Value of the i-th incoming entry (stored at even operand index i * 2).
  Value* GetIncomingValue(size_t i) const {
    const size_t value_slot = i * 2;
    return GetOperand(value_slot);
  }

  // Block of the i-th incoming entry; defined out of line.
  BasicBlock* GetIncomingBlock(size_t i) const;

  // Overwrites the value of the i-th incoming entry; its block is untouched.
  void SetIncomingValue(size_t i, Value* val) {
    const size_t value_slot = i * 2;
    SetOperand(value_slot, val);
  }
};
// ─── BasicBlock ─────────────────────────────────────────────────────────────── // ─── BasicBlock ───────────────────────────────────────────────────────────────
class BasicBlock : public Value { class BasicBlock : public Value {
public: public:
@ -368,6 +400,16 @@ class BasicBlock : public Value {
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const; const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
const std::vector<BasicBlock*>& GetPredecessors() const; const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const; const std::vector<BasicBlock*>& GetSuccessors() const;
void AddPredecessor(BasicBlock* bb);
void RemovePredecessor(BasicBlock* bb);
void ClearPredecessors();
void AddSuccessor(BasicBlock* bb);
void ClearSuccessors();
// 指令管理
void RemoveInstruction(Instruction* inst);
// Insert inst before `before`; when `before` is nullptr, append it to the end.
void InsertBefore(Instruction* inst, Instruction* before);
template <typename T, typename... Args> template <typename T, typename... Args>
T* Append(Args&&... args) { T* Append(Args&&... args) {
@ -381,6 +423,29 @@ class BasicBlock : public Value {
return ptr; return ptr;
} }
// Constructs a T from `args`, makes this block its parent, and stores it at the
// front of the instruction list. Returns a non-owning pointer to the new
// instruction; ownership stays with instructions_.
template <typename T, typename... Args>
T* Prepend(Args&&... args) {
  std::unique_ptr<T> owned = std::make_unique<T>(std::forward<Args>(args)...);
  T* const raw = owned.get();
  raw->SetParent(this);
  instructions_.insert(instructions_.begin(), std::move(owned));
  return raw;
}
// Constructs a T from `args` and places it directly ahead of the block's
// trailing terminator. When the block is empty, or its last instruction is not
// a terminator, the new instruction goes at the end instead.
// Returns a non-owning pointer; ownership stays with instructions_.
template <typename T, typename... Args>
T* InsertBeforeTerminator(Args&&... args) {
  std::unique_ptr<T> owned = std::make_unique<T>(std::forward<Args>(args)...);
  T* const raw = owned.get();
  raw->SetParent(this);

  const bool ends_with_terminator =
      !instructions_.empty() && instructions_.back()->IsTerminator();

  // Inserting at end() is equivalent to push_back; step back one slot only
  // when a terminator occupies the last position.
  auto slot = instructions_.end();
  if (ends_with_terminator) {
    --slot;
  }
  instructions_.insert(slot, std::move(owned));
  return raw;
}
private: private:
Function* parent_ = nullptr; Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_; std::vector<std::unique_ptr<Instruction>> instructions_;
@ -409,6 +474,12 @@ class Function : public Value {
Argument* GetArgument(size_t i) const; Argument* GetArgument(size_t i) const;
size_t GetNumArgs() const { return args_.size(); } size_t GetNumArgs() const { return args_.size(); }
bool IsVoidReturn() const { return type_->IsVoid(); } bool IsVoidReturn() const { return type_->IsVoid(); }
// 将某个块移动到 blocks_ 列表末尾(用于确保块顺序正确)
void MoveBlockToEnd(BasicBlock* bb);
// Rebuild the CFG: compute every block's predecessors/successors from its terminator instruction.
void RebuildCFG();
// 从函数中移除一个基本块
void RemoveBlock(BasicBlock* bb);
private: private:
BasicBlock* entry_ = nullptr; BasicBlock* entry_ = nullptr;
@ -437,7 +508,9 @@ class Module {
const std::vector<std::unique_ptr<Function>>& GetFunctions() const; const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
GlobalVariable* CreateGlobalVariable(const std::string& name, bool is_const, GlobalVariable* CreateGlobalVariable(const std::string& name, bool is_const,
int init_val, int num_elements = 1); int init_val, int num_elements = 1,
bool is_array_decl = false,
bool is_float = false);
GlobalVariable* GetGlobalVariable(const std::string& name) const; GlobalVariable* GetGlobalVariable(const std::string& name) const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobalVariables() const; const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobalVariables() const;
@ -494,9 +567,12 @@ class IRBuilder {
AllocaInst* CreateAllocaI32(const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaF32(const std::string& name); AllocaInst* CreateAllocaF32(const std::string& name);
AllocaInst* CreateAllocaArray(int num_elements, const std::string& name); AllocaInst* CreateAllocaArray(int num_elements, const std::string& name);
AllocaInst* CreateAllocaArrayF32(int num_elements, const std::string& name);
GepInst* CreateGep(Value* base_ptr, Value* index, const std::string& name); GepInst* CreateGep(Value* base_ptr, Value* index, const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr); StoreInst* CreateStore(Value* val, Value* ptr);
// Zero-initialize an array by emitting a memset call.
void CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod);
// 控制流 // 控制流
ReturnInst* CreateRet(Value* v); ReturnInst* CreateRet(Value* v);
@ -504,6 +580,8 @@ class IRBuilder {
BrInst* CreateBr(BasicBlock* target); BrInst* CreateBr(BasicBlock* target);
CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_bb, CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_bb,
BasicBlock* false_bb); BasicBlock* false_bb);
// PHI 节点(添加到当前块开头)
PhiInst* CreatePhi(std::shared_ptr<Type> ty, const std::string& name);
// 调用 // 调用
CallInst* CreateCall(Function* callee, std::vector<Value*> args, CallInst* CreateCall(Function* callee, std::vector<Value*> args,
@ -518,9 +596,31 @@ class IRBuilder {
SIToFPInst* CreateSIToFP(Value* val, const std::string& name); SIToFPInst* CreateSIToFP(Value* val, const std::string& name);
FPToSIInst* CreateFPToSI(Value* val, const std::string& name); FPToSIInst* CreateFPToSI(Value* val, const std::string& name);
void SetAllocaBlock(BasicBlock* bb) { alloca_block_ = bb; }
private: private:
Context& ctx_; Context& ctx_;
BasicBlock* insert_block_; BasicBlock* insert_block_;
BasicBlock* alloca_block_ = nullptr;
};
// ─── DominatorTree ────────────────────────────────────────────────────────────
// Per-function dominance information keyed by BasicBlock*. Only the interface
// is visible here; every member function except GetDFOrder() is defined
// elsewhere.
class DominatorTree {
public:
// Runs the analysis on `func`.
// NOTE(review): presumably (re)populates all of the private tables below --
// confirm against the implementation.
void Compute(Function& func);
// Lookup accessors for a single block; their behavior for a block that was
// never analyzed is not visible from this declaration.
BasicBlock* GetIDom(BasicBlock* bb) const;
const std::vector<BasicBlock*>& GetChildren(BasicBlock* bb) const;
const std::vector<BasicBlock*>& GetDominanceFrontier(BasicBlock* bb) const;
bool Dominates(BasicBlock* a, BasicBlock* b) const;
// Returns df_order_ by const reference (no copy).
// NOTE(review): "DF" here may stand for depth-first rather than dominance
// frontier -- confirm.
const std::vector<BasicBlock*>& GetDFOrder() const { return df_order_; }
private:
// Analysis tables, all keyed by block pointer.
std::unordered_map<BasicBlock*, BasicBlock*> idom_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> children_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> df_;
std::unordered_map<BasicBlock*, size_t> dom_level_;
std::vector<BasicBlock*> df_order_;
std::unordered_set<BasicBlock*> visited_;
}; };
// ─── IRPrinter ──────────────────────────────────────────────────────────────── // ─── IRPrinter ────────────────────────────────────────────────────────────────
@ -529,4 +629,7 @@ class IRPrinter {
void Print(const Module& module, std::ostream& os); void Print(const Module& module, std::ostream& os);
}; };
// ─── Pass Manager ────────────────────────────────────────────────────────────
void RunPasses(Module& module);
} // namespace ir } // namespace ir

@ -113,5 +113,4 @@ class IRGenImpl final : public SysYBaseVisitor {
std::vector<LoopCtx> loop_stack_; std::vector<LoopCtx> loop_stack_;
}; };
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, const SemanticContext& sema);
const SemanticContext& sema);

@ -19,39 +19,165 @@ class MIRContext {
MIRContext& DefaultContext(); MIRContext& DefaultContext();
enum class PhysReg { W0, W8, W9, X29, X30, SP }; // RISC-V 64位寄存器定义
// RISC-V physical registers. The integer registers are listed in x0..x31
// order, followed by floating-point register names. Enumerator order fixes the
// underlying values, so do not reorder entries.
enum class PhysReg {
// General-purpose (integer) registers
ZERO, // x0, hard-wired to 0
RA, // x1, return address
SP, // x2, stack pointer
GP, // x3, global pointer
TP, // x4, thread pointer
T0, // x5, temporary register
T1, // x6, temporary register
T2, // x7, temporary register
S0, // x8, frame pointer / saved register
S1, // x9, saved register
A0, // x10, argument / return value
A1, // x11, argument
A2, // x12, argument
A3, // x13, argument
A4, // x14, argument
A5, // x15, argument
A6, // x16, argument
A7, // x17, argument
S2, // x18, saved register
S3, // x19, saved register
S4, // x20, saved register
S5, // x21, saved register
S6, // x22, saved register
S7, // x23, saved register
S8, // x24, saved register
S9, // x25, saved register
S10, // x26, saved register
S11, // x27, saved register
T3, // x28, temporary register
T4, // x29, temporary register
T5, // x30, temporary register
T6, // x31, temporary register
// Floating-point register names (FT*/FS*/FA*). Of the FS* names only FS0 and
// FS1 are declared here.
FT0, FT1, FT2, FT3, FT4, FT5, FT6, FT7,
FS0, FS1,
FA0, FA1, FA2, FA3, FA4, FA5, FA6, FA7,
FT8, FT9, FT10, FT11,
};
const char* PhysRegName(PhysReg reg); const char* PhysRegName(PhysReg reg);
// Description of one global variable.
// (Original placement note: add this to MIR.h, ahead of the Opcode enum.)
//
// Every scalar member carries a default initializer so that a
// default-constructed GlobalVarInfo never exposes indeterminate values. The
// struct is still an aggregate (C++14 and later), so brace initialization in
// member order keeps working.
struct GlobalVarInfo {
  std::string name;
  // NOTE(review): value/valueF look like the int and float forms of a scalar
  // initial value, selected by isFloat -- confirm against the producer.
  int value = 0;
  float valueF = 0.0f;
  bool isConst = false;
  bool isArray = false;
  bool isFloat = false;
  // NOTE(review): presumably the per-element initial values for int and float
  // arrays respectively -- confirm.
  std::vector<int> arrayValues;
  std::vector<float> arrayValuesF;
  int arraySize = 0;
};
enum class Opcode { enum class Opcode {
Prologue, Prologue,
Epilogue, Epilogue,
MovImm, MovImm,
LoadStack, Load,
StoreStack, Store,
AddRR, Add,
Addi,
Sub,
Mul,
Div,
Rem,
Slt,
Slti,
Slli,
Sltu, // 无符号小于
Sltiu,
Xori,
LoadGlobalAddr,
LoadGlobal,
StoreGlobal,
LoadIndirect, // lw rd, 0(rs1) 从寄存器地址加载
StoreIndirect, // sw rs2, 0(rs1)
LoadIndirectFloat, // flw rd, 0(rs1)
StoreIndirectFloat, // fsw rs2, 0(rs1)
Call,
GEP,
LoadAddr,
Ret, Ret,
// 浮点指令
FMov, // 浮点移动
FMovWX, // fmv.w.x fs, x 整数寄存器移动到浮点寄存器
FMovXW, // fmv.x.w x, fs 浮点寄存器移动到整数寄存器
FAdd,
FSub,
FMul,
FDiv,
FEq, // 浮点相等比较
FLt, // 浮点小于比较
FLe, // 浮点小于等于比较
FNeg, // 浮点取反
FAbs, // 浮点绝对值
SIToFP, // int 转 float
FPToSI, // float 转 int
LoadFloat, // 浮点加载 (flw)
StoreFloat, // 浮点存储 (fsw)
Br,
CondBr,
Label,
LoadCallerStackArg, // 从调用者栈帧加载参数
LoadCallerStackArgFloat, // 从调用者栈帧加载浮点参数
};
// Section kind associated with a global variable.
enum class GlobalKind {
Data, // .data section (initialized)
BSS, // .bss section (uninitialized; starts out as 0)
RoData // .rodata section (read-only constants)
};
// Global variable information.
// None of the scalar members has a default initializer, so a
// default-constructed GlobalInfo must be filled in before it is read.
struct GlobalInfo {
std::string name;
GlobalKind kind;
int size; // size in bytes
int value; // initial value (for simple variables)
bool isArray;
int arraySize;
std::vector<int> dimensions; // array dimensions
}; };
class Operand { class Operand {
public: public:
enum class Kind { Reg, Imm, FrameIndex }; enum class Kind { Reg, Imm, FrameIndex, Global, Func };
static Operand Reg(PhysReg reg); static Operand Reg(PhysReg reg);
static Operand Imm(int value); static Operand Imm(int value);
static Operand Imm64(int64_t value); // 新增:存储 64 位值
static Operand FrameIndex(int index); static Operand FrameIndex(int index);
static Operand Global(const std::string& name);
static Operand Func(const std::string& name);
Kind GetKind() const { return kind_; } Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; } PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; } int GetImm() const { return imm_; }
int64_t GetImm64() const { return imm64_; } // 新增
int GetFrameIndex() const { return imm_; } int GetFrameIndex() const { return imm_; }
const std::string& GetGlobalName() const { return global_name_; }
const std::string& GetFuncName() const { return func_name_; }
private: private:
Operand(Kind kind, PhysReg reg, int imm); Operand(Kind kind, PhysReg reg, int imm);
Operand(Kind kind, PhysReg reg, int64_t imm64); // 新增构造函数
Operand(Kind kind, PhysReg reg, int imm, const std::string& name);
Kind kind_; Kind kind_;
PhysReg reg_; PhysReg reg_;
int imm_; int imm_;
int64_t imm64_; // 新增
std::string global_name_;
std::string func_name_;
}; };
class MachineInstr { class MachineInstr {
@ -71,7 +197,6 @@ struct FrameSlot {
int size = 4; int size = 4;
int offset = 0; int offset = 0;
}; };
class MachineBasicBlock { class MachineBasicBlock {
public: public:
explicit MachineBasicBlock(std::string name); explicit MachineBasicBlock(std::string name);
@ -93,9 +218,14 @@ class MachineFunction {
explicit MachineFunction(std::string name); explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; } const std::string& GetName() const { return name_; }
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
// 基本块管理
MachineBasicBlock* CreateBlock(const std::string& name);
MachineBasicBlock* GetEntry() { return entry_; }
const MachineBasicBlock* GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const { return blocks_; }
// 栈帧管理
int CreateFrameIndex(int size = 4); int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index); FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const; const FrameSlot& GetFrameSlot(int index) const;
@ -103,17 +233,20 @@ class MachineFunction {
int GetFrameSize() const { return frame_size_; } int GetFrameSize() const { return frame_size_; }
void SetFrameSize(int size) { frame_size_ = size; } void SetFrameSize(int size) { frame_size_ = size; }
int GetLocalVarsSize() const { return local_vars_size_; }
void SetLocalVarsSize(int s) { local_vars_size_ = s; }
private: private:
std::string name_; std::string name_;
MachineBasicBlock entry_; MachineBasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<FrameSlot> frame_slots_; std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0; int frame_size_ = 0;
int local_vars_size_ = 0;
}; };
//std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function); void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function); void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os); //void PrintAsm(const MachineFunction& function, std::ostream& os);
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module);
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os);
} // namespace mir } // namespace mir

@ -1,14 +1,16 @@
// 简易命令行解析:支持帮助、输入文件与输出阶段选择。 // 命令行解析compiler <input.sy> -S|-IR -o <output> [-O1]
// 同时兼容 --emit-ir / --emit-asm
#pragma once #pragma once
#include <string> #include <string>
struct CLIOptions { struct CLIOptions {
std::string input; std::string input;
bool emit_parse_tree = false; std::string output; // -o <file>,为空则输出到 stdout
bool emit_ir = true; bool emit_ir = false; // -IR / --emit-ir
bool emit_asm = false; bool emit_asm = false; // -S / --emit-asm
bool show_help = false; bool show_help = false;
bool opt = false; // -O1
}; };
CLIOptions ParseCLI(int argc, char** argv); CLIOptions ParseCLI(int argc, char** argv);

272
ir.txt

@ -0,0 +1,272 @@
@MAX = global i32 1000000000
@TWO = global i32 2
@THREE = global i32 3
@FIVE = global i32 5
declare void @putch(i32)
declare void @memset(i32*, i32, i32)
declare i32 @getfarray(float*)
declare float @getfloat()
declare void @putfloat(float)
declare void @putint(i32)
declare void @putfarray(i32, float*)
define float @float_abs(float %x) {
entry:
%0 = alloca float
store float %x, float* %0
%2 = load float, float* %0
%3 = sitofp i32 0 to float
%4 = fcmp olt float %2, %3
br i1 %4, label %L0.if.then, label %L1.if.end
L0.if.then:
%6 = load float, float* %0
%7 = fsub float 0x0, %6
ret float %7
L1.if.end:
%9 = load float, float* %0
ret float %9
}
define float @circle_area(i32 %radius) {
entry:
%0 = alloca i32
store i32 %radius, i32* %0
%2 = load i32, i32* %0
%3 = sitofp i32 %2 to float
%4 = fmul float 0x400921FB60000000, %3
%5 = load i32, i32* %0
%6 = sitofp i32 %5 to float
%7 = fmul float %4, %6
%8 = load i32, i32* %0
%9 = load i32, i32* %0
%10 = mul i32 %8, %9
%11 = sitofp i32 %10 to float
%12 = fmul float %11, 0x400921FB60000000
%13 = fadd float %7, %12
%14 = sitofp i32 2 to float
%15 = fdiv float %13, %14
ret float %15
}
define i32 @float_eq(float %a, float %b) {
entry:
%0 = alloca float
%1 = alloca float
store float %a, float* %0
store float %b, float* %1
%4 = load float, float* %0
%5 = load float, float* %1
%6 = fsub float %4, %5
%7 = call float @float_abs(float %6)
%8 = fcmp olt float %7, 0x3EB0C6F7A0000000
br i1 %8, label %L2.if.then, label %L3.if.else
L2.if.then:
%10 = sitofp i32 1 to float
%11 = fmul float %10, 0x4000000000000000
%12 = sitofp i32 2 to float
%13 = fdiv float %11, %12
%14 = fptosi float %13 to i32
ret i32 %14
L3.if.else:
ret i32 0
L4.if.end:
ret i32 0
}
define void @error() {
entry:
call void @putch(i32 101)
call void @putch(i32 114)
call void @putch(i32 114)
call void @putch(i32 111)
call void @putch(i32 114)
call void @putch(i32 10)
ret void
}
define void @ok() {
entry:
call void @putch(i32 111)
call void @putch(i32 107)
call void @putch(i32 10)
ret void
}
define void @assert(i32 %cond) {
entry:
%0 = alloca i32
store i32 %cond, i32* %0
%2 = load i32, i32* %0
%3 = icmp eq i32 %2, 0
%4 = zext i1 %3 to i32
%5 = icmp ne i32 %4, 0
br i1 %5, label %L5.if.then, label %L6.if.else
L5.if.then:
call void @error()
br label %L7.if.end
L6.if.else:
call void @ok()
br label %L7.if.end
L7.if.end:
ret void
}
define void @assert_not(i32 %cond) {
entry:
%0 = alloca i32
store i32 %cond, i32* %0
%2 = load i32, i32* %0
%3 = icmp ne i32 %2, 0
br i1 %3, label %L8.if.then, label %L9.if.else
L8.if.then:
call void @error()
br label %L10.if.end
L9.if.else:
call void @ok()
br label %L10.if.end
L10.if.end:
ret void
}
define i32 @main() {
entry:
%0 = alloca i32
%1 = alloca i32
%2 = alloca i32
%3 = alloca i32
%4 = alloca float, i32 10
%5 = alloca i32
%6 = alloca float
%7 = alloca float
%8 = alloca float
%9 = call i32 @float_eq(float 0x3FB4000000000000, float 0xC0E01D0000000000)
call void @assert_not(i32 %9)
%11 = call i32 @float_eq(float 0x4057C21FC0000000, float 0x4041475CE0000000)
call void @assert_not(i32 %11)
%13 = call i32 @float_eq(float 0x4041475CE0000000, float 0x4041475CE0000000)
call void @assert(i32 %13)
%15 = fptosi float 0x4016000000000000 to i32
%16 = call float @circle_area(i32 %15)
%17 = load i32, i32* @FIVE
%18 = call float @circle_area(i32 %17)
%19 = call i32 @float_eq(float %16, float %18)
call void @assert(i32 %19)
%21 = call i32 @float_eq(float 0x406D200000000000, float 0x40AFFE0000000000)
call void @assert_not(i32 %21)
%23 = fcmp one float 0x3FF8000000000000, 0x0
br i1 %23, label %L11.if.then, label %L12.if.end
L11.if.then:
call void @ok()
br label %L12.if.end
L12.if.end:
%27 = fcmp oeq float 0x400A666660000000, 0x0
%28 = zext i1 %27 to i32
%29 = icmp eq i32 %28, 0
%30 = zext i1 %29 to i32
%31 = icmp ne i32 %30, 0
br i1 %31, label %L13.if.then, label %L14.if.end
L13.if.then:
call void @ok()
br label %L14.if.end
L14.if.end:
%35 = fcmp one float 0x0, 0x0
%36 = zext i1 %35 to i32
store i32 %36, i32* %0
br i1 %35, label %L15.and.rhs, label %L16.and.end
L15.and.rhs:
%39 = icmp ne i32 3, 0
%40 = zext i1 %39 to i32
store i32 %40, i32* %0
br label %L16.and.end
L16.and.end:
%43 = load i32, i32* %0
%44 = icmp ne i32 %43, 0
br i1 %44, label %L17.if.then, label %L18.if.end
L17.if.then:
call void @error()
br label %L18.if.end
L18.if.end:
%48 = icmp ne i32 0, 0
%49 = zext i1 %48 to i32
store i32 %49, i32* %1
br i1 %48, label %L20.or.end, label %L19.or.rhs
L19.or.rhs:
%52 = fcmp one float 0x3FD3333340000000, 0x0
%53 = zext i1 %52 to i32
store i32 %53, i32* %1
br label %L20.or.end
L20.or.end:
%56 = load i32, i32* %1
%57 = icmp ne i32 %56, 0
br i1 %57, label %L21.if.then, label %L22.if.end
L21.if.then:
call void @ok()
br label %L22.if.end
L22.if.end:
store i32 1, i32* %2
store i32 0, i32* %3
call void @memset(float* %4, i32 0, i32 40)
%64 = getelementptr float, float* %4, i32 0
store float 0x3FF0000000000000, float* %64
%66 = getelementptr float, float* %4, i32 1
store i32 2, float* %66
%68 = getelementptr float, float* %4, i32 0
%69 = call i32 @getfarray(float* %68)
store i32 %69, i32* %5
br label %L23.while.cond
L23.while.cond:
%72 = load i32, i32* %2
%73 = load i32, i32* @MAX
%74 = icmp slt i32 %72, %73
br i1 %74, label %L24.while.body, label %L25.while.end
L24.while.body:
%76 = call float @getfloat()
store float %76, float* %6
%78 = load float, float* %6
%79 = fmul float 0x400921FB60000000, %78
%80 = load float, float* %6
%81 = fmul float %79, %80
store float %81, float* %7
%83 = load float, float* %6
%84 = fptosi float %83 to i32
%85 = call float @circle_area(i32 %84)
store float %85, float* %8
%87 = load i32, i32* %3
%88 = getelementptr float, float* %4, i32 %87
%89 = load float, float* %88
%90 = load float, float* %6
%91 = fadd float %89, %90
%92 = load i32, i32* %3
%93 = getelementptr float, float* %4, i32 %92
store float %91, float* %93
%95 = load float, float* %7
call void @putfloat(float %95)
call void @putch(i32 32)
%98 = load float, float* %8
%99 = fptosi float %98 to i32
call void @putint(i32 %99)
call void @putch(i32 10)
%102 = load i32, i32* %2
%103 = fsub float 0x0, 0x4024000000000000
%104 = fsub float 0x0, %103
%105 = sitofp i32 %102 to float
%106 = fmul float %105, %104
%107 = fptosi float %106 to i32
store i32 %107, i32* %2
%109 = load i32, i32* %3
%110 = add i32 %109, 1
store i32 %110, i32* %3
br label %L23.while.cond
L25.while.end:
%113 = load i32, i32* %5
%114 = getelementptr float, float* %4, i32 0
call void @putfarray(i32 %113, float* %114)
%116 = srem i32 0, 256
%117 = add i32 %116, 256
%118 = srem i32 %117, 256
call void @putint(i32 %118)
call void @putch(i32 10)
ret i32 0
}

@ -0,0 +1,309 @@
.text
.global main
.type main, @function
main:
addi sp, sp, -272
sw ra, 264(sp)
sw s0, 256(sp)
addi a0, sp, -4
li a1, 0
li a2, 32
call
addi a0, sp, -8
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -8
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -8
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -8
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -8
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -8
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -8
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -8
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -8
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
addi a0, sp, -44
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -44
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -44
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -44
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -44
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -44
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -44
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -44
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -44
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
addi a0, sp, -80
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -80
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -80
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -80
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -80
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -80
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -80
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -112(sp)
li t0, 1
lw t1, -112(sp)
add t0, t0, t1
sw t0, -116(sp)
addi t0, sp, -80
lw t1, -116(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -124(sp)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -128(sp)
li t0, 1
lw t1, -128(sp)
add t0, t0, t1
sw t0, -132(sp)
addi t0, sp, -44
lw t1, -132(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -140(sp)
addi a0, sp, -108
li a1, 0
li a2, 32
call
lw t2, -124(sp)
addi t0, sp, -108
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
lw t2, -140(sp)
addi t0, sp, -108
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -108
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -108
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -108
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -108
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -108
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -108
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t0, 3
li t1, 2
mul t0, t0, t1
sw t0, -176(sp)
li t0, 1
lw t1, -176(sp)
add t0, t0, t1
sw t0, -180(sp)
addi t0, sp, -108
lw t1, -180(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -188(sp)
li t0, 0
li t1, 2
mul t0, t0, t1
sw t0, -192(sp)
li t0, 0
lw t1, -192(sp)
add t0, t0, t1
sw t0, -196(sp)
addi t0, sp, -108
lw t1, -196(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -204(sp)
lw t0, -188(sp)
lw t1, -204(sp)
add t0, t0, t1
sw t0, -208(sp)
li t0, 0
li t1, 2
mul t0, t0, t1
sw t0, -212(sp)
li t0, 1
lw t1, -212(sp)
add t0, t0, t1
sw t0, -216(sp)
addi t0, sp, -108
lw t1, -216(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -224(sp)
lw t0, -208(sp)
lw t1, -224(sp)
add t0, t0, t1
sw t0, -228(sp)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -232(sp)
li t0, 0
lw t1, -232(sp)
add t0, t0, t1
sw t0, -236(sp)
addi t0, sp, -4
lw t1, -236(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -244(sp)
lw t0, -228(sp)
lw t1, -244(sp)
add t0, t0, t1
sw t0, -248(sp)
lw a0, -248(sp)
lw ra, 264(sp)
lw s0, 256(sp)
addi sp, sp, 272
ret
.size main, .-main

@ -0,0 +1,147 @@
#!/usr/bin/env bash
# Optimization benchmark: measures compile time and run time at O0 versus O1.
# Usage: bash scripts/bench_ir.sh [--test-dir=<dir>] [--result-dir=<dir>]
#
# For every *.sy file under the test directory the script emits IR at both
# optimization levels, lowers each with llc, links it with clang against
# sylib/sylib.c, runs both executables, and appends one row per level to
# <result-dir>/summary.csv.
#
# Note: -e is not set, so a failing command does not abort the run; failures
# are handled per test case inside the loop.
set -uo pipefail

PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
TEST_CASE_DIR="${PROJECT_ROOT}/test/test_case"
RESULT_DIR="${PROJECT_ROOT}/test/test_result/bench"

# Command-line overrides for the input and output directories.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --test-dir=*) TEST_CASE_DIR="${1#*=}" ;;
    --result-dir=*) RESULT_DIR="${1#*=}" ;;
    *) echo "未知参数: $1" >&2; exit 1 ;;
  esac
  shift
done

# Required tools: the project compiler plus llc and clang.
compiler="${PROJECT_ROOT}/build/bin/compiler"
[[ -x "$compiler" ]] || { echo "错误:未找到编译器 $compiler" >&2; exit 1; }
command -v llc >/dev/null 2>&1 || { echo "错误:未找到 llc" >&2; exit 1; }
command -v clang >/dev/null 2>&1 || { echo "错误:未找到 clang" >&2; exit 1; }
mkdir -p "$RESULT_DIR"

# Timing helpers.
# now: prints the current time as seconds with a nanosecond fraction.
# elapsed <start> <end>: prints end-start with four decimals, using python3 and
# falling back to awk when python3 fails.
now() { date +%s.%N; }
elapsed() { python3 -c "print(f'{float($2)-float($1):.4f}')" 2>/dev/null || awk "BEGIN{printf \"%.4f\\n\",$2-$1}"; }

# build_exe <ll-file> <obj-file> <exe-file>
# Lowers the IR to an object file and links it with the runtime library.
# Returns non-zero if either step fails.
build_exe() {
  llc -filetype=obj "$1" -o "$2" 2>/dev/null &&
    clang "$2" "${PROJECT_ROOT}/sylib/sylib.c" -o "$3" -lm 2>/dev/null
}

# run_exe <exe-file> <stdin-file>
# Runs the executable with an unlimited stack, feeding <stdin-file> when it
# exists, discarding all output. Returns the program's exit status.
run_exe() {
  if [[ -f "$2" ]]; then
    (ulimit -s unlimited; "$1" < "$2") > /dev/null 2>&1
  else
    (ulimit -s unlimited; "$1") > /dev/null 2>&1
  fi
}

summary_file="${RESULT_DIR}/summary.csv"
echo "test,opt,compile_s,exec_s,compile+exec_s" > "$summary_file"

total=0
o0_ct_total=0; o1_ct_total=0
o0_et_total=0; o1_et_total=0

# Resolve the test root once; it is the same for every test case.
tcdir=$(readlink -f "$TEST_CASE_DIR")

echo "=== 优化效果对比 O0 vs O1 ==="
echo ""

while IFS= read -r test_file; do
  # Path of the test relative to the test root, always starting with "/".
  # The pattern is quoted so glob characters in the directory name are literal.
  full_path=$(readlink -f "$test_file")
  rel="${full_path#"$tcdir"}"
  [[ "${rel:0:1}" != "/" ]] && rel="/$rel"
  base=$(basename "$test_file")
  stem="${base%.sy}"
  stdin_file="$(dirname "$test_file")/${stem}.in"

  total=$((total+1))
  printf "[%4d] %s" "$total" "$rel"

  # All artifacts mirror the test's relative path so that equally named tests
  # in different directories do not overwrite each other's objects/binaries.
  o0_ll="${RESULT_DIR}/O0/${rel%.sy}.ll"
  o1_ll="${RESULT_DIR}/O1/${rel%.sy}.ll"
  o0_obj="${o0_ll%.ll}.o"
  o1_obj="${o1_ll%.ll}.o"
  o0_exe="${o0_ll%.ll}.exe"
  o1_exe="${o1_ll%.ll}.exe"
  mkdir -p "$(dirname "$o0_ll")" "$(dirname "$o1_ll")"

  # --- Compile at O0 ---
  t1=$(now)
  "$compiler" "$test_file" -IR -o "$o0_ll" 2>/dev/null; rc0=$?
  t2=$(now)
  if [[ $rc0 -ne 0 ]]; then
    echo " | O0编译失败"
    echo "$stem,O0,-,-,-" >> "$summary_file"
    echo "$stem,O1,-,-,-" >> "$summary_file"
    continue
  fi
  o0_ct=$(elapsed "$t1" "$t2")

  # --- Compile at O1 ---
  t1=$(now)
  "$compiler" "$test_file" -IR -o "$o1_ll" -O1 2>/dev/null; rc1=$?
  t2=$(now)
  if [[ $rc1 -ne 0 ]]; then
    echo " | O1编译失败"
    echo "$stem,O0,$o0_ct,-,-" >> "$summary_file"
    echo "$stem,O1,-,-,-" >> "$summary_file"
    continue
  fi
  o1_ct=$(elapsed "$t1" "$t2")

  # --- llc + clang for both levels ---
  # Without this check a failed build would be "timed" as the shell failing to
  # start a missing binary, producing meaningless run times.
  if ! build_exe "$o0_ll" "$o0_obj" "$o0_exe" || ! build_exe "$o1_ll" "$o1_obj" "$o1_exe"; then
    echo " | llc/clang 构建失败"
    echo "$stem,O0,$o0_ct,-,-" >> "$summary_file"
    echo "$stem,O1,$o1_ct,-,-" >> "$summary_file"
    continue
  fi

  # --- Run O0 ---
  t1=$(now)
  sr0=0
  run_exe "$o0_exe" "$stdin_file" || sr0=$?
  t2=$(now)
  o0_et=$(elapsed "$t1" "$t2")

  # --- Run O1 ---
  t1=$(now)
  sr1=0
  run_exe "$o1_exe" "$stdin_file" || sr1=$?
  t2=$(now)
  o1_et=$(elapsed "$t1" "$t2")

  # Consistency check: both builds must exit with the same status.
  flag=""
  if [[ $sr0 -ne $sr1 ]]; then
    flag=" EXIT:O0=$sr0 O1=$sr1"
  fi

  # Running totals and O0/O1 ratios ("-" when the O1 time is zero).
  o0_ct_total=$(awk "BEGIN{printf \"%.4f\",$o0_ct_total+$o0_ct}")
  o1_ct_total=$(awk "BEGIN{printf \"%.4f\",$o1_ct_total+$o1_ct}")
  o0_et_total=$(awk "BEGIN{printf \"%.4f\",$o0_et_total+$o0_et}")
  o1_et_total=$(awk "BEGIN{printf \"%.4f\",$o1_et_total+$o1_et}")
  cspd=$(awk "BEGIN{if($o1_ct>0)printf \"%.1fx\",$o0_ct/$o1_ct; else print \"-\"}")
  espd=$(awk "BEGIN{if($o1_et>0)printf \"%.1fx\",$o0_et/$o1_et; else print \"-\"}")

  printf " | 编译 O0:%.4fs O1:%.4fs(%s) 运行 O0:%.4fs O1:%.4fs(%s)%s\n" \
    "$o0_ct" "$o1_ct" "$cspd" "$o0_et" "$o1_et" "$espd" "$flag"
  echo "$stem,O0,$o0_ct,$o0_et,$(awk "BEGIN{printf \"%.4f\",$o0_ct+$o0_et}")" >> "$summary_file"
  echo "$stem,O1,$o1_ct,$o1_et,$(awk "BEGIN{printf \"%.4f\",$o1_ct+$o1_et}")" >> "$summary_file"
done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort)

# Final summary.
echo ""
echo "============================================"
echo "总用例: $total"
echo "O0 编译总耗时: ${o0_ct_total}s"
echo "O1 编译总耗时: ${o1_ct_total}s"
echo "O0 运行总耗时: ${o0_et_total}s"
echo "O1 运行总耗时: ${o1_et_total}s"
echo "CSV: $summary_file"

@ -0,0 +1,100 @@
#!/usr/bin/env python3
import sys
def read_file(filepath):
    """Return the lines of the UTF-8 text file at ``filepath``.

    Line endings are kept, exactly as ``readlines()`` returns them. If the
    file does not exist, an error message is printed and the process exits
    with status 1.
    """
    try:
        handle = open(filepath, 'r', encoding='utf-8')
    except FileNotFoundError:
        print(f"错误:文件 '{filepath}' 不存在")
        sys.exit(1)
    with handle:
        return handle.readlines()
def get_context(line, pos, context_len=10):
    """Return the text around index ``pos`` of ``line``.

    The window spans ``context_len`` characters on each side of ``pos`` and is
    clipped to the bounds of ``line``. "..." is added on whichever side was
    cut short of the line's start or end.
    """
    lo = pos - context_len
    if lo < 0:
        lo = 0
    hi = pos + context_len + 1
    if hi > len(line):
        hi = len(line)

    snippet = line[lo:hi]
    if lo > 0:
        snippet = "..." + snippet
    if hi < len(line):
        snippet = snippet + "..."
    return snippet
def compare_files(file1, file2):
    """Compare two text files line by line and print a detailed report.

    Lines present in only one file are reported as added or deleted. For lines
    present in both files but with different content, each maximal run of
    consecutive differing characters is reported once, at the position where
    the run starts, together with some surrounding context from each line.

    Args:
        file1: path of the first (reference) file.
        file2: path of the second file.

    The total number of reported differences is printed at the end.
    """
    lines1 = read_file(file1)
    lines2 = read_file(file2)
    print(f"比较文件: {file1} vs {file2}")
    print("=" * 80)
    max_lines = max(len(lines1), len(lines2))
    differences = 0
    for line_num in range(max_lines):
        line1 = lines1[line_num] if line_num < len(lines1) else None
        line2 = lines2[line_num] if line_num < len(lines2) else None
        if line1 is None:
            print(f"\n[新增行] 第 {line_num + 1}")
            print(f" + {repr(line2)}")
            differences += 1
            continue
        if line2 is None:
            print(f"\n[删除行] 第 {line_num + 1}")
            print(f" - {repr(line1)}")
            differences += 1
            continue
        if line1 == line2:
            continue

        # The lines differ: walk them character by character.
        print(f"\n[差异行] 第 {line_num + 1}")
        max_chars = max(len(line1), len(line2))
        # An explicit index is used instead of ``for ... in range`` because
        # the index must jump past a whole run of differing characters;
        # reassigning a ``for`` loop variable has no effect on the iteration.
        char_pos = 0
        while char_pos < max_chars:
            char1 = line1[char_pos] if char_pos < len(line1) else None
            char2 = line2[char_pos] if char_pos < len(line2) else None
            if char1 == char2:
                char_pos += 1
                continue

            # Start of a run of differing characters: report it once.
            context1 = get_context(line1, char_pos) if line1 else ""
            context2 = get_context(line2, char_pos) if line2 else ""
            print(f" 字符位置 {char_pos + 1}:")
            if char1 is not None:
                print(f" - {repr(char1)} | 上下文: {repr(context1)}")
            else:
                print(f" - (缺失)")
            if char2 is not None:
                print(f" + {repr(char2)} | 上下文: {repr(context2)}")
            else:
                print(f" + (缺失)")
            differences += 1

            # Skip the rest of this run so the output stays readable.
            char_pos += 1
            while char_pos < max_chars:
                next_char1 = line1[char_pos] if char_pos < len(line1) else None
                next_char2 = line2[char_pos] if char_pos < len(line2) else None
                if next_char1 == next_char2:
                    break
                char_pos += 1
    print("\n" + "=" * 80)
    print(f"比较完成,共发现 {differences} 处差异")
def main():
    """Command-line entry point.

    Expects exactly two positional arguments (the two files to compare).
    Prints a usage message and exits with status 1 otherwise.
    """
    paths = sys.argv[1:]
    if len(paths) != 2:
        print("用法: python diff.py <文件1> <文件2>")
        sys.exit(1)
    first, second = paths
    compare_files(first, second)
if __name__ == "__main__":
main()

@ -0,0 +1,154 @@
#!/bin/bash
# RISC-V back-end batch test.
# Phase 1 generates assembly for every non-performance *.sy test case with the
# project compiler. Phase 2 links each assembly file with riscv64-linux-gnu-gcc,
# runs it under qemu-riscv64 with a 10 second timeout, and checks the result
# against the test's .out file when one exists.
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/mir"
SYLIB_C="$PROJECT_ROOT/sylib/sylib.c"
SYLIB_O="/tmp/sylib.o"
# ANSI color escapes used with `echo -e`.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
if [ ! -x "$COMPILER" ]; then
echo "错误:编译器不存在或不可执行: $COMPILER"
exit 1
fi
# Build the sylib runtime library.
# The object is only compiled when $SYLIB_O does not exist yet, so an object
# left over from an earlier run is reused as-is.
echo "编译运行时库..."
if [ ! -f "$SYLIB_O" ]; then
riscv64-linux-gnu-gcc -c "$SYLIB_C" -o "$SYLIB_O" 2>/dev/null
if [ $? -ne 0 ]; then
echo "警告:无法编译 sylib.c部分测试可能链接失败"
fi
fi
echo ""
mkdir -p "$TEST_RESULT_DIR"
echo "=========================================="
echo "RISC-V 后端测试"
echo "=========================================="
echo ""
# Collect the test cases (sorted), excluding any path that has a directory
# component containing "performance".
mapfile -t test_files < <(find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort)
total=${#test_files[@]}
pass_gen=0
fail_gen=0
pass_run=0
fail_run=0
timeout_cnt=0
echo "=== 阶段1汇编生成 ==="
echo ""
# Phase 1: assembly generation. A case passes when the compiler exits with 0
# and the generated .s file is non-empty.
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.s"
mkdir -p "$(dirname "$output_file")"
"$COMPILER" --emit-asm "$test_file" 2>/dev/null > "$output_file"
if [ $? -eq 0 ] && [ -s "$output_file" ]; then
echo -e " ${GREEN}${NC} $relative_path"
((pass_gen++))
else
echo -e " ${RED}${NC} $relative_path"
((fail_gen++))
fi
done
echo ""
echo "--- 汇编生成: 通过 $pass_gen / 失败 $fail_gen / 总计 $total ---"
echo ""
# Phase 2: link, run, and verify each case that produced assembly.
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
stem="${relative_path%.sy}"
asm_file="$TEST_RESULT_DIR/${stem}.s"
exe_file="$TEST_RESULT_DIR/${stem}"
expected_file="${test_file%.sy}.out"
# Cases whose assembly is missing or empty are skipped (not counted as failed).
if [ ! -s "$asm_file" ]; then
echo -e " ${YELLOW}${NC} $relative_path (跳过)"
continue
fi
# Link statically, adding the runtime library object when it is available.
if [ -f "$SYLIB_O" ]; then
riscv64-linux-gnu-gcc -static "$asm_file" "$SYLIB_O" -o "$exe_file" -no-pie 2>/dev/null
else
riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe_file" -no-pie 2>/dev/null
fi
# $? here is the status of the if/else above, i.e. of the gcc command that ran.
if [ $? -ne 0 ]; then
echo -e " ${RED}${NC} $relative_path (链接失败)"
((fail_run++))
continue
fi
# Run the program: stderr is discarded, only stdout is captured.
input_file="${test_file%.sy}.in"
tmp_out=$(mktemp)
if [ -f "$input_file" ]; then
timeout 10 qemu-riscv64 "$exe_file" < "$input_file" > "$tmp_out" 2>/dev/null
else
timeout 10 qemu-riscv64 "$exe_file" > "$tmp_out" 2>/dev/null
fi
# Must immediately follow the if/else so that it reflects the timeout/qemu run.
exit_code=$?
# `timeout` exits with 124 when the time limit was hit.
if [ $exit_code -eq 124 ]; then
echo -e " ${YELLOW}${NC} $relative_path (超时)"
((timeout_cnt++))
rm -f "$tmp_out"
continue
fi
# Read the captured stdout with every newline removed.
program_output=$(cat "$tmp_out" | tr -d '\n')
rm -f "$tmp_out"
if [ -f "$expected_file" ]; then
# The expected file is likewise compared with every newline removed.
expected=$(cat "$expected_file" | tr -d '\n')
if [[ "$expected" =~ ^[0-9]+$ ]] && [ "$expected" -ge 0 ] && [ "$expected" -le 255 ] && [ -z "$program_output" ]; then
# The expected file is a bare number in 0..255 and the program printed
# nothing: treat the number as the expected exit code.
if [ $exit_code -eq "$expected" ] 2>/dev/null; then
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (退出码: 期望 $expected, 实际 $exit_code)"
((fail_run++))
fi
else
# Otherwise compare the program's stdout with the expected content.
# NOTE(review): the exit code is not part of $program_output here; if the
# .out files end with an exit-code line, confirm this comparison is intended.
if [ "$program_output" = "$expected" ]; then
echo -e " ${GREEN}${NC} $relative_path (输出匹配)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (输出不匹配: 期望 '$expected', 实际 '$program_output')"
((fail_run++))
fi
fi
else
# No expected file: the case is counted as passed and its result is printed.
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code, 输出: '$program_output')"
((pass_run++))
fi
done
# Final summary of both phases.
echo ""
echo "--- 运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt ---"
echo ""
echo "=========================================="
echo "测试完成"
echo "汇编生成: 通过 $pass_gen / 失败 $fail_gen"
echo "运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt"
echo "=========================================="

@ -1,58 +1,308 @@
#!/bin/bash #!/usr/bin/env bash
# 串行执行IR测试脚本实时输出结果
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd) PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/ir"
if [ ! -x "$COMPILER" ]; then # 默认参数
echo "错误:编译器不存在或不可执行: $COMPILER" TEST_CASE_DIR="${PROJECT_ROOT}/test/test_case"
echo "请先构建项目cmake --build build -j\$(nproc)" TEST_RESULT_DIR="${PROJECT_ROOT}/test/test_result/ir"
RUN_EXEC=false
VERBOSE=false
OPT_FLAG="-O1" # 默认开启优化
# 解析命令行参数
while [[ $# -gt 0 ]]; do
case "$1" in
--run)
RUN_EXEC=true
;;
--verbose|-v)
VERBOSE=true
;;
--O0)
OPT_FLAG=""
;;
--O1)
OPT_FLAG="-O1"
;;
--test-dir=*)
TEST_CASE_DIR="${1#*=}"
;;
--result-dir=*)
TEST_RESULT_DIR="${1#*=}"
;;
*)
echo "未知参数: $1" >&2
echo "用法: $0 [--run] [--O0|--O1] [--verbose] [--test-dir=<dir>] [--result-dir=<dir>]" >&2
exit 1
;;
esac
shift
done
# 检查编译器是否存在
compiler="${PROJECT_ROOT}/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "错误:未找到编译器 $compiler" >&2
echo "请先构建项目: mkdir -p build && cd build && cmake .. && make -j" >&2
exit 1 exit 1
fi fi
# 创建输出目录
mkdir -p "$TEST_RESULT_DIR" mkdir -p "$TEST_RESULT_DIR"
pass_count=0 # 统计变量
fail_count=0 total_tests=0
failed_cases=() passed_tests=0
failed_tests=0
# 汇总日志文件
summary_log="${TEST_RESULT_DIR}/summary.log"
> "$summary_log"
echo "=== 开始测试 IR 生成 ===" # 失败测试列表
failed_list=""
echo "=== 开始IR测试 ==="
echo "测试目录: $TEST_CASE_DIR"
echo "结果目录: $TEST_RESULT_DIR"
echo "优化级别: ${OPT_FLAG:--O0}"
echo "运行可执行文件: $RUN_EXEC"
echo "" echo ""
while IFS= read -r test_file; do # 串行遍历所有测试用例
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file") while read -r test_file; do
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.ll" total_tests=$((total_tests + 1))
mkdir -p "$(dirname "$output_file")" # 计算相对路径
full_path=$(readlink -f "$test_file")
test_case_path=$(readlink -f "$TEST_CASE_DIR")
relative_path="${full_path#$test_case_path}"
# 确保路径以 / 开头
if [[ "${relative_path:0:1}" != "/" ]]; then
relative_path="/$relative_path"
fi
echo -n "测试: $relative_path ... " # 计算输出文件路径
base=$(basename "$test_file")
stem="${base%.sy}"
output_file="${TEST_RESULT_DIR}/${relative_path%.sy}.ll"
output_dir=$(dirname "$output_file")
"$COMPILER" --emit-ir "$test_file" > "$output_file" 2>&1 # 创建输出目录
exit_code=$? mkdir -p "$output_dir"
if [ $exit_code -eq 0 ] && [ -s "$output_file" ] && ! grep -q '\[error\]' "$output_file"; then # 获取输入和预期输出文件路径
echo "通过" input_dir=$(dirname "$test_file")
pass_count=$((pass_count + 1)) stdin_file="${input_dir}/${stem}.in"
expected_file="${input_dir}/${stem}.out"
# 每个测试用例的详细日志文件
test_log="${output_dir}/${stem}.log"
> "$test_log"
echo "[$total_tests] 处理: $relative_path"
echo "[$(date '+%Y-%m-%d %H:%M:%S')] 开始处理: $relative_path" >> "$test_log"
echo "输入文件: $test_file" >> "$test_log"
echo "输出目录: $output_dir" >> "$test_log"
# 生成IR
if $VERBOSE; then
echo " 生成IR..."
fi
echo "步骤1: 生成IR" >> "$test_log"
# 计时编译
compile_start=$(date +%s.%N)
set +e
"$compiler" "$test_file" -IR -o "$output_file" $OPT_FLAG 2>&1
ir_status=$?
set -e
compile_end=$(date +%s.%N)
compile_time=$(python3 -c "print(f'{float($compile_end)-float($compile_start):.3f}s')" 2>/dev/null || echo "-")
if [[ $ir_status -ne 0 ]]; then
echo " ✗ IR生成失败"
echo "结果: FAILED (IR生成失败)" >> "$test_log"
echo "错误信息:" >> "$test_log"
cat "$output_file" >> "$test_log"
failed_tests=$((failed_tests + 1))
failed_list="$failed_list\n[$total_tests] $relative_path - IR生成失败"
if $VERBOSE; then
cat "$output_file"
fi
echo ""
continue
fi
echo "IR文件: $output_file" >> "$test_log"
if $VERBOSE; then
echo " ✓ IR已生成: $output_file"
fi
# 如果需要运行可执行文件
if [[ "$RUN_EXEC" == true ]]; then
echo "步骤2: 编译和运行" >> "$test_log"
if ! command -v llc >/dev/null 2>&1; then
echo " 警告: 未找到 llc跳过执行测试" >&2
echo "警告: 未找到 llc跳过执行测试" >> "$test_log"
echo "结果: SKIPPED (缺少llc)" >> "$test_log"
passed_tests=$((passed_tests + 1))
echo ""
continue
fi
if ! command -v clang >/dev/null 2>&1; then
echo " 警告: 未找到 clang跳过执行测试" >&2
echo "警告: 未找到 clang跳过执行测试" >> "$test_log"
echo "结果: SKIPPED (缺少clang)" >> "$test_log"
passed_tests=$((passed_tests + 1))
echo ""
continue
fi
obj="${output_dir}/${stem}.o"
exe="${output_dir}/${stem}"
stdout_file="${output_dir}/${stem}.stdout"
actual_file="${output_dir}/${stem}.actual.out"
# 编译IR为目标文件
if $VERBOSE; then
echo " 编译IR..."
fi
echo "编译IR: llc -filetype=obj $output_file -o $obj" >> "$test_log"
llc -filetype=obj "$output_file" -o "$obj" 2>/dev/null
if [[ $? -ne 0 ]]; then
echo " ✗ IR编译失败"
echo "结果: FAILED (IR编译失败)" >> "$test_log"
failed_tests=$((failed_tests + 1))
failed_list="$failed_list\n[$total_tests] $relative_path - IR编译失败"
echo ""
continue
fi
echo "目标文件: $obj" >> "$test_log"
# 链接为可执行文件
echo "链接: clang $obj ${PROJECT_ROOT}/sylib/sylib.c -o $exe -lm" >> "$test_log"
clang "$obj" "${PROJECT_ROOT}/sylib/sylib.c" -o "$exe" -lm 2>/dev/null
if [[ $? -ne 0 ]]; then
echo " ✗ 链接失败"
echo "结果: FAILED (链接失败)" >> "$test_log"
failed_tests=$((failed_tests + 1))
failed_list="$failed_list\n[$total_tests] $relative_path - 链接失败"
echo ""
continue
fi
echo "可执行文件: $exe" >> "$test_log"
# 运行可执行文件
if $VERBOSE; then
echo " 运行..."
fi
echo "运行命令: $exe" >> "$test_log"
if [[ -f "$stdin_file" ]]; then
echo "标准输入: $stdin_file" >> "$test_log"
fi
run_start=$(date +%s.%N)
set +e
if [[ -f "$stdin_file" ]]; then
(ulimit -s unlimited; "$exe" < "$stdin_file") > "$stdout_file"
else else
echo "失败" (ulimit -s unlimited; "$exe") > "$stdout_file"
fail_count=$((fail_count + 1)) fi
failed_cases+=("$relative_path") status=$?
echo " 错误信息已保存到: $output_file" set -e
run_end=$(date +%s.%N)
run_time=$(python3 -c "print(f'{float($run_end)-float($run_start):.3f}s')" 2>/dev/null || echo "-")
echo "退出码: $status" >> "$test_log"
echo "标准输出:" >> "$test_log"
cat "$stdout_file" >> "$test_log"
# 保存实际输出(包含退出码)
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
# 比对输出
echo "步骤3: 比对输出" >> "$test_log"
if [[ -f "$expected_file" ]]; then
echo "预期输出: $expected_file" >> "$test_log"
echo "实际输出: $actual_file" >> "$test_log"
if diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
<(tr -d '\r' < "$actual_file" | sed -e '$a\') > /dev/null 2>&1; then
echo " ✓ 编译:${compile_time} 运行:${run_time}"
echo "结果: PASSED (输出匹配)" >> "$test_log"
passed_tests=$((passed_tests + 1))
else
echo " ✗ 输出不匹配"
echo "结果: FAILED (输出不匹配)" >> "$test_log"
echo "差异:" >> "$test_log"
diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
<(tr -d '\r' < "$actual_file" | sed -e '$a\') >> "$test_log" 2>&1 || true
failed_tests=$((failed_tests + 1))
failed_list="$failed_list\n[$total_tests] $relative_path - 输出不匹配"
if $VERBOSE; then
echo " 预期:"
cat "$expected_file"
echo " 实际:"
cat "$actual_file"
fi
fi
else
echo " ? 无预期输出文件,跳过比对 (编译:${compile_time})"
echo "警告: 无预期输出文件,跳过比对" >> "$test_log"
echo "结果: SKIPPED (无预期输出)" >> "$test_log"
passed_tests=$((passed_tests + 1))
fi
else
echo "步骤2: 跳过执行测试 (--run未启用)" >> "$test_log"
echo "结果: PASSED (仅IR生成)" >> "$test_log"
passed_tests=$((passed_tests + 1))
echo " ✓ 编译:${compile_time}"
fi fi
done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort)
echo "" echo ""
done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort)
# 输出统计结果到终端
echo "=== 测试完成 ===" echo "=== 测试完成 ==="
echo "通过: $pass_count" echo "总测试数: $total_tests"
echo "失败: $fail_count" echo "通过: $passed_tests"
echo "结果保存在: $TEST_RESULT_DIR" echo "失败: $failed_tests"
if [ ${#failed_cases[@]} -gt 0 ]; then # 写入汇总日志
echo "" echo "=== IR测试汇总报告 ===" > "$summary_log"
echo "=== 失败的用例 ===" echo "测试时间: $(date '+%Y-%m-%d %H:%M:%S')" >> "$summary_log"
for f in "${failed_cases[@]}"; do echo "测试目录: $TEST_CASE_DIR" >> "$summary_log"
echo " - $f" echo "结果目录: $TEST_RESULT_DIR" >> "$summary_log"
done echo "运行可执行文件: $RUN_EXEC" >> "$summary_log"
echo "" >> "$summary_log"
echo "=== 统计结果 ===" >> "$summary_log"
echo "总测试数: $total_tests" >> "$summary_log"
echo "通过: $passed_tests" >> "$summary_log"
echo "失败: $failed_tests" >> "$summary_log"
echo "成功率: $((passed_tests * 100 / total_tests))%" >> "$summary_log"
echo "" >> "$summary_log"
if [[ $failed_tests -gt 0 ]]; then
echo "=== 失败测试列表 ===" >> "$summary_log"
echo -e "$failed_list" >> "$summary_log"
fi
echo "详细日志已保存到各测试用例目录"
echo "汇总日志: $summary_log"
if [[ $failed_tests -eq 0 ]]; then
echo "所有测试通过!"
exit 0
else
echo "$failed_tests 个测试失败"
exit 1 exit 1
fi fi

@ -3,6 +3,8 @@
set -euo pipefail set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
if [[ $# -lt 1 || $# -gt 3 ]]; then if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2 echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
exit 1 exit 1
@ -31,7 +33,7 @@ if [[ ! -f "$input" ]]; then
exit 1 exit 1
fi fi
compiler="./build/bin/compiler" compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建(如: mkdir -p build && cd build && cmake .. && make -j" >&2 echo "未找到编译器: $compiler ,请先构建(如: mkdir -p build && cd build && cmake .. && make -j" >&2
exit 1 exit 1
@ -43,7 +45,7 @@ stem=${base%.sy}
out_file="$out_dir/$stem.ll" out_file="$out_dir/$stem.ll"
stdin_file="$input_dir/$stem.in" stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out" expected_file="$input_dir/$stem.out"
"$compiler" --emit-ir "$input" > "$out_file" "$compiler" "$input" -IR -o "$out_file" -O1
echo "IR 已生成: $out_file" echo "IR 已生成: $out_file"
if [[ "$run_exec" == true ]]; then if [[ "$run_exec" == true ]]; then
@ -60,13 +62,13 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout" stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out" actual_file="$out_dir/$stem.actual.out"
llc -filetype=obj "$out_file" -o "$obj" llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" -o "$exe" clang "$obj" "$PROJECT_ROOT/sylib/sylib.c" -o "$exe" -lm
echo "运行 $exe ..." echo "运行 $exe ..."
set +e set +e
if [[ -f "$stdin_file" ]]; then if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file" (ulimit -s unlimited; "$exe" < "$stdin_file") > "$stdout_file"
else else
"$exe" > "$stdout_file" (ulimit -s unlimited; "$exe") > "$stdout_file"
fi fi
status=$? status=$?
set -e set -e
@ -81,9 +83,11 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file" } > "$actual_file"
if [[ -f "$expected_file" ]]; then if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then if diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
<(tr -d '\r' < "$actual_file" | sed -e '$a\') > /dev/null 2>&1; then
echo "输出匹配: $expected_file" echo "输出匹配: $expected_file"
else else
diff -u "$expected_file" "$actual_file" || true
echo "输出不匹配: $expected_file" >&2 echo "输出不匹配: $expected_file" >&2
echo "实际输出已保存: $actual_file" >&2 echo "实际输出已保存: $actual_file" >&2
exit 1 exit 1

@ -0,0 +1,101 @@
#!/usr/bin/env bash
# Emit MIR and RISC-V assembly for a single SysY source file.
# Usage: <this script> <input.sy> [output_dir] [--run]
#
# With --run the assembly is additionally linked with riscv64-linux-gnu-gcc
# (together with the sylib runtime when sylib/sylib.c exists), executed under
# qemu-riscv64, and its stdout plus exit code are compared with <stem>.out
# located next to the input file.
set -euo pipefail

PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)

if [[ $# -lt 1 || $# -gt 3 ]]; then
  echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
  exit 1
fi

input=$1
out_dir="test/test_result/mir"
run_exec=false
input_dir=$(dirname "$input")
shift

# Remaining arguments: --run enables execution, anything else is the output dir.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --run)
      run_exec=true
      ;;
    *)
      out_dir="$1"
      ;;
  esac
  shift
done

if [[ ! -f "$input" ]]; then
  echo "输入文件不存在: $input" >&2
  exit 1
fi

compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
  echo "未找到编译器: $compiler ,请先构建。" >&2
  exit 1
fi

mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
mir_file="$out_dir/$stem.mir"
asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"

# Generate MIR.
"$compiler" --emit-mir "$input" > "$mir_file"
echo "MIR 已生成: $mir_file"

# Generate assembly.
"$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file"

if [[ "$run_exec" == true ]]; then
  if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
    echo "未找到 riscv64-linux-gnu-gcc" >&2
    exit 1
  fi
  if ! command -v qemu-riscv64 >/dev/null 2>&1; then
    echo "未找到 qemu-riscv64" >&2
    exit 1
  fi

  # Link against the runtime library as well; programs that call runtime
  # functions cannot be linked from the assembly file alone.
  sylib_src="$PROJECT_ROOT/sylib/sylib.c"
  link_inputs=("$asm_file")
  if [[ -f "$sylib_src" ]]; then
    link_inputs+=("$sylib_src")
  fi
  riscv64-linux-gnu-gcc -static "${link_inputs[@]}" -o "$exe" -no-pie

  stdout_file="$out_dir/$stem.stdout"
  actual_file="$out_dir/$stem.actual.out"
  echo "运行 $exe ..."

  # The program's exit status is data here, so -e is suspended around the run.
  set +e
  if [[ -f "$stdin_file" ]]; then
    qemu-riscv64 "$exe" < "$stdin_file" > "$stdout_file"
  else
    qemu-riscv64 "$exe" > "$stdout_file"
  fi
  status=$?
  set -e

  cat "$stdout_file"
  echo "退出码: $status"

  # Actual result = stdout (newline-terminated if non-empty) + exit code line.
  {
    cat "$stdout_file"
    if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
      printf '\n'
    fi
    printf '%s\n' "$status"
  } > "$actual_file"

  if [[ -f "$expected_file" ]]; then
    # Compare ignoring carriage returns and a missing final newline, so that
    # CRLF expected files do not cause spurious mismatches.
    if diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
            <(tr -d '\r' < "$actual_file" | sed -e '$a\') > /dev/null 2>&1; then
      echo "输出匹配: $expected_file"
    else
      diff -u "$expected_file" "$actual_file" || true
      echo "输出不匹配: $expected_file" >&2
      exit 1
    fi
  else
    echo "未找到预期输出文件,跳过比对: $expected_file"
  fi
fi

@ -0,0 +1,101 @@
#!/usr/bin/env bash
# Compile a single SysY source file to RISC-V assembly and link it into a
# static executable.
# Usage: <this script> <input.sy> [output_dir] [--run]
#
# The assembly is linked with riscv64-linux-gnu-gcc (together with the sylib
# runtime when sylib/sylib.c exists). With --run the executable is run under
# qemu-riscv64 and its stdout plus exit code are compared with <stem>.out
# located next to the input file.
set -euo pipefail

PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)

if [[ $# -lt 1 || $# -gt 3 ]]; then
  echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
  exit 1
fi

input=$1
out_dir="test/test_result/riscv_asm"
run_exec=false
input_dir=$(dirname "$input")
shift

# Remaining arguments: --run enables execution, anything else is the output dir.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --run)
      run_exec=true
      ;;
    *)
      out_dir="$1"
      ;;
  esac
  shift
done

if [[ ! -f "$input" ]]; then
  echo "输入文件不存在: $input" >&2
  exit 1
fi

compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
  echo "未找到编译器: $compiler ,请先构建。" >&2
  exit 1
fi

if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
  echo "未找到 riscv64-linux-gnu-gcc无法汇编/链接。" >&2
  echo "请安装: sudo apt install gcc-riscv64-linux-gnu" >&2
  exit 1
fi

mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"

"$compiler" --emit-asm "$input" 2>/dev/null > "$asm_file"
echo "汇编已生成: $asm_file"

# Link statically to avoid dynamic-linker problems. The runtime library is
# linked in as well; programs that call runtime functions cannot be linked
# from the assembly file alone.
sylib_src="$PROJECT_ROOT/sylib/sylib.c"
link_inputs=("$asm_file")
if [[ -f "$sylib_src" ]]; then
  link_inputs+=("$sylib_src")
fi
riscv64-linux-gnu-gcc -static "${link_inputs[@]}" -o "$exe" -no-pie
echo "可执行文件已生成: $exe"

if [[ "$run_exec" == true ]]; then
  if ! command -v qemu-riscv64 >/dev/null 2>&1; then
    echo "未找到 qemu-riscv64无法运行生成的可执行文件。" >&2
    echo "请安装: sudo apt install qemu-user" >&2
    exit 1
  fi

  stdout_file="$out_dir/$stem.stdout"
  actual_file="$out_dir/$stem.actual.out"
  echo "运行 $exe ..."

  # The program's exit status is data here, so -e is suspended around the run.
  set +e
  if [[ -f "$stdin_file" ]]; then
    qemu-riscv64 "$exe" < "$stdin_file" > "$stdout_file"
  else
    qemu-riscv64 "$exe" > "$stdout_file"
  fi
  status=$?
  set -e

  cat "$stdout_file"
  echo "退出码: $status"

  # Actual result = stdout (newline-terminated if non-empty) + exit code line.
  {
    cat "$stdout_file"
    if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
      printf '\n'
    fi
    printf '%s\n' "$status"
  } > "$actual_file"

  if [[ -f "$expected_file" ]]; then
    # Compare ignoring carriage returns and a missing final newline, so that
    # CRLF expected files do not cause spurious mismatches.
    if diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
            <(tr -d '\r' < "$actual_file" | sed -e '$a\') > /dev/null 2>&1; then
      echo "输出匹配: $expected_file"
    else
      diff -u "$expected_file" "$actual_file" || true
      echo "输出不匹配: $expected_file" >&2
      echo "实际输出已保存: $actual_file" >&2
      exit 1
    fi
  else
    echo "未找到预期输出文件,跳过比对: $expected_file"
  fi
fi

@ -21,6 +21,7 @@ if(NOT COMPILER_PARSE_ONLY)
target_link_libraries(compiler PRIVATE target_link_libraries(compiler PRIVATE
sem sem
irgen irgen
ir
mir mir
) )
target_compile_definitions(compiler PRIVATE COMPILER_PARSE_ONLY=0) target_compile_definitions(compiler PRIVATE COMPILER_PARSE_ONLY=0)

@ -9,6 +9,7 @@
#include "ir/IR.h" #include "ir/IR.h"
#include <algorithm>
#include <utility> #include <utility>
namespace ir { namespace ir {
@ -42,4 +43,53 @@ const std::vector<BasicBlock*>& BasicBlock::GetSuccessors() const {
return successors_; return successors_;
} }
// Appends `bb` to this block's predecessor list. A null pointer is ignored;
// no check for duplicates is made.
void BasicBlock::AddPredecessor(BasicBlock* bb) {
  if (bb == nullptr) {
    return;
  }
  predecessors_.push_back(bb);
}
// Removes every occurrence of `bb` from this block's predecessor list.
// Does nothing when `bb` is not a recorded predecessor.
void BasicBlock::RemovePredecessor(BasicBlock* bb) {
  const auto kept_end =
      std::remove(predecessors_.begin(), predecessors_.end(), bb);
  predecessors_.erase(kept_end, predecessors_.end());
}
// Forgets all recorded predecessors of this block.
void BasicBlock::ClearPredecessors() {
  predecessors_.clear();
}
// Appends `bb` to this block's successor list. A null pointer is ignored;
// no check for duplicates is made.
void BasicBlock::AddSuccessor(BasicBlock* bb) {
  if (bb == nullptr) {
    return;
  }
  successors_.push_back(bb);
}
// Forgets all recorded successors of this block.
void BasicBlock::ClearSuccessors() {
  successors_.clear();
}
// Erases the first instruction in this block whose address equals `inst`.
// The block owns its instructions, so erasing the entry also destroys the
// instruction. Does nothing when `inst` is not found in this block.
void BasicBlock::RemoveInstruction(Instruction* inst) {
  const auto pos = std::find_if(
      instructions_.begin(), instructions_.end(),
      [inst](const auto& owned) { return owned.get() == inst; });
  if (pos != instructions_.end()) {
    instructions_.erase(pos);
  }
}
// Inserts `inst` into this block and takes ownership of it.
//
// - `before` == nullptr: append, but keep an existing trailing terminator as
//   the last instruction (the new instruction goes right in front of it).
// - `before` found in this block: insert immediately in front of it.
// - `before` not found in this block: append at the very end.
//
// In every case the instruction's parent is set to this block. (Previously the
// "insert in front of `before`" path returned early and left the parent
// unset.) A null `inst` is ignored.
void BasicBlock::InsertBefore(Instruction* inst, Instruction* before) {
  if (!inst) return;

  // Default position: the very end (also the fallback when `before` is absent).
  auto pos = instructions_.end();
  if (!before) {
    // Append while respecting the terminator.
    if (!instructions_.empty() && instructions_.back()->IsTerminator()) {
      pos = instructions_.end() - 1;
    }
  } else {
    pos = std::find_if(instructions_.begin(), instructions_.end(),
                       [before](const std::unique_ptr<Instruction>& owned) {
                         return owned.get() == before;
                       });
  }

  instructions_.insert(pos, std::unique_ptr<Instruction>(inst));
  inst->SetParent(this);
}
} // namespace ir } // namespace ir

@ -29,4 +29,10 @@ std::string Context::NextTemp() {
return oss.str(); return oss.str();
} }
// Returns a fresh label name of the form "L<n>". The label counter is
// incremented before it is used, so each call yields a new number.
std::string Context::NextLabel() {
  const auto id = ++label_index_;
  std::ostringstream label;
  label << 'L' << id;
  return label.str();
}
} // namespace ir } // namespace ir

@ -1,6 +1,8 @@
// IR Function // IR Function
#include "ir/IR.h" #include "ir/IR.h"
#include <algorithm>
namespace ir { namespace ir {
Function::Function(std::string name, std::shared_ptr<Type> ret_type) Function::Function(std::string name, std::shared_ptr<Type> ret_type)
@ -17,6 +19,17 @@ BasicBlock* Function::CreateBlock(const std::string& name) {
return ptr; return ptr;
} }
// Moves the block `bb` to the last position of this function's block list,
// keeping the relative order of all other blocks. Does nothing when `bb` is
// not one of this function's blocks.
void Function::MoveBlockToEnd(BasicBlock* bb) {
  const auto pos =
      std::find_if(blocks_.begin(), blocks_.end(),
                   [bb](const std::unique_ptr<BasicBlock>& owned) {
                     return owned.get() == bb;
                   });
  if (pos == blocks_.end()) {
    return;
  }
  // Left-rotate the tail by one: the found block becomes the last element.
  std::rotate(pos, pos + 1, blocks_.end());
}
BasicBlock* Function::GetEntry() { return entry_; } BasicBlock* Function::GetEntry() { return entry_; }
const BasicBlock* Function::GetEntry() const { return entry_; } const BasicBlock* Function::GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const { const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
@ -34,4 +47,84 @@ Argument* Function::GetArgument(size_t i) const {
return args_[i].get(); return args_[i].get();
} }
void Function::RebuildCFG() {
// 清除所有块的前驱/后继
for (auto& bb : blocks_) {
bb->ClearPredecessors();
bb->ClearSuccessors();
}
// 根据终结指令重新计算
for (auto& bb : blocks_) {
const auto& insts = bb->GetInstructions();
if (insts.empty()) continue;
auto* term = insts.back().get();
if (!term->IsTerminator()) continue;
switch (term->GetOpcode()) {
case Opcode::Br: {
auto* target = static_cast<BrInst*>(term)->GetTarget();
bb->AddSuccessor(target);
target->AddPredecessor(bb.get());
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<CondBrInst*>(term);
auto* t = cbr->GetTrueBB();
auto* f = cbr->GetFalseBB();
bb->AddSuccessor(t);
bb->AddSuccessor(f);
t->AddPredecessor(bb.get());
f->AddPredecessor(bb.get());
break;
}
case Opcode::Ret:
// 无后继
break;
default:
break;
}
}
}
// Removes the block `bb` from this function and destroys it.
// The entry block is never removed: the call is a no-op when `bb` is the entry.
// Before the block is erased, references between it and the rest of the
// function are set to nullptr (steps 1-3 below). The predecessor/successor
// lists of other blocks are not touched here.
void Function::RemoveBlock(BasicBlock* bb) {
if (entry_ == bb) return;
// Step 1: clear every reference to this block in the PHI nodes of all other
// blocks. The incoming entry is not erased; its value and block operand slots
// (operands 2*i and 2*i+1) are set to nullptr.
for (auto& other_bb : blocks_) {
if (other_bb.get() == bb) continue;
for (auto& inst : other_bb->GetInstructions()) {
if (auto* phi = dynamic_cast<PhiInst*>(inst.get())) {
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (phi->GetIncomingBlock(i) == bb) {
phi->SetOperand(i * 2, nullptr); // value
phi->SetOperand(i * 2 + 1, nullptr); // block
}
}
}
}
}
// Step 2: "undefine" every instruction in the block by replacing all of its
// uses with undef (nullptr). Other places that referenced these instructions
// become null; per the original author's note, a later SanitizePhis step is
// expected to repair them (that step is not visible in this file section).
for (auto& inst : bb->GetInstructions()) {
inst->ReplaceAllUsesWith(nullptr);
}
// Step 3: detach this block's instructions from their own operands.
for (auto& inst : bb->GetInstructions()) {
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
inst->SetOperand(i, nullptr);
}
}
// Step 4: erase the block from blocks_. blocks_ holds unique_ptrs, so this
// also destroys the block.
blocks_.erase(
std::remove_if(blocks_.begin(), blocks_.end(),
[bb](const std::unique_ptr<BasicBlock>& b) {
return b.get() == bb;
}),
blocks_.end());
}
} // namespace ir } // namespace ir

@ -136,6 +136,15 @@ AllocaInst* IRBuilder::CreateAllocaArray(int num_elements,
num_elements, name); num_elements, name);
} }
// Appends an alloca for an array of `num_elements` float32 values to the
// current insertion block and returns the new instruction. The instruction is
// created with the float32-pointer type.
// Throws std::runtime_error when no insertion block has been set.
AllocaInst* IRBuilder::CreateAllocaArrayF32(int num_elements,
                                            const std::string& name) {
  if (insert_block_) {
    return insert_block_->Append<AllocaInst>(Type::GetPtrFloat32Type(),
                                             num_elements, name);
  }
  throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
GepInst* IRBuilder::CreateGep(Value* base_ptr, Value* index, GepInst* IRBuilder::CreateGep(Value* base_ptr, Value* index,
const std::string& name) { const std::string& name) {
if (!insert_block_) { if (!insert_block_) {
@ -237,4 +246,28 @@ FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) {
return insert_block_->Append<FPToSIInst>(val, name); return insert_block_->Append<FPToSIInst>(val, name);
} }
// Creates a PHI instruction of type `ty` and places it at the front of the
// current insertion block (via Prepend). Returns the new instruction.
// Throws std::runtime_error when no insertion block has been set.
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> ty,
                              const std::string& name) {
  if (insert_block_) {
    return insert_block_->Prepend<PhiInst>(std::move(ty), name);
  }
  throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
// Appends a call `memset(ptr, 0, num_elements * 4)` to the current insertion
// block. The byte count uses 4 bytes per element.
// If the module has no external declaration named "memset" yet, one is added
// first with a void return type and (i32 pointer, i32, i32) parameters.
// Throws std::runtime_error when no insertion block has been set.
void IRBuilder::CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod) {
  if (insert_block_ == nullptr) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }

  // Declare memset on first use only.
  const bool already_declared = mod.HasExternalDecl("memset");
  if (!already_declared) {
    mod.DeclareExternalFunc("memset", Type::GetVoidType(),
                            {Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
  }

  const int total_bytes = 4 * num_elements;
  Value* fill_value = ctx.GetConstInt(0);
  Value* length = ctx.GetConstInt(total_bytes);
  insert_block_->Append<CallInst>(
      std::string("memset"), Type::GetVoidType(),
      std::vector<Value*>{ptr, fill_value, length},
      std::string(""));
}
} // namespace ir } // namespace ir

@ -1,8 +1,13 @@
#include "ir/IR.h" #include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <ostream> #include <ostream>
#include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <algorithm>
#include <unordered_map>
#include "utils/Log.h" #include "utils/Log.h"
@ -44,201 +49,141 @@ static const char* FPredToStr(FCmpPredicate pred) {
return "?"; return "?";
} }
static std::string ValStr(const Value* v) { using RenameMap = std::unordered_map<const Value*, int>;
static std::string ValStr(const Value* v, const RenameMap& rm) {
if (!v) return "<null>"; if (!v) return "<null>";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) { if (dynamic_cast<const ConstantInt*>(v))
return std::to_string(ci->GetValue()); return std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) { if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::to_string(cf->GetValue()); double d = static_cast<double>(cf->GetValue());
} uint64_t bits;
// BasicBlock: 打印为 label %name std::memcpy(&bits, &d, sizeof(bits));
if (dynamic_cast<const BasicBlock*>(v)) { std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << bits;
return oss.str();
}
if (dynamic_cast<const BasicBlock*>(v))
return "%" + v->GetName(); return "%" + v->GetName();
}
// GlobalVariable: 打印为 @name
if (auto* gv = dynamic_cast<const GlobalVariable*>(v)) { if (auto* gv = dynamic_cast<const GlobalVariable*>(v)) {
if (gv->IsArray()) { if (gv->IsArray()) {
// 数组全局变量的指针getelementptr [N x i32], [N x i32]* @name, i32 0, i32 0 const char* et = gv->IsFloat() ? "float" : "i32";
return "getelementptr ([" + std::to_string(gv->GetNumElements()) + return std::string("getelementptr ([") + std::to_string(gv->GetNumElements()) +
" x i32], [" + std::to_string(gv->GetNumElements()) + " x " + et + "], [" + std::to_string(gv->GetNumElements()) +
" x i32]* @" + gv->GetName() + ", i32 0, i32 0)"; " x " + et + "]* @" + gv->GetName() + ", i32 0, i32 0)";
} }
return "@" + v->GetName(); return "@" + v->GetName();
} }
auto it = rm.find(v);
if (it != rm.end()) return "%" + std::to_string(it->second);
return "%" + v->GetName(); return "%" + v->GetName();
} }
static std::string TypeVal(const Value* v) { static std::string TypeVal(const Value* v, const RenameMap& rm) {
if (!v) return "void"; if (!v) return "void";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) { if (dynamic_cast<const ConstantInt*>(v))
return std::string(TypeToStr(*ci->GetType())) + " " + return std::string(TypeToStr(*v->GetType())) + " " +
std::to_string(ci->GetValue()); std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) { if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::string(TypeToStr(*cf->GetType())) + " " + double d = static_cast<double>(cf->GetValue());
std::to_string(cf->GetValue()); uint64_t bits;
} std::memcpy(&bits, &d, sizeof(bits));
return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v); std::ostringstream oss;
} oss << "float 0x" << std::hex << std::uppercase << bits;
return oss.str();
void IRPrinter::Print(const Module& module, std::ostream& os) {
// 1. 全局变量/常量
for (const auto& g : module.GetGlobalVariables()) {
if (g->IsArray()) {
// 全局数组zeroinitializer
if (g->IsConst()) {
os << "@" << g->GetName() << " = constant [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
} else {
os << "@" << g->GetName() << " = global [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
}
} else {
if (g->IsConst()) {
os << "@" << g->GetName() << " = constant i32 " << g->GetInitVal()
<< "\n";
} else {
os << "@" << g->GetName() << " = global i32 " << g->GetInitVal()
<< "\n";
}
}
}
if (!module.GetGlobalVariables().empty()) os << "\n";
// 2. 外部函数声明
for (const auto& decl : module.GetExternalDecls()) {
os << "declare " << TypeToStr(*decl.ret_type) << " @" << decl.name << "(";
for (size_t i = 0; i < decl.param_types.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToStr(*decl.param_types[i]);
}
if (decl.is_variadic) {
if (!decl.param_types.empty()) os << ", ";
os << "...";
} }
os << ")\n"; return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v, rm);
} }
if (!module.GetExternalDecls().empty()) os << "\n";
// 3. 函数定义 // Print one instruction (non-alloca) using rename map
for (const auto& func : module.GetFunctions()) { static void PrintInst(const Instruction* inst, std::ostream& os,
os << "define " << TypeToStr(*func->GetType()) << " @" << func->GetName() const RenameMap& rm) {
<< "("; auto N = [&](const Value* v) -> std::string {
for (size_t i = 0; i < func->GetNumArgs(); ++i) { auto it = rm.find(v);
if (i > 0) os << ", "; if (it != rm.end()) return std::to_string(it->second);
auto* arg = func->GetArgument(i); return v->GetName();
os << TypeToStr(*arg->GetType()) << " %" << arg->GetName(); };
} auto VS = [&](const Value* v) { return ValStr(v, rm); };
os << ") {\n"; auto TV = [&](const Value* v) { return TypeVal(v, rm); };
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
os << bb->GetName() << ":\n";
for (const auto& instPtr : bb->GetInstructions()) {
const auto* inst = instPtr.get();
switch (inst->GetOpcode()) { switch (inst->GetOpcode()) {
// ── 算术 ────────────────────────────────────────────────────────── case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
case Opcode::Add: case Opcode::Div: case Opcode::Mod: {
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod: {
auto* bin = static_cast<const BinaryInst*>(inst); auto* bin = static_cast<const BinaryInst*>(inst);
const char* op_str = nullptr; const char* op = nullptr;
switch (bin->GetOpcode()) { switch (bin->GetOpcode()) {
case Opcode::Add: op_str = "add"; break; case Opcode::Add: op = "add"; break;
case Opcode::Sub: op_str = "sub"; break; case Opcode::Sub: op = "sub"; break;
case Opcode::Mul: op_str = "mul"; break; case Opcode::Mul: op = "mul"; break;
case Opcode::Div: op_str = "sdiv"; break; case Opcode::Div: op = "sdiv"; break;
case Opcode::Mod: op_str = "srem"; break; case Opcode::Mod: op = "srem"; break;
default: op_str = "?"; break; default: op = "?"; break;
} }
os << " %" << bin->GetName() << " = " << op_str << " i32 " os << " %" << N(bin) << " = " << op << " i32 "
<< ValStr(bin->GetLhs()) << ", " << ValStr(bin->GetRhs()) << VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
<< "\n";
break; break;
} }
// ── 浮点算术 ────────────────────────────────────────────────────── case Opcode::FAdd: case Opcode::FSub:
case Opcode::FAdd: case Opcode::FMul: case Opcode::FDiv: {
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst); auto* bin = static_cast<const BinaryInst*>(inst);
const char* op_str = nullptr; const char* op = nullptr;
switch (bin->GetOpcode()) { switch (bin->GetOpcode()) {
case Opcode::FAdd: op_str = "fadd"; break; case Opcode::FAdd: op = "fadd"; break;
case Opcode::FSub: op_str = "fsub"; break; case Opcode::FSub: op = "fsub"; break;
case Opcode::FMul: op_str = "fmul"; break; case Opcode::FMul: op = "fmul"; break;
case Opcode::FDiv: op_str = "fdiv"; break; case Opcode::FDiv: op = "fdiv"; break;
default: op_str = "?"; break; default: op = "?"; break;
} }
os << " %" << bin->GetName() << " = " << op_str << " float " os << " %" << N(bin) << " = " << op << " float "
<< ValStr(bin->GetLhs()) << ", " << ValStr(bin->GetRhs()) << VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
<< "\n";
break; break;
} }
// ── 比较 ──────────────────────────────────────────────────────────
case Opcode::ICmp: { case Opcode::ICmp: {
auto* cmp = static_cast<const ICmpInst*>(inst); auto* cmp = static_cast<const ICmpInst*>(inst);
os << " %" << cmp->GetName() << " = icmp " os << " %" << N(cmp) << " = icmp " << PredToStr(cmp->GetPredicate())
<< PredToStr(cmp->GetPredicate()) << " i32 " << " i32 " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
<< ValStr(cmp->GetLhs()) << ", " << ValStr(cmp->GetRhs())
<< "\n";
break; break;
} }
case Opcode::FCmp: { case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst); auto* cmp = static_cast<const FCmpInst*>(inst);
os << " %" << cmp->GetName() << " = fcmp " os << " %" << N(cmp) << " = fcmp " << FPredToStr(cmp->GetPredicate())
<< FPredToStr(cmp->GetPredicate()) << " float " << " float " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
<< ValStr(cmp->GetLhs()) << ", " << ValStr(cmp->GetRhs())
<< "\n";
break; break;
} }
// ── 内存 ──────────────────────────────────────────────────────────
case Opcode::Alloca: { case Opcode::Alloca: {
auto* al = static_cast<const AllocaInst*>(inst); auto* al = static_cast<const AllocaInst*>(inst);
const char* elem_type = al->GetType()->IsPtrFloat32() ? "float" : "i32"; const char* et = al->GetType()->IsPtrFloat32() ? "float" : "i32";
if (al->IsArray()) { if (al->IsArray())
os << " %" << al->GetName() << " = alloca " << elem_type << ", i32 " os << " %" << N(al) << " = alloca " << et << ", i32 " << al->GetNumElements() << "\n";
<< al->GetNumElements() << "\n"; else
} else { os << " %" << N(al) << " = alloca " << et << "\n";
os << " %" << al->GetName() << " = alloca " << elem_type << "\n";
}
break; break;
} }
case Opcode::Gep: { case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst); auto* gep = static_cast<const GepInst*>(inst);
os << " %" << gep->GetName() bool fp = gep->GetBasePtr()->GetType()->IsPtrFloat32();
<< " = getelementptr i32, i32* " os << " %" << N(gep) << " = getelementptr " << (fp ? "float" : "i32")
<< ValStr(gep->GetBasePtr()) << ", i32 " << ", " << (fp ? "float*" : "i32*") << " "
<< ValStr(gep->GetIndex()) << "\n"; << VS(gep->GetBasePtr()) << ", i32 " << VS(gep->GetIndex()) << "\n";
break; break;
} }
case Opcode::Load: { case Opcode::Load: {
auto* ld = static_cast<const LoadInst*>(inst); auto* ld = static_cast<const LoadInst*>(inst);
const char* val_type = ld->GetType()->IsFloat32() ? "float" : "i32"; bool fp = ld->GetPtr()->GetType()->IsPtrFloat32();
const char* ptr_type = ld->GetPtr()->GetType()->IsPtrFloat32() ? "float*" : "i32*"; os << " %" << N(ld) << " = load " << (fp ? "float" : "i32")
os << " %" << ld->GetName() << " = load " << val_type << ", " << ptr_type << " " << ", " << (fp ? "float*" : "i32*") << " " << VS(ld->GetPtr()) << "\n";
<< ValStr(ld->GetPtr()) << "\n";
break; break;
} }
case Opcode::Store: { case Opcode::Store: {
auto* st = static_cast<const StoreInst*>(inst); auto* st = static_cast<const StoreInst*>(inst);
os << " store " << TypeVal(st->GetValue()) << ", " os << " store " << TV(st->GetValue()) << ", "
<< TypeToStr(*st->GetPtr()->GetType()) << " " << TypeToStr(*st->GetPtr()->GetType()) << " " << VS(st->GetPtr()) << "\n";
<< ValStr(st->GetPtr()) << "\n";
break; break;
} }
// ── 控制流 ────────────────────────────────────────────────────────
case Opcode::Ret: { case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst); auto* ret = static_cast<const ReturnInst*>(inst);
if (!ret->HasValue()) { if (!ret->HasValue()) os << " ret void\n";
os << " ret void\n"; else os << " ret " << TV(ret->GetValue()) << "\n";
} else {
auto* v = ret->GetValue();
os << " ret " << TypeVal(v) << "\n";
}
break; break;
} }
case Opcode::Br: { case Opcode::Br: {
@ -248,57 +193,195 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
} }
case Opcode::CondBr: { case Opcode::CondBr: {
auto* cbr = static_cast<const CondBrInst*>(inst); auto* cbr = static_cast<const CondBrInst*>(inst);
os << " br i1 " << ValStr(cbr->GetCond()) << ", label %" os << " br i1 " << VS(cbr->GetCond()) << ", label %"
<< cbr->GetTrueBB()->GetName() << ", label %" << cbr->GetTrueBB()->GetName() << ", label %"
<< cbr->GetFalseBB()->GetName() << "\n"; << cbr->GetFalseBB()->GetName() << "\n";
break; break;
} }
// ── 调用 ──────────────────────────────────────────────────────────
case Opcode::Call: { case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst); auto* call = static_cast<const CallInst*>(inst);
std::string ret_type_str; if (!call->IsVoid() && !call->GetName().empty())
if (call->IsVoid()) { os << " %" << N(call) << " = ";
ret_type_str = "void"; else
} else {
ret_type_str = TypeToStr(*call->GetType());
}
// 打印赋值部分(仅当有返回值时)
if (!call->IsVoid() && !call->GetName().empty()) {
os << " %" << call->GetName() << " = ";
} else {
os << " "; os << " ";
} os << "call " << (call->IsVoid() ? "void" : TypeToStr(*call->GetType()))
os << "call " << ret_type_str << " @" << call->GetCalleeName() << " @" << call->GetCalleeName() << "(";
<< "(";
for (size_t i = 0; i < call->GetNumArgs(); ++i) { for (size_t i = 0; i < call->GetNumArgs(); ++i) {
if (i > 0) os << ", "; if (i > 0) os << ", ";
auto* arg = call->GetArg(i); os << TV(call->GetArg(i));
os << TypeVal(arg);
} }
os << ")\n"; os << ")\n";
break; break;
} }
// ── 类型转换 ──────────────────────────────────────────────────────
case Opcode::ZExt: { case Opcode::ZExt: {
auto* ze = static_cast<const ZExtInst*>(inst); auto* ze = static_cast<const ZExtInst*>(inst);
os << " %" << ze->GetName() << " = zext i1 " os << " %" << N(ze) << " = zext i1 " << VS(ze->GetSrc()) << " to i32\n";
<< ValStr(ze->GetSrc()) << " to i32\n";
break; break;
} }
case Opcode::SIToFP: { case Opcode::SIToFP: {
auto* si = static_cast<const SIToFPInst*>(inst); auto* si = static_cast<const SIToFPInst*>(inst);
os << " %" << si->GetName() << " = sitofp i32 " os << " %" << N(si) << " = sitofp i32 " << VS(si->GetSrc()) << " to float\n";
<< ValStr(si->GetSrc()) << " to float\n";
break; break;
} }
case Opcode::FPToSI: { case Opcode::FPToSI: {
auto* fp = static_cast<const FPToSIInst*>(inst); auto* fp = static_cast<const FPToSIInst*>(inst);
os << " %" << fp->GetName() << " = fptosi float " os << " %" << N(fp) << " = fptosi float " << VS(fp->GetSrc()) << " to i32\n";
<< ValStr(fp->GetSrc()) << " to i32\n"; break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " %" << N(phi) << " = phi " << TypeToStr(*phi->GetType()) << " ";
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (i > 0) os << ", ";
os << "[ " << VS(phi->GetIncomingValue(i)) << ", %"
<< phi->GetIncomingBlock(i)->GetName() << " ]";
}
os << "\n";
break; break;
} }
} }
} }
void IRPrinter::Print(const Module& module, std::ostream& os) {
// 1. 全局变量/常量
for (const auto& gv : module.GetGlobalVariables()) {
if (!gv) continue;
if (gv->IsConstant()) {
os << "@" << gv->GetName() << " = constant i32 " << gv->GetInitVal() << "\n";
} else if (gv->IsArray()) {
const char* et = gv->IsFloat() ? "float" : "i32";
os << "@" << gv->GetName() << " = global [" << gv->GetNumElements()
<< " x " << et << "] ";
if (!gv->HasInitVals()) {
os << "zeroinitializer\n";
} else if (gv->IsFloat()) {
const auto& vals = gv->GetInitValsF();
bool all_zero = std::all_of(vals.begin(), vals.end(), [](float f){ return f == 0.0f; });
if (all_zero) {
os << "zeroinitializer\n";
} else {
os << "[";
for (int i = 0; i < gv->GetNumElements(); ++i) {
if (i > 0) os << ", ";
float fv = (i < (int)vals.size()) ? vals[i] : 0.0f;
double d = static_cast<double>(fv);
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
os << oss.str();
}
os << "]\n";
}
} else {
const auto& vals = gv->GetInitVals();
bool all_zero = std::all_of(vals.begin(), vals.end(), [](int v){ return v == 0; });
if (all_zero) {
os << "zeroinitializer\n";
} else {
os << "[";
for (int i = 0; i < gv->GetNumElements(); ++i) {
if (i > 0) os << ", ";
os << "i32 " << (i < (int)vals.size() ? vals[i] : 0);
}
os << "]\n";
}
}
} else {
if(gv->IsFloat()) {
double d = static_cast<double>(gv->GetInitValF());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << bits;
os << "@" << gv->GetName() << " = global float " << oss.str() << "\n";
} else
{
os << "@" << gv->GetName() << " = global i32 " << gv->GetInitVal() << "\n";
}
}
}
if (!module.GetGlobalVariables().empty()) os << "\n";
// 2. 外部声明
for (const auto& decl : module.GetExternalDecls()) {
os << "declare " << TypeToStr(*decl.ret_type) << " @" << decl.name << "(";
for (size_t i = 0; i < decl.param_types.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToStr(*decl.param_types[i]);
}
os << ")\n";
}
if (!module.GetExternalDecls().empty()) os << "\n";
// 3. 函数定义
for (const auto& func : module.GetFunctions()) {
os << "define " << TypeToStr(*func->GetType()) << " @" << func->GetName() << "(";
for (size_t i = 0; i < func->GetNumArgs(); ++i) {
if (i > 0) os << ", ";
auto* arg = func->GetArgument(i);
os << TypeToStr(*arg->GetType()) << " %" << arg->GetName();
}
os << ") {\n";
// Build rename map: alloca instructions first (in block order), then rest
RenameMap rm;
int next_id = 0;
auto assign = [&](const Value* v) {
if (!v) return;
if (dynamic_cast<const ConstantInt*>(v)) return;
if (dynamic_cast<const ConstantFloat*>(v)) return;
if (dynamic_cast<const BasicBlock*>(v)) return;
if (dynamic_cast<const GlobalVariable*>(v)) return;
if (dynamic_cast<const Argument*>(v)) return;
if (rm.count(v) == 0) rm[v] = next_id++;
};
// Pass 1: all allocas across all blocks
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Alloca) assign(ip.get());
}
// Pass 2: all non-alloca instructions in block order
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca) assign(ip.get());
}
// Print: entry block first with all allocas hoisted, then rest
bool first_bb = true;
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
os << bb->GetName() << ":\n";
if (first_bb) {
first_bb = false;
// Print all allocas from all blocks (only for entry block)
for (const auto& bb2 : func->GetBlocks()) {
if (!bb2) continue;
for (const auto& ip : bb2->GetInstructions())
if (ip->GetOpcode() == Opcode::Alloca)
PrintInst(ip.get(), os, rm);
}
// Print PHI nodes of entry block
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Phi)
PrintInst(ip.get(), os, rm);
// Print non-alloca non-phi instructions of entry block
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi)
PrintInst(ip.get(), os, rm);
} else {
// Non-entry blocks: skip allocas (already printed)
// Print PHI nodes first
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Phi)
PrintInst(ip.get(), os, rm);
// Print non-alloca non-phi instructions
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi)
PrintInst(ip.get(), os, rm);
}
} }
os << "}\n\n"; os << "}\n\n";
} }

@ -48,6 +48,12 @@ bool Instruction::IsTerminator() const {
opcode_ == Opcode::CondBr; opcode_ == Opcode::CondBr;
} }
/// Detaches this instruction from the basic block that currently holds it.
/// Does nothing when the instruction has no parent block.
void Instruction::RemoveFromParent() {
  if (!parent_) return;
  // The actual unlinking is delegated to the owning block.
  parent_->RemoveInstruction(this);
}
BasicBlock* Instruction::GetParent() const { return parent_; } BasicBlock* Instruction::GetParent() const { return parent_; }
void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
@ -224,7 +230,11 @@ AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, int num_elements,
// ─── GepInst ────────────────────────────────────────────────────────────────── // ─── GepInst ──────────────────────────────────────────────────────────────────
GepInst::GepInst(Value* base_ptr, Value* index, std::string name) GepInst::GepInst(Value* base_ptr, Value* index, std::string name)
: Instruction(Opcode::Gep, Type::GetPtrInt32Type(), std::move(name)) { : Instruction(Opcode::Gep,
(base_ptr && base_ptr->GetType()->IsPtrFloat32())
? Type::GetPtrFloat32Type()
: Type::GetPtrInt32Type(),
std::move(name)) {
if (!base_ptr || !index) { if (!base_ptr || !index) {
throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数")); throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数"));
} }
@ -263,12 +273,30 @@ Value* StoreInst::GetPtr() const { return GetOperand(1); }
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name) GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {} : User(std::move(ty), std::move(name)) {}
// ─── PhiInst ──────────────────────────────────────────────────────────────────
/// Constructs a PHI instruction with result type `ty` and the given name.
/// The node starts without incoming operands; they are added afterwards via
/// AddIncoming().
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
/// Appends one incoming (value, block) pair to this PHI node.
///
/// Operands are stored interleaved as [value0, block0, value1, block1, ...];
/// GetIncomingBlock() relies on this layout, so the value must be added
/// before the block.
///
/// @param val Value flowing in along the edge.
/// @param bb  Predecessor block the value comes from.
void PhiInst::AddIncoming(Value* val, BasicBlock* bb) {
AddOperand(val);
AddOperand(bb);
}
/// Returns the predecessor block of the i-th incoming pair.
///
/// Incoming pairs occupy two consecutive operand slots (value, then block),
/// so the block of pair `i` sits at operand index `2 * i + 1`.
BasicBlock* PhiInst::GetIncomingBlock(size_t i) const {
  const size_t block_slot = i * 2 + 1;
  Value* operand = GetOperand(block_slot);
  return static_cast<BasicBlock*>(operand);
}
// ─── GlobalVariable ──────────────────────────────────────────────────────────── // ─── GlobalVariable ────────────────────────────────────────────────────────────
GlobalVariable::GlobalVariable(std::string name, bool is_const, int init_val, GlobalVariable::GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements) int num_elements, bool is_array_decl,
: Value(Type::GetPtrInt32Type(), std::move(name)), bool is_float)
: Value(is_float ? Type::GetPtrFloat32Type() : Type::GetPtrInt32Type(),
std::move(name)),
is_const_(is_const), is_const_(is_const),
is_float_(is_float),
init_val_(init_val), init_val_(init_val),
num_elements_(num_elements) {} init_val_f_(0.0f),
num_elements_(num_elements),
is_array_decl_(is_array_decl) {}
} // namespace ir } // namespace ir

@ -28,9 +28,10 @@ const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
// ─── 全局变量管理 ───────────────────────────────────────────────────────────── // ─── 全局变量管理 ─────────────────────────────────────────────────────────────
GlobalVariable* Module::CreateGlobalVariable(const std::string& name, GlobalVariable* Module::CreateGlobalVariable(const std::string& name,
bool is_const, int init_val, bool is_const, int init_val,
int num_elements) { int num_elements, bool is_array_decl,
bool is_float) {
globals_.push_back( globals_.push_back(
std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements)); std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements, is_array_decl, is_float));
GlobalVariable* g = globals_.back().get(); GlobalVariable* g = globals_.back().get();
global_map_[name] = g; global_map_[name] = g;
return g; return g;

@ -1,4 +1,205 @@
// 支配树分析: // 支配树分析:
// - 构建/查询 Dominator Tree 及相关关系 // - 构建/查询 Dominator Tree 及相关关系
// - 为 mem2reg、CFG 优化与循环分析提供基础能力 // - 使用 Cooper-Harvey-Kennedy 算法,近线性时间复杂度
#include "ir/IR.h"
#include <algorithm>
#include <functional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
/// Recomputes all dominator information for `func`.
///
/// After the call the tree holds, for every block reachable from the entry:
///  - its immediate dominator (nullptr for the entry block),
///  - its children in the dominator tree,
///  - its depth in the dominator tree,
///  - its dominance frontier,
///  - a pre-order traversal of the dominator tree.
///
/// Immediate dominators are computed with the iterative
/// Cooper-Harvey-Kennedy scheme over a reverse postorder (RPO) numbering.
///
/// All per-block lists (children, dominance frontier) are produced in RPO
/// order, so results do not depend on pointer values or hash-map iteration
/// order and are reproducible from run to run.
///
/// @param func Function to analyse; its CFG is rebuilt first.
void DominatorTree::Compute(Function& func) {
  // Reset every cached result up front so that an early return (function
  // without an entry block) never leaves data from a previous run behind.
  idom_.clear();
  children_.clear();
  dom_level_.clear();
  df_.clear();
  df_order_.clear();
  visited_.clear();

  func.RebuildCFG();

  auto* entry = func.GetEntry();
  if (!entry) return;

  // ─── 0. Reverse postorder of the reachable blocks ──────────────────────
  std::vector<BasicBlock*> blocks;
  std::unordered_map<BasicBlock*, int> rpo;
  {
    std::unordered_set<BasicBlock*> seen;
    std::function<void(BasicBlock*)> dfs = [&](BasicBlock* bb) {
      if (!bb || seen.count(bb)) return;
      seen.insert(bb);
      for (auto* succ : bb->GetSuccessors()) {
        dfs(succ);
      }
      blocks.push_back(bb);  // postorder
    };
    dfs(entry);
    // Reversing the postorder yields RPO; the entry ends up at index 0.
    std::reverse(blocks.begin(), blocks.end());
    for (int i = 0; i < static_cast<int>(blocks.size()); ++i) {
      rpo[blocks[i]] = i;
    }
  }
  const int n = static_cast<int>(blocks.size());

  // ─── 1. CHK algorithm for immediate dominators ─────────────────────────
  // During the fixpoint the entry is its own dominator; this terminates the
  // upward walks in `intersect`.
  idom_[entry] = entry;

  // Intersect: find the common ancestor of b1 and b2 by walking up the
  // partially built dominator tree. A dominator always has a lower RPO
  // number, so the block with the larger number is the one that moves up.
  auto intersect = [&](BasicBlock* b1, BasicBlock* b2) -> BasicBlock* {
    auto i1 = rpo.find(b1), i2 = rpo.find(b2);
    if (i1 == rpo.end() || i2 == rpo.end()) return entry;
    int r1 = i1->second, r2 = i2->second;
    while (b1 != b2) {
      while (r1 > r2) {
        auto it = idom_.find(b1);
        if (it == idom_.end() || it->second == b1) return b1;
        b1 = it->second;
        r1 = rpo[b1];
      }
      while (r2 > r1) {
        auto it = idom_.find(b2);
        if (it == idom_.end() || it->second == b2) return b2;
        b2 = it->second;
        r2 = rpo[b2];
      }
    }
    return b1;
  };

  bool changed = true;
  while (changed) {
    changed = false;
    // Process in RPO, skipping the entry (index 0).
    for (int i = 1; i < n; ++i) {
      auto* bb = blocks[i];

      // Seed with the first predecessor that already has an IDOM.
      BasicBlock* new_idom = nullptr;
      for (auto* pred : bb->GetPredecessors()) {
        if (idom_.count(pred) && pred != bb) {
          new_idom = pred;
          break;
        }
      }
      if (!new_idom) continue;

      // Fold in the remaining processed predecessors.
      for (auto* pred : bb->GetPredecessors()) {
        if (pred == new_idom || pred == bb) continue;
        if (idom_.count(pred)) {
          new_idom = intersect(pred, new_idom);
        }
      }

      auto old = idom_.find(bb);
      if (old == idom_.end() || old->second != new_idom) {
        idom_[bb] = new_idom;
        changed = true;
      }
    }
  }

  // Externally the entry has no immediate dominator.
  idom_[entry] = nullptr;

  // Defensive fallback: any listed block still lacking an IDOM hangs off the
  // entry.
  for (auto* bb : blocks) {
    if (!idom_.count(bb)) idom_[bb] = entry;
  }

  // ─── 2. Build children map and dom levels ──────────────────────────────
  // Iterate the RPO list instead of the hash map so that every child list
  // has a stable, address-independent order.
  for (auto* bb : blocks) {
    if (auto* parent = idom_[bb]) children_[parent].push_back(bb);
  }

  // BFS from the entry assigns each block its depth in the dominator tree.
  std::queue<BasicBlock*> q;
  dom_level_[entry] = 0;
  q.push(entry);
  while (!q.empty()) {
    auto* cur = q.front();
    q.pop();
    size_t cur_level = dom_level_[cur];
    auto it = children_.find(cur);
    if (it != children_.end()) {
      for (auto* child : it->second) {
        dom_level_[child] = cur_level + 1;
        q.push(child);
      }
    }
  }

  // ─── 3. Compute dominance frontier ─────────────────────────────────────
  // For every join block b, each block on the IDOM chain from a predecessor
  // up to (but excluding) idom(b) has b in its frontier.
  for (auto* b : blocks) {
    const auto& preds = b->GetPredecessors();
    if (preds.size() < 2) continue;
    auto* b_idom = GetIDom(b);
    for (auto* p : preds) {
      for (auto* runner = p; runner && runner != b_idom;
           runner = GetIDom(runner)) {
        auto& frontier = df_[runner];
        // All insertions of `b` happen within this single outer iteration,
        // so a duplicate can only be the most recently appended element.
        // This also keeps each frontier in RPO order without sorting.
        if (frontier.empty() || frontier.back() != b) frontier.push_back(b);
      }
    }
  }

  // ─── 4. Compute DFS order of dominator tree ────────────────────────────
  std::function<void(BasicBlock*)> dfs_tree = [&](BasicBlock* bb) {
    if (!bb || visited_.count(bb)) return;
    visited_.insert(bb);
    df_order_.push_back(bb);
    auto it = children_.find(bb);
    if (it != children_.end()) {
      for (auto* child : it->second) {
        dfs_tree(child);
      }
    }
  };
  dfs_tree(entry);
}
/// Returns the immediate dominator recorded for `bb`, or nullptr when none
/// is recorded (the entry block stores nullptr; unknown blocks have no
/// record at all).
BasicBlock* DominatorTree::GetIDom(BasicBlock* bb) const {
  const auto found = idom_.find(bb);
  if (found == idom_.end()) {
    return nullptr;
  }
  return found->second;
}
/// Returns the blocks whose immediate dominator is `bb`.
/// A shared empty list is returned for blocks without recorded children.
const std::vector<BasicBlock*>& DominatorTree::GetChildren(
    BasicBlock* bb) const {
  // Function-local static so a reference to "no children" stays valid.
  static const std::vector<BasicBlock*> kNoChildren;
  if (auto found = children_.find(bb); found != children_.end()) {
    return found->second;
  }
  return kNoChildren;
}
/// Returns the dominance frontier recorded for `bb`.
/// A shared empty list is returned for blocks without a recorded frontier.
const std::vector<BasicBlock*>& DominatorTree::GetDominanceFrontier(
    BasicBlock* bb) const {
  // Function-local static so a reference to "empty frontier" stays valid.
  static const std::vector<BasicBlock*> kEmptyFrontier;
  if (auto found = df_.find(bb); found != df_.end()) {
    return found->second;
  }
  return kEmptyFrontier;
}
/// Reports whether block `a` dominates block `b`.
///
/// A block dominates itself. Otherwise `a` dominates `b` when `a` appears on
/// the chain of immediate dominators starting at `b`.
///
/// A null block neither dominates nor is dominated. Without this guard the
/// upward walk would end with `runner == nullptr` and report a null `a` as
/// dominating every block.
///
/// @param a Candidate dominator.
/// @param b Block being tested.
/// @return true if `a` dominates `b`, false otherwise.
bool DominatorTree::Dominates(BasicBlock* a, BasicBlock* b) const {
  if (!a || !b) return false;
  if (a == b) return true;
  // Walk b's IDOM chain towards the root looking for a.
  for (BasicBlock* runner = GetIDom(b); runner; runner = GetIDom(runner)) {
    if (runner == a) return true;
  }
  return false;
}
} // namespace ir

@ -2,3 +2,143 @@
// - 删除不可达块、合并空块、简化分支等 // - 删除不可达块、合并空块、简化分支等
// - 改善 IR 结构,便于后续优化与后端生成 // - 改善 IR 结构,便于后续优化与后端生成
#include "ir/IR.h"
#include <queue>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// Fold conditional branches whose condition is an integer constant into
// unconditional branches to the statically chosen successor.
//
// Returns true when at least one branch was rewritten.
//
// NOTE(review): PHI nodes in the successor that is no longer targeted are not
// updated here; presumably a later step is expected to cope with that —
// confirm.
bool SimplifyConstantBranches(Function& func, Context& ctx) {
  bool folded_any = false;
  for (auto& block : func.GetBlocks()) {
    const auto& body = block->GetInstructions();
    if (body.empty()) continue;

    // Only a trailing conditional branch is of interest.
    auto* cond_br = dynamic_cast<CondBrInst*>(body.back().get());
    if (!cond_br) continue;

    // ...and only when its condition is a known integer constant.
    auto* known = dynamic_cast<ConstantInt*>(cond_br->GetCond());
    if (!known) continue;

    // Pick the destination before the branch is detached from the block.
    BasicBlock* taken = nullptr;
    if (known->GetValue() != 0) {
      taken = cond_br->GetTrueBB();
    } else {
      taken = cond_br->GetFalseBB();
    }

    // Swap the conditional branch for an unconditional one.
    cond_br->RemoveFromParent();
    block->Append<BrInst>(taken);
    folded_any = true;
  }
  return folded_any;
}
// Bypass "empty" blocks: a non-entry block whose only instruction is an
// unconditional branch is removed and its predecessors are redirected
// straight to the branch target.
//
// One block is removed per iteration; the CFG is rebuilt before each
// iteration because removing a block changes predecessor lists.
//
// When the target starts with PHI nodes the bypass is only performed if it
// can be expressed by renaming one incoming block, i.e. the empty block has
// exactly one predecessor and that predecessor does not already branch to the
// target. Otherwise the PHI would lose incoming entries (several
// predecessors) or end up with two entries for the same predecessor, so the
// block is left in place.
//
// Returns true when at least one block was removed.
bool MergeEmptyBlocks(Function& func) {
  bool changed = false;
  bool local_changed = true;

  // Iterate: bypassing one block can expose further candidates.
  while (local_changed) {
    local_changed = false;
    func.RebuildCFG();

    BasicBlock* block_to_remove = nullptr;

    for (auto& bb : func.GetBlocks()) {
      auto* block = bb.get();
      if (block == func.GetEntry()) continue;

      const auto& insts = block->GetInstructions();
      if (insts.size() != 1) continue;

      auto* br = dynamic_cast<BrInst*>(insts[0].get());
      if (!br) continue;

      auto* target = br->GetTarget();
      if (target == block) continue;

      // Do not bypass into another empty block; that one is handled in a
      // later iteration.
      const auto& target_insts = target->GetInstructions();
      if (target_insts.size() == 1 &&
          dynamic_cast<BrInst*>(target_insts[0].get()) &&
          target != func.GetEntry()) {
        continue;
      }

      // Distinct predecessors of the empty block, in CFG order.
      std::vector<BasicBlock*> preds;
      std::unordered_set<BasicBlock*> seen_preds;
      for (auto* pred : block->GetPredecessors()) {
        if (pred == block) continue;
        if (seen_preds.insert(pred).second) preds.push_back(pred);
      }

      bool target_has_phi = false;
      for (const auto& t_inst : target_insts) {
        if (dynamic_cast<PhiInst*>(t_inst.get())) {
          target_has_phi = true;
          break;
        }
      }

      if (target_has_phi) {
        // The PHI entry for `block` can only be renamed to a single
        // predecessor.
        if (preds.size() != 1) continue;
        // If that predecessor already reaches the target directly, the PHI
        // already has an entry for it; adding a second one would be invalid.
        bool already_pred = false;
        for (auto* t_pred : target->GetPredecessors()) {
          if (t_pred == preds.front()) {
            already_pred = true;
            break;
          }
        }
        if (already_pred) continue;
      }

      // Redirect every predecessor edge that still points at the empty block.
      for (auto* pred : preds) {
        auto& p_insts = pred->GetInstructions();
        if (p_insts.empty()) continue;

        auto* p_term = p_insts.back().get();
        if (auto* cbr = dynamic_cast<CondBrInst*>(p_term)) {
          if (cbr->GetTrueBB() == block) cbr->SetOperand(1, target);
          if (cbr->GetFalseBB() == block) cbr->SetOperand(2, target);
        } else if (auto* p_br = dynamic_cast<BrInst*>(p_term)) {
          if (p_br->GetTarget() == block) p_br->SetOperand(0, target);
        }

        // PHI entries that named the empty block now come from `pred`.
        // (Only reached with a single predecessor when PHIs exist.)
        for (auto& t_inst : target->GetInstructions()) {
          if (auto* phi = dynamic_cast<PhiInst*>(t_inst.get())) {
            for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
              if (phi->GetIncomingBlock(i) == block) {
                phi->SetOperand(i * 2 + 1, pred);
              }
            }
          }
        }
      }

      block_to_remove = block;
      local_changed = true;
      changed = true;
      break;  // One block per iteration; the CFG is rebuilt next round.
    }

    if (block_to_remove) {
      func.RemoveBlock(block_to_remove);
    }
  }

  if (changed) func.RebuildCFG();
  return changed;
}
// Delete every basic block that cannot be reached from the entry block.
//
// Returns true when at least one block was removed; returns false when the
// function has no entry block or everything is reachable.
//
// NOTE(review): PHI nodes in surviving blocks are not touched here; if a
// removed block was one of their incoming blocks that entry presumably needs
// cleaning up elsewhere — confirm.
bool RemoveUnreachableBlocks(Function& func) {
  func.RebuildCFG();

  auto* entry = func.GetEntry();
  if (entry == nullptr) return false;

  // Breadth-first walk over successor edges, starting at the entry.
  std::unordered_set<BasicBlock*> live;
  std::queue<BasicBlock*> worklist;
  live.insert(entry);
  worklist.push(entry);
  while (!worklist.empty()) {
    BasicBlock* current = worklist.front();
    worklist.pop();
    for (auto* next : current->GetSuccessors()) {
      if (next == nullptr) continue;
      const bool first_visit = live.insert(next).second;
      if (first_visit) worklist.push(next);
    }
  }

  // Collect first, remove afterwards, so the block list is not modified
  // while it is being iterated.
  std::vector<BasicBlock*> dead;
  for (auto& owned : func.GetBlocks()) {
    BasicBlock* candidate = owned.get();
    if (live.find(candidate) == live.end()) dead.push_back(candidate);
  }
  if (dead.empty()) return false;

  for (auto* victim : dead) {
    func.RemoveBlock(victim);
  }
  func.RebuildCFG();
  return true;
}
} // namespace
/// Runs the CFG clean-up steps on `func` in a fixed order: constant-branch
/// folding, empty-block bypassing, then unreachable-block removal.
///
/// Every step always runs, regardless of what the earlier ones reported.
///
/// @return true if any step modified the function.
bool RunCFGSimplify(Function& func, Context& ctx) {
  const bool folded = SimplifyConstantBranches(func, ctx);
  const bool merged = MergeEmptyBlocks(func);
  const bool pruned = RemoveUnreachableBlocks(func);
  return folded || merged || pruned;
}
} // namespace ir

@ -1,4 +1,123 @@
// 公共子表达式消除CSE // 公共子表达式消除CSE
// - 识别并复用重复计算的等价表达式 // - 识别并复用重复计算的等价表达式
// - 典型放置在 ConstFold 之后、DCE 之前 // - 局部值编号:在单个基本块内消除重复计算
// - 当前为 Lab4 的框架占位,具体算法由实验实现
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <sstream>
#include <string>
#include <unordered_map>
namespace ir {
namespace {
// Build a string that identifies an operand for value numbering.
//
//  - Integer constants:  "ci" + decimal value.
//  - Float constants:    "cf" + decimal value of the 32-bit bit pattern, so
//                        values are distinguished by their exact bits.
//  - Everything else:    "p"  + decimal value of the object's address.
std::string ValKey(Value* v) {
  if (auto* ci = dynamic_cast<ConstantInt*>(v))
    return "ci" + std::to_string(ci->GetValue());

  if (auto* cf = dynamic_cast<ConstantFloat*>(v)) {
    // Copy the bits out with memcpy: reading the inactive member of a union
    // is undefined behaviour in C++.
    const float value = cf->GetValue();
    static_assert(sizeof(std::uint32_t) == sizeof(float),
                  "float is expected to be 32 bits wide");
    std::uint32_t bits = 0;
    std::memcpy(&bits, &value, sizeof(bits));
    return "cf" + std::to_string(bits);
  }

  // Non-constants are identified by their address.
  return "p" + std::to_string(reinterpret_cast<std::uintptr_t>(v));
}
// Build the lookup key for an instruction that may be eliminated as a common
// subexpression. The key is the numeric opcode followed by '|'-separated
// fields (predicate for comparisons, then the operand keys). An empty string
// means the instruction is not a CSE candidate.
std::string MakeKey(Instruction* inst) {
  const auto op = inst->GetOpcode();
  const std::string head = std::to_string(static_cast<int>(op)) + "|";
  const auto operand_pair = [](Value* first, Value* second) {
    return ValKey(first) + "|" + ValKey(second);
  };

  switch (op) {
    case Opcode::Add:
    case Opcode::Sub:
    case Opcode::Mul:
    case Opcode::Div:
    case Opcode::Mod:
    case Opcode::FAdd:
    case Opcode::FSub:
    case Opcode::FMul:
    case Opcode::FDiv: {
      auto* binary = static_cast<BinaryInst*>(inst);
      return head + operand_pair(binary->GetLhs(), binary->GetRhs());
    }
    case Opcode::ICmp: {
      auto* icmp = static_cast<ICmpInst*>(inst);
      return head + std::to_string(static_cast<int>(icmp->GetPredicate())) +
             "|" + operand_pair(icmp->GetLhs(), icmp->GetRhs());
    }
    case Opcode::FCmp: {
      auto* fcmp = static_cast<FCmpInst*>(inst);
      return head + std::to_string(static_cast<int>(fcmp->GetPredicate())) +
             "|" + operand_pair(fcmp->GetLhs(), fcmp->GetRhs());
    }
    case Opcode::Gep: {
      auto* gep = static_cast<GepInst*>(inst);
      return head + operand_pair(gep->GetBasePtr(), gep->GetIndex());
    }
    case Opcode::Load:
      return head + ValKey(static_cast<LoadInst*>(inst)->GetPtr());
    case Opcode::ZExt:
      return head + ValKey(static_cast<ZExtInst*>(inst)->GetSrc());
    case Opcode::SIToFP:
      return head + ValKey(static_cast<SIToFPInst*>(inst)->GetSrc());
    case Opcode::FPToSI:
      return head + ValKey(static_cast<FPToSIInst*>(inst)->GetSrc());
    default:
      return "";
  }
}
} // namespace
// Local common-subexpression elimination (value numbering within one block).
// For every basic block, the first instruction with a given key becomes the
// "available" value; later instructions with the same key are replaced by it
// and deleted. Returns true when any instruction was eliminated.
bool RunCSE(Function& func) {
  bool changed = false;
  for (auto& bb : func.GetBlocks()) {
    std::unordered_map<std::string, Value*> available;
    // Keys in `available` that belong to loads. A load is only reusable
    // until memory may have been written, so these entries are dropped at
    // every store or call.
    std::vector<std::string> load_keys;
    std::vector<Instruction*> to_remove;

    for (auto& inst : bb->GetInstructions()) {
      auto* ip = inst.get();
      const auto op = ip->GetOpcode();

      if (op == Opcode::Store || op == Opcode::Call) {
        // Memory may change here: a later load from the same pointer must
        // not be replaced by a value loaded before this point.
        for (const auto& stale : load_keys) available.erase(stale);
        load_keys.clear();
        continue;
      }

      const std::string key = MakeKey(ip);
      if (key.empty()) continue;  // not a CSE candidate

      auto it = available.find(key);
      if (it != available.end()) {
        // An equivalent instruction was seen earlier in this block.
        ip->ReplaceAllUsesWith(it->second);
        to_remove.push_back(ip);
        changed = true;
      } else {
        if (op == Opcode::Load) load_keys.push_back(key);
        available[key] = ip;
      }
    }

    // Detach operands before removal so that no operand keeps a use entry
    // pointing at a deleted instruction.
    for (auto* ip : to_remove) {
      for (size_t i = 0; i < ip->GetNumOperands(); ++i)
        ip->SetOperand(i, nullptr);
      ip->RemoveFromParent();
    }
  }
  return changed;
}
} // namespace ir

@ -1,4 +1,187 @@
// IR 常量折叠: // IR 常量折叠:
// - 折叠可判定的常量表达式 // - 折叠可判定的常量表达式
// - 简化常量控制流分支(按实现范围裁剪) // - 简化常量控制流分支
#include "ir/IR.h"
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// Unlink an instruction from its operands and then drop it from its parent
// block. Operands are cleared first so that no other value keeps a use entry
// that refers to the removed instruction.
static void DetachAndRemove(Instruction* inst) {
  size_t index = 0;
  while (index < inst->GetNumOperands()) {
    inst->SetOperand(index, nullptr);
    ++index;
  }
  auto* owner = inst->GetParent();
  if (owner == nullptr) return;
  owner->RemoveInstruction(inst);
}
// Fold an integer comparison whose operands are both integer constants.
// The comparison is replaced by the constant 1 (true) or 0 (false) and then
// removed. Returns false, leaving the IR untouched, when either operand is
// not a constant.
bool FoldICmp(ICmpInst* cmp, Context& ctx) {
  auto* const left = dynamic_cast<ConstantInt*>(cmp->GetLhs());
  auto* const right = dynamic_cast<ConstantInt*>(cmp->GetRhs());
  if (left == nullptr || right == nullptr) return false;

  const int a = left->GetValue();
  const int b = right->GetValue();
  int folded = 0;  // stays 0 for a predicate not listed below
  switch (cmp->GetPredicate()) {
    case ICmpPredicate::EQ:
      folded = (a == b) ? 1 : 0;
      break;
    case ICmpPredicate::NE:
      folded = (a != b) ? 1 : 0;
      break;
    case ICmpPredicate::SLT:
      folded = (a < b) ? 1 : 0;
      break;
    case ICmpPredicate::SLE:
      folded = (a <= b) ? 1 : 0;
      break;
    case ICmpPredicate::SGT:
      folded = (a > b) ? 1 : 0;
      break;
    case ICmpPredicate::SGE:
      folded = (a >= b) ? 1 : 0;
      break;
  }
  cmp->ReplaceAllUsesWith(ctx.GetConstInt(folded));
  DetachAndRemove(cmp);
  return true;
}
// Fold a floating-point comparison whose operands are both float constants.
// The comparison is replaced by the constant 1 (true) or 0 (false) and then
// removed. Returns false, leaving the IR untouched, when either operand is
// not a constant.
//
// All predicates handled here are "ordered": they yield false when either
// operand is NaN.
bool FoldFCmp(FCmpInst* cmp, Context& ctx) {
  auto* lhs = dynamic_cast<ConstantFloat*>(cmp->GetLhs());
  auto* rhs = dynamic_cast<ConstantFloat*>(cmp->GetRhs());
  if (!lhs || !rhs) return false;
  const float lv = lhs->GetValue();
  const float rv = rhs->GetValue();
  bool result = false;
  switch (cmp->GetPredicate()) {
    case FCmpPredicate::OEQ: result = lv == rv; break;
    // `lv != rv` is true when an operand is NaN, which would make this an
    // unordered comparison. "Less or greater" is false for NaN, as required.
    case FCmpPredicate::ONE: result = (lv < rv) || (lv > rv); break;
    case FCmpPredicate::OLT: result = lv < rv; break;
    case FCmpPredicate::OLE: result = lv <= rv; break;
    case FCmpPredicate::OGT: result = lv > rv; break;
    case FCmpPredicate::OGE: result = lv >= rv; break;
  }
  cmp->ReplaceAllUsesWith(ctx.GetConstInt(result ? 1 : 0));
  DetachAndRemove(cmp);
  return true;
}
// Fold a zero-extension of an integer constant: any non-zero source becomes
// the constant 1, zero stays 0. Returns false when the source is not an
// integer constant.
bool FoldZExt(ZExtInst* zext, Context& ctx) {
  if (auto* constant = dynamic_cast<ConstantInt*>(zext->GetSrc())) {
    const int widened = (constant->GetValue() != 0) ? 1 : 0;
    zext->ReplaceAllUsesWith(ctx.GetConstInt(widened));
    DetachAndRemove(zext);
    return true;
  }
  return false;
}
// Fold an int-to-float conversion of an integer constant into a float
// constant. Returns false when the source is not an integer constant.
bool FoldSIToFP(SIToFPInst* inst, Context& ctx) {
  if (auto* constant = dynamic_cast<ConstantInt*>(inst->GetSrc())) {
    const float converted = static_cast<float>(constant->GetValue());
    inst->ReplaceAllUsesWith(ctx.GetConstFloat(converted));
    DetachAndRemove(inst);
    return true;
  }
  return false;
}
// Fold a float-to-int conversion of a float constant into an integer
// constant. Returns false when the source is not a float constant, or when
// its value cannot be represented as an int.
bool FoldFPToSI(FPToSIInst* inst, Context& ctx) {
  auto* src = dynamic_cast<ConstantFloat*>(inst->GetSrc());
  if (!src) return false;
  const float value = src->GetValue();
  // Converting NaN or an out-of-range float to int is undefined behaviour in
  // C++, so such conversions are left for run time. Both comparisons are
  // false for NaN. The bounds are written for a 32-bit int.
  const bool representable = value >= -2147483648.0f && value < 2147483648.0f;
  if (!representable) return false;
  inst->ReplaceAllUsesWith(ctx.GetConstInt(static_cast<int>(value)));
  DetachAndRemove(inst);
  return true;
}
// Fold a binary operation whose operands are both constants of the same kind
// (int/int or float/float). On success the instruction is replaced by the
// computed constant and removed, and true is returned.
//
// Not folded (returns false):
//  - integer or float division / remainder by zero,
//  - the overflowing integer division INT_MIN / -1 (and INT_MIN % -1),
//  - opcodes that do not match the operand kind.
bool FoldBinaryWithCtx(BinaryInst* bin, Context& ctx) {
  auto* lhs_c = dynamic_cast<ConstantInt*>(bin->GetLhs());
  auto* rhs_c = dynamic_cast<ConstantInt*>(bin->GetRhs());
  auto* lhs_f = dynamic_cast<ConstantFloat*>(bin->GetLhs());
  auto* rhs_f = dynamic_cast<ConstantFloat*>(bin->GetRhs());

  if (lhs_c && rhs_c) {
    const int lv = lhs_c->GetValue();
    const int rv = rhs_c->GetValue();
    // Add/sub/mul are done on unsigned values so that overflow wraps instead
    // of being undefined behaviour inside the compiler.
    const unsigned ul = static_cast<unsigned>(lv);
    const unsigned ur = static_cast<unsigned>(rv);
    // Bit pattern of the most negative int; dividing it by -1 overflows.
    const unsigned min_bits = ~(~0u >> 1);
    const bool can_divide = rv != 0 && !(rv == -1 && ul == min_bits);

    int result = 0;
    bool valid = true;
    switch (bin->GetOpcode()) {
      case Opcode::Add: result = static_cast<int>(ul + ur); break;
      case Opcode::Sub: result = static_cast<int>(ul - ur); break;
      case Opcode::Mul: result = static_cast<int>(ul * ur); break;
      case Opcode::Div:
        if (can_divide) result = lv / rv; else valid = false;
        break;
      case Opcode::Mod:
        if (can_divide) result = lv % rv; else valid = false;
        break;
      default: valid = false; break;
    }
    if (valid) {
      bin->ReplaceAllUsesWith(ctx.GetConstInt(result));
      DetachAndRemove(bin);
      return true;
    }
  }

  if (lhs_f && rhs_f) {
    const float lv = lhs_f->GetValue();
    const float rv = rhs_f->GetValue();
    float result = 0.0f;
    bool valid = true;
    switch (bin->GetOpcode()) {
      case Opcode::FAdd: result = lv + rv; break;
      case Opcode::FSub: result = lv - rv; break;
      case Opcode::FMul: result = lv * rv; break;
      case Opcode::FDiv:
        if (rv != 0.0f) result = lv / rv; else valid = false;
        break;
      default: valid = false; break;
    }
    if (valid) {
      bin->ReplaceAllUsesWith(ctx.GetConstFloat(result));
      DetachAndRemove(bin);
      return true;
    }
  }
  return false;
}
} // namespace
// Constant folding driver. Sweeps every block repeatedly, folding
// instructions whose operands are all constants, until a full sweep makes no
// further change. Returns true when anything was folded.
bool RunConstFold(Function& func, Context& ctx) {
  // Dispatch one instruction to the folder matching its opcode.
  const auto try_fold = [&ctx](Instruction* inst) -> bool {
    switch (inst->GetOpcode()) {
      case Opcode::Add:
      case Opcode::Sub:
      case Opcode::Mul:
      case Opcode::Div:
      case Opcode::Mod:
      case Opcode::FAdd:
      case Opcode::FSub:
      case Opcode::FMul:
      case Opcode::FDiv:
        return FoldBinaryWithCtx(static_cast<BinaryInst*>(inst), ctx);
      case Opcode::ICmp:
        return FoldICmp(static_cast<ICmpInst*>(inst), ctx);
      case Opcode::FCmp:
        return FoldFCmp(static_cast<FCmpInst*>(inst), ctx);
      case Opcode::ZExt:
        return FoldZExt(static_cast<ZExtInst*>(inst), ctx);
      case Opcode::SIToFP:
        return FoldSIToFP(static_cast<SIToFPInst*>(inst), ctx);
      case Opcode::FPToSI:
        return FoldFPToSI(static_cast<FPToSIInst*>(inst), ctx);
      default:
        return false;
    }
  };

  bool changed = false;
  std::unordered_set<void*> erased;  // addresses of instructions folded away
  for (bool progress = true; progress;) {
    progress = false;
    for (auto& bb : func.GetBlocks()) {
      // Work on a snapshot: folding mutates the block's instruction list.
      std::vector<Instruction*> snapshot;
      for (auto& owned : bb->GetInstructions()) snapshot.push_back(owned.get());

      for (auto* inst : snapshot) {
        if (erased.count(inst) != 0) continue;
        if (!try_fold(inst)) continue;
        erased.insert(inst);
        progress = true;
        changed = true;
      }
    }
  }
  return changed;
}
} // namespace ir

@ -1,5 +1,63 @@
// 常量传播Constant Propagation // 常量传播Constant Propagation
// - 沿 use-def 关系传播已知常量 // - 沿 use-def 关系传播已知常量
// - 将可替换的 SSA 值改写为常量,暴露更多折叠机会 // - 将可替换的 SSA 值改写为常量,暴露更多折叠机会
// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用
#include "ir/IR.h"
#include <queue>
#include <unordered_map>
#include <vector>
namespace ir {
// Constant propagation for conversion instructions. A ZExt / SIToFP / FPToSI
// whose source is a constant is replaced by the converted constant and
// removed. Returns true when any instruction was replaced.
bool RunConstProp(Function& func, Context& ctx) {
  bool changed = false;

  // Redirect all users to `replacement`, then delete `inst`. Operands are
  // detached before removal (as the other passes do) so that the constant
  // operand does not keep a use entry pointing at a deleted instruction.
  const auto replace_and_erase = [&changed](Instruction* inst,
                                            Value* replacement) {
    inst->ReplaceAllUsesWith(replacement);
    for (size_t i = 0; i < inst->GetNumOperands(); ++i)
      inst->SetOperand(i, nullptr);
    inst->RemoveFromParent();
    changed = true;
  };

  for (auto& bb : func.GetBlocks()) {
    // Snapshot the instruction list; it is modified while we walk it.
    std::vector<Instruction*> insts;
    for (auto& inst : bb->GetInstructions())
      insts.push_back(inst.get());

    for (auto* inst : insts) {
      if (inst->GetParent() == nullptr) continue;
      switch (inst->GetOpcode()) {
        case Opcode::ZExt: {
          auto* ze = static_cast<ZExtInst*>(inst);
          if (auto* ci = dynamic_cast<ConstantInt*>(ze->GetSrc())) {
            replace_and_erase(ze, ctx.GetConstInt(ci->GetValue() != 0 ? 1 : 0));
          }
          break;
        }
        case Opcode::SIToFP: {
          auto* si = static_cast<SIToFPInst*>(inst);
          if (auto* ci = dynamic_cast<ConstantInt*>(si->GetSrc())) {
            replace_and_erase(
                si, ctx.GetConstFloat(static_cast<float>(ci->GetValue())));
          }
          break;
        }
        case Opcode::FPToSI: {
          auto* fp = static_cast<FPToSIInst*>(inst);
          if (auto* cf = dynamic_cast<ConstantFloat*>(fp->GetSrc())) {
            const float value = cf->GetValue();
            // Converting NaN or an out-of-range float to int is undefined
            // behaviour, so those are left alone. Both comparisons are false
            // for NaN. The bounds are written for a 32-bit int.
            if (value >= -2147483648.0f && value < 2147483648.0f) {
              replace_and_erase(fp, ctx.GetConstInt(static_cast<int>(value)));
            }
          }
          break;
        }
        default:
          break;
      }
    }
  }
  return changed;
}
} // namespace ir

@ -1,4 +1,68 @@
// 死代码删除DCE // 死代码删除DCE
// - 删除无用指令与无用基本块 // - 删除无用指令与无用基本块
// - 通常与 CFG 简化配合使用 // - 标记-清扫算法:先标记有用指令,再清除未标记的
#include "ir/IR.h"
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// True for instructions that must be kept even when their result is unused:
// stores, calls and the control-flow instructions (Br, CondBr, Ret).
bool HasSideEffect(Instruction* inst) {
  const auto op = inst->GetOpcode();
  return op == Opcode::Store || op == Opcode::Call || op == Opcode::Br ||
         op == Opcode::CondBr || op == Opcode::Ret;
}
} // namespace
// Dead-code elimination: one sweep per block that deletes every instruction
// which is not a terminator, has no side effect, and has no user.
// Returns true when at least one instruction was deleted.
bool RunDCE(Function& func) {
  bool changed = false;
  for (auto& bb : func.GetBlocks()) {
    // Phase 1: pick the victims without touching the instruction list.
    std::vector<Instruction*> dead;
    for (auto& owned : bb->GetInstructions()) {
      Instruction* candidate = owned.get();
      if (candidate->IsTerminator() || HasSideEffect(candidate)) continue;

      // A use entry only counts when it still names a user.
      bool live = false;
      for (const auto& use : candidate->GetUses()) {
        if (use.GetUser()) {
          live = true;
          break;
        }
      }
      if (!live) dead.push_back(candidate);
    }

    // Phase 2: release the operands of each victim, then remove it.
    for (auto* victim : dead) {
      for (size_t slot = 0; slot < victim->GetNumOperands(); ++slot) {
        victim->SetOperand(slot, nullptr);
      }
      victim->RemoveFromParent();
    }
    if (!dead.empty()) changed = true;
  }
  return changed;
}
} // namespace ir

@ -1,4 +1,190 @@
// Mem2RegSSA 构造): // Mem2RegSSA 构造):
// - 将局部变量的 alloca/load/store 提升为 SSA 形式 // - 将局部变量的 alloca/load/store 提升为 SSA 形式
// - 插入 PHI 并重写使用,依赖支配树等分析 // - 插入 PHI 节点并重写使用,所有可提升的 alloca 在一次重命名遍中处理
#include "ir/IR.h"
#include <functional>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
// An alloca can be promoted when it is not an array and every user is a load
// or a store that uses it as the address operand.
bool IsPromotable(AllocaInst* alloca) {
  if (alloca->IsArray()) return false;
  for (const auto& use : alloca->GetUses()) {
    auto* user = use.GetUser();
    bool address_only = false;
    if (auto* load = dynamic_cast<LoadInst*>(user)) {
      address_only = (load->GetPtr() == alloca);
    } else if (auto* store = dynamic_cast<StoreInst*>(user)) {
      address_only = (store->GetPtr() == alloca);
    }
    // Any other kind of user, or a load/store with a different address,
    // blocks promotion.
    if (!address_only) return false;
  }
  return true;
}
// Split the users of `alloca` into stores (only those where the alloca is
// operand 1) and loads, appending them to the output vectors.
void CollectStoresAndLoads(AllocaInst* alloca,
                           std::vector<StoreInst*>& stores,
                           std::vector<LoadInst*>& loads) {
  for (const auto& use : alloca->GetUses()) {
    auto* user = use.GetUser();
    if (auto* store = dynamic_cast<StoreInst*>(user)) {
      if (use.GetOperandIndex() == 1) stores.push_back(store);
      continue;
    }
    if (auto* load = dynamic_cast<LoadInst*>(user)) {
      loads.push_back(load);
    }
  }
}
// Iterated dominance frontier of `def_blocks`: the closure obtained by
// repeatedly adding the dominance frontier of every block already reached.
std::set<BasicBlock*> ComputeIDF(const std::set<BasicBlock*>& def_blocks,
                                 DominatorTree& dt) {
  std::set<BasicBlock*> frontier;
  std::set<BasicBlock*> enqueued(def_blocks.begin(), def_blocks.end());
  std::vector<BasicBlock*> stack(def_blocks.begin(), def_blocks.end());

  while (!stack.empty()) {
    BasicBlock* current = stack.back();
    stack.pop_back();
    for (auto* candidate : dt.GetDominanceFrontier(current)) {
      const bool newly_added = frontier.insert(candidate).second;
      if (!newly_added) continue;
      // Each block is expanded at most once.
      if (enqueued.insert(candidate).second) stack.push_back(candidate);
    }
  }
  return frontier;
}
// Per-alloca bookkeeping used while promoting the alloca to SSA form.
struct AllocaInfo {
  AllocaInst* alloca{nullptr};                     // the slot being promoted
  std::vector<StoreInst*> stores;                  // stores writing the slot
  std::vector<LoadInst*> loads;                    // loads reading the slot
  std::unordered_map<BasicBlock*, PhiInst*> phis;  // PHI inserted per block
  std::vector<Value*> value_stack;                 // current value during renaming
  Value* undef_val{nullptr};                       // value used before any store
};
} // namespace
// Mem2Reg: promote scalar allocas that are only loaded and stored into SSA
// values. PHI nodes are placed on the iterated dominance frontier of the
// storing blocks, and one depth-first walk over the dominator tree renames
// all promoted allocas together. Returns true when at least one alloca was
// promoted.
bool RunMem2Reg(Function& func, Context& ctx) {
  DominatorTree dt;
  dt.Compute(func);

  // Collect every promotable alloca.
  std::vector<AllocaInst*> promotable;
  for (auto& bb : func.GetBlocks()) {
    for (auto& inst : bb->GetInstructions()) {
      if (auto* alloca = dynamic_cast<AllocaInst*>(inst.get())) {
        if (IsPromotable(alloca)) promotable.push_back(alloca);
      }
    }
  }
  if (promotable.empty()) return false;

  // Build the per-alloca records and insert the PHI nodes.
  std::vector<AllocaInfo> infos(promotable.size());
  std::unordered_map<StoreInst*, int> store_to_info;
  std::unordered_map<LoadInst*, int> load_to_info;
  for (size_t i = 0; i < promotable.size(); ++i) {
    auto* alloca = promotable[i];
    auto& info = infos[i];
    info.alloca = alloca;
    CollectStoresAndLoads(alloca, info.stores, info.loads);

    std::set<BasicBlock*> def_blocks;
    for (auto* s : info.stores) def_blocks.insert(s->GetParent());

    auto val_type = alloca->GetType()->IsPtrFloat32() ? Type::GetFloat32Type()
                                                      : Type::GetInt32Type();
    // A zero constant of the slot's element type stands in for "no value
    // stored yet"; it is the bottom entry of the renaming stack.
    info.undef_val = alloca->GetType()->IsPtrFloat32()
                         ? static_cast<Value*>(ctx.GetConstFloat(0.0f))
                         : static_cast<Value*>(ctx.GetConstInt(0));
    info.value_stack.push_back(info.undef_val);

    // One PHI per block in the iterated dominance frontier.
    auto df_plus = ComputeIDF(def_blocks, dt);
    for (auto* bb : df_plus) {
      auto* phi = bb->Prepend<PhiInst>(val_type, "");
      info.phis[bb] = phi;
    }

    // Fast lookup from a load/store to the record it belongs to.
    for (auto* s : info.stores) store_to_info[s] = static_cast<int>(i);
    for (auto* l : info.loads) load_to_info[l] = static_cast<int>(i);
  }

  // Single renaming walk over the dominator tree, handling all allocas.
  std::function<void(BasicBlock*)> rename = [&](BasicBlock* bb) {
    // Remember the stack heights so they can be restored on the way out; a
    // PHI in this block becomes the current value of its alloca.
    std::vector<size_t> saved_sizes(infos.size());
    for (size_t i = 0; i < infos.size(); ++i) {
      saved_sizes[i] = infos[i].value_stack.size();
      auto phi_it = infos[i].phis.find(bb);
      if (phi_it != infos[i].phis.end()) {
        infos[i].value_stack.push_back(phi_it->second);
      }
    }

    // Walk the block: a store defines a new current value, a load is
    // replaced by the current value.
    for (auto& inst_up : bb->GetInstructions()) {
      auto* inst = inst_up.get();
      // PHI nodes were already pushed above.
      if (dynamic_cast<PhiInst*>(inst)) continue;
      if (auto* store = dynamic_cast<StoreInst*>(inst)) {
        auto it = store_to_info.find(store);
        if (it != store_to_info.end()) {
          infos[it->second].value_stack.push_back(store->GetValue());
        }
      } else if (auto* load = dynamic_cast<LoadInst*>(inst)) {
        auto it = load_to_info.find(load);
        if (it != load_to_info.end()) {
          load->ReplaceAllUsesWith(infos[it->second].value_stack.back());
        }
      }
    }

    // Feed the current values into the PHI nodes of every successor.
    for (auto* succ : bb->GetSuccessors()) {
      for (size_t i = 0; i < infos.size(); ++i) {
        auto phi_it = infos[i].phis.find(succ);
        if (phi_it != infos[i].phis.end()) {
          phi_it->second->AddIncoming(infos[i].value_stack.back(), bb);
        }
      }
    }

    // Recurse into the dominator-tree children.
    for (auto* child : dt.GetChildren(bb)) rename(child);

    // Restore the stacks.
    for (size_t i = 0; i < infos.size(); ++i) {
      infos[i].value_stack.resize(saved_sizes[i]);
    }
  };
  rename(func.GetEntry());

  // Delete the promoted loads, stores and allocas. Operands are detached
  // first so that no value keeps a use entry for a deleted instruction.
  for (auto& info : infos) {
    for (auto* ld : info.loads) {
      // A load in a block the walk above never visited (a block not under
      // the entry in the dominator tree) still has its users attached.
      // Point them at the placeholder so they do not reference a deleted
      // instruction. For loads that were renamed this finds no uses left.
      ld->ReplaceAllUsesWith(info.undef_val);
      ld->SetOperand(0, nullptr);  // drop the reference to the alloca
      ld->RemoveFromParent();
    }
    for (auto* st : info.stores) {
      st->SetOperand(0, nullptr);  // drop the reference to the stored value
      st->SetOperand(1, nullptr);  // drop the reference to the alloca
      st->RemoveFromParent();
    }
    info.alloca->RemoveFromParent();
  }
  return true;
}
} // namespace ir

@ -1 +1,89 @@
// IR Pass 管理骨架。 // IR Pass 管理:按顺序执行优化遍,支持迭代至不动点
#include "ir/IR.h"
#include <queue>
#include <unordered_set>
namespace ir {
// 前向声明(定义在各 pass 文件中)
extern bool RunMem2Reg(Function& func, Context& ctx);
extern bool RunConstFold(Function& func, Context& ctx);
extern bool RunConstProp(Function& func, Context& ctx);
extern bool RunDCE(Function& func);
extern bool RunCFGSimplify(Function& func, Context& ctx);
extern bool RunCSE(Function& func);
// Repair PHI nodes after other passes have changed the function:
//  - add an incoming entry for every predecessor that has none,
//  - overwrite entries whose value or block is null.
// The filler value is a zero constant matching the PHI's type.
static void SanitizePhis(Function& func, Context& ctx) {
  func.RebuildCFG();
  auto* entry = func.GetEntry();
  if (!entry) return;

  for (auto& bb : func.GetBlocks()) {
    // Collect the PHI nodes of this block.
    std::vector<PhiInst*> phis;
    for (auto& inst : bb->GetInstructions()) {
      if (auto* phi = dynamic_cast<PhiInst*>(inst.get()))
        phis.push_back(phi);
    }
    if (phis.empty()) continue;

    auto& preds = bb->GetPredecessors();
    for (auto* phi : phis) {
      // The filler must have the PHI's own type: a float PHI fed with an
      // integer constant would be ill-typed.
      Value* undef_val = phi->IsFloat32()
                             ? static_cast<Value*>(ctx.GetConstFloat(0.0f))
                             : static_cast<Value*>(ctx.GetConstInt(0));

      // Blocks that already have an incoming entry.
      std::unordered_set<BasicBlock*> existing;
      for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
        existing.insert(phi->GetIncomingBlock(i));
      }

      // Every predecessor needs an entry (LLVM requires a PHI to cover all
      // predecessors).
      for (auto* pred : preds) {
        if (!existing.count(pred)) {
          phi->AddIncoming(undef_val, pred);
        }
      }

      // Overwrite broken entries. Operands are laid out as
      // (value, block) pairs: 2*i is the value, 2*i + 1 the block.
      for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
        auto* inc_val = phi->GetIncomingValue(i);
        auto* inc_bb = phi->GetIncomingBlock(i);
        if (!inc_val || !inc_bb) {
          phi->SetOperand(i * 2, undef_val);
          phi->SetOperand(i * 2 + 1, entry);
        }
      }
    }
  }
}
// Pass pipeline. Mem2Reg runs once per function; then ConstFold, ConstProp,
// CSE and DCE (followed by PHI sanitising) are repeated over all functions
// until a round changes nothing or the round limit is reached.
void RunPasses(Module& module) {
  Context& ctx = module.GetContext();

  for (auto& func : module.GetFunctions()) {
    RunMem2Reg(*func, ctx);
  }

  const int kMaxIterations = 10;
  bool changed = true;
  for (int round = 0; changed && round < kMaxIterations; ++round) {
    changed = false;
    for (auto& func : module.GetFunctions()) {
      if (RunConstFold(*func, ctx)) changed = true;
      if (RunConstProp(*func, ctx)) changed = true;
      if (RunCSE(*func)) changed = true;
      // CFGSimplify is left out on purpose: on CFGs with many PHI nodes it
      // leaves dangling pointers. Part of its job (empty-block merging and
      // unreachable-block removal) is covered by DCE plus SanitizePhis.
      if (RunDCE(*func)) changed = true;
      SanitizePhis(*func, ctx);
    }
  }
}
} // namespace ir

@ -117,33 +117,101 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
for (int d : dims) total *= (d > 0 ? d : 1); for (int d : dims) total *= (d > 0 ? d : 1);
if (in_global_scope_) { if (in_global_scope_) {
auto* gv = module_.CreateGlobalVariable(name, true, 0, total); auto* gv = module_.CreateGlobalVariable(name, true, 0, total, true);
global_storage_map_[constDef] = gv; global_storage_map_[constDef] = gv;
global_array_dims_[constDef] = dims; global_array_dims_[constDef] = dims;
// 计算初始值并存入全局变量
if (constDef->constInitVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::vector<int> flat(total, 0);
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) { flat[pos] = EvalConstExprInt(iv->constExp()); return; }
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos; fill(sub,a,sub_stride); cur=a+sub_stride; }
}
};
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur+top_stride-1)/top_stride)*top_stride; fill(sub,a,top_stride); cur=a+top_stride; }
}
}
gv->SetInitVals(flat);
}
} else { } else {
auto* slot = builder_.CreateAllocaArray(total, name); auto* slot = builder_.CreateAllocaArray(total, name);
storage_map_[constDef] = slot; storage_map_[constDef] = slot;
array_dims_[constDef] = dims; array_dims_[constDef] = dims;
// 扁平化初始化 // 按 C 语义扁平化初始化(子列表对齐到维度边界)
if (constDef->constInitVal()) { if (constDef->constInitVal()) {
std::vector<int> flat; std::vector<int> flat(total, 0);
flat.reserve(total);
std::function<void(SysYParser::ConstInitValContext*)> flatten = std::vector<int> strides(dims.size(), 1);
[&](SysYParser::ConstInitValContext* iv) { for (int i = (int)dims.size() - 2; i >= 0; --i)
if (!iv) return; strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) { if (iv->constExp()) {
flat.push_back(EvalConstExprInt(iv->constExp())); flat[pos] = EvalConstExprInt(iv->constExp());
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else { } else {
for (auto* sub : iv->constInitVal()) flatten(sub); int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
}
} }
}; };
flatten(constDef->constInitVal());
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else {
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
for (int i = 0; i < total; ++i) { for (int i = 0; i < total; ++i) {
int v = (i < (int)flat.size()) ? flat[i] : 0;
auto* ptr = builder_.CreateGep( auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i), slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp()); module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(v), ptr); builder_.CreateStore(builder_.CreateConstInt(flat[i]), ptr);
} }
} }
} }
@ -179,7 +247,13 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (in_global_scope_) { if (in_global_scope_) {
if (!is_array) { if (!is_array) {
// 全局标量初始化器必须是常量简化处理为0 auto* gv = module_.CreateGlobalVariable(name, false, 0, 1, false, is_float);
// 全局标量:初始化器必须是常量
if(is_float){
float init_valf = 0;
try { init_valf = static_cast<float>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).float_val); } catch (...) {}
gv->SetInitValF(init_valf);
} else {
int init_val = 0; int init_val = 0;
if (ctx->initVal() && ctx->initVal()->exp()) { if (ctx->initVal() && ctx->initVal()->exp()) {
try { try {
@ -189,14 +263,103 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
init_val = 0; init_val = 0;
} }
} }
auto* gv = module_.CreateGlobalVariable(name, false, init_val); gv->SetInitVal(init_val);
}
global_storage_map_[ctx] = gv; global_storage_map_[ctx] = gv;
} else { } else {
int total = 1; int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1); for (int d : dims) total *= (d > 0 ? d : 1);
auto* gv = module_.CreateGlobalVariable(name, false, 0, total); auto* gv = module_.CreateGlobalVariable(name, false, 0, total, true, is_float);
global_storage_map_[ctx] = gv; global_storage_map_[ctx] = gv;
global_array_dims_[ctx] = dims; global_array_dims_[ctx] = dims;
// 计算初始值
if (ctx->initVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
if (is_float) {
std::vector<float> flat(total, 0.0f);
std::function<void(SysYParser::InitValContext*, int, int)> fill_f;
fill_f = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<float>(sem::EvaluateExp(*iv->exp()->addExp()).float_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill_f(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<float>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).float_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill_f(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitValsF(flat);
} else {
std::vector<int> flat(total, 0);
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<int>(sem::EvaluateExp(*iv->exp()->addExp()).int_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<int>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).int_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitVals(flat);
}
}
} }
} else { } else {
if (storage_map_.count(ctx)) { if (storage_map_.count(ctx)) {
@ -211,6 +374,14 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
ir::Value* init; ir::Value* init;
if (ctx->initVal() && ctx->initVal()->exp()) { if (ctx->initVal() && ctx->initVal()->exp()) {
init = EvalExpr(*ctx->initVal()->exp()); init = EvalExpr(*ctx->initVal()->exp());
// Coerce init value to slot type
if (!is_float && init->IsFloat32()) {
init = ToInt(init);
} else if (is_float && init->IsInt32()) {
init = ToFloat(init);
} else if (!is_float && init->IsInt1()) {
init = ToI32(init);
}
} else { } else {
init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f)) init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0)); : static_cast<ir::Value*>(builder_.CreateConstInt(0));
@ -219,40 +390,95 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
} else { } else {
int total = 1; int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1); for (int d : dims) total *= (d > 0 ? d : 1);
auto* slot = builder_.CreateAllocaArray(total, name); auto* slot = is_float ? builder_.CreateAllocaArrayF32(total, module_.GetContext().NextTemp())
: builder_.CreateAllocaArray(total, module_.GetContext().NextTemp());
storage_map_[ctx] = slot; storage_map_[ctx] = slot;
array_dims_[ctx] = dims; array_dims_[ctx] = dims;
ir::Value* zero_init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
if (ctx->initVal()) { if (ctx->initVal()) {
// 收集扁平化初始值 // 按 C 语义扁平化初始值:子列表对齐到对应维度边界
std::vector<ir::Value*> flat; std::vector<ir::Value*> flat(total, zero_init);
flat.reserve(total);
std::function<void(SysYParser::InitValContext*)> flatten = // 计算各维度的 stridestride[i] = dims[i]*dims[i+1]*...*dims[n-1]
[&](SysYParser::InitValContext* iv) { // 但我们需要「子列表对应第几维的 stride」
if (!iv) return; // 顶层stride = total / dims[0](即每行的元素数)
// 递归时 stride 继续除以当前维度大小
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0]; // 每个顶层子列表占用的元素数
// fill(iv, pos, stride):将 iv 的内容填入 flat[pos..pos+stride)
// stride 表示当前层子列表对应的元素个数
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) { if (iv->exp()) {
flat.push_back(EvalExpr(*iv->exp())); flat[pos] = EvalExpr(*iv->exp());
return;
}
// 子列表内的 stride = stride / (当前层首维大小)
// 找到对应的 strides 层stride == strides[k] → 子stride = strides[k+1]
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k) {
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
}
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else { } else {
for (auto* sub : iv->initVal()) flatten(sub); // 对齐到 sub_stride 边界
int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
} }
};
flatten(ctx->initVal());
for (int i = 0; i < total; ++i) {
ir::Value* v = (i < (int)flat.size()) ? flat[i]
: builder_.CreateConstInt(0);
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(v, ptr);
} }
};
// 顶层扫描
int cur = 0;
if (ctx->initVal()->exp()) {
flat[0] = EvalExpr(*ctx->initVal()->exp());
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else { } else {
// 零初始化 // 对齐到 top_stride 边界
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
// 先 memset 归零,再只写入非零元素
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
for (int i = 0; i < total; ++i) { for (int i = 0; i < total; ++i) {
bool is_zero = false;
if (auto* ci = dynamic_cast<ir::ConstantInt*>(flat[i])) {
is_zero = (ci->GetValue() == 0);
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(flat[i])) {
is_zero = (cf->GetValue() == 0.0f);
}
if (is_zero) continue;
auto* ptr = builder_.CreateGep( auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i), slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp()); module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(0), ptr); builder_.CreateStore(flat[i], ptr);
} }
} else {
// 零初始化:用 memset 归零
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
(void)zero_init;
} }
} }
} }

@ -7,13 +7,66 @@
#include "ir/IR.h" #include "ir/IR.h"
#include "sem/func.h" #include "sem/func.h"
#include "utils/Log.h" #include "utils/Log.h"
#include <cmath> // 用于 ldexp
// ─── 辅助 ───────────────────────────────────────────────────────────────────── // ─── 辅助 ─────────────────────────────────────────────────────────────────────
// 把 i32 值转成 i1icmp ne i32 v, 0
// Parse a hexadecimal floating-point literal such as "0x1.8p3".
//
// Accepted shape: optional "0x"/"0X" prefix, hex digits with an optional '.',
// then an optional binary exponent introduced by 'p'/'P' with an optional
// sign and decimal digits. Characters in the mantissa that are neither hex
// digits nor '.' are skipped. Anything after the exponent digits is ignored.
//
// Returns significand * 2^exponent converted to float. An exponent too large
// in magnitude yields infinity (or zero when negative) instead of overflowing.
static float ParseHexFloat(const std::string& str) {
  const char* s = str.c_str();
  // Skip the "0x" / "0X" prefix.
  if (s[0] == '0' && (s[1] == 'x' || s[1] == 'X')) s += 2;

  double significand = 0.0;
  bool have_dot = false;
  double dot_scale = 1.0 / 16.0;  // weight of the next fractional digit
  for (; *s && *s != 'p' && *s != 'P'; ++s) {
    if (*s == '.') {
      have_dot = true;
      continue;
    }
    int digit = -1;
    if (*s >= '0' && *s <= '9') digit = *s - '0';
    else if (*s >= 'a' && *s <= 'f') digit = *s - 'a' + 10;
    else if (*s >= 'A' && *s <= 'F') digit = *s - 'A' + 10;
    if (digit < 0) continue;
    if (have_dot) {
      significand += digit * dot_scale;
      dot_scale /= 16.0;
    } else {
      significand = significand * 16 + digit;
    }
  }

  int exponent = 0;
  if (*s == 'p' || *s == 'P') {
    ++s;
    bool negative = false;
    if (*s == '-') {
      negative = true;
      ++s;
    } else if (*s == '+') {
      ++s;
    }
    // Stop accumulating once the exponent is far outside the range of double.
    // This avoids signed overflow on long digit strings; the result saturates
    // to infinity or zero either way.
    constexpr int kExponentCap = 100000;
    for (; *s >= '0' && *s <= '9'; ++s) {
      if (exponent < kExponentCap) exponent = exponent * 10 + (*s - '0');
    }
    if (negative) exponent = -exponent;
  }
  return static_cast<float>(std::ldexp(significand, exponent));
}
// 把 i32/float 值转成 i1
ir::Value* IRGenImpl::ToI1(ir::Value* v) { ir::Value* IRGenImpl::ToI1(ir::Value* v) {
if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value")); if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value"));
if (v->IsInt1()) return v; if (v->IsInt1()) return v;
if (v->IsFloat32()) {
return builder_.CreateFCmp(ir::FCmpPredicate::ONE, v,
builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
return builder_.CreateICmp(ir::ICmpPredicate::NE, v, return builder_.CreateICmp(ir::ICmpPredicate::NE, v,
builder_.CreateConstInt(0), builder_.CreateConstInt(0),
module_.GetContext().NextTemp()); module_.GetContext().NextTemp());
@ -87,7 +140,13 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
} else if (name == "getch") { } else if (name == "getch") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {}); module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {});
} else if (name == "getfloat") { } else if (name == "getfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); // 近似 module_.DeclareExternalFunc(name, ir::Type::GetFloat32Type(), {});
} else if (name == "getarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrInt32Type()});
} else if (name == "getfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrFloat32Type()});
} else if (name == "putint") { } else if (name == "putint") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()}); {ir::Type::GetInt32Type()});
@ -95,10 +154,16 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()}); {ir::Type::GetInt32Type()});
} else if (name == "putfloat") { } else if (name == "putfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetFloat32Type()});
} else if (name == "putarray") { } else if (name == "putarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()}); {ir::Type::GetInt32Type(),
ir::Type::GetPtrInt32Type()});
} else if (name == "putfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type(),
ir::Type::GetPtrFloat32Type()});
} else if (name == "starttime" || name == "stoptime") { } else if (name == "starttime" || name == "stoptime") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()}); {ir::Type::GetInt32Type()});
@ -227,13 +292,113 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
std::vector<ir::Value*> args; std::vector<ir::Value*> args;
if (ctx->funcRParams()) { if (ctx->funcRParams()) {
for (auto* exp : ctx->funcRParams()->exp()) { for (auto* exp : ctx->funcRParams()->exp()) {
args.push_back(EvalExpr(*exp)); // 检查是否是数组变量(无索引的 lVar若是则传指针而非 load
ir::Value* arg = nullptr;
auto* add = exp->addExp();
if (add && add->mulExp().size() == 1) {
auto* mul = add->mulExp(0);
if (mul && mul->unaryExp().size() == 1) {
auto* unary = mul->unaryExp(0);
if (unary && !unary->unaryOp() && unary->primaryExp()) {
auto* primary = unary->primaryExp();
if (primary && primary->lVar() && primary->lVar()->exp().empty()) {
auto* lvar = primary->lVar();
auto* decl = sema_.ResolveVarUse(lvar->Ident());
if (decl) {
// 检查是否是数组参数storage_map_ 里存的是指针)
auto it = storage_map_.find(decl);
if (it != storage_map_.end()) {
auto* val = it->second;
if (val && (val->IsPtrInt32() || val->IsPtrFloat32())) {
// 检查是否是 Argument数组参数直接传指针
if (dynamic_cast<ir::Argument*>(val)) {
arg = val;
} else if (array_dims_.count(decl)) {
// 本地数组(含 dims 记录):传首元素地址
arg = builder_.CreateGep(val, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
// 检查全局数组
if (!arg) {
auto git = global_storage_map_.find(decl);
if (git != global_storage_map_.end()) {
auto* gv = dynamic_cast<ir::GlobalVariable*>(git->second);
if (gv && gv->IsArray()) {
arg = builder_.CreateGep(git->second,
builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
}
}
}
}
}
// Also handle partially-indexed multi-dim arrays: arr[i] where arr is
// int arr[M][N] should pass &arr[i*N] as i32*, not load arr[i] as i32.
if (!arg) {
auto* add2 = exp->addExp();
if (add2 && add2->mulExp().size() == 1) {
auto* mul2 = add2->mulExp(0);
if (mul2 && mul2->unaryExp().size() == 1) {
auto* unary2 = mul2->unaryExp(0);
if (unary2 && !unary2->unaryOp() && unary2->primaryExp()) {
auto* primary2 = unary2->primaryExp();
if (primary2 && primary2->lVar() && !primary2->lVar()->exp().empty()) {
auto* lvar2 = primary2->lVar();
auto* decl2 = sema_.ResolveVarUse(lvar2->Ident());
if (decl2) {
std::vector<int> dims2;
ir::Value* base2 = nullptr;
auto it2 = array_dims_.find(decl2);
if (it2 != array_dims_.end()) {
dims2 = it2->second;
auto sit = storage_map_.find(decl2);
if (sit != storage_map_.end()) base2 = sit->second;
} else {
auto git2 = global_array_dims_.find(decl2);
if (git2 != global_array_dims_.end()) {
dims2 = git2->second;
auto gsit = global_storage_map_.find(decl2);
if (gsit != global_storage_map_.end()) base2 = gsit->second;
}
}
// Partially indexed: fewer indices than dimensions -> pass pointer
bool is_param = !dims2.empty() && dims2[0] == -1;
size_t effective_dims = is_param ? dims2.size() - 1 : dims2.size();
if (base2 && !dims2.empty() &&
lvar2->exp().size() < effective_dims + (is_param ? 1 : 0)) {
arg = EvalLVarAddr(lvar2);
}
}
}
}
}
}
}
if (!arg) arg = EvalExpr(*exp);
args.push_back(arg);
} }
} }
// 模块内已知函数? // 模块内已知函数?
ir::Function* callee = module_.GetFunction(callee_name); ir::Function* callee = module_.GetFunction(callee_name);
if (callee) { if (callee) {
// Coerce args to match parameter types
for (size_t i = 0; i < args.size() && i < callee->GetNumArgs(); ++i) {
auto* param = callee->GetArgument(i);
if (!param || !args[i]) continue;
if (param->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name = std::string ret_name =
callee->IsVoidReturn() ? "" : module_.GetContext().NextTemp(); callee->IsVoidReturn() ? "" : module_.GetContext().NextTemp();
auto* call = auto* call =
@ -246,15 +411,28 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// 外部函数 // 外部函数
EnsureExternalDecl(callee_name); EnsureExternalDecl(callee_name);
// 获取返回类型 // 获取返回类型和参数类型
std::shared_ptr<ir::Type> ret_type = ir::Type::GetInt32Type(); std::shared_ptr<ir::Type> ret_type = ir::Type::GetInt32Type();
std::vector<std::shared_ptr<ir::Type>> param_types;
for (const auto& decl : module_.GetExternalDecls()) { for (const auto& decl : module_.GetExternalDecls()) {
if (decl.name == callee_name) { if (decl.name == callee_name) {
ret_type = decl.ret_type; ret_type = decl.ret_type;
param_types = decl.param_types;
break; break;
} }
} }
bool is_void = ret_type->IsVoid(); bool is_void = ret_type->IsVoid();
// Coerce args to match external function parameter types
for (size_t i = 0; i < args.size() && i < param_types.size(); ++i) {
if (!args[i]) continue;
if (param_types[i]->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param_types[i]->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param_types[i]->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name = is_void ? "" : module_.GetContext().NextTemp(); std::string ret_name = is_void ? "" : module_.GetContext().NextTemp();
auto* call = builder_.CreateCallExternal(callee_name, ret_type, auto* call = builder_.CreateCallExternal(callee_name, ret_type,
std::move(args), ret_name); std::move(args), ret_name);
@ -331,40 +509,26 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
throw std::runtime_error(FormatError("irgen", "数组索引维度过多")); throw std::runtime_error(FormatError("irgen", "数组索引维度过多"));
} }
ir::Value* offset = builder_.CreateConstInt(0); ir::Value* offset = nullptr;
if (is_array_param) { if (is_array_param) {
// 数组参数dims[0]=-1, dims[1..n]是已知维度 // 数组参数dims[0]=-1, dims[1..n]是已知维度
// 索引indices[0]对应第一维indices[1]对应第二维...
for (size_t i = 0; i < indices.size(); ++i) { for (size_t i = 0; i < indices.size(); ++i) {
ir::Value* idx = EvalExpr(*indices[i]); ir::Value* idx = EvalExpr(*indices[i]);
if (i == 0) {
// 第一维stride = dims[1] * dims[2] * ... (如果有的话)
int stride = 1; int stride = 1;
for (size_t j = 1; j < dims.size(); ++j) { size_t start = (i == 0) ? 1 : i + 1;
stride *= dims[j]; for (size_t j = start; j < dims.size(); ++j) stride *= dims[j];
} ir::Value* term;
if (stride > 1) { if (stride == 1) {
ir::Value* scaled = builder_.CreateMul( term = idx;
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
} else { } else {
offset = builder_.CreateAdd(offset, idx, term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp()); module_.GetContext().NextTemp());
} }
if (!offset) {
offset = term;
} else { } else {
// 后续维度 offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
int stride = 1;
for (size_t j = i + 1; j < dims.size(); ++j) {
stride *= dims[j];
}
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
} }
} }
} else { } else {
@ -374,14 +538,23 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
stride = (i == (int)dims.size() - 1) ? 1 : stride * dims[i + 1]; stride = (i == (int)dims.size() - 1) ? 1 : stride * dims[i + 1];
if (i < (int)indices.size()) { if (i < (int)indices.size()) {
ir::Value* idx = EvalExpr(*indices[i]); ir::Value* idx = EvalExpr(*indices[i]);
ir::Value* scaled = builder_.CreateMul( ir::Value* term;
idx, builder_.CreateConstInt(stride), if (stride == 1) {
module_.GetContext().NextTemp()); term = idx;
offset = builder_.CreateAdd(offset, scaled, } else {
term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp()); module_.GetContext().NextTemp());
} }
if (!offset) {
offset = term;
} else {
offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
}
} }
} }
}
if (!offset) offset = builder_.CreateConstInt(0);
return builder_.CreateGep(base, offset, module_.GetContext().NextTemp()); return builder_.CreateGep(base, offset, module_.GetContext().NextTemp());
} }
@ -446,12 +619,16 @@ std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
if (ctx->FloatConst()) { if (ctx->FloatConst()) {
std::string text = ctx->FloatConst()->getText(); std::string text = ctx->FloatConst()->getText();
float val = 0.0f; float val = 0.0f;
if (text.size() >= 2 && (text[1] == 'x' || text[1] == 'X')) {
val = ParseHexFloat(text);
} else {
try { try {
val = std::stof(text); val = std::stof(text);
} catch (...) { } catch (...) {
throw std::runtime_error( throw std::runtime_error(
FormatError("irgen", "浮点字面量解析失败: " + text)); FormatError("irgen", "浮点字面量解析失败: " + text));
} }
}
return static_cast<ir::Value*>(builder_.CreateConstFloat(val)); return static_cast<ir::Value*>(builder_.CreateConstFloat(val));
} }
throw std::runtime_error(FormatError("irgen", "非法数字节点")); throw std::runtime_error(FormatError("irgen", "非法数字节点"));
@ -486,8 +663,8 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
ir::Value* res_ext = ToI32(result); ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot); builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.rhs"); ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.end"); ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.end");
builder_.CreateCondBr(result, end_bb, rhs_bb); builder_.CreateCondBr(result, end_bb, rhs_bb);
builder_.SetInsertPoint(rhs_bb); builder_.SetInsertPoint(rhs_bb);
@ -498,6 +675,7 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
builder_.CreateBr(end_bb); builder_.CreateBr(end_bb);
} }
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb); builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp())); result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
} }
@ -523,8 +701,8 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
ir::Value* res_ext = ToI32(result); ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot); builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.rhs"); ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.end"); ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.end");
builder_.CreateCondBr(result, rhs_bb, end_bb); builder_.CreateCondBr(result, rhs_bb, end_bb);
builder_.SetInsertPoint(rhs_bb); builder_.SetInsertPoint(rhs_bb);
@ -535,6 +713,7 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
builder_.CreateBr(end_bb); builder_.CreateBr(end_bb);
} }
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb); builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp())); result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
} }

@ -76,6 +76,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
// 设置插入点到入口块 // 设置插入点到入口块
builder_.SetInsertPoint(func_->GetEntry()); builder_.SetInsertPoint(func_->GetEntry());
builder_.SetAllocaBlock(func_->GetEntry());
// 处理参数 // 处理参数
if (ctx->funcFParams()) { if (ctx->funcFParams()) {

@ -25,6 +25,29 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->Return()) { if (ctx->Return()) {
if (ctx->exp()) { if (ctx->exp()) {
ir::Value* v = EvalExpr(*ctx->exp()); ir::Value* v = EvalExpr(*ctx->exp());
// Coerce return value to function return type
if (func_->GetType()->IsInt32() && v->IsFloat32()) {
v = ToInt(v);
} else if (func_->GetType()->IsFloat32() && v->IsInt32()) {
v = ToFloat(v);
} else if (func_->GetType()->IsInt32() && v->IsInt1()) {
v = ToI32(v);
}
if (func_->GetName() == "main" && !func_->GetType()->IsVoid()) {
auto* nl = builder_.CreateConstInt(10);
// ((v % 256) + 256) % 256
auto* mod256 = builder_.CreateMod(v, builder_.CreateConstInt(256),
module_.GetContext().NextTemp());
auto* add256 = builder_.CreateAdd(mod256, builder_.CreateConstInt(256),
module_.GetContext().NextTemp());
auto* masked = builder_.CreateMod(add256, builder_.CreateConstInt(256),
module_.GetContext().NextTemp());
std::vector<ir::Value*> args1 = {masked};
builder_.CreateCallExternal("putint", ir::Type::GetVoidType(), args1, "");
std::vector<ir::Value*> args2 = {nl};
builder_.CreateCallExternal("putch", ir::Type::GetVoidType(), args2, "");
}
builder_.CreateRet(v); builder_.CreateRet(v);
} else { } else {
builder_.CreateRetVoid(); builder_.CreateRetVoid();
@ -54,6 +77,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->lVar() && ctx->Assign()) { if (ctx->lVar() && ctx->Assign()) {
ir::Value* rhs = EvalExpr(*ctx->exp()); ir::Value* rhs = EvalExpr(*ctx->exp());
ir::Value* addr = EvalLVarAddr(ctx->lVar()); ir::Value* addr = EvalLVarAddr(ctx->lVar());
// Coerce rhs to match slot type
if (addr->IsPtrInt32() && rhs->IsFloat32()) {
rhs = ToInt(rhs);
} else if (addr->IsPtrFloat32() && rhs->IsInt32()) {
rhs = ToFloat(rhs);
} else if (addr->IsPtrInt32() && rhs->IsInt1()) {
rhs = ToI32(rhs);
}
builder_.CreateStore(rhs, addr); builder_.CreateStore(rhs, addr);
return BlockFlow::Continue; return BlockFlow::Continue;
} }
@ -74,32 +105,47 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
} }
auto stmts = ctx->stmt(); auto stmts = ctx->stmt();
// Step 1: evaluate condition (may create short-circuit blocks with lower
// SSA numbers — must happen before any branch-target blocks are created).
ir::Value* cond_val = EvalCond(*ctx->cond());
// Step 2: create then_bb now (its label number will be >= all short-circuit
// block numbers allocated during EvalCond).
ir::BasicBlock* then_bb = func_->CreateBlock( ir::BasicBlock* then_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".if.then"); module_.GetContext().NextLabel() + ".if.then");
// Step 3: create else_bb/merge_bb as placeholders. They will be moved to
// the end of the block list after their predecessors are filled in, so the
// block ordering in the output will be correct even though their label
// numbers are allocated here (before then-body sub-blocks are created).
ir::BasicBlock* else_bb = nullptr; ir::BasicBlock* else_bb = nullptr;
ir::BasicBlock* merge_bb = func_->CreateBlock( ir::BasicBlock* merge_bb = nullptr;
module_.GetContext().NextTemp() + ".if.end");
// 求值条件(可能创建短路求值块) if (stmts.size() >= 2) {
ir::Value* cond_val = EvalCond(*ctx->cond()); else_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.else");
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
} else {
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
}
// 检查当前块是否已终结(短路求值可能导致) // Check if current block already terminated (short-circuit may do this)
if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) { if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) {
// 条件求值已经终结了当前块,无法继续 func_->MoveBlockToEnd(then_bb);
// 这种情况下我们需要在merge_bb继续 if (else_bb) func_->MoveBlockToEnd(else_bb);
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb); builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue; return BlockFlow::Continue;
} }
if (stmts.size() >= 2) { if (stmts.size() >= 2) {
// if-else
else_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".if.else");
builder_.CreateCondBr(cond_val, then_bb, else_bb); builder_.CreateCondBr(cond_val, then_bb, else_bb);
} else { } else {
builder_.CreateCondBr(cond_val, then_bb, merge_bb); builder_.CreateCondBr(cond_val, then_bb, merge_bb);
} }
// then 分支 // then 分支 — visit body (may create many sub-blocks with higher numbers)
func_->MoveBlockToEnd(then_bb);
builder_.SetInsertPoint(then_bb); builder_.SetInsertPoint(then_bb);
auto then_flow = VisitStmt(*stmts[0]); auto then_flow = VisitStmt(*stmts[0]);
if (then_flow != BlockFlow::Terminated) { if (then_flow != BlockFlow::Terminated) {
@ -108,6 +154,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
// else 分支 // else 分支
if (else_bb) { if (else_bb) {
func_->MoveBlockToEnd(else_bb);
builder_.SetInsertPoint(else_bb); builder_.SetInsertPoint(else_bb);
auto else_flow = VisitStmt(*stmts[1]); auto else_flow = VisitStmt(*stmts[1]);
if (else_flow != BlockFlow::Terminated) { if (else_flow != BlockFlow::Terminated) {
@ -115,6 +162,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
} }
} }
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb); builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue; return BlockFlow::Continue;
} }
@ -124,28 +172,32 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (!ctx->cond()) { if (!ctx->cond()) {
throw std::runtime_error(FormatError("irgen", "while 缺少条件")); throw std::runtime_error(FormatError("irgen", "while 缺少条件"));
} }
ir::BasicBlock* cond_bb = func_->CreateBlock( ir::BasicBlock* cond_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.cond"); module_.GetContext().NextLabel() + ".while.cond");
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.end");
// 跳转到条件块 // 跳转到条件块
if (!builder_.GetInsertBlock()->HasTerminator()) { if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateBr(cond_bb); builder_.CreateBr(cond_bb);
} }
// 条件块 // EvalCond MUST come before creating body_bb/after_bb so that
// short-circuit blocks get lower SSA numbers than the loop body blocks.
builder_.SetInsertPoint(cond_bb); builder_.SetInsertPoint(cond_bb);
ir::Value* cond_val = EvalCond(*ctx->cond()); ir::Value* cond_val = EvalCond(*ctx->cond());
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.end");
// 检查条件求值后是否已终结 // 检查条件求值后是否已终结
if (!builder_.GetInsertBlock()->HasTerminator()) { if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateCondBr(cond_val, body_bb, after_bb); builder_.CreateCondBr(cond_val, body_bb, after_bb);
} }
// 循环体(压入循环栈) // 循环体(压入循环栈)
func_->MoveBlockToEnd(body_bb);
loop_stack_.push_back({cond_bb, after_bb}); loop_stack_.push_back({cond_bb, after_bb});
builder_.SetInsertPoint(body_bb); builder_.SetInsertPoint(body_bb);
auto stmts = ctx->stmt(); auto stmts = ctx->stmt();
@ -159,6 +211,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
} }
loop_stack_.pop_back(); loop_stack_.pop_back();
func_->MoveBlockToEnd(after_bb);
builder_.SetInsertPoint(after_bb); builder_.SetInsertPoint(after_bb);
return BlockFlow::Continue; return BlockFlow::Continue;
} }

@ -1,4 +1,5 @@
#include <exception> #include <exception>
#include <fstream>
#include <iostream> #include <iostream>
#include <stdexcept> #include <stdexcept>
@ -21,13 +22,20 @@ int main(int argc, char** argv) {
return 0; return 0;
} }
auto antlr = ParseFileWithAntlr(opts.input); // 确定输出流
bool need_blank_line = false; std::ofstream ofs;
if (opts.emit_parse_tree) { std::ostream* out = &std::cout;
PrintSyntaxTree(antlr.tree, antlr.parser.get(), std::cout); if (!opts.output.empty()) {
need_blank_line = true; ofs.open(opts.output);
if (!ofs) {
throw std::runtime_error(
FormatError("main", "无法打开输出文件: " + opts.output));
}
out = &ofs;
} }
auto antlr = ParseFileWithAntlr(opts.input);
#if !COMPILER_PARSE_ONLY #if !COMPILER_PARSE_ONLY
auto* comp_unit = dynamic_cast<SysYParser::CompUnitContext*>(antlr.tree); auto* comp_unit = dynamic_cast<SysYParser::CompUnitContext*>(antlr.tree);
if (!comp_unit) { if (!comp_unit) {
@ -36,23 +44,23 @@ int main(int argc, char** argv) {
auto sema = RunSema(*comp_unit); auto sema = RunSema(*comp_unit);
auto module = GenerateIR(*comp_unit, sema); auto module = GenerateIR(*comp_unit, sema);
if (opts.opt) {
ir::RunPasses(*module);
}
if (opts.emit_ir) { if (opts.emit_ir) {
ir::IRPrinter printer; ir::IRPrinter printer;
if (need_blank_line) { printer.Print(*module, *out);
std::cout << "\n";
}
printer.Print(*module, std::cout);
need_blank_line = true;
} }
if (opts.emit_asm) { if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module); auto machine_funcs = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func); for (auto& mf : machine_funcs) {
mir::RunFrameLowering(*machine_func); mir::RunRegAlloc(*mf);
if (need_blank_line) { mir::RunFrameLowering(*mf);
std::cout << "\n";
} }
mir::PrintAsm(*machine_func, std::cout); mir::PrintAsm(machine_funcs, *out);
} }
#else #else
if (opts.emit_ir || opts.emit_asm) { if (opts.emit_ir || opts.emit_asm) {

@ -2,9 +2,16 @@
#include <ostream> #include <ostream>
#include <stdexcept> #include <stdexcept>
#include <iostream>
#include <vector>
#include <unordered_map>
#include "utils/Log.h" #include "utils/Log.h"
// 引用全局变量(定义在 Lowering.cpp 中)
extern std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir { namespace mir {
namespace { namespace {
@ -16,63 +23,624 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex()); return function.GetFrameSlot(operand.GetFrameIndex());
} }
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, // 32位整数加载/存储
int offset) { void EmitStackLoad(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::S0) {
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset if (offset >= -2048 && offset <= 2047) {
<< "]\n"; os << " lw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " lw " << PhysRegName(dst) << ", 0(t4)\n";
}
} }
} // namespace void EmitStackStore(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::S0) {
if (offset >= -2048 && offset <= 2047) {
os << " sw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " sw " << PhysRegName(src) << ", 0(t4)\n";
}
}
void PrintAsm(const MachineFunction& function, std::ostream& os) { // 64位指针加载/存储
os << ".text\n"; void EmitStackLoad64(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::S0) {
os << ".global " << function.GetName() << "\n"; if (offset >= -2048 && offset <= 2047) {
os << ".type " << function.GetName() << ", %function\n"; os << " ld " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n";
os << function.GetName() << ":\n"; } else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " ld " << PhysRegName(dst) << ", 0(t4)\n";
}
}
for (const auto& inst : function.GetEntry().GetInstructions()) { void EmitStackStore64(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::S0) {
if (offset >= -2048 && offset <= 2047) {
os << " sd " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " sd " << PhysRegName(src) << ", 0(t4)\n";
}
}
// 浮点加载/存储保持32位
// Writes the assembly for a single-precision load (`flw`) of `dst` from the
// address `base + offset`; `base` defaults to s0.
// Offsets inside [-2048, 2047] are encoded directly in the load. Any other
// offset is first loaded into t4 and added to `base`, and the load then goes
// through `0(t4)`.
void EmitStackLoadFloat(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::S0) {
  const bool needs_scratch = offset < -2048 || offset > 2047;
  if (needs_scratch) {
    // Build the effective address in t4, then load through it.
    os << " li t4, " << offset << "\n";
    os << " add t4, " << PhysRegName(base) << ", t4\n";
    os << " flw " << PhysRegName(dst) << ", 0(t4)\n";
    return;
  }
  os << " flw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n";
}
// Writes the assembly for a single-precision store (`fsw`) of `src` to the
// address `base + offset`; `base` defaults to s0.
// Offsets inside [-2048, 2047] are encoded directly in the store. Any other
// offset is first loaded into t4 and added to `base`, and the store then goes
// through `0(t4)`.
void EmitStackStoreFloat(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::S0) {
  const bool needs_scratch = offset < -2048 || offset > 2047;
  if (needs_scratch) {
    // Build the effective address in t4, then store through it.
    os << " li t4, " << offset << "\n";
    os << " add t4, " << PhysRegName(base) << ", t4\n";
    os << " fsw " << PhysRegName(src) << ", 0(t4)\n";
    return;
  }
  os << " fsw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n";
}
// 输出单个函数的汇编
void PrintAsmFunction(const MachineFunction& function, std::ostream& os) {
// 收集所有基本块名称
std::unordered_map<const MachineBasicBlock*, std::string> block_names;
for (const auto& block_ptr : function.GetBlocks()) {
block_names[block_ptr.get()] = block_ptr->GetName();
}
int frame_size = function.GetFrameSize(); // 局部变量区大小(正数)
int local_vars = function.GetLocalVarsSize();
int total_frame = local_vars + 16 ;
bool prologue_done = false;
for (const auto& block_ptr : function.GetBlocks()) {
const auto& block = *block_ptr;
// 输出基本块标签(入口块不输出,因为函数名已经是标签)
if (block.GetName() != "entry") {
os << block.GetName() << ":\n";
}
for (const auto& inst : block.GetInstructions()) {
const auto& ops = inst.GetOperands(); const auto& ops = inst.GetOperands();
// 在入口块的第一条指令前输出序言
if (!prologue_done && block.GetName() == "entry") {
// 分配栈帧sp -= total_frame
if (total_frame <= 2047) {
os << " addi sp, sp, -" << total_frame << "\n";
} else {
os << " li t4, -" << total_frame << "\n";
os << " add sp, sp, t4\n";
}
// 保存 ra 和 s0在局部变量区之后即 sp + frame_size 处)
// ra 保存在 sp + frame_size
// s0 保存在 sp + frame_size + 8
int ra_offset = local_vars;
int s0_offset = local_vars + 8;
if (ra_offset <= 2047) {
os << " sd ra, " << ra_offset << "(sp)\n";
} else {
os << " li t4, " << ra_offset << "\n";
os << " add t4, sp, t4\n";
os << " sd ra, 0(t4)\n";
}
if (s0_offset <= 2047) {
os << " sd s0, " << s0_offset << "(sp)\n";
} else {
os << " li t4, " << s0_offset << "\n";
os << " add t4, sp, t4\n";
os << " sd s0, 0(t4)\n";
}
os << " mv s0, sp\n";
prologue_done = true;
}
switch (inst.GetOpcode()) { switch (inst.GetOpcode()) {
case Opcode::Prologue: case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
}
break;
case Opcode::Epilogue: case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
}
os << " ldp x29, x30, [sp], #16\n";
break; break;
case Opcode::MovImm: case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #" os << " li " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetImm() << "\n"; << ops.at(1).GetImm() << "\n";
break; break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1)); case Opcode::Load: {
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " ld " << PhysRegName(ops[0].GetReg()) << ", 0(" << PhysRegName(ops[1].GetReg()) << ")\n";
} /*else if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Imm) {
// 用于调用者 outgoing 存储的占位偏移(将在 Outgoing 中修正)
int offset = ops[1].GetImm(); // 实际偏移 = local_vars + 16 + offset
os << " ld " << PhysRegName(ops[0].GetReg()) << ", " << offset << "(sp)\n";*/
else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
if (slot.size == 8) EmitStackLoad64(os, ops[0].GetReg(), slot.offset);
else EmitStackLoad(os, ops[0].GetReg(), slot.offset);
}
break;
}
case Opcode::Store: {
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " sd " << PhysRegName(ops[0].GetReg()) << ", 0(" << PhysRegName(ops[1].GetReg()) << ")\n";
} /*else if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Imm) {
int offset = ops[1].GetImm();
// 实际偏移需在 AsmPrinter 中加上 local_vars+16这里简单先直接用 offset动态修正稍复杂
// 临时方案:直接生成 sw t0, offset(sp),但 offset 应为 local_vars+16=?
// 由于 AsmPrinter 中可访问 function.GetLocalVarsSize(),我们计算:
int actual_offset = function.GetLocalVarsSize() + 16 + offset;
if (actual_offset <= 2047) os << " sd " << PhysRegName(ops[0].GetReg()) << ", " << actual_offset << "(sp)\n";
else { }*/
else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
if (slot.size == 8) EmitStackStore64(os, ops[0].GetReg(), slot.offset);
else EmitStackStore(os, ops[0].GetReg(), slot.offset);
}
break;
}
case Opcode::LoadCallerStackArg: {
// ops: [0] dst (T0), [1] dstFrameIndex, [2] argvIndex (Imm)
int argv_index = ops[2].GetImm();
int dst_slot = ops[1].GetFrameIndex();
int total_frame = function.GetFrameSize();
// 调用者栈参数位于 sp + total_frame + argv_index*8
int caller_offset = total_frame + argv_index * 8;
// 加载到 T0
if (caller_offset <= 2047) {
os << " ld " << PhysRegName(ops[0].GetReg()) << ", " << caller_offset << "(s0)\n";
} else {
os << " li t4, " << caller_offset << "\n";
os << " add t4, s0, t4\n";
os << " ld " << PhysRegName(ops[0].GetReg()) << ", 0(t4)\n";
}
// 再存入本地槽
const auto& slot = function.GetFrameSlot(dst_slot);
if (slot.size == 8) {
EmitStackStore64(os, ops[0].GetReg(), slot.offset);
} else {
EmitStackStore(os, ops[0].GetReg(), slot.offset);
}
break; break;
} }
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1)); case Opcode::LoadCallerStackArgFloat: {
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset); int argv_index = ops.at(2).GetImm();
int dst_slot = ops.at(1).GetFrameIndex();
int total_frame = function.GetFrameSize();
int caller_offset = total_frame + argv_index * 8;
// 使用 s0 保证稳定
if (caller_offset <= 2047) {
os << " flw " << PhysRegName(ops.at(0).GetReg()) << ", " << caller_offset << "(s0)\n";
} else {
os << " li t4, " << caller_offset << "\n";
os << " add t4, s0, t4\n";
os << " flw " << PhysRegName(ops.at(0).GetReg()) << ", 0(t4)\n";
}
const auto& slot = function.GetFrameSlot(dst_slot);
EmitStackStoreFloat(os, ops.at(0).GetReg(), slot.offset);
break; break;
} }
case Opcode::AddRR:
case Opcode::Add:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n"; << PhysRegName(ops.at(2).GetReg()) << "\n";
break; break;
case Opcode::Ret: case Opcode::Addi:
os << " addi " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< ops[2].GetImm() << "\n";
break;
case Opcode::Sub:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Mul: {
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
} else {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
}
break;
}
case Opcode::Div:
os << " div " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Rem:
os << " rem " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slt:
os << " slt " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slti:
os << " slti " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Sltu:
os << " sltu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Sltiu:
os << " sltiu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Xori:
os << " xori " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::LoadGlobalAddr: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la " << PhysRegName(ops.at(0).GetReg()) << ", " << global_name << "\n";
break;
}
case Opcode::LoadGlobal:
// 全局变量加载 - 使用 lw32位
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::StoreGlobal: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la t1, " << global_name << "\n";
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0(t1)\n";
break;
}
case Opcode::GEP:
break;
case Opcode::LoadIndirect:
// 间接加载 - 使用 lw32位
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::LoadIndirectFloat:
os << " flw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
break;
case Opcode::Call: {
std::string func_name = "memset"; // 默认值
if (!ops.empty() && ops[0].GetKind() == Operand::Kind::Func) {
func_name = ops[0].GetFuncName();
}
os << " call " << func_name << "\n";
break;
}
case Opcode::LoadAddr: {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
// 计算地址64 位offset 是正数
if (slot.offset <= 2047) {
os << " addi " << PhysRegName(ops.at(0).GetReg()) << ", s0, " << slot.offset << "\n";
} else {
os << " li " << PhysRegName(ops.at(0).GetReg()) << ", " << slot.offset << "\n";
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", s0, "
<< PhysRegName(ops.at(0).GetReg()) << "\n";
}
break;
}
case Opcode::Slli:
os << " slli " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::StoreIndirect:
// 间接存储 - 使用 sw32位
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::StoreIndirectFloat:
os << " fsw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
break;
case Opcode::Ret:{
// 恢复 ra 和 s0
int ra_offset = local_vars;
int s0_offset = local_vars + 8;
if (ra_offset <= 2047) {
os << " ld ra, " << ra_offset << "(s0)\n";
} else {
os << " li t3, " << ra_offset << "\n";
os << " add t3, s0, t3\n";
os << " ld ra, 0(t3)\n";
}
// 恢复 sp
if (total_frame <= 2047) {
os << " addi sp, s0, " << total_frame << "\n";
} else {
os << " li t3, " << total_frame << "\n";
os << " add sp, s0, t3\n";
}
if (s0_offset <= 2047) {
os << " ld s0, " << s0_offset << "(s0)\n";
} else {
os << " li t3, " << s0_offset << "\n";
os << " add t3, s0, t3\n";
os << " ld s0, 0(t3)\n";
}
os << " ret\n"; os << " ret\n";
break; break;
} }
case Opcode::Br: {
auto* target = reinterpret_cast<MachineBasicBlock*>(ops[0].GetImm64());
os << " j " << target->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* true_target = reinterpret_cast<MachineBasicBlock*>(ops[1].GetImm64());
auto* false_target = reinterpret_cast<MachineBasicBlock*>(ops[2].GetImm64());
auto true_it = block_names.find(true_target);
auto false_it = block_names.find(false_target);
if (true_it == block_names.end() || false_it == block_names.end()) {
throw std::runtime_error(FormatError("mir", "CondBr: 找不到基本块名称"));
}
// 生成一个唯一的本地标签作为跳板
static int condbr_id = 0;
std::string temp_label = ".L_condbr_" + std::to_string(condbr_id++);
os << " bnez " << PhysRegName(ops[0].GetReg()) << ", " << temp_label << "\n";
os << " j " << false_it->second << "\n";
os << temp_label << ":\n";
os << " j " << true_it->second << "\n";
break;
} }
os << ".size " << function.GetName() << ", .-" << function.GetName() // 浮点运算
<< "\n"; case Opcode::FAdd:
os << " fadd.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FSub:
os << " fsub.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMul:
os << " fmul.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FDiv:
os << " fdiv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FEq:
os << " feq.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLt:
os << " flt.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLe:
os << " fle.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMov:
os << " fmv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovWX:
os << " fmv.w.x " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovXW:
os << " fmv.x.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " fcvt.s.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvt.w.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", rtz\n";
break;
case Opcode::LoadFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " flw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackLoadFloat(os, ops[0].GetReg(), slot.offset);
}
break;
case Opcode::StoreFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " fsw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackStoreFloat(os, ops[0].GetReg(), slot.offset);
}
break;
default:
break;
}
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
}
} // namespace
// Emit the assembly text for a whole module.
//
// Output order:
//   1. every non-const entry of g_globalVars, under a `.data` header;
//   2. every const entry of g_globalVars, under a `.section .rodata` header;
//   3. `.text`, then for each function in `functions` its `.global` / `.type`
//      directives, its label, its body (via PrintAsmFunction) and a blank line.
// A data section header is written only when the section has at least one
// entry.
//
// Every global is written as 32-bit `.word`s; float values are written as
// their raw bit pattern.  An entry is treated as an array only when `isArray`
// is set and `arraySize > 1`; anything else is emitted as a single word taken
// from `value` / `valueF`.
//
// @param functions  machine functions to print, in output order
// @param os         stream receiving the assembly text
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os) {
  // Raw 32-bit pattern of a float, suitable for a `.word` directive.
  const auto float_bits = [](float value) {
    union { float f; uint32_t i; } u;
    u.f = value;
    return u.i;
  };

  // Emits all globals whose const-ness equals `want_const`, preceded (once) by
  // `section_header`.  Both data sections share exactly the same layout rules,
  // so they share this one emitter.
  const auto emit_section = [&](bool want_const, const char* section_header) {
    bool header_written = false;
    for (const auto& gv : g_globalVars) {
      const bool is_const = gv.isConst;
      if (is_const != want_const) continue;
      if (!header_written) {
        os << section_header;
        header_written = true;
      }
      os << " .global " << gv.name << "\n";
      os << " .type " << gv.name << ", @object\n";

      if (gv.isArray && gv.arraySize > 1) {
        const int total_size = gv.arraySize * 4;
        os << " .size " << gv.name << ", " << total_size << "\n";
        os << gv.name << ":\n";
        int emitted = 0;
        if (gv.isFloat) {
          for (float val : gv.arrayValuesF) {
            os << " .word " << float_bits(val) << "\n";
            ++emitted;
          }
        } else {
          for (int val : gv.arrayValues) {
            os << " .word " << val << "\n";
            ++emitted;
          }
        }
        // Zero-fill whatever the initializer list did not cover.  This also
        // handles the "no initializer at all" case.  Without the padding a
        // short initializer list would reserve fewer bytes than the `.size`
        // announced above and the following symbol would overlap the array.
        for (int i = emitted; i < gv.arraySize; ++i) os << " .word 0\n";
      } else {
        os << " .size " << gv.name << ", 4\n";
        os << gv.name << ":\n";
        if (gv.isFloat) {
          os << " .word " << float_bits(gv.valueF) << "\n";
        } else {
          os << " .word " << gv.value << "\n";
        }
      }
    }
  };

  // ========== Global variables ==========
  emit_section(false, ".data\n");            // mutable globals
  emit_section(true, ".section .rodata\n");  // read-only constants

  // ========== Code ==========
  os << ".text\n";
  for (const auto& func_ptr : functions) {
    os << ".global " << func_ptr->GetName() << "\n";
    os << ".type " << func_ptr->GetName() << ", @function\n";
    os << func_ptr->GetName() << ":\n";
    PrintAsmFunction(*func_ptr, os);
    os << "\n";  // blank line between functions
  }
}
} // namespace mir } // namespace mir

@ -15,6 +15,8 @@ target_link_libraries(mir_core PUBLIC
ir ir
) )
target_compile_options(mir_core PRIVATE -Wno-unused-parameter)
add_subdirectory(passes) add_subdirectory(passes)
add_library(mir INTERFACE) add_library(mir INTERFACE)

@ -16,21 +16,25 @@ int AlignTo(int value, int align) {
void RunFrameLowering(MachineFunction& function) { void RunFrameLowering(MachineFunction& function) {
int cursor = 0; int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) { const auto& slots = function.GetFrameSlots();
cursor += slot.size;
if (-cursor < -256) {
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
}
}
cursor = 0; // 为每个栈槽分配偏移
for (const auto& slot : function.GetFrameSlots()) { for (const auto& slot : slots) {
int align = slot.size;
cursor = AlignTo(cursor, align);
function.GetFrameSlot(slot.index).offset = cursor;
cursor += slot.size; cursor += slot.size;
function.GetFrameSlot(slot.index).offset = -cursor;
} }
function.SetFrameSize(AlignTo(cursor, 16));
auto& insts = function.GetEntry().GetInstructions(); // 局部变量区按 16 字节对齐
int local_vars_size = AlignTo(cursor, 16);
function.SetLocalVarsSize(local_vars_size);
// 总帧大小 = 局部变量区 + 16保存 ra 和 s0
function.SetFrameSize(local_vars_size + 16);
// 插入 Prologue/Epilogue 占位符(原逻辑)
auto& insts = function.GetEntry()->GetInstructions();
std::vector<MachineInstr> lowered; std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue); lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) { for (const auto& inst : insts) {

@ -1,123 +1,828 @@
#include "mir/MIR.h" #include "mir/MIR.h"
#include <iostream>
#include <stdexcept> #include <stdexcept>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include <cstring>
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir { namespace mir {
namespace { namespace {
// Returns true when `reg` is one of the floating-point physical registers
// listed below (FT0-FT11, FA0-FA7, FS0, FS1); every other register yields
// false.
static bool IsFloatReg(PhysReg reg) {
  static constexpr PhysReg kFloatRegs[] = {
      PhysReg::FT0, PhysReg::FT1, PhysReg::FT2,  PhysReg::FT3,
      PhysReg::FT4, PhysReg::FT5, PhysReg::FT6,  PhysReg::FT7,
      PhysReg::FT8, PhysReg::FT9, PhysReg::FT10, PhysReg::FT11,
      PhysReg::FA0, PhysReg::FA1, PhysReg::FA2,  PhysReg::FA3,
      PhysReg::FA4, PhysReg::FA5, PhysReg::FA6,  PhysReg::FA7,
      PhysReg::FS0, PhysReg::FS1,
  };
  for (const PhysReg candidate : kFloatRegs) {
    if (candidate == reg) {
      return true;
    }
  }
  return false;
}
using ValueSlotMap = std::unordered_map<const ir::Value*, int>; using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
// Cache from IR basic blocks to the machine blocks created for them.
static std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> block_map;
// Returns the machine block associated with `ir_block`, creating it in
// `function` on first request.  A block whose IR name is empty is named
// "block_<N>", where N is the number of entries already in the cache.
MachineBasicBlock* GetOrCreateBlock(const ir::BasicBlock* ir_block,
                                    MachineFunction& function) {
  const auto cached = block_map.find(ir_block);
  if (cached == block_map.end()) {
    std::string label = ir_block->GetName();
    if (label.empty()) {
      label = "block_" + std::to_string(block_map.size());
    }
    auto* created = function.CreateBlock(label);
    block_map[ir_block] = created;
    return created;
  }
  return cached->second;
}
// Appends to `block` the machine instructions that bring `value` into the
// physical register `target`.
//
// NOTE(review): several lines in this region contain two copies of a statement
// (an older and a newer revision side by side). The comments below describe
// the newer revision, i.e. the one taking `function` and `for_address`.
//
// Cases, tried in this order:
//   * ir::Argument that has a stack slot: loaded from the slot, converting
//     between int and float when the value type and the register class differ.
//   * ir::ConstantInt: MovImm of the value narrowed to int.
//   * ir::ConstantFloat: the value is narrowed to float and its bit pattern is
//     moved as an integer immediate; for a float target it goes through T4 and
//     a fresh 4-byte stack slot.
//   * ir::GepInst: base address (evaluated with for_address = true) plus
//     index * 4.
//   * ir::AllocaInst that has a stack slot: LoadAddr of that slot.
//   * ir::GlobalVariable: LoadGlobalAddr, followed by a load through `target`
//     unless `for_address` is true.
//   * any other value that has a stack slot: loaded from the slot with the
//     same int/float conversion rules as arguments.
// If nothing matches, diagnostics are written to std::cerr and
// std::runtime_error is thrown.
//
// Scratch registers written besides `target`: FT0 / T0 (conversions),
// T1 (GEP index), T4 (float constant into a float register).
//
// @param value        IR value to materialize
// @param target       destination physical register
// @param slots        map from IR values to frame indices
// @param block        machine block receiving the instructions
// @param function     used to create temporary frame slots
// @param for_address  for globals: stop after loading the address
void EmitValueToReg(const ir::Value* value, PhysReg target, void EmitValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) { const ValueSlotMap& slots, MachineBasicBlock& block,
MachineFunction& function, bool for_address = false){
// Function arguments (Argument)
if (auto* arg = dynamic_cast<const ir::Argument*>(value)) {
auto it = slots.find(arg);
if (it != slots.end()) {
bool src_is_float = value->GetType()->IsFloat32();
bool dst_is_float = IsFloatReg(target);
if (src_is_float == dst_is_float) {
// Same class on both sides -> load directly, no conversion
if (src_is_float)
block.Append(Opcode::LoadFloat, {Operand::Reg(target), Operand::FrameIndex(it->second)});
else
block.Append(Opcode::Load, {Operand::Reg(target), Operand::FrameIndex(it->second)});
} else if (src_is_float && !dst_is_float) {
// float -> int
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(it->second)});
block.Append(Opcode::FPToSI, {Operand::Reg(target), Operand::Reg(PhysReg::FT0)});
} else if (!src_is_float && dst_is_float) {
// int -> float
block.Append(Opcode::Load,
{Operand::Reg(PhysReg::T0), Operand::FrameIndex(it->second)});
block.Append(Opcode::SIToFP, {Operand::Reg(target), Operand::Reg(PhysReg::T0)});
}
return;
}
}
// Integer constants
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) { if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
int64_t val = constant->GetValue();
block.Append(Opcode::MovImm, block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(constant->GetValue())}); {Operand::Reg(target), Operand::Imm(static_cast<int>(val))});
return;
}
// Floating-point constants
if (auto* fconstant = dynamic_cast<const ir::ConstantFloat*>(value)) {
// Use the standard double -> float conversion; no special-case branch needed
float fval = static_cast<float>(fconstant->GetValue());
uint32_t bits;
std::memcpy(&bits, &fval, sizeof(fval));
int32_t imm = static_cast<int32_t>(bits);
if (IsFloatReg(target)) {
// Load through a stack slot so the float register meets the NaN-boxing requirement
int tmp_slot = function.CreateFrameIndex(4);
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::T4), Operand::Imm(imm)});
block.Append(Opcode::Store, {Operand::Reg(PhysReg::T4),
Operand::FrameIndex(tmp_slot)});
block.Append(Opcode::LoadFloat, {Operand::Reg(target),
Operand::FrameIndex(tmp_slot)});
} else {
block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(imm)});
}
return;
}
// GEP instructions: target = base + (index << 2)
// NOTE(review): the index is evaluated into T1 after the base is in `target`;
// if a caller passes T1 as `target` the base would be overwritten - confirm
// no caller does so.
if (auto* gep = dynamic_cast<const ir::GepInst*>(value)) {
EmitValueToReg(gep->GetBasePtr(), target, slots, block,function, true);
EmitValueToReg(gep->GetIndex(), PhysReg::T1, slots, block,function);
block.Append(Opcode::Slli, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Imm(2)});
block.Append(Opcode::Add, {Operand::Reg(target),
Operand::Reg(target),
Operand::Reg(PhysReg::T1)});
return;
}
// Alloca instructions: produce the address of the slot
if (auto* alloca = dynamic_cast<const ir::AllocaInst*>(value)) {
auto it = slots.find(alloca);
if (it != slots.end()) {
block.Append(Opcode::LoadAddr,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
return;
}
}
// Global variables
// NOTE(review): for a float global the same register receives both the
// address (LoadGlobalAddr) and the LoadFloat result, so one of the two
// instructions is given a register of the other class - verify that this
// path is only reached with for_address = true for float globals.
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(value)) {
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(target), Operand::Global(global->GetName())});
if (!for_address) {
if (global->IsFloat()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::Reg(target)});
} else {
block.Append(Opcode::LoadGlobal,
{Operand::Reg(target), Operand::Reg(target)});
}
}
return; return;
} }
// Any other value that lives in a stack slot
auto it = slots.find(value); auto it = slots.find(value);
if (it == slots.end()) { if (it != slots.end()) {
bool src_is_float = value->GetType()->IsFloat32();
bool dst_is_float = IsFloatReg(target);
if (src_is_float && !dst_is_float) {
// float -> int
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(it->second)});
block.Append(Opcode::FPToSI, {Operand::Reg(target), Operand::Reg(PhysReg::FT0)});
} else if (!src_is_float && dst_is_float) {
// int -> float
block.Append(Opcode::Load,
{Operand::Reg(PhysReg::T0), Operand::FrameIndex(it->second)});
block.Append(Opcode::SIToFP, {Operand::Reg(target), Operand::Reg(PhysReg::T0)});
} else {
// Same class on both sides: load directly
if (src_is_float) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
} else {
block.Append(Opcode::Load,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
}
}
return;
}
// None of the cases above matched: report the value and fail
std::cerr << "未找到的值: " << value << std::endl;
std::cerr << " 名称: " << value->GetName() << std::endl;
std::cerr << " 类型: " << (value->GetType()->IsFloat32() ? "float" : "int") << std::endl;
throw std::runtime_error( throw std::runtime_error(
FormatError("mir", "找不到值对应的栈槽: " + value->GetName())); FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
} }
// StoreRegToSlot: appends to `block` a store of register `reg` into the frame
// slot with index `slot`. Opcode::StoreFloat is used when `isFloat` is true,
// Opcode::Store otherwise.
// NOTE(review): the first two lines below also carry the tail of an older
// revision of the preceding function (the Opcode::LoadStack append).
block.Append(Opcode::LoadStack, void StoreRegToSlot(PhysReg reg, int slot, MachineBasicBlock& block, bool isFloat = false) {
{Operand::Reg(target), Operand::FrameIndex(it->second)}); if (isFloat) {
// Float value: store with the float opcode
block.Append(Opcode::StoreFloat,
{Operand::Reg(reg), Operand::FrameIndex(slot)});
} else {
// Integer / pointer value
block.Append(Opcode::Store,
{Operand::Reg(reg), Operand::FrameIndex(slot)});
}
} }
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, // 将 LowerInstruction 重命名为 LowerInstructionToBlock并添加 MachineBasicBlock 参数
ValueSlotMap& slots) { void LowerInstructionToBlock(const ir::Instruction& inst, MachineFunction& function,
auto& block = function.GetEntry(); ValueSlotMap& slots, MachineBasicBlock& block) {
switch (inst.GetOpcode()) { switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: { case ir::Opcode::Alloca: {
slots.emplace(&inst, function.CreateFrameIndex()); auto& alloca = static_cast<const ir::AllocaInst&>(inst);
int size = 4;
if (alloca.GetNumElements() > 1) {
size = alloca.GetNumElements() * 4;
}
slots.emplace(&inst, function.CreateFrameIndex(size));
return; return;
} }
case ir::Opcode::Store: { case ir::Opcode::Store: {
auto& store = static_cast<const ir::StoreInst&>(inst); auto& store = static_cast<const ir::StoreInst&>(inst);
// 如果指针是 GEP手动生成地址并 store避免额外计算
if (auto* gep = dynamic_cast<const ir::GepInst*>(store.GetPtr())) {
// 判断值的类型是否为浮点
bool val_is_float = store.GetValue()->GetType()->IsFloat32();
if (val_is_float) {
// 将浮点值加载到 FT0
EmitValueToReg(store.GetValue(), PhysReg::FT0, slots, block, function);
} else {
// 整数值加载到 T2
EmitValueToReg(store.GetValue(), PhysReg::T2, slots, block, function);
}
// 计算基址 + 索引*4
EmitValueToReg(gep->GetBasePtr(), PhysReg::T0, slots, block, function, true);
auto idx_it = slots.find(gep->GetIndex());
if (idx_it != slots.end()) {
block.Append(Opcode::Load, {Operand::Reg(PhysReg::T1), Operand::FrameIndex(idx_it->second)});
} else {
EmitValueToReg(gep->GetIndex(), PhysReg::T1, slots, block, function);
}
block.Append(Opcode::Slli, {Operand::Reg(PhysReg::T1), Operand::Reg(PhysReg::T1), Operand::Imm(2)});
block.Append(Opcode::Add, {Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T1)});
// 使用正确的间接存储操作码
if (val_is_float) {
block.Append(Opcode::StoreIndirectFloat, {Operand::Reg(PhysReg::FT0), Operand::Reg(PhysReg::T0)});
} else {
block.Append(Opcode::StoreIndirect, {Operand::Reg(PhysReg::T2), Operand::Reg(PhysReg::T0)});
}
return;
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(store.GetPtr())) {
EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block, function);
std::string global_name = global->GetName();
block.Append(Opcode::StoreGlobal, {Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
return;
}
auto dst = slots.find(store.GetPtr()); auto dst = slots.find(store.GetPtr());
if (dst == slots.end()) { if (dst != slots.end()) {
throw std::runtime_error( bool val_is_float = store.GetValue()->GetType()->IsFloat32();
FormatError("mir", "暂不支持对非栈变量地址进行写入")); if (val_is_float) {
EmitValueToReg(store.GetValue(), PhysReg::FT0, slots, block, function);
StoreRegToSlot(PhysReg::FT0, dst->second, block, true);
} else {
EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block, function);
StoreRegToSlot(PhysReg::T0, dst->second, block, false);
} }
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
return; return;
} }
throw std::runtime_error(FormatError("mir", "Store: 无法处理的指针类型"));
}
case ir::Opcode::Load: { case ir::Opcode::Load: {
auto& load = static_cast<const ir::LoadInst&>(inst); auto& load = static_cast<const ir::LoadInst&>(inst);
if (auto* gep = dynamic_cast<const ir::GepInst*>(load.GetPtr())) {
// 计算地址到 T0
EmitValueToReg(gep->GetBasePtr(), PhysReg::T0, slots, block, function, true);
auto idx_it = slots.find(gep->GetIndex());
if (idx_it != slots.end()) {
block.Append(Opcode::Load, {Operand::Reg(PhysReg::T1), Operand::FrameIndex(idx_it->second)});
} else {
EmitValueToReg(gep->GetIndex(), PhysReg::T1, slots, block, function);
}
block.Append(Opcode::Slli, {Operand::Reg(PhysReg::T1), Operand::Reg(PhysReg::T1), Operand::Imm(2)});
block.Append(Opcode::Add, {Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T1)});
bool load_is_float = load.GetType()->IsFloat32();
int dst_slot = function.CreateFrameIndex(4);
if (load_is_float) {
// 浮点加载FT0 = [T0]
block.Append(Opcode::LoadIndirectFloat, {Operand::Reg(PhysReg::FT0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
} else {
// 整数加载
block.Append(Opcode::LoadIndirect, {Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
return;
}
if (dynamic_cast<const ir::GepInst*>(load.GetPtr())) {
EmitValueToReg(load.GetPtr(), PhysReg::T0, slots, block, function, true);
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
int dst_slot = function.CreateFrameIndex(4);
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
slots.emplace(&inst, dst_slot);
return;
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(load.GetPtr())) {
int dst_slot = function.CreateFrameIndex(4);
std::string global_name = global->GetName();
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
if (global->IsFloat()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, true);
} else {
block.Append(Opcode::LoadGlobal,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
return;
}
auto src = slots.find(load.GetPtr()); auto src = slots.find(load.GetPtr());
if (src == slots.end()) { if (src != slots.end()) {
throw std::runtime_error( int dst_slot = function.CreateFrameIndex(4);
FormatError("mir", "暂不支持对非栈变量地址进行读取"));
if (load.GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(src->second)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
} else {
block.Append(Opcode::Load,
{Operand::Reg(PhysReg::T0), Operand::FrameIndex(src->second)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
} }
int dst_slot = function.CreateFrameIndex();
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot); slots.emplace(&inst, dst_slot);
return; return;
} }
case ir::Opcode::Add: {
throw std::runtime_error(FormatError("mir", "Load: 无法处理的指针类型"));
}
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Mod:
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv: {
auto& bin = static_cast<const ir::BinaryInst&>(inst); auto& bin = static_cast<const ir::BinaryInst&>(inst);
bool lhs_is_float = bin.GetLhs()->GetType()->IsFloat32();
bool rhs_is_float = bin.GetRhs()->GetType()->IsFloat32();
bool result_is_float = lhs_is_float || rhs_is_float;
int dst_slot = function.CreateFrameIndex(4);
if (result_is_float) {
EmitValueToReg(bin.GetLhs(), PhysReg::FT0, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::FT1, slots, block, function);
Opcode op;
switch (inst.GetOpcode()) {
case ir::Opcode::Add: case ir::Opcode::FAdd: op = Opcode::FAdd; break;
case ir::Opcode::Sub: case ir::Opcode::FSub: op = Opcode::FSub; break;
case ir::Opcode::Mul: case ir::Opcode::FMul: op = Opcode::FMul; break;
case ir::Opcode::Div: case ir::Opcode::FDiv: op = Opcode::FDiv; break;
default: op = Opcode::FAdd; break;
}
block.Append(op, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
} else {
EmitValueToReg(bin.GetLhs(), PhysReg::T0, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::T1, slots, block, function);
Opcode op;
switch (inst.GetOpcode()) {
case ir::Opcode::Add: op = Opcode::Add; break;
case ir::Opcode::Sub: op = Opcode::Sub; break;
case ir::Opcode::Mul: op = Opcode::Mul; break;
case ir::Opcode::Div: op = Opcode::Div; break;
case ir::Opcode::Mod: op = Opcode::Rem; break;
default: op = Opcode::Add; break;
}
block.Append(op, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Gep: {
int dst_slot = function.CreateFrameIndex(); int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot); slots.emplace(&inst, dst_slot);
return; return;
} }
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst); case ir::Opcode::Call: {
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block); auto& call = static_cast<const ir::CallInst&>(inst);
block.Append(Opcode::Ret); int numArgs = static_cast<int>(call.GetNumArgs());
int ireg = 0, freg = 0;
std::vector<std::pair<bool, ir::Value*>> stack_args; // (is_float, value)
for (int i = 0; i < numArgs; ++i) {
bool arg_is_float = call.GetArg(i)->GetType()->IsFloat32();
if (arg_is_float) {
if (freg < 8) {
PhysReg fregnum = static_cast<PhysReg>(static_cast<int>(PhysReg::FA0) + freg);
EmitValueToReg(call.GetArg(i), fregnum, slots, block, function);
freg++;
} else {
stack_args.push_back({true, call.GetArg(i)});
}
} else { // integer or pointer
if (ireg < 8) {
PhysReg iregnum = static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + ireg);
EmitValueToReg(call.GetArg(i), iregnum, slots, block, function);
ireg++;
} else {
stack_args.push_back({false, call.GetArg(i)});
}
}
}
int stackArgs = static_cast<int>(stack_args.size());
if (stackArgs > 0) {
int stackSpace = (stackArgs * 8 + 15) & ~15;
// sp -= stackSpace
if (stackSpace <= 2047) {
block.Append(Opcode::Addi, {Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::SP),
Operand::Imm(-stackSpace)});
} else {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::T4), Operand::Imm(-stackSpace)});
block.Append(Opcode::Add, {Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::T4)});
}
for (int idx = 0; idx < stackArgs; ++idx) {
bool is_float = stack_args[idx].first;
ir::Value* val = stack_args[idx].second;
int offset = idx * 8;
// 1. 先加载值
if (is_float) {
EmitValueToReg(val, PhysReg::FT0, slots, block, function);
} else {
EmitValueToReg(val, PhysReg::T0, slots, block, function);
}
// 2. 再计算栈地址到 T4
if (offset == 0) {
block.Append(Opcode::Add, {Operand::Reg(PhysReg::T4),
Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::ZERO)});
} else if (offset <= 2047) {
block.Append(Opcode::Addi, {Operand::Reg(PhysReg::T4),
Operand::Reg(PhysReg::SP),
Operand::Imm(offset)});
} else {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::T4), Operand::Imm(offset)});
block.Append(Opcode::Add, {Operand::Reg(PhysReg::T4),
Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::T4)});
}
// 3. 存储
if (is_float) {
block.Append(Opcode::StoreFloat, {Operand::Reg(PhysReg::FT0), Operand::Reg(PhysReg::T4)});
} else {
block.Append(Opcode::Store, {Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T4)});
}
}
}
// 调用目标
std::string func_name = call.GetCalleeName();
block.Append(Opcode::Call, {Operand::Func(func_name)});
// 恢复 sp
if (stackArgs > 0) {
int stackSpace = (stackArgs * 8 + 15) & ~15;
if (stackSpace <= 2047) {
block.Append(Opcode::Addi, {Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::SP),
Operand::Imm(stackSpace)});
} else {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::T4), Operand::Imm(stackSpace)});
block.Append(Opcode::Add, {Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::T4)});
}
}
// 返回值处理(原有代码保持不变)
if (!call.GetType()->IsVoid()) {
int dst_slot = function.CreateFrameIndex();
bool ret_is_float = call.GetType()->IsFloat32();
if (ret_is_float) {
StoreRegToSlot(PhysReg::FA0, dst_slot, block, true);
} else {
StoreRegToSlot(PhysReg::A0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
}
return; return;
} }
case ir::Opcode::Sub:
case ir::Opcode::Mul: case ir::Opcode::ICmp: {
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算")); auto& icmp = static_cast<const ir::ICmpInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(icmp.GetLhs(), PhysReg::T0, slots, block, function);
EmitValueToReg(icmp.GetRhs(), PhysReg::T1, slots, block, function);
ir::ICmpPredicate pred = icmp.GetPredicate();
switch (pred) {
case ir::ICmpPredicate::EQ:
block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
break;
case ir::ICmpPredicate::NE:
block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
case ir::ICmpPredicate::SLT:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
break;
case ir::ICmpPredicate::SGT:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
break;
case ir::ICmpPredicate::SGE:
// lhs >= rhs 等价于 !(lhs < rhs)
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
break;
case ir::ICmpPredicate::SLE:
// lhs <= rhs 等价于 !(rhs < lhs)
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)}); // 注意操作数顺序rhs < lhs
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
break;
} }
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
} }
case ir::Opcode::ZExt: {
auto& zext = static_cast<const ir::ZExtInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
} // namespace EmitValueToReg(zext.GetSrc(), PhysReg::T0, slots, block, function);
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) { case ir::Opcode::FCmp: {
DefaultContext(); auto& fcmp = static_cast<const ir::FCmpInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
EmitValueToReg(fcmp.GetLhs(), PhysReg::FT0, slots, block, function);
EmitValueToReg(fcmp.GetRhs(), PhysReg::FT1, slots, block, function);
ir::FCmpPredicate pred = fcmp.GetPredicate();
switch (pred) {
case ir::FCmpPredicate::OEQ:
block.Append(Opcode::FEq, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
case ir::FCmpPredicate::ONE:
block.Append(Opcode::FEq, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
break;
case ir::FCmpPredicate::OLT:
block.Append(Opcode::FLt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
case ir::FCmpPredicate::OGT:
block.Append(Opcode::FLt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT1),
Operand::Reg(PhysReg::FT0)});
break;
case ir::FCmpPredicate::OLE:
block.Append(Opcode::FLe, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
case ir::FCmpPredicate::OGE:
block.Append(Opcode::FLe, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT1),
Operand::Reg(PhysReg::FT0)});
break;
default:
block.Append(Opcode::FEq, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
}
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::SIToFP: {
auto& conv = static_cast<const ir::SIToFPInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
// 直接加载源操作数到 T0不依赖 slots 中是否存在
EmitValueToReg(conv.GetSrc(), PhysReg::T0, slots, block, function);
block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::FPToSI: {
auto& conv = static_cast<const ir::FPToSIInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
if (module.GetFunctions().size() != 1) { // 直接加载源操作数到 FT0不依赖 slots
throw std::runtime_error(FormatError("mir", "暂不支持多个函数")); EmitValueToReg(conv.GetSrc(), PhysReg::FT0, slots, block, function);
block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Br: {
auto& br = static_cast<const ir::BrInst&>(inst);
auto* target = br.GetTarget();
MachineBasicBlock* target_block = GetOrCreateBlock(target, function);
block.Append(Opcode::Br, {Operand::Imm64(reinterpret_cast<intptr_t>(target_block))});
return;
}
case ir::Opcode::CondBr: {
auto& condbr = static_cast<const ir::CondBrInst&>(inst);
auto* true_bb = condbr.GetTrueBB();
auto* false_bb = condbr.GetFalseBB();
EmitValueToReg(condbr.GetCond(), PhysReg::T0, slots, block, function);
MachineBasicBlock* true_block = GetOrCreateBlock(true_bb, function);
MachineBasicBlock* false_block = GetOrCreateBlock(false_bb, function);
block.Append(Opcode::CondBr, {Operand::Reg(PhysReg::T0),
Operand::Imm64(reinterpret_cast<intptr_t>(true_block)),
Operand::Imm64(reinterpret_cast<intptr_t>(false_block))});
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (ret.GetValue()) {
auto val = ret.GetValue();
if (val->GetType()->IsFloat32()) {
EmitValueToReg(val, PhysReg::FA0, slots, block, function);
} else {
EmitValueToReg(val, PhysReg::A0, slots, block, function);
}
} else {
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::A0), Operand::Imm(0)});
}
block.Append(Opcode::Ret);
return;
} }
const auto& func = *module.GetFunctions().front(); default: {
if (func.GetName() != "main") { break;
throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数")); }
}
} }
} // namespace
std::unique_ptr<MachineFunction> LowerFunctionToMIR(const ir::Function& func) {
block_map.clear();
auto machine_func = std::make_unique<MachineFunction>(func.GetName()); auto machine_func = std::make_unique<MachineFunction>(func.GetName());
ValueSlotMap slots; ValueSlotMap slots;
const auto* entry = func.GetEntry(); int ireg = 0, freg = 0, stack_idx = 0;
if (!entry) { for (size_t i = 0; i < func.GetNumArgs(); i++) {
throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块")); ir::Argument* arg = func.GetArgument(i);
int size = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat32()) ? 8 : 4;
int slot = machine_func->CreateFrameIndex(size);
MachineBasicBlock* entry = machine_func->GetEntry();
if (arg->GetType()->IsFloat32()) {
if (freg < 8) {
PhysReg argReg = static_cast<PhysReg>(static_cast<int>(PhysReg::FA0) + freg);
entry->Append(Opcode::StoreFloat, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
freg++;
} else {
entry->Append(Opcode::LoadCallerStackArgFloat, {
Operand::Reg(PhysReg::FT0),
Operand::FrameIndex(slot),
Operand::Imm(stack_idx)
});
stack_idx++;
}
} else {
if (ireg < 8) {
PhysReg argReg = static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + ireg);
entry->Append(Opcode::Store, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
ireg++;
} else {
entry->Append(Opcode::LoadCallerStackArg, {
Operand::Reg(PhysReg::T0),
Operand::FrameIndex(slot),
Operand::Imm(stack_idx)
});
stack_idx++;
}
}
slots[arg] = slot;
} }
for (const auto& inst : entry->GetInstructions()) { // 第一遍:创建所有 IR 基本块对应的 MIR 基本块
LowerInstruction(*inst, *machine_func, slots); for (const auto& ir_block : func.GetBlocks()) {
GetOrCreateBlock(ir_block.get(), *machine_func);
}
// 第二遍:遍历所有基本块,降低指令
for (const auto& ir_block : func.GetBlocks()) {
MachineBasicBlock* mbb = GetOrCreateBlock(ir_block.get(), *machine_func);
for (const auto& inst : ir_block->GetInstructions()) {
LowerInstructionToBlock(*inst, *machine_func, slots, *mbb);
}
} }
return machine_func; return machine_func;
} }
/// Lowers every function of an IR module to MIR.
///
/// Side effects: calls DefaultContext() and rebuilds the file-level
/// g_globalVars table from module.GetGlobalVariables() before any function is
/// lowered (so the table is populated even when the call ends up throwing).
///
/// @param module  IR module to lower.
/// @return One MachineFunction per IR function, in module order.
/// @throws std::runtime_error if the module contains no functions.
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module) {
  DefaultContext();

  // Collect global variables; the table is rebuilt from scratch on each call.
  g_globalVars.clear();
  for (const auto& global : module.GetGlobalVariables()) {
    GlobalVarInfo info;
    info.name = global->GetName();
    info.isConst = global->IsConst();
    info.isArray = global->IsArray();
    info.arraySize = global->GetNumElements();
    info.isFloat = global->IsFloat();
    info.value = 0;
    info.valueF = 0.0f;

    if (info.isArray) {
      if (info.isFloat) {
        // NOTE(review): the integer path below checks HasInitVals() first but
        // this float path does not — confirm GetInitValsF() is safe to call on
        // an array without an initializer.
        for (float val : global->GetInitValsF()) {
          info.arrayValuesF.push_back(val);
        }
      } else if (global->HasInitVals()) {
        for (int val : global->GetInitVals()) {
          info.arrayValues.push_back(val);
        }
      }
    } else if (info.isFloat) {
      info.valueF = global->GetInitValF();
    } else {
      info.value = global->GetInitVal();
    }

    // Move rather than copy: the record may own large initializer arrays.
    g_globalVars.push_back(std::move(info));
  }

  const auto& functions = module.GetFunctions();
  if (functions.empty()) {
    throw std::runtime_error(FormatError("mir", "模块中没有函数"));
  }

  std::vector<std::unique_ptr<MachineFunction>> result;
  result.reserve(functions.size());
  for (const auto& func : functions) {
    result.push_back(LowerFunctionToMIR(*func));
  }
  return result;
}
} // namespace mir } // namespace mir

@ -8,7 +8,16 @@
namespace mir { namespace mir {
MachineFunction::MachineFunction(std::string name) MachineFunction::MachineFunction(std::string name)
: name_(std::move(name)), entry_("entry") {} : name_(std::move(name)) {
entry_ = CreateBlock("entry");
}
// Creates a new MachineBasicBlock called `name`, appends it to blocks_ (which
// takes ownership) and returns a non-owning pointer to it.
MachineBasicBlock* MachineFunction::CreateBlock(const std::string& name) {
  blocks_.push_back(std::make_unique<MachineBasicBlock>(name));
  return blocks_.back().get();
}
int MachineFunction::CreateFrameIndex(int size) { int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size()); int index = static_cast<int>(frame_slots_.size());

@ -6,17 +6,33 @@ namespace mir {
Operand::Operand(Kind kind, PhysReg reg, int imm) Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm) {} : kind_(kind), reg_(reg), imm_(imm) {}
// Constructs an operand that carries a 64-bit immediate in imm64_.
// The `reg` argument is ignored: reg_ is always set to PhysReg::ZERO and the
// 32-bit imm_ field is set to 0.
Operand::Operand(Kind kind, PhysReg reg, int64_t imm64)
: kind_(kind), reg_(PhysReg::ZERO), imm_(0), imm64_(imm64) {}
// Newly added constructor: like the (kind, reg, imm) form, but additionally
// records a symbol name in global_name_.
Operand::Operand(Kind kind, PhysReg reg, int imm, const std::string& name)
: kind_(kind), reg_(reg), imm_(imm), global_name_(name) {}
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); } Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
Operand Operand::Imm(int value) { Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value); return Operand(Kind::Imm, PhysReg::ZERO, value);
}
Operand Operand::Imm64(int64_t value) {
return Operand(Kind::Imm, PhysReg::ZERO, value);
} }
Operand Operand::FrameIndex(int index) { Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::W0, index); return Operand(Kind::FrameIndex, PhysReg::ZERO, index);
} }
// Factory for a Kind::Global operand that refers to the symbol `name`.
// The register field is PhysReg::ZERO and the immediate is 0.
Operand Operand::Global(const std::string& name) {
  Operand symbol(Kind::Global, PhysReg::ZERO, 0, name);
  return symbol;
}
// Factory for a Kind::Func operand: built with PhysReg::ZERO and immediate 0,
// then tagged with the callee name in func_name_.
Operand Operand::Func(const std::string& name) {
  Operand callee(Kind::Func, PhysReg::ZERO, 0);
  callee.func_name_ = name;
  return callee;
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands) MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
: opcode_(opcode), operands_(std::move(operands)) {} : opcode_(opcode), operands_(std::move(operands)) {}

@ -9,12 +9,66 @@ namespace {
bool IsAllowedReg(PhysReg reg) { bool IsAllowedReg(PhysReg reg) {
switch (reg) { switch (reg) {
case PhysReg::W0: // 临时寄存器
case PhysReg::W8: case PhysReg::T0:
case PhysReg::W9: case PhysReg::T1:
case PhysReg::X29: case PhysReg::T2:
case PhysReg::X30: case PhysReg::T3:
case PhysReg::T4:
case PhysReg::T5:
case PhysReg::T6:
// 参数/返回值寄存器
case PhysReg::A0:
case PhysReg::A1:
case PhysReg::A2:
case PhysReg::A3:
case PhysReg::A4:
case PhysReg::A5:
case PhysReg::A6:
case PhysReg::A7:
// 保存寄存器
case PhysReg::S0:
case PhysReg::S1:
case PhysReg::S2:
case PhysReg::S3:
case PhysReg::S4:
case PhysReg::S5:
case PhysReg::S6:
case PhysReg::S7:
case PhysReg::S8:
case PhysReg::S9:
case PhysReg::S10:
case PhysReg::S11:
// 特殊寄存器
case PhysReg::ZERO:
case PhysReg::RA:
case PhysReg::SP: case PhysReg::SP:
case PhysReg::GP:
case PhysReg::TP:
case PhysReg::FT0:
case PhysReg::FT1:
case PhysReg::FT2:
case PhysReg::FT3:
case PhysReg::FT4:
case PhysReg::FT5:
case PhysReg::FT6:
case PhysReg::FT7:
case PhysReg::FT8:
case PhysReg::FT9:
case PhysReg::FT10:
case PhysReg::FT11:
// 浮点保存寄存器
case PhysReg::FS0:
case PhysReg::FS1:
// 浮点参数寄存器
case PhysReg::FA0:
case PhysReg::FA1:
case PhysReg::FA2:
case PhysReg::FA3:
case PhysReg::FA4:
case PhysReg::FA5:
case PhysReg::FA6:
case PhysReg::FA7:
return true; return true;
} }
return false; return false;
@ -23,7 +77,8 @@ bool IsAllowedReg(PhysReg reg) {
} // namespace } // namespace
void RunRegAlloc(MachineFunction& function) { void RunRegAlloc(MachineFunction& function) {
for (const auto& inst : function.GetEntry().GetInstructions()) { // 修复GetEntry() 返回指针,使用 ->
for (const auto& inst : function.GetEntry()->GetInstructions()) {
for (const auto& operand : inst.GetOperands()) { for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg && if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) { !IsAllowedReg(operand.GetReg())) {

@ -8,18 +8,94 @@ namespace mir {
const char* PhysRegName(PhysReg reg) { const char* PhysRegName(PhysReg reg) {
switch (reg) { switch (reg) {
case PhysReg::W0: // 整数寄存器
return "w0"; case PhysReg::ZERO:
case PhysReg::W8: return "zero";
return "w8"; case PhysReg::RA:
case PhysReg::W9: return "ra";
return "w9";
case PhysReg::X29:
return "x29";
case PhysReg::X30:
return "x30";
case PhysReg::SP: case PhysReg::SP:
return "sp"; return "sp";
case PhysReg::GP:
return "gp";
case PhysReg::TP:
return "tp";
case PhysReg::T0:
return "t0";
case PhysReg::T1:
return "t1";
case PhysReg::T2:
return "t2";
case PhysReg::S0:
return "s0";
case PhysReg::S1:
return "s1";
case PhysReg::A0:
return "a0";
case PhysReg::A1:
return "a1";
case PhysReg::A2:
return "a2";
case PhysReg::A3:
return "a3";
case PhysReg::A4:
return "a4";
case PhysReg::A5:
return "a5";
case PhysReg::A6:
return "a6";
case PhysReg::A7:
return "a7";
case PhysReg::S2:
return "s2";
case PhysReg::S3:
return "s3";
case PhysReg::S4:
return "s4";
case PhysReg::S5:
return "s5";
case PhysReg::S6:
return "s6";
case PhysReg::S7:
return "s7";
case PhysReg::S8:
return "s8";
case PhysReg::S9:
return "s9";
case PhysReg::S10:
return "s10";
case PhysReg::S11:
return "s11";
case PhysReg::T3:
return "t3";
case PhysReg::T4:
return "t4";
case PhysReg::T5:
return "t5";
case PhysReg::T6:
return "t6";
// 浮点寄存器
case PhysReg::FT0: return "ft0";
case PhysReg::FT1: return "ft1";
case PhysReg::FT2: return "ft2";
case PhysReg::FT3: return "ft3";
case PhysReg::FT4: return "ft4";
case PhysReg::FT5: return "ft5";
case PhysReg::FT6: return "ft6";
case PhysReg::FT7: return "ft7";
case PhysReg::FS0: return "fs0";
case PhysReg::FS1: return "fs1";
case PhysReg::FA0: return "fa0";
case PhysReg::FA1: return "fa1";
case PhysReg::FA2: return "fa2";
case PhysReg::FA3: return "fa3";
case PhysReg::FA4: return "fa4";
case PhysReg::FA5: return "fa5";
case PhysReg::FA6: return "fa6";
case PhysReg::FA7: return "fa7";
case PhysReg::FT8: return "ft8";
case PhysReg::FT9: return "ft9";
case PhysReg::FT10: return "ft10";
case PhysReg::FT11: return "ft11";
} }
throw std::runtime_error(FormatError("mir", "未知物理寄存器")); throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
} }

@ -66,8 +66,8 @@ class SemaVisitor final : public SysYBaseVisitor {
} }
func->accept(this); func->accept(this);
} }
if (!has_main && ctx->funcDef().empty()) { if (!has_main) {
throw std::runtime_error(FormatError("sema", "缺少函数定义")); throw std::runtime_error(FormatError("sema", "缺少main函数定义"));
} }
scope_.PopScope(); scope_.PopScope();
return {}; return {};

@ -1,12 +1,71 @@
#include "sem/func.h" #include "sem/func.h"
#include <cstring>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include "utils/Log.h" #include "utils/Log.h"
#include <cmath> // 提供 ldexp
namespace {
// Parses a hexadecimal floating-point literal of the form 0xH.Hp±E and returns
// its value as a double.
//
//  - An optional leading "0x"/"0X" prefix is skipped.
//  - Hex digits before the 'p'/'P' marker form the significand; digits that
//    follow a '.' are fractional. Any other character in that part is skipped
//    silently.
//  - The optional binary exponent after 'p'/'P' is a decimal integer with an
//    optional sign. Its magnitude saturates instead of overflowing `int`, so an
//    absurdly long exponent yields +inf or 0 rather than undefined behaviour.
double ParseHexFloat(const std::string& str) {
  const char* s = str.c_str();
  if (s[0] == '0' && (s[1] == 'x' || s[1] == 'X')) s += 2;

  double significand = 0.0;
  bool have_dot = false;
  double dot_scale = 1.0 / 16.0;  // weight of the next fractional hex digit
  for (; *s != '\0' && *s != 'p' && *s != 'P'; ++s) {
    if (*s == '.') {
      have_dot = true;
      continue;
    }
    int digit = -1;
    if (*s >= '0' && *s <= '9') digit = *s - '0';
    else if (*s >= 'a' && *s <= 'f') digit = *s - 'a' + 10;
    else if (*s >= 'A' && *s <= 'F') digit = *s - 'A' + 10;
    if (digit < 0) continue;  // not a hex digit: ignore it

    if (have_dot) {
      significand += digit * dot_scale;
      dot_scale /= 16.0;
    } else {
      significand = significand * 16 + digit;
    }
  }

  int exponent = 0;
  if (*s == 'p' || *s == 'P') {
    ++s;
    bool negative = false;
    if (*s == '-') {
      negative = true;
      ++s;
    } else if (*s == '+') {
      ++s;
    }
    // Every exponent at or beyond this magnitude already maps to +inf or 0 for
    // any representable significand, so stop accumulating once it is reached.
    constexpr int kExponentLimit = 100000;
    for (; *s >= '0' && *s <= '9'; ++s) {
      if (exponent < kExponentLimit) {
        exponent = exponent * 10 + (*s - '0');
      }
    }
    if (negative) exponent = -exponent;
  }
  return std::ldexp(significand, exponent);
}
} // anonymous namespace
namespace sem { namespace sem {
// Rounds `v` to single precision by converting it to float and back to double,
// so constant folding sees the same value a C `float` would hold.
static double ToFloat32(double v) {
  return static_cast<double>(static_cast<float>(v));
}
// 编译时求值常量表达式 // 编译时求值常量表达式
ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) { ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) {
return EvaluateExp(*ctx.addExp()); return EvaluateExp(*ctx.addExp());
@ -73,14 +132,65 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp()) { if (ctx.exp()) {
return EvaluateExp(*ctx.exp()->addExp()); return EvaluateExp(*ctx.exp()->addExp());
} else if (ctx.lVar()) { } else if (ctx.lVar()) {
// 处理变量引用(必须是已定义的常量) // 处理变量引用:向上遍历 AST 找到对应的常量定义并求值
auto* ident = ctx.lVar()->Ident(); auto* ident = ctx.lVar()->Ident();
if (!ident) { if (!ident) {
throw std::runtime_error(FormatError("sema", "非法变量引用")); throw std::runtime_error(FormatError("sema", "非法变量引用"));
} }
std::string name = ident->getText(); std::string name = ident->getText();
// 这里简化处理,实际应该在符号表中查找常量 // 向上遍历 AST 找到作用域内的 constDef
// 暂时假设常量已经在前面被处理过 antlr4::ParserRuleContext* scope =
dynamic_cast<antlr4::ParserRuleContext*>(ctx.lVar()->parent);
while (scope) {
// 检查当前作用域中的所有 constDecl
for (auto* tree_child : scope->children) {
auto* child = dynamic_cast<antlr4::ParserRuleContext*>(tree_child);
if (!child) continue;
auto* block_item = dynamic_cast<SysYParser::BlockItemContext*>(child);
if (block_item && block_item->decl()) {
auto* decl = block_item->decl();
if (decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
// compUnit 级别的 constDecl
auto* decl = dynamic_cast<SysYParser::DeclContext*>(child);
if (decl && decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
// If declared as int, truncate to integer
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
scope = dynamic_cast<antlr4::ParserRuleContext*>(scope->parent);
}
// 未找到常量定义,返回 0
ConstValue val; ConstValue val;
val.is_int = true; val.is_int = true;
val.int_val = 0; val.int_val = 0;
@ -94,11 +204,16 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
ConstValue val; ConstValue val;
if (int_const) { if (int_const) {
val.is_int = true; val.is_int = true;
val.int_val = std::stoll(int_const->getText()); val.int_val = std::stoll(int_const->getText(), nullptr, 0);
val.float_val = static_cast<double>(val.int_val); val.float_val = static_cast<double>(val.int_val);
} else if (float_const) { } else if (float_const) {
val.is_int = false; val.is_int = false;
val.float_val = std::stod(float_const->getText()); std::string text = float_const->getText();
if (text.size() >= 2 && (text[1] == 'x' || text[1] == 'X')) {
val.float_val = ToFloat32(ParseHexFloat(text));
} else {
val.float_val = ToFloat32(std::stod(text));
}
val.int_val = static_cast<long long>(val.float_val); val.int_val = static_cast<long long>(val.float_val);
} else { } else {
throw std::runtime_error(FormatError("sema", "非法数字字面量")); throw std::runtime_error(FormatError("sema", "非法数字字面量"));
@ -127,8 +242,9 @@ ConstValue AddValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val); result.float_val = static_cast<double>(result.int_val);
} else { } else {
result.is_int = false; result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) + double l = lhs.is_int ? lhs.int_val : lhs.float_val;
(rhs.is_int ? rhs.int_val : rhs.float_val); double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l + r);
result.int_val = static_cast<long long>(result.float_val); result.int_val = static_cast<long long>(result.float_val);
} }
return result; return result;
@ -143,8 +259,9 @@ ConstValue SubValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val); result.float_val = static_cast<double>(result.int_val);
} else { } else {
result.is_int = false; result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) - double l = lhs.is_int ? lhs.int_val : lhs.float_val;
(rhs.is_int ? rhs.int_val : rhs.float_val); double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l - r);
result.int_val = static_cast<long long>(result.float_val); result.int_val = static_cast<long long>(result.float_val);
} }
return result; return result;
@ -159,8 +276,9 @@ ConstValue MulValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val); result.float_val = static_cast<double>(result.int_val);
} else { } else {
result.is_int = false; result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) * double l = lhs.is_int ? lhs.int_val : lhs.float_val;
(rhs.is_int ? rhs.int_val : rhs.float_val); double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l * r);
result.int_val = static_cast<long long>(result.float_val); result.int_val = static_cast<long long>(result.float_val);
} }
return result; return result;
@ -175,8 +293,9 @@ ConstValue DivValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val); result.float_val = static_cast<double>(result.int_val);
} else { } else {
result.is_int = false; result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) / double l = lhs.is_int ? lhs.int_val : lhs.float_val;
(rhs.is_int ? rhs.int_val : rhs.float_val); double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l / r);
result.int_val = static_cast<long long>(result.float_val); result.int_val = static_cast<long long>(result.float_val);
} }
return result; return result;

@ -1,4 +1,6 @@
// 解析帮助、输入文件和输出阶段选项。 // 解析命令行: compiler <input.sy> -S -o <output.s> [-O1]
// 或: compiler <input.sy> -IR -o <output.ll> [-O1]
// 同时兼容 --emit-ir / --emit-asm 旧格式
#include "utils/CLI.h" #include "utils/CLI.h"
@ -15,30 +17,31 @@ CLIOptions ParseCLI(int argc, char** argv) {
if (argc <= 1) { if (argc <= 1) {
throw std::runtime_error(FormatError( throw std::runtime_error(FormatError(
"cli", "cli",
"用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>")); "用法: compiler <input.sy> -S -o <output.s> [-O1]\n"
" 或: compiler <input.sy> -IR -o <output.ll> [-O1]"));
} }
for (int i = 1; i < argc; ++i) { for (int i = 1; i < argc; ++i) {
const char* arg = argv[i]; const char* arg = argv[i];
if (std::strcmp(arg, "-h") == 0 || std::strcmp(arg, "--help") == 0) { if (std::strcmp(arg, "-h") == 0 || std::strcmp(arg, "--help") == 0) {
opt.show_help = true; opt.show_help = true;
return opt; return opt;
} }
if (std::strcmp(arg, "--emit-parse-tree") == 0) { // 输出阶段(新格式)
if (std::strcmp(arg, "-S") == 0) {
if (!explicit_emit) { if (!explicit_emit) {
opt.emit_parse_tree = false;
opt.emit_ir = false; opt.emit_ir = false;
opt.emit_asm = false; opt.emit_asm = false;
explicit_emit = true; explicit_emit = true;
} }
opt.emit_parse_tree = true; opt.emit_asm = true;
continue; continue;
} }
if (std::strcmp(arg, "--emit-ir") == 0) { if (std::strcmp(arg, "-IR") == 0) {
if (!explicit_emit) { if (!explicit_emit) {
opt.emit_parse_tree = false;
opt.emit_ir = false; opt.emit_ir = false;
opt.emit_asm = false; opt.emit_asm = false;
explicit_emit = true; explicit_emit = true;
@ -47,9 +50,9 @@ CLIOptions ParseCLI(int argc, char** argv) {
continue; continue;
} }
// 输出阶段(兼容旧格式)
if (std::strcmp(arg, "--emit-asm") == 0) { if (std::strcmp(arg, "--emit-asm") == 0) {
if (!explicit_emit) { if (!explicit_emit) {
opt.emit_parse_tree = false;
opt.emit_ir = false; opt.emit_ir = false;
opt.emit_asm = false; opt.emit_asm = false;
explicit_emit = true; explicit_emit = true;
@ -58,6 +61,32 @@ CLIOptions ParseCLI(int argc, char** argv) {
continue; continue;
} }
if (std::strcmp(arg, "--emit-ir") == 0) {
if (!explicit_emit) {
opt.emit_ir = false;
opt.emit_asm = false;
explicit_emit = true;
}
opt.emit_ir = true;
continue;
}
// 优化级别
if (std::strcmp(arg, "-O1") == 0) {
opt.opt = true;
continue;
}
// 输出文件
if (std::strcmp(arg, "-o") == 0) {
if (i + 1 >= argc) {
throw std::runtime_error(
FormatError("cli", "-o 缺少输出文件名"));
}
opt.output = argv[++i];
continue;
}
if (arg[0] == '-') { if (arg[0] == '-') {
throw std::runtime_error( throw std::runtime_error(
FormatError("cli", std::string("未知参数: ") + arg + FormatError("cli", std::string("未知参数: ") + arg +
@ -73,11 +102,12 @@ CLIOptions ParseCLI(int argc, char** argv) {
if (opt.input.empty() && !opt.show_help) { if (opt.input.empty() && !opt.show_help) {
throw std::runtime_error( throw std::runtime_error(
FormatError("cli", "缺少输入文件:请提供 <input.sy>(使用 --help 查看用法)")); FormatError("cli", "缺少输入文件:请提供 <input.sy>"));
} }
if (!opt.emit_parse_tree && !opt.emit_ir && !opt.emit_asm) { if (!explicit_emit) {
throw std::runtime_error(FormatError( // 未显式选择输出阶段时默认输出 IR
"cli", "未选择任何输出:请使用 --emit-parse-tree / --emit-ir / --emit-asm")); opt.emit_ir = true;
} }
return opt; return opt;
} }

@ -50,17 +50,22 @@ void PrintHelp(std::ostream& os) {
os << "SysY Compiler\n" os << "SysY Compiler\n"
<< "\n" << "\n"
<< "用法:\n" << "用法:\n"
<< " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>\n" << " compiler <input.sy> -IR -o <output.ll> [-O1] # 输出 IR\n"
<< " compiler <input.sy> -S -o <output.s> [-O1] # 输出汇编\n"
<< "\n" << "\n"
<< "选项:\n" << "选项:\n"
<< " -IR 输出中间代码IR 文本)\n"
<< " -S 输出 AArch64 汇编码\n"
<< " -o <file> 输出文件(默认 stdout\n"
<< " -O1 启用 IR 优化Mem2Reg + 标量优化)\n"
<< " -h, --help 打印帮助信息并退出\n" << " -h, --help 打印帮助信息并退出\n"
<< " --emit-parse-tree 仅在显式模式下启用语法树输出\n"
<< " --emit-ir 仅在显式模式下启用 IR 输出\n"
<< " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n"
<< "\n" << "\n"
<< "说明:\n" << "兼容格式(仍可使用):\n"
<< " - 默认输出 IR\n" << " --emit-ir 同 -IR\n"
<< " - 若使用 --emit-parse-tree/--emit-ir/--emit-asm则仅输出显式选择的阶段\n" << " --emit-asm 同 -S\n"
<< " - 可使用重定向写入文件:\n" << "\n"
<< " compiler --emit-asm test/test_case/functional/simple_add.sy > out.s\n"; << "示例:\n"
<< " compiler test.sy -IR -o test.ll -O1 # 生成优化 IR\n"
<< " compiler test.sy -S -o test.s -O1 # 生成优化汇编\n"
<< " compiler test.sy -IR # IR 输出到 stdout\n";
} }

Binary file not shown.

@ -1,4 +1,83 @@
// SysY 运行库实现: #include<stdio.h>
// - 按实验/评测规范提供 I/O 等函数实现 #include<stdarg.h>
// - 与编译器生成的目标代码链接,支撑运行时行为 #include<sys/time.h>
#include"sylib.h"
/* Input & output functions */
/* Reads one decimal integer from stdin with scanf("%d") and returns it.
 * The scanf result is not checked, so the value is indeterminate if nothing
 * could be read. */
int getint(){
    int value;
    scanf("%d", &value);
    return value;
}
/* Reads a single character from stdin with scanf("%c") and returns it
 * converted to int. The scanf result is not checked. */
int getch(){
    char ch;
    scanf("%c", &ch);
    int code = (int)ch;
    return code;
}
/* Reads one float from stdin using the "%a" conversion and returns it. */
float getfloat(){
    float value;
    scanf("%a", &value);
    return value;
}
/* Reads an element count n from stdin, then n integers into a[0..n-1].
 * Returns n. The caller's array must be large enough; no bound is checked. */
int getarray(int a[]){
    int count;
    scanf("%d", &count);
    for (int idx = 0; idx < count; ++idx) {
        scanf("%d", a + idx);
    }
    return count;
}
/* Reads an element count n from stdin, then n floats (via "%a") into
 * a[0..n-1]. Returns n. No bound on the caller's array is checked. */
int getfarray(float a[]) {
    int count;
    scanf("%d", &count);
    for (int idx = 0; idx < count; ++idx)
        scanf("%a", a + idx);
    return count;
}
/* Writes `a` to stdout as a decimal integer, with no trailing separator. */
void putint(int a){
    printf("%d", a);
}
/* Writes the character whose code is `a` to stdout via printf("%c"). */
void putch(int a){
    printf("%c", a);
}
/* Prints "n:" followed by the first n elements of a, each preceded by a
 * space, and ends the line with a newline. */
void putarray(int n,int a[]){
    printf("%d:", n);
    for (const int *elem = a; elem < a + n; ++elem) {
        printf(" %d", *elem);
    }
    printf("\n");
}
/* Writes `a` to stdout in hexadecimal floating-point notation ("%a"). */
void putfloat(float a) {
    const float value = a;
    printf("%a", value);
}
/* Prints "n:" followed by the first n floats of a in "%a" notation, each
 * preceded by a space, and ends the line with a newline. */
void putfarray(int n, float a[]) {
    printf("%d:", n);
    int idx = 0;
    while (idx < n) {
        printf(" %a", a[idx]);
        ++idx;
    }
    printf("\n");
}
/* printf-style formatted output to stdout: `a` is the format string and the
 * variadic arguments are forwarded unchanged. */
void putf(char a[], ...) {
    va_list varargs;
    va_start(varargs, a);
    vprintf(a, varargs);
    va_end(varargs);
}
/* Timing function implementation */
/* Constructor hook: zeroes every hour/minute/second/microsecond timer slot
 * and sets _sysy_idx to 1, leaving slot 0 free for the accumulated total. */
__attribute((constructor)) void before_main(){
    for (int slot = 0; slot < _SYSY_N; ++slot) {
        _sysy_us[slot] = 0;
        _sysy_s[slot] = 0;
        _sysy_m[slot] = 0;
        _sysy_h[slot] = 0;
    }
    _sysy_idx = 1;
}
/* Destructor hook: prints every recorded timer (slots 1 .. _sysy_idx-1) to
 * stderr, then the accumulated total that is kept in slot 0. */
__attribute((destructor)) void after_main(){
    for(int i=1;i<_sysy_idx;i++){
        fprintf(stderr,"Timer@%04d-%04d: %dH-%dM-%dS-%dus\n",
            _sysy_l1[i],_sysy_l2[i],_sysy_h[i],_sysy_m[i],_sysy_s[i],_sysy_us[i]);
        /* Add this timer to the total, carrying each unit's overflow into the
         * next larger unit before reducing it; otherwise whole seconds,
         * minutes and hours would be dropped from the TOTAL line. */
        _sysy_us[0] += _sysy_us[i];
        _sysy_s[0]  += _sysy_s[i] + _sysy_us[0] / 1000000;
        _sysy_us[0] %= 1000000;
        _sysy_m[0]  += _sysy_m[i] + _sysy_s[0] / 60;
        _sysy_s[0]  %= 60;
        _sysy_h[0]  += _sysy_h[i] + _sysy_m[0] / 60;
        _sysy_m[0]  %= 60;
    }
    fprintf(stderr,"TOTAL: %dH-%dM-%dS-%dus\n",_sysy_h[0],_sysy_m[0],_sysy_s[0],_sysy_us[0]);
}
/* Starts the current timer: stores the caller's source line in
 * _sysy_l1[_sysy_idx], then samples the start time into _sysy_start. */
void _sysy_starttime(int lineno){
    const int slot = _sysy_idx;
    _sysy_l1[slot] = lineno;
    gettimeofday(&_sysy_start, NULL);
}
/* Stops the current timer: samples the end time, stores the caller's source
 * line in _sysy_l2, adds the elapsed microseconds to the current slot,
 * normalises the slot into hours/minutes/seconds/microseconds, and advances
 * _sysy_idx to the next slot. */
void _sysy_stoptime(int lineno){
    gettimeofday(&_sysy_end, NULL);
    const int slot = _sysy_idx;
    _sysy_l2[slot] = lineno;

    _sysy_us[slot] += 1000000 * (_sysy_end.tv_sec - _sysy_start.tv_sec) + _sysy_end.tv_usec - _sysy_start.tv_usec;

    /* Carry microseconds -> seconds -> minutes -> hours. */
    _sysy_s[slot] += _sysy_us[slot] / 1000000;
    _sysy_us[slot] %= 1000000;
    _sysy_m[slot] += _sysy_s[slot] / 60;
    _sysy_s[slot] %= 60;
    _sysy_h[slot] += _sysy_m[slot] / 60;
    _sysy_m[slot] %= 60;

    _sysy_idx = slot + 1;
}

@ -1,4 +1,31 @@
// SysY 运行库头文件: #ifndef __SYLIB_H_
// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用) #define __SYLIB_H_
// - 与 sylib.c 配套,按规范逐步补齐声明
#include<stdio.h>
#include<stdarg.h>
#include<sys/time.h>
/* Input & output functions */
int getint(),getch(),getarray(int a[]);
float getfloat();
int getfarray(float a[]);
void putint(int a),putch(int a),putarray(int n,int a[]);
void putfloat(float a);
void putfarray(int n, float a[]);
void putf(char a[], ...);
/* Timing function implementation */
struct timeval _sysy_start,_sysy_end;
#define starttime() _sysy_starttime(__LINE__)
#define stoptime() _sysy_stoptime(__LINE__)
#define _SYSY_N 1024
int _sysy_l1[_SYSY_N],_sysy_l2[_SYSY_N];
int _sysy_h[_SYSY_N], _sysy_m[_SYSY_N],_sysy_s[_SYSY_N],_sysy_us[_SYSY_N];
int _sysy_idx;
__attribute((constructor)) void before_main();
__attribute((destructor)) void after_main();
void _sysy_starttime(int lineno);
void _sysy_stoptime(int lineno);
#endif

@ -0,0 +1,7 @@
// test/test_case/functional/test_riscv.sy
int main() {
int a = 10;
int b = 20;
int c = a + b;
return c; // 应该返回30
}

@ -0,0 +1,8 @@
// Tests scoping when a global and a local variable share the same name.
int a = 3; // global a; hidden inside main by the local declaration
int b = 5;
int main(){
int a = 5; // local a shadows the global a
return a + b; // local a (5) + global b (5) = 10
}

@ -0,0 +1,8 @@
// Tests local variable definitions: several declarators in one statement,
// including a name containing a digit and a name starting with an underscore.
int main(){
int a, b0, _c;
a = 1; // assigned but not part of the returned value
b0 = 2;
_c = 3;
return b0 + _c; // 2 + 3 = 5
}

@ -0,0 +1,4 @@
// Tests the definition of a global two-dimensional array without an
// initialiser; the array is never accessed and main returns 0.
int a[10][10];
int main(){
return 0;
}

@ -0,0 +1,9 @@
// Tests local 2-D array definitions with the different initialiser forms:
// empty, flat, fully braced, partially braced, and non-constant elements.
int main(){
int a[4][2] = {}; // empty list: every element is 0
int b[4][2] = {1, 2, 3, 4, 5, 6, 7, 8}; // flat list fills rows in order
int c[4][2] = {{1, 2}, {3, 4}, {5, 6}, {7, 8}}; // one brace group per row
// Mixed form: 1, 2 fill row 0; {3} and {5} start rows 1 and 2 and leave their
// second element 0; 7, 8 fill row 3.
int d[4][2] = {1, 2, {3}, {5}, 7 , 8};
// Initialiser elements read from other arrays: d[2][1] is 0, c[2][1] is 6.
int e[4][2] = {{d[2][1], c[2][1]}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1] + e[0][0] + e[0][1] + a[2][0]; // 8 + 0 + 6 + 0 = 14
}

@ -0,0 +1,9 @@
// Tests array initialisers with a const array, a constant-expression
// dimension, and a three-dimensional array with nested brace groups.
int main(){
// Row 2 is an empty group (all 0); the trailing 7 starts row 3, giving {7, 0}.
const int a[4][2] = {{1, 2}, {3, 4}, {}, 7};
int b[4][2] = {};
int c[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
// First dimension is the constant expression 3 + 1; a[3][0] (7) is used as an
// element, so row 3 becomes {7, 8}.
int d[3 + 1][2] = {1, 2, {3}, {5}, a[3][0], 8};
// e[0] is {d[2][1], {c[2][1]}} = {0, 6}; the other rows are {3,4}, {5,6}, {7,8}.
int e[4][2][1] = {{d[2][1], {c[2][1]}}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1][0] + e[0][0][0] + e[0][1][0] + d[3][0]; // 8 + 0 + 6 + 7 = 21
}

@ -0,0 +1,6 @@
// Tests const global variable definitions (two declarators in one statement).
const int a = 10, b = 5;
int main(){
return b; // 5
}

@ -0,0 +1,5 @@
// Tests const local variable definitions (two declarators in one statement).
int main(){
const int a = 10, b = 5;
return b; // 5
}

@ -0,0 +1,5 @@
// Tests a const global array with a full initialiser list and a read of its
// last element.
const int a[5]={0,1,2,3,4};
int main(){
return a[4]; // 4
}

@ -0,0 +1,11 @@
// Tests a function call that takes a global variable as its argument and
// modifies its own parameter.
int a;
// Decrements its parameter and returns the result; p is a by-value copy, so
// the caller's variable is not changed.
int func(int p){
p = p - 1;
return p;
}
int main(){
int b;
a = 10;
b = func(a); // 10 - 1 = 9
return b;
}

@ -0,0 +1,8 @@
// Tests a parameterless function definition and a call used as the
// initialiser of a local variable.
int defn(){
return 4;
}
int main(){
int a=defn(); // a = 4
return a;
}

@ -0,0 +1,7 @@
// Tests addition of two variables, one of them holding a negative value.
int main(){
int a, b;
a = 10;
b = -1;
return a + b; // 10 + (-1) = 9
}

@ -0,0 +1,5 @@
// Tests addition of a const global and an integer literal.
const int a = 10;
int main(){
return a + 5; // 15
}

@ -0,0 +1,7 @@
// Tests subtraction of a const global from a local variable, producing a
// negative result.
const int a = 10;
int main(){
int b;
b = 2;
return b - a; // 2 - 10 = -8
}

@ -0,0 +1,6 @@
// Tests subtraction of an integer literal from a local variable.
int main(){
int a;
a = 10;
return a - 2; // 8
}

@ -0,0 +1,7 @@
// Tests multiplication of two local variables.
int main(){
int a, b;
a = 10;
b = 5;
return a * b; // 50
}

@ -0,0 +1,5 @@
// Tests multiplication of a const global by an integer literal.
const int a = 5;
int main(){
return a * 5; // 25
}

@ -0,0 +1,7 @@
// Tests integer division of two local variables.
int main(){
int a, b;
a = 10;
b = 5;
return a / b; // 2
}

@ -0,0 +1,5 @@
// Tests integer division of a const global by an integer literal.
const int a = 10;
int main(){
return a / 5; // 2
}

@ -0,0 +1,6 @@
//test mod
// NOTE(review): despite the "mod" label, the expression below uses `/`, not
// `%`, so this case exercises truncating integer division with a non-zero
// remainder. The code is left as is because the expected output belongs to it.
int main(){
int a;
a = 10;
return a / 3; // 10 / 3 = 3
}

@ -0,0 +1,6 @@
// Tests the remainder operator with a local variable and an integer literal.
int main(){
int a;
a = 10;
return a % 3; // 10 % 3 = 1
}

@ -0,0 +1,25 @@
// test if-else-if
// Exercises an if / else with a nested if / else-if / else chain, combined
// with || and && conditions, a hexadecimal literal and stacked unary
// operators. With the values assigned below the function returns -5.
int ifElseIf() {
int a;
a = 5;
int b;
b = 10;
// 0xb is 11, so neither a == 6 nor b == 0xb holds and the else branch runs.
if(a == 6 || b == 0xb) {
return a;
}
else {
// b == 10 holds in both tests, but a is neither 1 nor -5, so the final
// else is taken.
if (b == 10 && a == 1)
a = 25;
else if (b == 10 && a == -5)
a = a + 15;
else
// Unary plus then unary minus: a becomes -5.
a = -+a;
}
return a;
}
// Passes the value returned by ifElseIf() to the runtime library's putint
// and returns 0.
int main(){
putint(ifElseIf());
return 0;
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save