Compare commits

...

11 Commits

Author SHA1 Message Date
Junhe Wu 3c6ffe8e3e fix(ir):修复了一些ir的错误
4 weeks ago
ppxf25tqu de126b93d6 Merge pull request 'feat(mir):修正并完善功能' (#7) from pt9wfaocb/nudt-compiler-cpp:tansiping into develop
4 weeks ago
tansiping 310c7c3697 feat(mir):修正并完善功能
4 weeks ago
ptabmhn4l 248db05cf4 Merge pull request 'feat(mir):实现MIR后端' (#6) from pfwvrotsf/nudt-compiler-cpp:feature/mir into develop
4 weeks ago
cy feaba9abd4 fix(mir):修正测试用例
4 weeks ago
cy 1ff1b543d1 feat(mir): MIR 后端(RISC-V架构)
4 weeks ago
ptabmhn4l 80c46cee7e Merge pull request '初步通过verify测试' (#5) from ptabmhn4l/nudt-compiler-cpp:fix/irgen into develop
1 month ago
Junhe Wu 19ef82738f fix(irgen):通过了除了性能测试外的测试用例。
1 month ago
Junhe Wu 4693253459 Merge branch 'develop' of https://bdgit.educoder.net/ppxf25tqu/nudt-compiler-cpp into develop
1 month ago
ptabmhn4l fd45b74e2e Merge pull request '基本完成了ir生成' (#4) from ptabmhn4l/nudt-compiler-cpp:feature/ir into develop
1 month ago
ptabmhn4l 74bcb45776 Merge pull request '把比赛的测试用例放进来' (#2) from ptabmhn4l/nudt-compiler-cpp:fix/testdata into develop
1 month ago

1
.gitignore vendored

@ -54,6 +54,7 @@ compile_commands.json
.fleet/
.vs/
*.code-workspace
CLAUDE.md
# CLion
cmake-build-debug/

@ -60,12 +60,14 @@ class Context {
~Context();
ConstantInt* GetConstInt(int v);
ConstantFloat* GetConstFloat(float v);
std::string NextTemp();
std::string NextTemp(); // 用于指令名(数字,连续)
std::string NextLabel(); // 用于块名(字母前缀,独立计数)
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
int temp_index_ = -1;
int label_index_ = -1;
};
// ─── Type ─────────────────────────────────────────────────────────────────────
@ -198,16 +200,28 @@ class GlobalValue : public User {
class GlobalVariable : public Value {
public:
GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements = 1);
int num_elements = 1, bool is_array_decl = false,
bool is_float = false);
bool IsConst() const { return is_const_; }
bool IsFloat() const { return is_float_; }
int GetInitVal() const { return init_val_; }
float GetInitValF() const { return init_val_f_; }
int GetNumElements() const { return num_elements_; }
bool IsArray() const { return num_elements_ > 1; }
// GlobalVariable 的"指针类型"是 i32*,访问时使用 load/store
bool IsArray() const { return is_array_decl_ || num_elements_ > 1; }
void SetInitVals(std::vector<int> v) { init_vals_ = std::move(v); }
void SetInitValsF(std::vector<float> v) { init_vals_f_ = std::move(v); }
const std::vector<int>& GetInitVals() const { return init_vals_; }
const std::vector<float>& GetInitValsF() const { return init_vals_f_; }
bool HasInitVals() const { return !init_vals_.empty() || !init_vals_f_.empty(); }
private:
bool is_const_;
bool is_float_;
int init_val_;
float init_val_f_;
int num_elements_;
bool is_array_decl_;
std::vector<int> init_vals_;
std::vector<float> init_vals_f_;
};
// ─── Instruction ──────────────────────────────────────────────────────────────
@ -381,6 +395,29 @@ class BasicBlock : public Value {
return ptr;
}
template <typename T, typename... Args>
T* Prepend(Args&&... args) {
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin(), std::move(inst));
return ptr;
}
// Insert before terminator (or append if no terminator)
template <typename T, typename... Args>
T* InsertBeforeTerminator(Args&&... args) {
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
if (!instructions_.empty() && instructions_.back()->IsTerminator()) {
instructions_.insert(instructions_.end() - 1, std::move(inst));
} else {
instructions_.push_back(std::move(inst));
}
return ptr;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
@ -409,6 +446,8 @@ class Function : public Value {
Argument* GetArgument(size_t i) const;
size_t GetNumArgs() const { return args_.size(); }
bool IsVoidReturn() const { return type_->IsVoid(); }
// 将某个块移动到 blocks_ 列表末尾(用于确保块顺序正确)
void MoveBlockToEnd(BasicBlock* bb);
private:
BasicBlock* entry_ = nullptr;
@ -437,7 +476,9 @@ class Module {
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
GlobalVariable* CreateGlobalVariable(const std::string& name, bool is_const,
int init_val, int num_elements = 1);
int init_val, int num_elements = 1,
bool is_array_decl = false,
bool is_float = false);
GlobalVariable* GetGlobalVariable(const std::string& name) const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobalVariables() const;
@ -494,9 +535,12 @@ class IRBuilder {
AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaF32(const std::string& name);
AllocaInst* CreateAllocaArray(int num_elements, const std::string& name);
AllocaInst* CreateAllocaArrayF32(int num_elements, const std::string& name);
GepInst* CreateGep(Value* base_ptr, Value* index, const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);
// 零初始化数组emit memset call
void CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod);
// 控制流
ReturnInst* CreateRet(Value* v);
@ -518,9 +562,12 @@ class IRBuilder {
SIToFPInst* CreateSIToFP(Value* val, const std::string& name);
FPToSIInst* CreateFPToSI(Value* val, const std::string& name);
void SetAllocaBlock(BasicBlock* bb) { alloca_block_ = bb; }
private:
Context& ctx_;
BasicBlock* insert_block_;
BasicBlock* alloca_block_ = nullptr;
};
// ─── IRPrinter ────────────────────────────────────────────────────────────────

@ -19,39 +19,161 @@ class MIRContext {
MIRContext& DefaultContext();
enum class PhysReg { W0, W8, W9, X29, X30, SP };
// RISC-V 64位寄存器定义
enum class PhysReg {
// 通用寄存器
ZERO, // x0, 恒为0
RA, // x1, 返回地址
SP, // x2, 栈指针
GP, // x3, 全局指针
TP, // x4, 线程指针
T0, // x5, 临时寄存器
T1, // x6, 临时寄存器
T2, // x7, 临时寄存器
S0, // x8, 帧指针/保存寄存器
S1, // x9, 保存寄存器
A0, // x10, 参数/返回值
A1, // x11, 参数
A2, // x12, 参数
A3, // x13, 参数
A4, // x14, 参数
A5, // x15, 参数
A6, // x16, 参数
A7, // x17, 参数
S2, // x18, 保存寄存器
S3, // x19, 保存寄存器
S4, // x20, 保存寄存器
S5, // x21, 保存寄存器
S6, // x22, 保存寄存器
S7, // x23, 保存寄存器
S8, // x24, 保存寄存器
S9, // x25, 保存寄存器
S10, // x26, 保存寄存器
S11, // x27, 保存寄存器
T3, // x28, 临时寄存器
T4, // x29, 临时寄存器
T5, // x30, 临时寄存器
T6, // x31, 临时寄存器
FT0, FT1, FT2, FT3, FT4, FT5, FT6, FT7,
FS0, FS1,
FA0, FA1, FA2, FA3, FA4, FA5, FA6, FA7,
FT8, FT9, FT10, FT11,
};
const char* PhysRegName(PhysReg reg);
// 在 MIR.h 中添加(在 Opcode 枚举之前)
struct GlobalVarInfo {
std::string name;
int value;
float valueF;
bool isConst;
bool isArray;
bool isFloat;
std::vector<int> arrayValues;
std::vector<float> arrayValuesF;
int arraySize;
};
enum class Opcode {
Prologue,
Epilogue,
MovImm,
LoadStack,
StoreStack,
AddRR,
Load,
Store,
Add,
Addi,
Sub,
Mul,
Div,
Rem,
Slt,
Slti,
Slli,
Sltu, // 无符号小于
Sltiu,
Xori,
LoadGlobalAddr,
LoadGlobal,
StoreGlobal,
LoadIndirect, // lw rd, 0(rs1) 从寄存器地址加载
StoreIndirect, // sw rs2, 0(rs1)
Call,
GEP,
LoadAddr,
Ret,
// 浮点指令
FMov, // 浮点移动
FMovWX, // fmv.w.x fs, x 整数寄存器移动到浮点寄存器
FMovXW, // fmv.x.w x, fs 浮点寄存器移动到整数寄存器
FAdd,
FSub,
FMul,
FDiv,
FEq, // 浮点相等比较
FLt, // 浮点小于比较
FLe, // 浮点小于等于比较
FNeg, // 浮点取反
FAbs, // 浮点绝对值
SIToFP, // int 转 float
FPToSI, // float 转 int
LoadFloat, // 浮点加载 (flw)
StoreFloat, // 浮点存储 (fsw)
Br,
CondBr,
Label,
};
enum class GlobalKind {
Data, // .data 段(已初始化)
BSS, // .bss 段未初始化初始为0
RoData // .rodata 段(只读常量)
};
// 全局变量信息
struct GlobalInfo {
std::string name;
GlobalKind kind;
int size; // 大小(字节)
int value; // 初始值(对于简单变量)
bool isArray;
int arraySize;
std::vector<int> dimensions; // 数组维度
};
class Operand {
public:
enum class Kind { Reg, Imm, FrameIndex };
enum class Kind { Reg, Imm, FrameIndex, Global, Func };
static Operand Reg(PhysReg reg);
static Operand Imm(int value);
static Operand Imm64(int64_t value); // 新增:存储 64 位值
static Operand FrameIndex(int index);
static Operand Global(const std::string& name);
static Operand Func(const std::string& name);
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; }
int64_t GetImm64() const { return imm64_; } // 新增
int GetFrameIndex() const { return imm_; }
const std::string& GetGlobalName() const { return global_name_; }
const std::string& GetFuncName() const { return func_name_; }
private:
Operand(Kind kind, PhysReg reg, int imm);
Operand(Kind kind, PhysReg reg, int64_t imm64); // 新增构造函数
Operand(Kind kind, PhysReg reg, int imm, const std::string& name);
Kind kind_;
PhysReg reg_;
int imm_;
int64_t imm64_; // 新增
std::string global_name_;
std::string func_name_;
};
class MachineInstr {
@ -71,7 +193,6 @@ struct FrameSlot {
int size = 4;
int offset = 0;
};
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
@ -93,9 +214,14 @@ class MachineFunction {
explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; }
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
// 基本块管理
MachineBasicBlock* CreateBlock(const std::string& name);
MachineBasicBlock* GetEntry() { return entry_; }
const MachineBasicBlock* GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const { return blocks_; }
// 栈帧管理
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
@ -106,14 +232,15 @@ class MachineFunction {
private:
std::string name_;
MachineBasicBlock entry_;
MachineBasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0;
};
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
//std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os);
} // namespace mir
//void PrintAsm(const MachineFunction& function, std::ostream& os);
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module);
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os);
} // namespace mir

@ -0,0 +1,309 @@
.text
.global main
.type main, @function
main:
addi sp, sp, -272
sw ra, 264(sp)
sw s0, 256(sp)
addi a0, sp, -4
li a1, 0
li a2, 32
call
addi a0, sp, -8
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -8
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -8
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -8
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -8
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -8
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -8
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -8
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -8
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
addi a0, sp, -44
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -44
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -44
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -44
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -44
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -44
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -44
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -44
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -44
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
addi a0, sp, -80
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -80
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -80
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -80
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -80
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -80
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -80
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -112(sp)
li t0, 1
lw t1, -112(sp)
add t0, t0, t1
sw t0, -116(sp)
addi t0, sp, -80
lw t1, -116(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -124(sp)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -128(sp)
li t0, 1
lw t1, -128(sp)
add t0, t0, t1
sw t0, -132(sp)
addi t0, sp, -44
lw t1, -132(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -140(sp)
addi a0, sp, -108
li a1, 0
li a2, 32
call
lw t2, -124(sp)
addi t0, sp, -108
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
lw t2, -140(sp)
addi t0, sp, -108
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -108
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -108
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -108
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -108
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -108
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -108
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t0, 3
li t1, 2
mul t0, t0, t1
sw t0, -176(sp)
li t0, 1
lw t1, -176(sp)
add t0, t0, t1
sw t0, -180(sp)
addi t0, sp, -108
lw t1, -180(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -188(sp)
li t0, 0
li t1, 2
mul t0, t0, t1
sw t0, -192(sp)
li t0, 0
lw t1, -192(sp)
add t0, t0, t1
sw t0, -196(sp)
addi t0, sp, -108
lw t1, -196(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -204(sp)
lw t0, -188(sp)
lw t1, -204(sp)
add t0, t0, t1
sw t0, -208(sp)
li t0, 0
li t1, 2
mul t0, t0, t1
sw t0, -212(sp)
li t0, 1
lw t1, -212(sp)
add t0, t0, t1
sw t0, -216(sp)
addi t0, sp, -108
lw t1, -216(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -224(sp)
lw t0, -208(sp)
lw t1, -224(sp)
add t0, t0, t1
sw t0, -228(sp)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -232(sp)
li t0, 0
lw t1, -232(sp)
add t0, t0, t1
sw t0, -236(sp)
addi t0, sp, -4
lw t1, -236(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -244(sp)
lw t0, -228(sp)
lw t1, -244(sp)
add t0, t0, t1
sw t0, -248(sp)
lw a0, -248(sp)
lw ra, 264(sp)
lw s0, 256(sp)
addi sp, sp, 272
ret
.size main, .-main

@ -0,0 +1,137 @@
#!/bin/bash
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/mir"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
if [ ! -x "$COMPILER" ]; then
echo "错误:编译器不存在或不可执行: $COMPILER"
exit 1
fi
mkdir -p "$TEST_RESULT_DIR"
echo "=========================================="
echo "RISC-V 后端测试"
echo "=========================================="
echo ""
# 收集测试用例
mapfile -t test_files < <(find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort)
total=${#test_files[@]}
pass_gen=0
fail_gen=0
pass_run=0
fail_run=0
timeout_cnt=0
echo "=== 阶段1汇编生成 ==="
echo ""
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.s"
mkdir -p "$(dirname "$output_file")"
"$COMPILER" --emit-asm "$test_file" 2>/dev/null > "$output_file"
if [ $? -eq 0 ] && [ -s "$output_file" ]; then
echo -e " ${GREEN}${NC} $relative_path"
((pass_gen++))
else
echo -e " ${RED}${NC} $relative_path"
((fail_gen++))
fi
done
echo ""
echo "--- 汇编生成: 通过 $pass_gen / 失败 $fail_gen / 总计 $total ---"
echo ""
echo "=== 阶段2运行验证 ==="
echo ""
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
stem="${relative_path%.sy}"
asm_file="$TEST_RESULT_DIR/${stem}.s"
exe_file="$TEST_RESULT_DIR/${stem}"
expected_file="${test_file%.sy}.out"
if [ ! -s "$asm_file" ]; then
echo -e " ${YELLOW}${NC} $relative_path (跳过)"
continue
fi
riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe_file" -no-pie 2>/dev/null
if [ $? -ne 0 ]; then
echo -e " ${RED}${NC} $relative_path (链接失败)"
((fail_run++))
continue
fi
# 运行程序,设置超时 5 秒
timeout 5 qemu-riscv64 "$exe_file" 2>/dev/null
exit_code=$?
# 检查是否超时
if [ $exit_code -eq 124 ]; then
echo -e " ${YELLOW}${NC} $relative_path (超时)"
((timeout_cnt++))
continue
fi
# 获取程序输出(需要单独捕获,因为 timeout 会改变输出)
program_output=$(timeout 5 qemu-riscv64 "$exe_file" 2>/dev/null)
if [ $? -eq 124 ]; then
echo -e " ${YELLOW}${NC} $relative_path (超时)"
((timeout_cnt++))
continue
fi
if [ -f "$expected_file" ]; then
expected=$(cat "$expected_file" | tr -d '\n')
# 判断期望文件是输出内容还是退出码
if [ -z "$expected" ] || [[ "$expected" =~ ^[0-9]+$ ]]; then
# 期望退出码
if [ $exit_code -eq "$expected" ] 2>/dev/null; then
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (退出码: 期望 $expected, 实际 $exit_code)"
((fail_run++))
fi
else
# 期望输出内容
if [ "$program_output" = "$expected" ]; then
echo -e " ${GREEN}${NC} $relative_path (输出匹配)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (输出不匹配)"
((fail_run++))
fi
fi
else
# 没有期望文件,默认通过
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
((pass_run++))
fi
done
echo ""
echo "--- 运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt ---"
echo ""
echo "=========================================="
echo "测试完成"
echo "汇编生成: 通过 $pass_gen / 失败 $fail_gen"
echo "运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt"
echo "=========================================="

@ -0,0 +1,153 @@
#!/bin/bash
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/mir"
SYLIB_C="$PROJECT_ROOT/sylib/sylib.c"
SYLIB_O="/tmp/sylib.o"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
if [ ! -x "$COMPILER" ]; then
echo "错误:编译器不存在或不可执行: $COMPILER"
exit 1
fi
# 编译 sylib 运行时库
echo "编译运行时库..."
if [ ! -f "$SYLIB_O" ]; then
riscv64-linux-gnu-gcc -c "$SYLIB_C" -o "$SYLIB_O" 2>/dev/null
if [ $? -ne 0 ]; then
echo "警告:无法编译 sylib.c部分测试可能链接失败"
fi
fi
echo ""
mkdir -p "$TEST_RESULT_DIR"
echo "=========================================="
echo "RISC-V 后端测试"
echo "=========================================="
echo ""
# 收集测试用例
mapfile -t test_files < <(find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort)
total=${#test_files[@]}
pass_gen=0
fail_gen=0
pass_run=0
fail_run=0
timeout_cnt=0
echo "=== 阶段1汇编生成 ==="
echo ""
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.s"
mkdir -p "$(dirname "$output_file")"
"$COMPILER" --emit-asm "$test_file" 2>/dev/null > "$output_file"
if [ $? -eq 0 ] && [ -s "$output_file" ]; then
echo -e " ${GREEN}${NC} $relative_path"
((pass_gen++))
else
echo -e " ${RED}${NC} $relative_path"
((fail_gen++))
fi
done
echo ""
echo "--- 汇编生成: 通过 $pass_gen / 失败 $fail_gen / 总计 $total ---"
echo ""
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
stem="${relative_path%.sy}"
asm_file="$TEST_RESULT_DIR/${stem}.s"
exe_file="$TEST_RESULT_DIR/${stem}"
expected_file="${test_file%.sy}.out"
if [ ! -s "$asm_file" ]; then
echo -e " ${YELLOW}${NC} $relative_path (跳过)"
continue
fi
# 链接
if [ -f "$SYLIB_O" ]; then
riscv64-linux-gnu-gcc -static "$asm_file" "$SYLIB_O" -o "$exe_file" -no-pie 2>/dev/null
else
riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe_file" -no-pie 2>/dev/null
fi
if [ $? -ne 0 ]; then
echo -e " ${RED}${NC} $relative_path (链接失败)"
((fail_run++))
continue
fi
# 运行程序
input_file="${test_file%.sy}.in"
tmp_out=$(mktemp)
if [ -f "$input_file" ]; then
timeout 10 qemu-riscv64 "$exe_file" < "$input_file" > "$tmp_out" 2>&1
else
timeout 10 qemu-riscv64 "$exe_file" > "$tmp_out" 2>&1
fi
exit_code=$?
if [ $exit_code -eq 124 ]; then
echo -e " ${YELLOW}${NC} $relative_path (超时)"
((timeout_cnt++))
rm -f "$tmp_out"
continue
fi
program_output=$(cat "$tmp_out" | tr -d '\n' | sed 's/[[:space:]]*$//')
rm -f "$tmp_out"
if [ -f "$expected_file" ]; then
expected=$(cat "$expected_file" | tr -d '\n' | sed 's/[[:space:]]*$//')
if [[ "$expected" =~ ^[0-9]+$ ]] && [ "$expected" -ge 0 ] && [ "$expected" -le 255 ] && [ -z "$program_output" ]; then
# 期望退出码(且没有输出)
if [ $exit_code -eq "$expected" ] 2>/dev/null; then
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (退出码: 期望 $expected, 实际 $exit_code)"
((fail_run++))
fi
else
# 期望输出内容
if [ "$program_output" = "$expected" ]; then
echo -e " ${GREEN}${NC} $relative_path (输出匹配)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (输出不匹配: 期望 '$expected', 实际 '$program_output')"
((fail_run++))
fi
fi
else
# 没有期望文件
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code, 输出: '$program_output')"
((pass_run++))
fi
done
echo ""
echo "--- 运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt ---"
echo ""
echo "=========================================="
echo "测试完成"
echo "汇编生成: 通过 $pass_gen / 失败 $fail_gen"
echo "运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt"
echo "=========================================="

@ -0,0 +1,65 @@
#!/bin/bash
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_DIR="$PROJECT_ROOT/test/test_case/basic"
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m'
if [ ! -f "$COMPILER" ]; then
echo "错误: 编译器不存在: $COMPILER"
exit 1
fi
echo "=========================================="
echo "RISC-V 浮点转换测试"
echo "=========================================="
TESTS="
float_conv:3
float_add:13
float_mul:30
"
PASS=0
FAIL=0
for test in $TESTS; do
name=$(echo $test | cut -d: -f1)
expected=$(echo $test | cut -d: -f2)
echo -n "测试 $name (期望 $expected) ... "
"$COMPILER" "$TEST_DIR/$name.sy" --emit-asm > /tmp/test_$name.s 2>&1
if [ $? -ne 0 ]; then
echo -e "${RED}失败 (汇编错误)${NC}"
cat /tmp/test_$name.s | head -3
FAIL=$((FAIL + 1))
continue
fi
riscv64-linux-gnu-gcc -static /tmp/test_$name.s -o /tmp/test_$name -no-pie 2>/dev/null
if [ $? -ne 0 ]; then
echo -e "${RED}失败 (链接错误)${NC}"
FAIL=$((FAIL + 1))
continue
fi
qemu-riscv64 /tmp/test_$name > /dev/null 2>&1
exit_code=$?
if [ $exit_code -eq $expected ]; then
echo -e "${GREEN}通过${NC}"
PASS=$((PASS + 1))
else
echo -e "${RED}失败 (实际 $exit_code)${NC}"
FAIL=$((FAIL + 1))
fi
done
echo "=========================================="
echo -e "测试结果: ${GREEN}通过 $PASS${NC} / ${RED}失败 $FAIL${NC}"
echo "=========================================="

@ -4,6 +4,9 @@ PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/ir"
VERIFY_SCRIPT="$PROJECT_ROOT/scripts/verify_ir.sh"
PARALLEL=${PARALLEL:-$(nproc)}
LOG_FILE="$TEST_RESULT_DIR/verify.log"
if [ ! -x "$COMPILER" ]; then
echo "错误:编译器不存在或不可执行: $COMPILER"
@ -12,47 +15,130 @@ if [ ! -x "$COMPILER" ]; then
fi
mkdir -p "$TEST_RESULT_DIR"
> "$LOG_FILE"
pass_count=0
fail_count=0
failed_cases=()
# ── 阶段1IR 生成(并行)────────────────────────────────────────────────────
echo "=== 阶段1IR 生成 ===" | tee -a "$LOG_FILE"
echo "" | tee -a "$LOG_FILE"
echo "=== 开始测试 IR 生成 ==="
echo ""
GEN_TMPDIR=$(mktemp -d)
while IFS= read -r test_file; do
gen_one() {
test_file="$1"
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.ll"
mkdir -p "$(dirname "$output_file")"
echo -n "测试: $relative_path ... "
"$COMPILER" --emit-ir "$test_file" > "$output_file" 2>&1
exit_code=$?
# Use a per-case tmp file to avoid concurrent write issues
case_id=$(echo "$relative_path" | tr '/' '_')
if [ $exit_code -eq 0 ] && [ -s "$output_file" ] && ! grep -q '\[error\]' "$output_file"; then
echo "通过"
pass_count=$((pass_count + 1))
echo "通过: $relative_path" > "$GEN_TMPDIR/pass_${case_id}"
else
echo "失败"
fail_count=$((fail_count + 1))
echo "$relative_path" > "$GEN_TMPDIR/fail_${case_id}"
echo "失败: $relative_path" > "$GEN_TMPDIR/line_fail_${case_id}"
fi
}
export -f gen_one
export COMPILER TEST_CASE_DIR TEST_RESULT_DIR GEN_TMPDIR
find "$TEST_CASE_DIR" -name "*.sy" | sort | \
xargs -P "$PARALLEL" -I{} bash -c 'gen_one "$@"' _ {}
# Collect results in sorted order
failed_cases=()
for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f")
case_id=$(echo "$relative_path" | tr '/' '_')
if [ -f "$GEN_TMPDIR/pass_${case_id}" ]; then
cat "$GEN_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE"
elif [ -f "$GEN_TMPDIR/fail_${case_id}" ]; then
cat "$GEN_TMPDIR/line_fail_${case_id}" | tee -a "$LOG_FILE"
failed_cases+=("$relative_path")
echo " 错误信息已保存到: $output_file"
fi
done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort)
done
pass_count=$(ls "$GEN_TMPDIR"/pass_* 2>/dev/null | wc -l)
fail_count=${#failed_cases[@]}
rm -rf "$GEN_TMPDIR"
echo "" | tee -a "$LOG_FILE"
echo "--- 生成完成: 通过 $pass_count / 失败 $fail_count ---" | tee -a "$LOG_FILE"
# ── 阶段2IR 运行验证(并行,需要 llc + clang──────────────────────────────
if ! command -v llc >/dev/null 2>&1 || ! command -v clang >/dev/null 2>&1; then
echo "" | tee -a "$LOG_FILE"
echo "=== 跳过阶段2未找到 llc 或 clang无法运行 IR ===" | tee -a "$LOG_FILE"
else
echo "" | tee -a "$LOG_FILE"
echo "=== 阶段2IR 运行验证 ===" | tee -a "$LOG_FILE"
echo "" | tee -a "$LOG_FILE"
VRF_TMPDIR=$(mktemp -d)
verify_one() {
test_file="$1"
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
relative_dir=$(dirname "$relative_path")
out_dir="$TEST_RESULT_DIR/$relative_dir"
stem=$(basename "${test_file%.sy}")
case_log="$out_dir/$stem.verify.log"
case_id=$(echo "$relative_path" | tr '/' '_')
if bash "$VERIFY_SCRIPT" "$test_file" "$out_dir" --run > "$case_log" 2>&1; then
echo "通过: $relative_path" > "$VRF_TMPDIR/pass_${case_id}"
else
extra=$(grep -E '(退出码|输出不匹配|错误)' "$case_log" | head -3 | sed 's/^/ /' || true)
{ echo "失败: $relative_path"; [ -n "$extra" ] && echo "$extra"; } > "$VRF_TMPDIR/fail_${case_id}"
echo "$relative_path" > "$VRF_TMPDIR/failname_${case_id}"
fi
}
export -f verify_one
export TEST_CASE_DIR TEST_RESULT_DIR VERIFY_SCRIPT VRF_TMPDIR
find "$TEST_CASE_DIR" -name "*.sy" | sort | \
xargs -P "$PARALLEL" -I{} bash -c 'verify_one "$@"' _ {}
# Collect results in sorted order
verify_failed_cases=()
for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f")
case_id=$(echo "$relative_path" | tr '/' '_')
if [ -f "$VRF_TMPDIR/pass_${case_id}" ]; then
cat "$VRF_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE"
elif [ -f "$VRF_TMPDIR/fail_${case_id}" ]; then
cat "$VRF_TMPDIR/fail_${case_id}" | tee -a "$LOG_FILE"
verify_failed_cases+=("$relative_path")
fi
done
verify_pass=$(ls "$VRF_TMPDIR"/pass_* 2>/dev/null | wc -l)
verify_fail=${#verify_failed_cases[@]}
rm -rf "$VRF_TMPDIR"
echo "" | tee -a "$LOG_FILE"
echo "--- 验证完成: 通过 $verify_pass / 失败 $verify_fail ---" | tee -a "$LOG_FILE"
if [ ${#verify_failed_cases[@]} -gt 0 ]; then
echo "" | tee -a "$LOG_FILE"
echo "=== 验证失败的用例 ===" | tee -a "$LOG_FILE"
for f in "${verify_failed_cases[@]}"; do
[ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE"
done
fi
fi
echo ""
echo "=== 测试完成 ==="
echo "通过: $pass_count"
echo "失败: $fail_count"
echo "结果保存在: $TEST_RESULT_DIR"
# ── 汇总 ─────────────────────────────────────────────────────────────────────
echo "" | tee -a "$LOG_FILE"
echo "=== 测试完成 ===" | tee -a "$LOG_FILE"
echo "IR生成 通过: $pass_count 失败: $fail_count" | tee -a "$LOG_FILE"
echo "结果保存在: $TEST_RESULT_DIR" | tee -a "$LOG_FILE"
echo "日志保存在: $LOG_FILE" | tee -a "$LOG_FILE"
if [ ${#failed_cases[@]} -gt 0 ]; then
echo ""
echo "=== 失败的用例 ==="
echo "" | tee -a "$LOG_FILE"
echo "=== IR生成失败的用例 ===" | tee -a "$LOG_FILE"
for f in "${failed_cases[@]}"; do
echo " - $f"
[ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE"
done
exit 1
fi

@ -0,0 +1,85 @@
#!/bin/bash
# 获取项目根目录
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_DIR="$PROJECT_ROOT/test/test_case/basic"
# 颜色定义
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m'
# 检查编译器
if [ ! -f "$COMPILER" ]; then
echo "错误: 编译器不存在: $COMPILER"
exit 1
fi
# 检查工具链
if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
echo "错误: 未找到 riscv64-linux-gnu-gcc"
exit 1
fi
if ! command -v qemu-riscv64 >/dev/null 2>&1; then
echo "错误: 未找到 qemu-riscv64"
exit 1
fi
echo "=========================================="
echo "RISC-V 基础功能测试"
echo "=========================================="
# 定义测试用例
TESTS="arith:50 add:30 sub:7 mul:50 div:25 mod:2 var:43"
PASS=0
FAIL=0
for test in $TESTS; do
name=$(echo $test | cut -d: -f1)
expected=$(echo $test | cut -d: -f2)
echo -n "测试 $name (期望 $expected) ... "
# 生成汇编
"$COMPILER" "$TEST_DIR/$name.sy" --emit-asm > /tmp/test_$name.s 2>&1
if [ $? -ne 0 ]; then
echo -e "${RED}失败 (汇编错误)${NC}"
FAIL=$((FAIL + 1))
continue
fi
# 链接
riscv64-linux-gnu-gcc -static /tmp/test_$name.s -o /tmp/test_$name -no-pie 2>/dev/null
if [ $? -ne 0 ]; then
echo -e "${RED}失败 (链接错误)${NC}"
FAIL=$((FAIL + 1))
continue
fi
# 运行
qemu-riscv64 /tmp/test_$name > /dev/null 2>&1
exit_code=$?
if [ $exit_code -eq $expected ]; then
echo -e "${GREEN}通过${NC}"
PASS=$((PASS + 1))
else
echo -e "${RED}失败 (实际 $exit_code)${NC}"
FAIL=$((FAIL + 1))
fi
done
echo "=========================================="
echo -e "测试结果: ${GREEN}通过 $PASS${NC} / ${RED}失败 $FAIL${NC}"
echo "=========================================="
if [ $FAIL -eq 0 ]; then
echo -e "${GREEN}✓ 所有基础测试通过!${NC}"
exit 0
else
echo -e "${RED}✗ 有 $FAIL 个测试失败${NC}"
exit 1
fi

@ -3,6 +3,8 @@
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
@ -31,7 +33,7 @@ if [[ ! -f "$input" ]]; then
exit 1
fi
compiler="./build/bin/compiler"
compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建(如: mkdir -p build && cd build && cmake .. && make -j" >&2
exit 1
@ -60,13 +62,13 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" -o "$exe"
clang "$obj" "$PROJECT_ROOT/sylib/sylib.c" -o "$exe" -lm
echo "运行 $exe ..."
set +e
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
(ulimit -s unlimited; "$exe" < "$stdin_file") > "$stdout_file"
else
"$exe" > "$stdout_file"
(ulimit -s unlimited; "$exe") > "$stdout_file"
fi
status=$?
set -e
@ -81,9 +83,11 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then
if diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \
<(tr -d '\r' < "$actual_file" | sed -e '$a\') > /dev/null 2>&1; then
echo "输出匹配: $expected_file"
else
diff -u "$expected_file" "$actual_file" || true
echo "输出不匹配: $expected_file" >&2
echo "实际输出已保存: $actual_file" >&2
exit 1

@ -0,0 +1,101 @@
#!/usr/bin/env bash
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
fi
input=$1
out_dir="test/test_result/mir"
run_exec=false
input_dir=$(dirname "$input")
shift
while [[ $# -gt 0 ]]; do
case "$1" in
--run)
run_exec=true
;;
*)
out_dir="$1"
;;
esac
shift
done
if [[ ! -f "$input" ]]; then
echo "输入文件不存在: $input" >&2
exit 1
fi
compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建。" >&2
exit 1
fi
mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
mir_file="$out_dir/$stem.mir"
asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
# 生成 MIR
"$compiler" --emit-mir "$input" > "$mir_file"
echo "MIR 已生成: $mir_file"
# 生成汇编
"$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file"
if [[ "$run_exec" == true ]]; then
if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
echo "未找到 riscv64-linux-gnu-gcc" >&2
exit 1
fi
if ! command -v qemu-riscv64 >/dev/null 2>&1; then
echo "未找到 qemu-riscv64" >&2
exit 1
fi
riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe" -no-pie
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
echo "运行 $exe ..."
set +e
if [[ -f "$stdin_file" ]]; then
qemu-riscv64 "$exe" < "$stdin_file" > "$stdout_file"
else
qemu-riscv64 "$exe" > "$stdout_file"
fi
status=$?
set -e
cat "$stdout_file"
echo "退出码: $status"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then
echo "输出匹配: $expected_file"
else
echo "输出不匹配: $expected_file" >&2
exit 1
fi
else
echo "未找到预期输出文件,跳过比对: $expected_file"
fi
fi

@ -0,0 +1,101 @@
#!/usr/bin/env bash
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
fi
input=$1
out_dir="test/test_result/riscv_asm"
run_exec=false
input_dir=$(dirname "$input")
shift
while [[ $# -gt 0 ]]; do
case "$1" in
--run)
run_exec=true
;;
*)
out_dir="$1"
;;
esac
shift
done
if [[ ! -f "$input" ]]; then
echo "输入文件不存在: $input" >&2
exit 1
fi
compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建。" >&2
exit 1
fi
if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
echo "未找到 riscv64-linux-gnu-gcc无法汇编/链接。" >&2
echo "请安装: sudo apt install gcc-riscv64-linux-gnu" >&2
exit 1
fi
mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
"$compiler" --emit-asm "$input" 2>/dev/null > "$asm_file"
echo "汇编已生成: $asm_file"
# 使用静态链接避免动态链接器问题
riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe" -no-pie
echo "可执行文件已生成: $exe"
if [[ "$run_exec" == true ]]; then
if ! command -v qemu-riscv64 >/dev/null 2>&1; then
echo "未找到 qemu-riscv64无法运行生成的可执行文件。" >&2
echo "请安装: sudo apt install qemu-user" >&2
exit 1
fi
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
echo "运行 $exe ..."
set +e
if [[ -f "$stdin_file" ]]; then
qemu-riscv64 "$exe" < "$stdin_file" > "$stdout_file"
else
qemu-riscv64 "$exe" > "$stdout_file"
fi
status=$?
set -e
cat "$stdout_file"
echo "退出码: $status"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then
echo "输出匹配: $expected_file"
else
echo "输出不匹配: $expected_file" >&2
echo "实际输出已保存: $actual_file" >&2
exit 1
fi
else
echo "未找到预期输出文件,跳过比对: $expected_file"
fi
fi

@ -29,4 +29,10 @@ std::string Context::NextTemp() {
return oss.str();
}
std::string Context::NextLabel() {
std::ostringstream oss;
oss << "L" << ++label_index_;
return oss.str();
}
} // namespace ir

@ -17,6 +17,17 @@ BasicBlock* Function::CreateBlock(const std::string& name) {
return ptr;
}
void Function::MoveBlockToEnd(BasicBlock* bb) {
for (size_t i = 0; i < blocks_.size(); ++i) {
if (blocks_[i].get() == bb) {
auto tmp = std::move(blocks_[i]);
blocks_.erase(blocks_.begin() + i);
blocks_.push_back(std::move(tmp));
return;
}
}
}
BasicBlock* Function::GetEntry() { return entry_; }
const BasicBlock* Function::GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {

@ -136,6 +136,15 @@ AllocaInst* IRBuilder::CreateAllocaArray(int num_elements,
num_elements, name);
}
AllocaInst* IRBuilder::CreateAllocaArrayF32(int num_elements,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(Type::GetPtrFloat32Type(),
num_elements, name);
}
GepInst* IRBuilder::CreateGep(Value* base_ptr, Value* index,
const std::string& name) {
if (!insert_block_) {
@ -237,4 +246,20 @@ FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) {
return insert_block_->Append<FPToSIInst>(val, name);
}
void IRBuilder::CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
// declare memset if not already declared
if (!mod.HasExternalDecl("memset")) {
mod.DeclareExternalFunc("memset", Type::GetVoidType(),
{Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
}
int byte_count = num_elements * 4;
insert_block_->Append<CallInst>(
std::string("memset"), Type::GetVoidType(),
std::vector<Value*>{ptr, ctx.GetConstInt(0), ctx.GetConstInt(byte_count)},
std::string(""));
}
} // namespace ir

@ -1,8 +1,13 @@
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <algorithm>
#include <unordered_map>
#include "utils/Log.h"
@ -44,87 +49,253 @@ static const char* FPredToStr(FCmpPredicate pred) {
return "?";
}
static std::string ValStr(const Value* v) {
using RenameMap = std::unordered_map<const Value*, int>;
static std::string ValStr(const Value* v, const RenameMap& rm) {
if (!v) return "<null>";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
if (dynamic_cast<const ConstantInt*>(v))
return std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::to_string(cf->GetValue());
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << bits;
return oss.str();
}
// BasicBlock: 打印为 label %name
if (dynamic_cast<const BasicBlock*>(v)) {
if (dynamic_cast<const BasicBlock*>(v))
return "%" + v->GetName();
}
// GlobalVariable: 打印为 @name
if (auto* gv = dynamic_cast<const GlobalVariable*>(v)) {
if (gv->IsArray()) {
// 数组全局变量的指针getelementptr [N x i32], [N x i32]* @name, i32 0, i32 0
return "getelementptr ([" + std::to_string(gv->GetNumElements()) +
" x i32], [" + std::to_string(gv->GetNumElements()) +
" x i32]* @" + gv->GetName() + ", i32 0, i32 0)";
const char* et = gv->IsFloat() ? "float" : "i32";
return std::string("getelementptr ([") + std::to_string(gv->GetNumElements()) +
" x " + et + "], [" + std::to_string(gv->GetNumElements()) +
" x " + et + "]* @" + gv->GetName() + ", i32 0, i32 0)";
}
return "@" + v->GetName();
}
auto it = rm.find(v);
if (it != rm.end()) return "%" + std::to_string(it->second);
return "%" + v->GetName();
}
static std::string TypeVal(const Value* v) {
static std::string TypeVal(const Value* v, const RenameMap& rm) {
if (!v) return "void";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::string(TypeToStr(*ci->GetType())) + " " +
std::to_string(ci->GetValue());
}
if (dynamic_cast<const ConstantInt*>(v))
return std::string(TypeToStr(*v->GetType())) + " " +
std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::string(TypeToStr(*cf->GetType())) + " " +
std::to_string(cf->GetValue());
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
return oss.str();
}
return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v, rm);
}
// Print one instruction (non-alloca) using rename map
static void PrintInst(const Instruction* inst, std::ostream& os,
const RenameMap& rm) {
auto N = [&](const Value* v) -> std::string {
auto it = rm.find(v);
if (it != rm.end()) return std::to_string(it->second);
return v->GetName();
};
auto VS = [&](const Value* v) { return ValStr(v, rm); };
auto TV = [&](const Value* v) { return TypeVal(v, rm); };
switch (inst->GetOpcode()) {
case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
case Opcode::Div: case Opcode::Mod: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op = nullptr;
switch (bin->GetOpcode()) {
case Opcode::Add: op = "add"; break;
case Opcode::Sub: op = "sub"; break;
case Opcode::Mul: op = "mul"; break;
case Opcode::Div: op = "sdiv"; break;
case Opcode::Mod: op = "srem"; break;
default: op = "?"; break;
}
os << " %" << N(bin) << " = " << op << " i32 "
<< VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
break;
}
case Opcode::FAdd: case Opcode::FSub:
case Opcode::FMul: case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op = nullptr;
switch (bin->GetOpcode()) {
case Opcode::FAdd: op = "fadd"; break;
case Opcode::FSub: op = "fsub"; break;
case Opcode::FMul: op = "fmul"; break;
case Opcode::FDiv: op = "fdiv"; break;
default: op = "?"; break;
}
os << " %" << N(bin) << " = " << op << " float "
<< VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
break;
}
case Opcode::ICmp: {
auto* cmp = static_cast<const ICmpInst*>(inst);
os << " %" << N(cmp) << " = icmp " << PredToStr(cmp->GetPredicate())
<< " i32 " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
break;
}
case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst);
os << " %" << N(cmp) << " = fcmp " << FPredToStr(cmp->GetPredicate())
<< " float " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
break;
}
case Opcode::Alloca: {
auto* al = static_cast<const AllocaInst*>(inst);
const char* et = al->GetType()->IsPtrFloat32() ? "float" : "i32";
if (al->IsArray())
os << " %" << N(al) << " = alloca " << et << ", i32 " << al->GetNumElements() << "\n";
else
os << " %" << N(al) << " = alloca " << et << "\n";
break;
}
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
bool fp = gep->GetBasePtr()->GetType()->IsPtrFloat32();
os << " %" << N(gep) << " = getelementptr " << (fp ? "float" : "i32")
<< ", " << (fp ? "float*" : "i32*") << " "
<< VS(gep->GetBasePtr()) << ", i32 " << VS(gep->GetIndex()) << "\n";
break;
}
case Opcode::Load: {
auto* ld = static_cast<const LoadInst*>(inst);
bool fp = ld->GetPtr()->GetType()->IsPtrFloat32();
os << " %" << N(ld) << " = load " << (fp ? "float" : "i32")
<< ", " << (fp ? "float*" : "i32*") << " " << VS(ld->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* st = static_cast<const StoreInst*>(inst);
os << " store " << TV(st->GetValue()) << ", "
<< TypeToStr(*st->GetPtr()->GetType()) << " " << VS(st->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
if (!ret->HasValue()) os << " ret void\n";
else os << " ret " << TV(ret->GetValue()) << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BrInst*>(inst);
os << " br label %" << br->GetTarget()->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<const CondBrInst*>(inst);
os << " br i1 " << VS(cbr->GetCond()) << ", label %"
<< cbr->GetTrueBB()->GetName() << ", label %"
<< cbr->GetFalseBB()->GetName() << "\n";
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
if (!call->IsVoid() && !call->GetName().empty())
os << " %" << N(call) << " = ";
else
os << " ";
os << "call " << (call->IsVoid() ? "void" : TypeToStr(*call->GetType()))
<< " @" << call->GetCalleeName() << "(";
for (size_t i = 0; i < call->GetNumArgs(); ++i) {
if (i > 0) os << ", ";
os << TV(call->GetArg(i));
}
os << ")\n";
break;
}
case Opcode::ZExt: {
auto* ze = static_cast<const ZExtInst*>(inst);
os << " %" << N(ze) << " = zext i1 " << VS(ze->GetSrc()) << " to i32\n";
break;
}
case Opcode::SIToFP: {
auto* si = static_cast<const SIToFPInst*>(inst);
os << " %" << N(si) << " = sitofp i32 " << VS(si->GetSrc()) << " to float\n";
break;
}
case Opcode::FPToSI: {
auto* fp = static_cast<const FPToSIInst*>(inst);
os << " %" << N(fp) << " = fptosi float " << VS(fp->GetSrc()) << " to i32\n";
break;
}
}
return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v);
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
// 1. 全局变量/常量
for (const auto& g : module.GetGlobalVariables()) {
if (g->IsArray()) {
// 全局数组zeroinitializer
if (g->IsConst()) {
os << "@" << g->GetName() << " = constant [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
for (const auto& gv : module.GetGlobalVariables()) {
if (!gv) continue;
if (gv->IsConstant()) {
os << "@" << gv->GetName() << " = constant i32 " << gv->GetInitVal() << "\n";
} else if (gv->IsArray()) {
const char* et = gv->IsFloat() ? "float" : "i32";
os << "@" << gv->GetName() << " = global [" << gv->GetNumElements()
<< " x " << et << "] ";
if (!gv->HasInitVals()) {
os << "zeroinitializer\n";
} else if (gv->IsFloat()) {
const auto& vals = gv->GetInitValsF();
bool all_zero = std::all_of(vals.begin(), vals.end(), [](float f){ return f == 0.0f; });
if (all_zero) {
os << "zeroinitializer\n";
} else {
os << "[";
for (int i = 0; i < gv->GetNumElements(); ++i) {
if (i > 0) os << ", ";
float fv = (i < (int)vals.size()) ? vals[i] : 0.0f;
double d = static_cast<double>(fv);
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
os << oss.str();
}
os << "]\n";
}
} else {
os << "@" << g->GetName() << " = global [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
const auto& vals = gv->GetInitVals();
bool all_zero = std::all_of(vals.begin(), vals.end(), [](int v){ return v == 0; });
if (all_zero) {
os << "zeroinitializer\n";
} else {
os << "[";
for (int i = 0; i < gv->GetNumElements(); ++i) {
if (i > 0) os << ", ";
os << "i32 " << (i < (int)vals.size() ? vals[i] : 0);
}
os << "]\n";
}
}
} else {
if (g->IsConst()) {
os << "@" << g->GetName() << " = constant i32 " << g->GetInitVal()
<< "\n";
} else {
os << "@" << g->GetName() << " = global i32 " << g->GetInitVal()
<< "\n";
}
os << "@" << gv->GetName() << " = global i32 " << gv->GetInitVal() << "\n";
}
}
if (!module.GetGlobalVariables().empty()) os << "\n";
// 2. 外部函数声明
// 2. 外部声明
for (const auto& decl : module.GetExternalDecls()) {
os << "declare " << TypeToStr(*decl.ret_type) << " @" << decl.name << "(";
for (size_t i = 0; i < decl.param_types.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToStr(*decl.param_types[i]);
}
if (decl.is_variadic) {
if (!decl.param_types.empty()) os << ", ";
os << "...";
}
os << ")\n";
}
if (!module.GetExternalDecls().empty()) os << "\n";
// 3. 函数定义
for (const auto& func : module.GetFunctions()) {
os << "define " << TypeToStr(*func->GetType()) << " @" << func->GetName()
<< "(";
os << "define " << TypeToStr(*func->GetType()) << " @" << func->GetName() << "(";
for (size_t i = 0; i < func->GetNumArgs(); ++i) {
if (i > 0) os << ", ";
auto* arg = func->GetArgument(i);
@ -132,172 +303,54 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
}
os << ") {\n";
// Build rename map: alloca instructions first (in block order), then rest
RenameMap rm;
int next_id = 0;
auto assign = [&](const Value* v) {
if (!v) return;
if (dynamic_cast<const ConstantInt*>(v)) return;
if (dynamic_cast<const ConstantFloat*>(v)) return;
if (dynamic_cast<const BasicBlock*>(v)) return;
if (dynamic_cast<const GlobalVariable*>(v)) return;
if (dynamic_cast<const Argument*>(v)) return;
if (rm.count(v) == 0) rm[v] = next_id++;
};
// Pass 1: all allocas across all blocks
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Alloca) assign(ip.get());
}
// Pass 2: all non-alloca instructions in block order
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca) assign(ip.get());
}
// Print: entry block first with all allocas hoisted, then rest
bool first_bb = true;
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
os << bb->GetName() << ":\n";
for (const auto& instPtr : bb->GetInstructions()) {
const auto* inst = instPtr.get();
switch (inst->GetOpcode()) {
// ── 算术 ──────────────────────────────────────────────────────────
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op_str = nullptr;
switch (bin->GetOpcode()) {
case Opcode::Add: op_str = "add"; break;
case Opcode::Sub: op_str = "sub"; break;
case Opcode::Mul: op_str = "mul"; break;
case Opcode::Div: op_str = "sdiv"; break;
case Opcode::Mod: op_str = "srem"; break;
default: op_str = "?"; break;
}
os << " %" << bin->GetName() << " = " << op_str << " i32 "
<< ValStr(bin->GetLhs()) << ", " << ValStr(bin->GetRhs())
<< "\n";
break;
}
// ── 浮点算术 ──────────────────────────────────────────────────────
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op_str = nullptr;
switch (bin->GetOpcode()) {
case Opcode::FAdd: op_str = "fadd"; break;
case Opcode::FSub: op_str = "fsub"; break;
case Opcode::FMul: op_str = "fmul"; break;
case Opcode::FDiv: op_str = "fdiv"; break;
default: op_str = "?"; break;
}
os << " %" << bin->GetName() << " = " << op_str << " float "
<< ValStr(bin->GetLhs()) << ", " << ValStr(bin->GetRhs())
<< "\n";
break;
}
// ── 比较 ──────────────────────────────────────────────────────────
case Opcode::ICmp: {
auto* cmp = static_cast<const ICmpInst*>(inst);
os << " %" << cmp->GetName() << " = icmp "
<< PredToStr(cmp->GetPredicate()) << " i32 "
<< ValStr(cmp->GetLhs()) << ", " << ValStr(cmp->GetRhs())
<< "\n";
break;
}
case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst);
os << " %" << cmp->GetName() << " = fcmp "
<< FPredToStr(cmp->GetPredicate()) << " float "
<< ValStr(cmp->GetLhs()) << ", " << ValStr(cmp->GetRhs())
<< "\n";
break;
}
// ── 内存 ──────────────────────────────────────────────────────────
case Opcode::Alloca: {
auto* al = static_cast<const AllocaInst*>(inst);
const char* elem_type = al->GetType()->IsPtrFloat32() ? "float" : "i32";
if (al->IsArray()) {
os << " %" << al->GetName() << " = alloca " << elem_type << ", i32 "
<< al->GetNumElements() << "\n";
} else {
os << " %" << al->GetName() << " = alloca " << elem_type << "\n";
}
break;
}
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
os << " %" << gep->GetName()
<< " = getelementptr i32, i32* "
<< ValStr(gep->GetBasePtr()) << ", i32 "
<< ValStr(gep->GetIndex()) << "\n";
break;
}
case Opcode::Load: {
auto* ld = static_cast<const LoadInst*>(inst);
const char* val_type = ld->GetType()->IsFloat32() ? "float" : "i32";
const char* ptr_type = ld->GetPtr()->GetType()->IsPtrFloat32() ? "float*" : "i32*";
os << " %" << ld->GetName() << " = load " << val_type << ", " << ptr_type << " "
<< ValStr(ld->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* st = static_cast<const StoreInst*>(inst);
os << " store " << TypeVal(st->GetValue()) << ", "
<< TypeToStr(*st->GetPtr()->GetType()) << " "
<< ValStr(st->GetPtr()) << "\n";
break;
}
// ── 控制流 ────────────────────────────────────────────────────────
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
if (!ret->HasValue()) {
os << " ret void\n";
} else {
auto* v = ret->GetValue();
os << " ret " << TypeVal(v) << "\n";
}
break;
}
case Opcode::Br: {
auto* br = static_cast<const BrInst*>(inst);
os << " br label %" << br->GetTarget()->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<const CondBrInst*>(inst);
os << " br i1 " << ValStr(cbr->GetCond()) << ", label %"
<< cbr->GetTrueBB()->GetName() << ", label %"
<< cbr->GetFalseBB()->GetName() << "\n";
break;
}
// ── 调用 ──────────────────────────────────────────────────────────
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
std::string ret_type_str;
if (call->IsVoid()) {
ret_type_str = "void";
} else {
ret_type_str = TypeToStr(*call->GetType());
}
// 打印赋值部分(仅当有返回值时)
if (!call->IsVoid() && !call->GetName().empty()) {
os << " %" << call->GetName() << " = ";
} else {
os << " ";
}
os << "call " << ret_type_str << " @" << call->GetCalleeName()
<< "(";
for (size_t i = 0; i < call->GetNumArgs(); ++i) {
if (i > 0) os << ", ";
auto* arg = call->GetArg(i);
os << TypeVal(arg);
}
os << ")\n";
break;
}
// ── 类型转换 ──────────────────────────────────────────────────────
case Opcode::ZExt: {
auto* ze = static_cast<const ZExtInst*>(inst);
os << " %" << ze->GetName() << " = zext i1 "
<< ValStr(ze->GetSrc()) << " to i32\n";
break;
}
case Opcode::SIToFP: {
auto* si = static_cast<const SIToFPInst*>(inst);
os << " %" << si->GetName() << " = sitofp i32 "
<< ValStr(si->GetSrc()) << " to float\n";
break;
}
case Opcode::FPToSI: {
auto* fp = static_cast<const FPToSIInst*>(inst);
os << " %" << fp->GetName() << " = fptosi float "
<< ValStr(fp->GetSrc()) << " to i32\n";
break;
}
if (first_bb) {
first_bb = false;
// Print all allocas from all blocks
for (const auto& bb2 : func->GetBlocks()) {
if (!bb2) continue;
for (const auto& ip : bb2->GetInstructions())
if (ip->GetOpcode() == Opcode::Alloca)
PrintInst(ip.get(), os, rm);
}
// Print non-alloca instructions of entry block
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca)
PrintInst(ip.get(), os, rm);
} else {
// Non-entry blocks: skip allocas (already printed)
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca)
PrintInst(ip.get(), os, rm);
}
}
os << "}\n\n";

@ -224,7 +224,11 @@ AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, int num_elements,
// ─── GepInst ──────────────────────────────────────────────────────────────────
GepInst::GepInst(Value* base_ptr, Value* index, std::string name)
: Instruction(Opcode::Gep, Type::GetPtrInt32Type(), std::move(name)) {
: Instruction(Opcode::Gep,
(base_ptr && base_ptr->GetType()->IsPtrFloat32())
? Type::GetPtrFloat32Type()
: Type::GetPtrInt32Type(),
std::move(name)) {
if (!base_ptr || !index) {
throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数"));
}
@ -265,10 +269,15 @@ GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
// ─── GlobalVariable ────────────────────────────────────────────────────────────
GlobalVariable::GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements)
: Value(Type::GetPtrInt32Type(), std::move(name)),
int num_elements, bool is_array_decl,
bool is_float)
: Value(is_float ? Type::GetPtrFloat32Type() : Type::GetPtrInt32Type(),
std::move(name)),
is_const_(is_const),
is_float_(is_float),
init_val_(init_val),
num_elements_(num_elements) {}
init_val_f_(0.0f),
num_elements_(num_elements),
is_array_decl_(is_array_decl) {}
} // namespace ir

@ -28,9 +28,10 @@ const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
// ─── 全局变量管理 ─────────────────────────────────────────────────────────────
GlobalVariable* Module::CreateGlobalVariable(const std::string& name,
bool is_const, int init_val,
int num_elements) {
int num_elements, bool is_array_decl,
bool is_float) {
globals_.push_back(
std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements));
std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements, is_array_decl, is_float));
GlobalVariable* g = globals_.back().get();
global_map_[name] = g;
return g;

@ -117,33 +117,101 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
for (int d : dims) total *= (d > 0 ? d : 1);
if (in_global_scope_) {
auto* gv = module_.CreateGlobalVariable(name, true, 0, total);
auto* gv = module_.CreateGlobalVariable(name, true, 0, total, true);
global_storage_map_[constDef] = gv;
global_array_dims_[constDef] = dims;
// 计算初始值并存入全局变量
if (constDef->constInitVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::vector<int> flat(total, 0);
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) { flat[pos] = EvalConstExprInt(iv->constExp()); return; }
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos; fill(sub,a,sub_stride); cur=a+sub_stride; }
}
};
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur+top_stride-1)/top_stride)*top_stride; fill(sub,a,top_stride); cur=a+top_stride; }
}
}
gv->SetInitVals(flat);
}
} else {
auto* slot = builder_.CreateAllocaArray(total, name);
storage_map_[constDef] = slot;
array_dims_[constDef] = dims;
// 扁平化初始化
// 按 C 语义扁平化初始化(子列表对齐到维度边界)
if (constDef->constInitVal()) {
std::vector<int> flat;
flat.reserve(total);
std::function<void(SysYParser::ConstInitValContext*)> flatten =
[&](SysYParser::ConstInitValContext* iv) {
if (!iv) return;
if (iv->constExp()) {
flat.push_back(EvalConstExprInt(iv->constExp()));
} else {
for (auto* sub : iv->constInitVal()) flatten(sub);
}
};
flatten(constDef->constInitVal());
std::vector<int> flat(total, 0);
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) {
flat[pos] = EvalConstExprInt(iv->constExp());
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else {
int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
}
}
};
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else {
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
for (int i = 0; i < total; ++i) {
int v = (i < (int)flat.size()) ? flat[i] : 0;
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(v), ptr);
builder_.CreateStore(builder_.CreateConstInt(flat[i]), ptr);
}
}
}
@ -194,9 +262,97 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
} else {
int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1);
auto* gv = module_.CreateGlobalVariable(name, false, 0, total);
auto* gv = module_.CreateGlobalVariable(name, false, 0, total, true, is_float);
global_storage_map_[ctx] = gv;
global_array_dims_[ctx] = dims;
// 计算初始值
if (ctx->initVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
if (is_float) {
std::vector<float> flat(total, 0.0f);
std::function<void(SysYParser::InitValContext*, int, int)> fill_f;
fill_f = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<float>(sem::EvaluateExp(*iv->exp()->addExp()).float_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill_f(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<float>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).float_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill_f(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitValsF(flat);
} else {
std::vector<int> flat(total, 0);
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<int>(sem::EvaluateExp(*iv->exp()->addExp()).int_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<int>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).int_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitVals(flat);
}
}
}
} else {
if (storage_map_.count(ctx)) {
@ -211,6 +367,14 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
ir::Value* init;
if (ctx->initVal() && ctx->initVal()->exp()) {
init = EvalExpr(*ctx->initVal()->exp());
// Coerce init value to slot type
if (!is_float && init->IsFloat32()) {
init = ToInt(init);
} else if (is_float && init->IsInt32()) {
init = ToFloat(init);
} else if (!is_float && init->IsInt1()) {
init = ToI32(init);
}
} else {
init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
@ -219,40 +383,95 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
} else {
int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1);
auto* slot = builder_.CreateAllocaArray(total, name);
auto* slot = is_float ? builder_.CreateAllocaArrayF32(total, module_.GetContext().NextTemp())
: builder_.CreateAllocaArray(total, module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
array_dims_[ctx] = dims;
ir::Value* zero_init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
if (ctx->initVal()) {
// 收集扁平化初始值
std::vector<ir::Value*> flat;
flat.reserve(total);
std::function<void(SysYParser::InitValContext*)> flatten =
[&](SysYParser::InitValContext* iv) {
if (!iv) return;
if (iv->exp()) {
flat.push_back(EvalExpr(*iv->exp()));
} else {
for (auto* sub : iv->initVal()) flatten(sub);
}
};
flatten(ctx->initVal());
for (int i = 0; i < total; ++i) {
ir::Value* v = (i < (int)flat.size()) ? flat[i]
: builder_.CreateConstInt(0);
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(v, ptr);
// 按 C 语义扁平化初始值:子列表对齐到对应维度边界
std::vector<ir::Value*> flat(total, zero_init);
// 计算各维度的 stridestride[i] = dims[i]*dims[i+1]*...*dims[n-1]
// 但我们需要「子列表对应第几维的 stride」
// 顶层stride = total / dims[0](即每行的元素数)
// 递归时 stride 继续除以当前维度大小
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0]; // 每个顶层子列表占用的元素数
// fill(iv, pos, stride):将 iv 的内容填入 flat[pos..pos+stride)
// stride 表示当前层子列表对应的元素个数
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
flat[pos] = EvalExpr(*iv->exp());
return;
}
// 子列表内的 stride = stride / (当前层首维大小)
// 找到对应的 strides 层stride == strides[k] → 子stride = strides[k+1]
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k) {
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
}
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else {
// 对齐到 sub_stride 边界
int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
}
}
};
// 顶层扫描
int cur = 0;
if (ctx->initVal()->exp()) {
flat[0] = EvalExpr(*ctx->initVal()->exp());
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else {
// 对齐到 top_stride 边界
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
} else {
// 零初始化
// 先 memset 归零,再只写入非零元素
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
for (int i = 0; i < total; ++i) {
bool is_zero = false;
if (auto* ci = dynamic_cast<ir::ConstantInt*>(flat[i])) {
is_zero = (ci->GetValue() == 0);
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(flat[i])) {
is_zero = (cf->GetValue() == 0.0f);
}
if (is_zero) continue;
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(0), ptr);
builder_.CreateStore(flat[i], ptr);
}
} else {
// 零初始化:用 memset 归零
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
(void)zero_init;
}
}
}

@ -10,10 +10,15 @@
// ─── 辅助 ─────────────────────────────────────────────────────────────────────
// 把 i32 值转成 i1icmp ne i32 v, 0
// 把 i32/float 值转成 i1
ir::Value* IRGenImpl::ToI1(ir::Value* v) {
if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value"));
if (v->IsInt1()) return v;
if (v->IsFloat32()) {
return builder_.CreateFCmp(ir::FCmpPredicate::ONE, v,
builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
return builder_.CreateICmp(ir::ICmpPredicate::NE, v,
builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
@ -87,7 +92,13 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
} else if (name == "getch") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {});
} else if (name == "getfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); // 近似
module_.DeclareExternalFunc(name, ir::Type::GetFloat32Type(), {});
} else if (name == "getarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrInt32Type()});
} else if (name == "getfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrFloat32Type()});
} else if (name == "putint") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
@ -95,10 +106,16 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
} else if (name == "putfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {});
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetFloat32Type()});
} else if (name == "putarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
{ir::Type::GetInt32Type(),
ir::Type::GetPtrInt32Type()});
} else if (name == "putfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type(),
ir::Type::GetPtrFloat32Type()});
} else if (name == "starttime" || name == "stoptime") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
@ -227,13 +244,113 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
std::vector<ir::Value*> args;
if (ctx->funcRParams()) {
for (auto* exp : ctx->funcRParams()->exp()) {
args.push_back(EvalExpr(*exp));
// 检查是否是数组变量(无索引的 lVar若是则传指针而非 load
ir::Value* arg = nullptr;
auto* add = exp->addExp();
if (add && add->mulExp().size() == 1) {
auto* mul = add->mulExp(0);
if (mul && mul->unaryExp().size() == 1) {
auto* unary = mul->unaryExp(0);
if (unary && !unary->unaryOp() && unary->primaryExp()) {
auto* primary = unary->primaryExp();
if (primary && primary->lVar() && primary->lVar()->exp().empty()) {
auto* lvar = primary->lVar();
auto* decl = sema_.ResolveVarUse(lvar->Ident());
if (decl) {
// 检查是否是数组参数storage_map_ 里存的是指针)
auto it = storage_map_.find(decl);
if (it != storage_map_.end()) {
auto* val = it->second;
if (val && (val->IsPtrInt32() || val->IsPtrFloat32())) {
// 检查是否是 Argument数组参数直接传指针
if (dynamic_cast<ir::Argument*>(val)) {
arg = val;
} else if (array_dims_.count(decl)) {
// 本地数组(含 dims 记录):传首元素地址
arg = builder_.CreateGep(val, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
// 检查全局数组
if (!arg) {
auto git = global_storage_map_.find(decl);
if (git != global_storage_map_.end()) {
auto* gv = dynamic_cast<ir::GlobalVariable*>(git->second);
if (gv && gv->IsArray()) {
arg = builder_.CreateGep(git->second,
builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
}
}
}
}
}
// Also handle partially-indexed multi-dim arrays: arr[i] where arr is
// int arr[M][N] should pass &arr[i*N] as i32*, not load arr[i] as i32.
if (!arg) {
auto* add2 = exp->addExp();
if (add2 && add2->mulExp().size() == 1) {
auto* mul2 = add2->mulExp(0);
if (mul2 && mul2->unaryExp().size() == 1) {
auto* unary2 = mul2->unaryExp(0);
if (unary2 && !unary2->unaryOp() && unary2->primaryExp()) {
auto* primary2 = unary2->primaryExp();
if (primary2 && primary2->lVar() && !primary2->lVar()->exp().empty()) {
auto* lvar2 = primary2->lVar();
auto* decl2 = sema_.ResolveVarUse(lvar2->Ident());
if (decl2) {
std::vector<int> dims2;
ir::Value* base2 = nullptr;
auto it2 = array_dims_.find(decl2);
if (it2 != array_dims_.end()) {
dims2 = it2->second;
auto sit = storage_map_.find(decl2);
if (sit != storage_map_.end()) base2 = sit->second;
} else {
auto git2 = global_array_dims_.find(decl2);
if (git2 != global_array_dims_.end()) {
dims2 = git2->second;
auto gsit = global_storage_map_.find(decl2);
if (gsit != global_storage_map_.end()) base2 = gsit->second;
}
}
// Partially indexed: fewer indices than dimensions -> pass pointer
bool is_param = !dims2.empty() && dims2[0] == -1;
size_t effective_dims = is_param ? dims2.size() - 1 : dims2.size();
if (base2 && !dims2.empty() &&
lvar2->exp().size() < effective_dims + (is_param ? 1 : 0)) {
arg = EvalLVarAddr(lvar2);
}
}
}
}
}
}
}
if (!arg) arg = EvalExpr(*exp);
args.push_back(arg);
}
}
// 模块内已知函数?
ir::Function* callee = module_.GetFunction(callee_name);
if (callee) {
// Coerce args to match parameter types
for (size_t i = 0; i < args.size() && i < callee->GetNumArgs(); ++i) {
auto* param = callee->GetArgument(i);
if (!param || !args[i]) continue;
if (param->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name =
callee->IsVoidReturn() ? "" : module_.GetContext().NextTemp();
auto* call =
@ -246,15 +363,28 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// 外部函数
EnsureExternalDecl(callee_name);
// 获取返回类型
// 获取返回类型和参数类型
std::shared_ptr<ir::Type> ret_type = ir::Type::GetInt32Type();
std::vector<std::shared_ptr<ir::Type>> param_types;
for (const auto& decl : module_.GetExternalDecls()) {
if (decl.name == callee_name) {
ret_type = decl.ret_type;
param_types = decl.param_types;
break;
}
}
bool is_void = ret_type->IsVoid();
// Coerce args to match external function parameter types
for (size_t i = 0; i < args.size() && i < param_types.size(); ++i) {
if (!args[i]) continue;
if (param_types[i]->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param_types[i]->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param_types[i]->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name = is_void ? "" : module_.GetContext().NextTemp();
auto* call = builder_.CreateCallExternal(callee_name, ret_type,
std::move(args), ret_name);
@ -331,40 +461,26 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
throw std::runtime_error(FormatError("irgen", "数组索引维度过多"));
}
ir::Value* offset = builder_.CreateConstInt(0);
ir::Value* offset = nullptr;
if (is_array_param) {
// 数组参数dims[0]=-1, dims[1..n]是已知维度
// 索引indices[0]对应第一维indices[1]对应第二维...
for (size_t i = 0; i < indices.size(); ++i) {
ir::Value* idx = EvalExpr(*indices[i]);
if (i == 0) {
// 第一维stride = dims[1] * dims[2] * ... (如果有的话)
int stride = 1;
for (size_t j = 1; j < dims.size(); ++j) {
stride *= dims[j];
}
if (stride > 1) {
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
} else {
offset = builder_.CreateAdd(offset, idx,
module_.GetContext().NextTemp());
}
int stride = 1;
size_t start = (i == 0) ? 1 : i + 1;
for (size_t j = start; j < dims.size(); ++j) stride *= dims[j];
ir::Value* term;
if (stride == 1) {
term = idx;
} else {
// 后续维度
int stride = 1;
for (size_t j = i + 1; j < dims.size(); ++j) {
stride *= dims[j];
}
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
}
if (!offset) {
offset = term;
} else {
offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
}
}
} else {
@ -374,15 +490,24 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
stride = (i == (int)dims.size() - 1) ? 1 : stride * dims[i + 1];
if (i < (int)indices.size()) {
ir::Value* idx = EvalExpr(*indices[i]);
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
ir::Value* term;
if (stride == 1) {
term = idx;
} else {
term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
}
if (!offset) {
offset = term;
} else {
offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
}
}
}
}
if (!offset) offset = builder_.CreateConstInt(0);
return builder_.CreateGep(base, offset, module_.GetContext().NextTemp());
}
@ -486,8 +611,8 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.end");
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.end");
builder_.CreateCondBr(result, end_bb, rhs_bb);
builder_.SetInsertPoint(rhs_bb);
@ -498,6 +623,7 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
builder_.CreateBr(end_bb);
}
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
}
@ -523,8 +649,8 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.end");
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.end");
builder_.CreateCondBr(result, rhs_bb, end_bb);
builder_.SetInsertPoint(rhs_bb);
@ -535,6 +661,7 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
builder_.CreateBr(end_bb);
}
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
}

@ -76,6 +76,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
// 设置插入点到入口块
builder_.SetInsertPoint(func_->GetEntry());
builder_.SetAllocaBlock(func_->GetEntry());
// 处理参数
if (ctx->funcFParams()) {

@ -25,6 +25,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->Return()) {
if (ctx->exp()) {
ir::Value* v = EvalExpr(*ctx->exp());
// Coerce return value to function return type
if (func_->GetType()->IsInt32() && v->IsFloat32()) {
v = ToInt(v);
} else if (func_->GetType()->IsFloat32() && v->IsInt32()) {
v = ToFloat(v);
} else if (func_->GetType()->IsInt32() && v->IsInt1()) {
v = ToI32(v);
}
builder_.CreateRet(v);
} else {
builder_.CreateRetVoid();
@ -54,6 +62,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->lVar() && ctx->Assign()) {
ir::Value* rhs = EvalExpr(*ctx->exp());
ir::Value* addr = EvalLVarAddr(ctx->lVar());
// Coerce rhs to match slot type
if (addr->IsPtrInt32() && rhs->IsFloat32()) {
rhs = ToInt(rhs);
} else if (addr->IsPtrFloat32() && rhs->IsInt32()) {
rhs = ToFloat(rhs);
} else if (addr->IsPtrInt32() && rhs->IsInt1()) {
rhs = ToI32(rhs);
}
builder_.CreateStore(rhs, addr);
return BlockFlow::Continue;
}
@ -74,32 +90,47 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
auto stmts = ctx->stmt();
// Step 1: evaluate condition (may create short-circuit blocks with lower
// SSA numbers — must happen before any branch-target blocks are created).
ir::Value* cond_val = EvalCond(*ctx->cond());
// Step 2: create then_bb now (its label number will be >= all short-circuit
// block numbers allocated during EvalCond).
ir::BasicBlock* then_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".if.then");
module_.GetContext().NextLabel() + ".if.then");
// Step 3: create else_bb/merge_bb as placeholders. They will be moved to
// the end of the block list after their predecessors are filled in, so the
// block ordering in the output will be correct even though their label
// numbers are allocated here (before then-body sub-blocks are created).
ir::BasicBlock* else_bb = nullptr;
ir::BasicBlock* merge_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".if.end");
ir::BasicBlock* merge_bb = nullptr;
// 求值条件(可能创建短路求值块)
ir::Value* cond_val = EvalCond(*ctx->cond());
if (stmts.size() >= 2) {
else_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.else");
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
} else {
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
}
// 检查当前块是否已终结(短路求值可能导致)
// Check if current block already terminated (short-circuit may do this)
if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) {
// 条件求值已经终结了当前块,无法继续
// 这种情况下我们需要在merge_bb继续
func_->MoveBlockToEnd(then_bb);
if (else_bb) func_->MoveBlockToEnd(else_bb);
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
if (stmts.size() >= 2) {
// if-else
else_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".if.else");
builder_.CreateCondBr(cond_val, then_bb, else_bb);
} else {
builder_.CreateCondBr(cond_val, then_bb, merge_bb);
}
// then 分支
// then 分支 — visit body (may create many sub-blocks with higher numbers)
func_->MoveBlockToEnd(then_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = VisitStmt(*stmts[0]);
if (then_flow != BlockFlow::Terminated) {
@ -108,6 +139,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
// else 分支
if (else_bb) {
func_->MoveBlockToEnd(else_bb);
builder_.SetInsertPoint(else_bb);
auto else_flow = VisitStmt(*stmts[1]);
if (else_flow != BlockFlow::Terminated) {
@ -115,6 +147,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
}
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
@ -124,28 +157,32 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (!ctx->cond()) {
throw std::runtime_error(FormatError("irgen", "while 缺少条件"));
}
ir::BasicBlock* cond_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.cond");
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.end");
module_.GetContext().NextLabel() + ".while.cond");
// 跳转到条件块
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateBr(cond_bb);
}
// 条件块
// EvalCond MUST come before creating body_bb/after_bb so that
// short-circuit blocks get lower SSA numbers than the loop body blocks.
builder_.SetInsertPoint(cond_bb);
ir::Value* cond_val = EvalCond(*ctx->cond());
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.end");
// 检查条件求值后是否已终结
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateCondBr(cond_val, body_bb, after_bb);
}
// 循环体(压入循环栈)
func_->MoveBlockToEnd(body_bb);
loop_stack_.push_back({cond_bb, after_bb});
builder_.SetInsertPoint(body_bb);
auto stmts = ctx->stmt();
@ -159,6 +196,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
loop_stack_.pop_back();
func_->MoveBlockToEnd(after_bb);
builder_.SetInsertPoint(after_bb);
return BlockFlow::Continue;
}

@ -46,13 +46,16 @@ int main(int argc, char** argv) {
}
if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
// 修改:支持多函数
auto machine_funcs = mir::LowerToMIR(*module);
for (auto& mf : machine_funcs) {
mir::RunRegAlloc(*mf);
mir::RunFrameLowering(*mf);
}
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_func, std::cout);
mir::PrintAsm(machine_funcs, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {
@ -65,4 +68,4 @@ int main(int argc, char** argv) {
return 1;
}
return 0;
}
}

@ -2,9 +2,16 @@
#include <ostream>
#include <stdexcept>
#include <iostream>
#include <vector>
#include <unordered_map>
#include "utils/Log.h"
// 引用全局变量(定义在 Lowering.cpp 中)
extern std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir {
namespace {
@ -16,63 +23,498 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex());
}
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
int offset) {
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
<< "]\n";
void EmitStackLoad(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) {
if (offset >= -2048 && offset <= 2047) {
os << " lw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " lw " << PhysRegName(dst) << ", 0(t4)\n";
}
}
} // namespace
void EmitStackStore(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) {
if (offset >= -2048 && offset <= 2047) {
os << " sw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " sw " << PhysRegName(src) << ", 0(t4)\n";
}
}
void PrintAsm(const MachineFunction& function, std::ostream& os) {
os << ".text\n";
os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n";
os << function.GetName() << ":\n";
for (const auto& inst : function.GetEntry().GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
void EmitStackLoadFloat(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) {
if (offset >= -2048 && offset <= 2047) {
os << " flw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " flw " << PhysRegName(dst) << ", 0(t4)\n";
}
}
void EmitStackStoreFloat(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) {
if (offset >= -2048 && offset <= 2047) {
os << " fsw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " fsw " << PhysRegName(src) << ", 0(t4)\n";
}
}
// 输出单个函数的汇编
void PrintAsmFunction(const MachineFunction& function, std::ostream& os) {
// 收集所有基本块名称
std::unordered_map<const MachineBasicBlock*, std::string> block_names;
for (const auto& block_ptr : function.GetBlocks()) {
block_names[block_ptr.get()] = block_ptr->GetName();
}
int total_frame_size = 16 + function.GetFrameSize();
bool prologue_done = false;
for (const auto& block_ptr : function.GetBlocks()) {
const auto& block = *block_ptr;
// 输出基本块标签(入口块不输出,因为函数名已经是标签)
if (block.GetName() != "entry") {
os << block.GetName() << ":\n";
}
for (const auto& inst : block.GetInstructions()) {
const auto& ops = inst.GetOperands();
// 在入口块的第一条指令前输出序言
if (!prologue_done && block.GetName() == "entry") {
// 处理大栈帧的情况
if (total_frame_size <= 2047) {
os << " addi sp, sp, -" << total_frame_size << "\n";
} else {
os << " li t4, -" << total_frame_size << "\n";
os << " add sp, sp, t4\n";
}
// 保存 ra 和 s0
int ra_offset = total_frame_size - 8;
int s0_offset = total_frame_size - 16;
if (ra_offset <= 2047) {
os << " sw ra, " << ra_offset << "(sp)\n";
} else {
os << " li t4, " << ra_offset << "\n";
os << " add t4, sp, t4\n";
os << " sw ra, 0(t4)\n";
}
if (s0_offset <= 2047) {
os << " sw s0, " << s0_offset << "(sp)\n";
} else {
os << " li t4, " << s0_offset << "\n";
os << " add t4, sp, t4\n";
os << " sw s0, 0(t4)\n";
}
prologue_done = true;
}
switch (inst.GetOpcode()) {
case Opcode::Prologue:
case Opcode::Epilogue:
break;
case Opcode::MovImm:
os << " li " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::Load: {
if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
} else {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackLoad(os, ops.at(0).GetReg(), slot.offset);
}
break;
}
case Opcode::Store: {
if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
} else {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackStore(os, ops.at(0).GetReg(), slot.offset);
}
break;
}
case Opcode::Add:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Sub:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Mul: {
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
} else {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
}
break;
}
case Opcode::Div:
os << " div " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Rem:
os << " rem " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slt:
os << " slt " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slti:
os << " slti " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Sltu:
os << " sltu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Sltiu: // <-- 添加这个
os << " sltiu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Xori:
os << " xori " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::LoadGlobalAddr: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la " << PhysRegName(ops.at(0).GetReg()) << ", " << global_name << "\n";
break;
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
case Opcode::LoadGlobal:
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::StoreGlobal: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la t1, " << global_name << "\n";
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0(t1)\n";
break;
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break;
case Opcode::GEP:
break;
case Opcode::LoadIndirect:
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::Call: {
std::string func_name = "memset";
if (!ops.empty() && ops[0].GetKind() == Operand::Kind::Func) {
func_name = ops[0].GetFuncName();
}
os << " call " << func_name << "\n";
break;
}
case Opcode::LoadAddr: {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
if (slot.offset >= -2048 && slot.offset <= 2047) {
os << " addi " << PhysRegName(ops.at(0).GetReg()) << ", sp, " << slot.offset << "\n";
} else {
os << " li " << PhysRegName(ops.at(0).GetReg()) << ", " << slot.offset << "\n";
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", sp, "
<< PhysRegName(ops.at(0).GetReg()) << "\n";
}
break;
}
case Opcode::Slli:
os << " slli " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::StoreIndirect:
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::Ret:{
// 恢复 ra 和 s0
int ra_offset = total_frame_size - 8;
int s0_offset = total_frame_size - 16;
if (ra_offset <= 2047) {
os << " lw ra, " << ra_offset << "(sp)\n";
} else {
os << " li t3, " << ra_offset << "\n";
os << " add t3, sp, t3\n";
os << " lw ra, 0(t3)\n";
}
if (s0_offset <= 2047) {
os << " lw s0, " << s0_offset << "(sp)\n";
} else {
os << " li t3, " << s0_offset << "\n";
os << " add t3, sp, t3\n";
os << " lw s0, 0(t3)\n";
}
// 恢复 sp
if (total_frame_size <= 2047) {
os << " addi sp, sp, " << total_frame_size << "\n";
} else {
os << " li t3, " << total_frame_size << "\n";
os << " add sp, sp, t3\n";
}
os << " ret\n";
break;
}
case Opcode::Br: {
auto* target = reinterpret_cast<MachineBasicBlock*>(ops[0].GetImm64());
os << " j " << target->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* true_target = reinterpret_cast<MachineBasicBlock*>(ops[1].GetImm64());
auto* false_target = reinterpret_cast<MachineBasicBlock*>(ops[2].GetImm64());
auto true_it = block_names.find(true_target);
auto false_it = block_names.find(false_target);
if (true_it == block_names.end() || false_it == block_names.end()) {
throw std::runtime_error(FormatError("mir", "CondBr: 找不到基本块名称"));
}
os << " bnez " << PhysRegName(ops[0].GetReg()) << ", "
<< true_it->second << "\n";
os << " j " << false_it->second << "\n";
break;
}
// 浮点运算
case Opcode::FAdd:
os << " fadd.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FSub:
os << " fsub.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMul:
os << " fmul.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FDiv:
os << " fdiv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FEq:
os << " feq.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLt:
os << " flt.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLe:
os << " fle.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMov:
os << " fmv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovWX:
os << " fmv.w.x " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovXW:
os << " fmv.x.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " fcvt.s.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvt.w.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::LoadFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " flw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackLoadFloat(os, ops[0].GetReg(), slot.offset);
}
break;
case Opcode::StoreFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " fsw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackStoreFloat(os, ops[0].GetReg(), slot.offset);
}
break;
default:
break;
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
}
} // namespace
// 输出多个函数的汇编
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os) {
// ========== 输出全局变量 ==========
// 输出 .data 段(已初始化的全局变量)
bool hasData = false;
for (const auto& gv : g_globalVars) {
if (!gv.isConst) {
if (!hasData) {
os << ".data\n";
hasData = true;
}
os << " .global " << gv.name << "\n";
os << " .type " << gv.name << ", @object\n";
if (gv.isArray && gv.arraySize > 1) {
int totalSize = gv.arraySize * 4;
os << " .size " << gv.name << ", " << totalSize << "\n";
os << gv.name << ":\n";
if (!gv.arrayValues.empty()) {
for (int val : gv.arrayValues) {
os << " .word " << val << "\n";
}
} else {
for (int i = 0; i < gv.arraySize; i++) {
os << " .word 0\n";
}
}
} else {
os << " .size " << gv.name << ", 4\n";
os << gv.name << ":\n";
os << " .word " << gv.value << "\n";
}
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName()
<< "\n";
// 输出 .rodata 段(只读常量)
bool hasRodata = false;
for (const auto& gv : g_globalVars) {
if (gv.isConst) {
if (!hasRodata) {
os << ".section .rodata\n";
hasRodata = true;
}
os << " .global " << gv.name << "\n";
os << " .type " << gv.name << ", @object\n";
if (gv.isArray && gv.arraySize > 1) {
int totalSize = gv.arraySize * 4;
os << " .size " << gv.name << ", " << totalSize << "\n";
os << gv.name << ":\n";
if (!gv.arrayValues.empty()) {
for (int val : gv.arrayValues) {
os << " .word " << val << "\n";
}
} else {
for (int i = 0; i < gv.arraySize; i++) {
os << " .word 0\n";
}
}
} else {
os << " .size " << gv.name << ", 4\n";
os << gv.name << ":\n";
os << " .word " << gv.value << "\n";
}
}
}
// ========== 输出代码段 ==========
os << ".text\n";
// 输出每个函数
for (const auto& func_ptr : functions) {
os << ".global " << func_ptr->GetName() << "\n";
os << ".type " << func_ptr->GetName() << ", @function\n";
os << func_ptr->GetName() << ":\n";
PrintAsmFunction(*func_ptr, os);
os << "\n"; // 函数之间加空行
}
}
} // namespace mir
} // namespace mir

@ -15,10 +15,12 @@ target_link_libraries(mir_core PUBLIC
ir
)
target_compile_options(mir_core PRIVATE -Wno-unused-parameter)
add_subdirectory(passes)
add_library(mir INTERFACE)
target_link_libraries(mir INTERFACE
mir_core
mir_passes
)
)

@ -18,9 +18,9 @@ void RunFrameLowering(MachineFunction& function) {
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
if (-cursor < -256) {
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
}
//if (-cursor < -2048) {
//throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
//}
}
cursor = 0;
@ -30,7 +30,8 @@ void RunFrameLowering(MachineFunction& function) {
}
function.SetFrameSize(AlignTo(cursor, 16));
auto& insts = function.GetEntry().GetInstructions();
// 修复GetEntry() 返回指针,使用 ->
auto& insts = function.GetEntry()->GetInstructions();
std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) {
@ -42,4 +43,4 @@ void RunFrameLowering(MachineFunction& function) {
insts = std::move(lowered);
}
} // namespace mir
} // namespace mir

@ -1,123 +1,642 @@
#include "mir/MIR.h"
#include <iostream>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include <cstring>
#include "ir/IR.h"
#include "utils/Log.h"
std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir {
namespace {
using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
static std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> block_map;
MachineBasicBlock* GetOrCreateBlock(const ir::BasicBlock* ir_block,
MachineFunction& function) {
auto it = block_map.find(ir_block);
if (it != block_map.end()) {
return it->second;
}
std::string name = ir_block->GetName();
if (name.empty()) {
name = "block_" + std::to_string(block_map.size());
}
auto* block = function.CreateBlock(name);
block_map[ir_block] = block;
return block;
}
void EmitValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
const ValueSlotMap& slots, MachineBasicBlock& block,
bool for_address=false) {
if (auto* arg = dynamic_cast<const ir::Argument*>(value)) {
auto it = slots.find(arg);
if (it != slots.end()) {
// 从栈槽加载参数值
if (value->GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
} else {
block.Append(Opcode::Load,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
}
return;
}
}
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
int64_t val = constant->GetValue();
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(constant->GetValue())});
{Operand::Reg(target), Operand::Imm(static_cast<int>(val))});
return;
}
// 处理浮点常量
if (auto* fconstant = dynamic_cast<const ir::ConstantFloat*>(value)) {
float val = fconstant->GetValue();
uint32_t bits;
memcpy(&bits, &val, sizeof(val));
// 检查目标是否是浮点寄存器
bool target_is_fp = (target == PhysReg::FT0 || target == PhysReg::FT1 ||
target == PhysReg::FT2 || target == PhysReg::FT3 ||
target == PhysReg::FT4 || target == PhysReg::FT5 ||
target == PhysReg::FT6 || target == PhysReg::FT7 ||
target == PhysReg::FA0 || target == PhysReg::FA1 ||
target == PhysReg::FA2 || target == PhysReg::FA3 ||
target == PhysReg::FA4 || target == PhysReg::FA5 ||
target == PhysReg::FA6 || target == PhysReg::FA7);
if (target_is_fp) {
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::T0), Operand::Imm(static_cast<int>(bits))});
block.Append(Opcode::FMovWX, {Operand::Reg(target), Operand::Reg(PhysReg::T0)});
} else {
// 目标是整数寄存器,直接加载
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(static_cast<int>(bits))});
}
return;
}
if (auto* gep = dynamic_cast<const ir::GepInst*>(value)) {
EmitValueToReg(gep->GetBasePtr(), target, slots, block, true);
EmitValueToReg(gep->GetIndex(), PhysReg::T1, slots, block);
block.Append(Opcode::Slli, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Imm(2)});
block.Append(Opcode::Add, {Operand::Reg(target),
Operand::Reg(target),
Operand::Reg(PhysReg::T1)});
return;
}
if (auto* alloca = dynamic_cast<const ir::AllocaInst*>(value)) {
auto it = slots.find(alloca);
if (it != slots.end()) {
block.Append(Opcode::LoadAddr,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
return;
}
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(value)) {
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(target), Operand::Global(global->GetName())});
if (!for_address) {
if (global->IsFloat()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::Reg(target)});
} else {
block.Append(Opcode::LoadGlobal,
{Operand::Reg(target), Operand::Reg(target)});
}
}
return;
}
// 关键:在 slots 中查找,并根据类型生成正确的加载指令
auto it = slots.find(value);
if (it == slots.end()) {
throw std::runtime_error(
FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
if (it != slots.end()) {
if (value->GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
} else {
block.Append(Opcode::Load,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
}
return;
}
std::cerr << "未找到的值: " << value << std::endl;
std::cerr << " 名称: " << value->GetName() << std::endl;
std::cerr << " 类型: " << (value->GetType()->IsFloat32() ? "float" : "int") << std::endl;
std::cerr << " 是否是 ConstantInt: " << (dynamic_cast<const ir::ConstantInt*>(value) != nullptr) << std::endl;
std::cerr << " 是否是 ConstantFloat: " << (dynamic_cast<const ir::ConstantFloat*>(value) != nullptr) << std::endl;
std::cerr << " 是否是 Instruction: " << (dynamic_cast<const ir::Instruction*>(value) != nullptr) << std::endl;
throw std::runtime_error(
FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
}
block.Append(Opcode::LoadStack,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
void StoreRegToSlot(PhysReg reg, int slot, MachineBasicBlock& block, bool isFloat = false) {
if (isFloat) {
block.Append(Opcode::StoreFloat,
{Operand::Reg(reg), Operand::FrameIndex(slot)});
} else {
block.Append(Opcode::Store,
{Operand::Reg(reg), Operand::FrameIndex(slot)});
}
}
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots) {
auto& block = function.GetEntry();
// 将 LowerInstruction 重命名为 LowerInstructionToBlock并添加 MachineBasicBlock 参数
void LowerInstructionToBlock(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots, MachineBasicBlock& block) {
switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: {
slots.emplace(&inst, function.CreateFrameIndex());
auto& alloca = static_cast<const ir::AllocaInst&>(inst);
int size = 4;
if (alloca.GetNumElements() > 1) {
size = alloca.GetNumElements() * 4;
}
slots.emplace(&inst, function.CreateFrameIndex(size));
return;
}
case ir::Opcode::Store: {
auto& store = static_cast<const ir::StoreInst&>(inst);
if (dynamic_cast<const ir::GepInst*>(store.GetPtr())) {
EmitValueToReg(store.GetValue(), PhysReg::T2, slots, block);
EmitValueToReg(store.GetPtr(), PhysReg::T0, slots, block, true);
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::T2), Operand::Reg(PhysReg::T0)});
return;
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(store.GetPtr())) {
EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block);
std::string global_name = global->GetName();
block.Append(Opcode::StoreGlobal,
{Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
return;
}
auto dst = slots.find(store.GetPtr());
if (dst == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行写入"));
if (dst != slots.end()) {
EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block);
StoreRegToSlot(PhysReg::T0, dst->second, block);
return;
}
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
return;
throw std::runtime_error(FormatError("mir", "Store: 无法处理的指针类型"));
}
case ir::Opcode::Load: {
auto& load = static_cast<const ir::LoadInst&>(inst);
if (dynamic_cast<const ir::GepInst*>(load.GetPtr())) {
EmitValueToReg(load.GetPtr(), PhysReg::T0, slots, block, true);
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
int dst_slot = function.CreateFrameIndex(4);
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
slots.emplace(&inst, dst_slot);
return;
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(load.GetPtr())) {
int dst_slot = function.CreateFrameIndex(4);
std::string global_name = global->GetName();
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
if (global->IsFloat()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, true);
} else {
block.Append(Opcode::LoadGlobal,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
return;
}
auto src = slots.find(load.GetPtr());
if (src == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行读取"));
if (src != slots.end()) {
int dst_slot = function.CreateFrameIndex(4);
if (load.GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(src->second)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
} else {
block.Append(Opcode::Load,
{Operand::Reg(PhysReg::T0), Operand::FrameIndex(src->second)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
return;
}
throw std::runtime_error(FormatError("mir", "Load: 无法处理的指针类型"));
}
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Mod: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
EmitValueToReg(bin.GetLhs(), PhysReg::T0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::T1, slots, block);
Opcode op;
switch (inst.GetOpcode()) {
case ir::Opcode::Add: op = Opcode::Add; break;
case ir::Opcode::Sub: op = Opcode::Sub; break;
case ir::Opcode::Mul: op = Opcode::Mul; break;
case ir::Opcode::Div: op = Opcode::Div; break;
case ir::Opcode::Mod: op = Opcode::Rem; break;
default: op = Opcode::Add; break;
}
block.Append(op, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Add: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
case ir::Opcode::Gep: {
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
block.Append(Opcode::Ret);
case ir::Opcode::Call: {
auto& call = static_cast<const ir::CallInst&>(inst);
for (size_t i = 0; i < call.GetNumArgs() && i < 8; i++) {
PhysReg argReg = static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + i);
EmitValueToReg(call.GetArg(i), argReg, slots, block);
}
std::string func_name = call.GetCalleeName();
block.Append(Opcode::Call, {Operand::Func(func_name)});
if (!call.GetType()->IsVoid()) {
int dst_slot = function.CreateFrameIndex();
StoreRegToSlot(PhysReg::A0, dst_slot, block);
slots.emplace(&inst, dst_slot);
}
return;
}
case ir::Opcode::Sub:
case ir::Opcode::Mul:
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算"));
}
case ir::Opcode::ICmp: {
auto& icmp = static_cast<const ir::ICmpInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(icmp.GetLhs(), PhysReg::T0, slots, block);
EmitValueToReg(icmp.GetRhs(), PhysReg::T1, slots, block);
ir::ICmpPredicate pred = icmp.GetPredicate();
switch (pred) {
case ir::ICmpPredicate::EQ:
block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
break;
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
}
case ir::ICmpPredicate::NE:
block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
case ir::ICmpPredicate::SLT:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
break;
case ir::ICmpPredicate::SLE:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
case ir::ICmpPredicate::SGT:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
break;
case ir::ICmpPredicate::SGE:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
}
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::ZExt: {
auto& zext = static_cast<const ir::ZExtInst&>(inst);
int dst_slot = function.CreateFrameIndex(4); // i32 是 4 字节
// 获取源操作数的值
EmitValueToReg(zext.GetSrc(), PhysReg::T0, slots, block);
// 存储到新栈槽
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
EmitValueToReg(bin.GetLhs(), PhysReg::FT0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::FT1, slots, block);
Opcode op;
switch (inst.GetOpcode()) {
case ir::Opcode::FAdd: op = Opcode::FAdd; break;
case ir::Opcode::FSub: op = Opcode::FSub; break;
case ir::Opcode::FMul: op = Opcode::FMul; break;
case ir::Opcode::FDiv: op = Opcode::FDiv; break;
default: op = Opcode::FAdd; break;
}
block.Append(op, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
} // namespace
case ir::Opcode::FCmp: {
auto& fcmp = static_cast<const ir::FCmpInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
EmitValueToReg(fcmp.GetLhs(), PhysReg::FT0, slots, block);
EmitValueToReg(fcmp.GetRhs(), PhysReg::FT1, slots, block);
ir::FCmpPredicate pred = fcmp.GetPredicate();
switch (pred) {
case ir::FCmpPredicate::OEQ:
block.Append(Opcode::FEq, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
case ir::FCmpPredicate::OLT:
block.Append(Opcode::FLt, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
case ir::FCmpPredicate::OLE:
block.Append(Opcode::FLe, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
default:
block.Append(Opcode::FEq, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
}
block.Append(Opcode::FMov, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) {
DefaultContext();
case ir::Opcode::SIToFP: {
auto& conv = static_cast<const ir::SIToFPInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
auto src_it = slots.find(conv.GetSrc());
if (src_it == slots.end()) {
throw std::runtime_error(FormatError("mir", "SIToFP: 找不到源操作数的栈槽"));
}
block.Append(Opcode::Load,
{Operand::Reg(PhysReg::T0), Operand::FrameIndex(src_it->second)});
block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
slots.emplace(&inst, dst_slot);
return;
}
if (module.GetFunctions().size() != 1) {
throw std::runtime_error(FormatError("mir", "暂不支持多个函数"));
}
case ir::Opcode::FPToSI: {
auto& conv = static_cast<const ir::FPToSIInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
auto src_it = slots.find(conv.GetSrc());
if (src_it == slots.end()) {
throw std::runtime_error(FormatError("mir", "FPToSI: 找不到源操作数的栈槽"));
}
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(src_it->second)});
block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Br: {
auto& br = static_cast<const ir::BrInst&>(inst);
auto* target = br.GetTarget();
MachineBasicBlock* target_block = GetOrCreateBlock(target, function);
block.Append(Opcode::Br, {Operand::Imm64(reinterpret_cast<intptr_t>(target_block))});
return;
}
case ir::Opcode::CondBr: {
auto& condbr = static_cast<const ir::CondBrInst&>(inst);
auto* true_bb = condbr.GetTrueBB();
auto* false_bb = condbr.GetFalseBB();
const auto& func = *module.GetFunctions().front();
if (func.GetName() != "main") {
throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数"));
// 如果条件涉及函数调用,需要特殊处理
// 简单方案:将条件值保存到栈槽
int cond_slot = function.CreateFrameIndex(4);
EmitValueToReg(condbr.GetCond(), PhysReg::T0, slots, block);
// 保存条件值到栈
block.Append(Opcode::Store, {Operand::Reg(PhysReg::T0), Operand::FrameIndex(cond_slot)});
// 从栈加载条件值(确保函数调用后还能获取)
block.Append(Opcode::Load, {Operand::Reg(PhysReg::T0), Operand::FrameIndex(cond_slot)});
block.Append(Opcode::Sltu, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::ZERO),
Operand::Reg(PhysReg::T0)});
MachineBasicBlock* true_block = GetOrCreateBlock(true_bb, function);
MachineBasicBlock* false_block = GetOrCreateBlock(false_bb, function);
block.Append(Opcode::CondBr, {Operand::Reg(PhysReg::T1),
Operand::Imm64(reinterpret_cast<intptr_t>(true_block)),
Operand::Imm64(reinterpret_cast<intptr_t>(false_block))});
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (ret.GetValue()) {
auto val = ret.GetValue();
if (val->GetType()->IsFloat32()) {
auto it = slots.find(val);
if (it != slots.end()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(it->second)});
block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::A0),
Operand::Reg(PhysReg::FT0)});
} else {
throw std::runtime_error(FormatError("mir", "Ret: 找不到浮点返回值的栈槽"));
}
} else {
EmitValueToReg(val, PhysReg::A0, slots, block);
}
} else {
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::A0), Operand::Imm(0)});
}
block.Append(Opcode::Ret);
return;
}
default: {
break;
}
}
}
} // namespace
std::unique_ptr<MachineFunction> LowerFunctionToMIR(const ir::Function& func) {
block_map.clear();
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
ValueSlotMap slots;
const auto* entry = func.GetEntry();
if (!entry) {
throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块"));
// ========== 新增:为函数参数分配栈槽 ==========
for (size_t i = 0; i < func.GetNumArgs(); i++) {
ir::Argument* arg = func.GetArgument(i);
int slot = machine_func->CreateFrameIndex(4); // int 和指针都是 4 字节
// 将参数值从寄存器存储到栈槽
PhysReg argReg = static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + i);
MachineBasicBlock* entry = machine_func->GetEntry();
// 存储参数到栈槽
if (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat32()) {
// 指针类型
entry->Append(Opcode::Store, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
} else if (arg->GetType()->IsInt32()) {
// 整数类型
entry->Append(Opcode::Store, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
} else if (arg->GetType()->IsFloat32()) {
// 浮点类型
entry->Append(Opcode::StoreFloat, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
}
slots[arg] = slot;
}
for (const auto& inst : entry->GetInstructions()) {
LowerInstruction(*inst, *machine_func, slots);
// 第一遍:创建所有 IR 基本块对应的 MIR 基本块
for (const auto& ir_block : func.GetBlocks()) {
GetOrCreateBlock(ir_block.get(), *machine_func);
}
// 第二遍:遍历所有基本块,降低指令
for (const auto& ir_block : func.GetBlocks()) {
MachineBasicBlock* mbb = GetOrCreateBlock(ir_block.get(), *machine_func);
for (const auto& inst : ir_block->GetInstructions()) {
LowerInstructionToBlock(*inst, *machine_func, slots, *mbb);
}
}
return machine_func;
}
} // namespace mir
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module) {
DefaultContext();
// 收集全局变量(只做一次)
g_globalVars.clear();
for (const auto& global : module.GetGlobalVariables()) {
GlobalVarInfo info;
info.name = global->GetName();
info.isConst = global->IsConst();
info.isArray = global->IsArray();
info.arraySize = global->GetNumElements();
info.isFloat = global->IsFloat();
info.value = 0;
info.valueF = 0.0f;
if (info.isArray) {
if (info.isFloat) {
const auto& initVals = global->GetInitValsF();
for (float val : initVals) {
info.arrayValuesF.push_back(val);
}
} else {
if (global->HasInitVals()) {
const auto& initVals = global->GetInitVals();
for (int val : initVals) {
info.arrayValues.push_back(val);
}
}
}
} else {
if (info.isFloat) {
info.valueF = global->GetInitValF();
} else {
info.value = global->GetInitVal();
}
}
g_globalVars.push_back(info);
}
const auto& functions = module.GetFunctions();
if (functions.empty()) {
throw std::runtime_error(FormatError("mir", "模块中没有函数"));
}
std::vector<std::unique_ptr<MachineFunction>> result;
// 为每个函数生成 MachineFunction
for (const auto& func : functions) {
auto machine_func = LowerFunctionToMIR(*func);
result.push_back(std::move(machine_func));
}
return result;
}
} // namespace mir

@ -8,7 +8,16 @@
namespace mir {
MachineFunction::MachineFunction(std::string name)
: name_(std::move(name)), entry_("entry") {}
: name_(std::move(name)) {
entry_ = CreateBlock("entry");
}
MachineBasicBlock* MachineFunction::CreateBlock(const std::string& name) {
auto block = std::make_unique<MachineBasicBlock>(name);
auto* ptr = block.get();
blocks_.push_back(std::move(block));
return ptr;
}
int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size());
@ -30,4 +39,4 @@ const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
return frame_slots_[index];
}
} // namespace mir
} // namespace mir

@ -6,18 +6,34 @@ namespace mir {
Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm) {}
Operand::Operand(Kind kind, PhysReg reg, int64_t imm64)
: kind_(kind), reg_(PhysReg::ZERO), imm_(0), imm64_(imm64) {}
// 新增构造函数
Operand::Operand(Kind kind, PhysReg reg, int imm, const std::string& name)
: kind_(kind), reg_(reg), imm_(imm), global_name_(name) {}
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value);
return Operand(Kind::Imm, PhysReg::ZERO, value);
}
Operand Operand::Imm64(int64_t value) {
return Operand(Kind::Imm, PhysReg::ZERO, value);
}
Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::W0, index);
return Operand(Kind::FrameIndex, PhysReg::ZERO, index);
}
// 新增
Operand Operand::Global(const std::string& name) {
return Operand(Kind::Global, PhysReg::ZERO, 0, name);
}
Operand Operand::Func(const std::string& name) {
Operand op(Kind::Func, PhysReg::ZERO, 0);
op.func_name_ = name;
return op;
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
: opcode_(opcode), operands_(std::move(operands)) {}
} // namespace mir
} // namespace mir

@ -9,12 +9,66 @@ namespace {
bool IsAllowedReg(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::X29:
case PhysReg::X30:
// 临时寄存器
case PhysReg::T0:
case PhysReg::T1:
case PhysReg::T2:
case PhysReg::T3:
case PhysReg::T4:
case PhysReg::T5:
case PhysReg::T6:
// 参数/返回值寄存器
case PhysReg::A0:
case PhysReg::A1:
case PhysReg::A2:
case PhysReg::A3:
case PhysReg::A4:
case PhysReg::A5:
case PhysReg::A6:
case PhysReg::A7:
// 保存寄存器
case PhysReg::S0:
case PhysReg::S1:
case PhysReg::S2:
case PhysReg::S3:
case PhysReg::S4:
case PhysReg::S5:
case PhysReg::S6:
case PhysReg::S7:
case PhysReg::S8:
case PhysReg::S9:
case PhysReg::S10:
case PhysReg::S11:
// 特殊寄存器
case PhysReg::ZERO:
case PhysReg::RA:
case PhysReg::SP:
case PhysReg::GP:
case PhysReg::TP:
case PhysReg::FT0:
case PhysReg::FT1:
case PhysReg::FT2:
case PhysReg::FT3:
case PhysReg::FT4:
case PhysReg::FT5:
case PhysReg::FT6:
case PhysReg::FT7:
case PhysReg::FT8:
case PhysReg::FT9:
case PhysReg::FT10:
case PhysReg::FT11:
// 浮点保存寄存器
case PhysReg::FS0:
case PhysReg::FS1:
// 浮点参数寄存器
case PhysReg::FA0:
case PhysReg::FA1:
case PhysReg::FA2:
case PhysReg::FA3:
case PhysReg::FA4:
case PhysReg::FA5:
case PhysReg::FA6:
case PhysReg::FA7:
return true;
}
return false;
@ -23,7 +77,8 @@ bool IsAllowedReg(PhysReg reg) {
} // namespace
void RunRegAlloc(MachineFunction& function) {
for (const auto& inst : function.GetEntry().GetInstructions()) {
// 修复GetEntry() 返回指针,使用 ->
for (const auto& inst : function.GetEntry()->GetInstructions()) {
for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) {
@ -33,4 +88,4 @@ void RunRegAlloc(MachineFunction& function) {
}
}
} // namespace mir
} // namespace mir

@ -8,20 +8,96 @@ namespace mir {
const char* PhysRegName(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
return "w0";
case PhysReg::W8:
return "w8";
case PhysReg::W9:
return "w9";
case PhysReg::X29:
return "x29";
case PhysReg::X30:
return "x30";
// 整数寄存器
case PhysReg::ZERO:
return "zero";
case PhysReg::RA:
return "ra";
case PhysReg::SP:
return "sp";
case PhysReg::GP:
return "gp";
case PhysReg::TP:
return "tp";
case PhysReg::T0:
return "t0";
case PhysReg::T1:
return "t1";
case PhysReg::T2:
return "t2";
case PhysReg::S0:
return "s0";
case PhysReg::S1:
return "s1";
case PhysReg::A0:
return "a0";
case PhysReg::A1:
return "a1";
case PhysReg::A2:
return "a2";
case PhysReg::A3:
return "a3";
case PhysReg::A4:
return "a4";
case PhysReg::A5:
return "a5";
case PhysReg::A6:
return "a6";
case PhysReg::A7:
return "a7";
case PhysReg::S2:
return "s2";
case PhysReg::S3:
return "s3";
case PhysReg::S4:
return "s4";
case PhysReg::S5:
return "s5";
case PhysReg::S6:
return "s6";
case PhysReg::S7:
return "s7";
case PhysReg::S8:
return "s8";
case PhysReg::S9:
return "s9";
case PhysReg::S10:
return "s10";
case PhysReg::S11:
return "s11";
case PhysReg::T3:
return "t3";
case PhysReg::T4:
return "t4";
case PhysReg::T5:
return "t5";
case PhysReg::T6:
return "t6";
// 浮点寄存器
case PhysReg::FT0: return "ft0";
case PhysReg::FT1: return "ft1";
case PhysReg::FT2: return "ft2";
case PhysReg::FT3: return "ft3";
case PhysReg::FT4: return "ft4";
case PhysReg::FT5: return "ft5";
case PhysReg::FT6: return "ft6";
case PhysReg::FT7: return "ft7";
case PhysReg::FS0: return "fs0";
case PhysReg::FS1: return "fs1";
case PhysReg::FA0: return "fa0";
case PhysReg::FA1: return "fa1";
case PhysReg::FA2: return "fa2";
case PhysReg::FA3: return "fa3";
case PhysReg::FA4: return "fa4";
case PhysReg::FA5: return "fa5";
case PhysReg::FA6: return "fa6";
case PhysReg::FA7: return "fa7";
case PhysReg::FT8: return "ft8";
case PhysReg::FT9: return "ft9";
case PhysReg::FT10: return "ft10";
case PhysReg::FT11: return "ft11";
}
throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
}
} // namespace mir
} // namespace mir

@ -1,5 +1,6 @@
#include "sem/func.h"
#include <cstring>
#include <stdexcept>
#include <string>
@ -7,6 +8,12 @@
namespace sem {
// Truncate double to float32 precision (mimics C float arithmetic)
static double ToFloat32(double v) {
float f = static_cast<float>(v);
return static_cast<double>(f);
}
// 编译时求值常量表达式
ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) {
return EvaluateExp(*ctx.addExp());
@ -73,14 +80,65 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp()) {
return EvaluateExp(*ctx.exp()->addExp());
} else if (ctx.lVar()) {
// 处理变量引用(必须是已定义的常量)
// 处理变量引用:向上遍历 AST 找到对应的常量定义并求值
auto* ident = ctx.lVar()->Ident();
if (!ident) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
}
std::string name = ident->getText();
// 这里简化处理,实际应该在符号表中查找常量
// 暂时假设常量已经在前面被处理过
// 向上遍历 AST 找到作用域内的 constDef
antlr4::ParserRuleContext* scope =
dynamic_cast<antlr4::ParserRuleContext*>(ctx.lVar()->parent);
while (scope) {
// 检查当前作用域中的所有 constDecl
for (auto* tree_child : scope->children) {
auto* child = dynamic_cast<antlr4::ParserRuleContext*>(tree_child);
if (!child) continue;
auto* block_item = dynamic_cast<SysYParser::BlockItemContext*>(child);
if (block_item && block_item->decl()) {
auto* decl = block_item->decl();
if (decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
// compUnit 级别的 constDecl
auto* decl = dynamic_cast<SysYParser::DeclContext*>(child);
if (decl && decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
// If declared as int, truncate to integer
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
scope = dynamic_cast<antlr4::ParserRuleContext*>(scope->parent);
}
// 未找到常量定义,返回 0
ConstValue val;
val.is_int = true;
val.int_val = 0;
@ -94,11 +152,11 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
ConstValue val;
if (int_const) {
val.is_int = true;
val.int_val = std::stoll(int_const->getText());
val.int_val = std::stoll(int_const->getText(), nullptr, 0);
val.float_val = static_cast<double>(val.int_val);
} else if (float_const) {
val.is_int = false;
val.float_val = std::stod(float_const->getText());
val.float_val = ToFloat32(std::stod(float_const->getText()));
val.int_val = static_cast<long long>(val.float_val);
} else {
throw std::runtime_error(FormatError("sema", "非法数字字面量"));
@ -127,8 +185,9 @@ ConstValue AddValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) +
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l + r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -143,8 +202,9 @@ ConstValue SubValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) -
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l - r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -159,8 +219,9 @@ ConstValue MulValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) *
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l * r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -175,8 +236,9 @@ ConstValue DivValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) /
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l / r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;

@ -2,3 +2,41 @@
// - 按实验/评测规范提供 I/O 等函数实现
// - 与编译器生成的目标代码链接,支撑运行时行为
#include<stdio.h>
#include"sylib.h"
/* Input & output functions */
int getint(){int t; scanf("%d",&t); return t; }
int getch(){char c; scanf("%c",&c); return (int)c; }
int getarray(int a[]){
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)scanf("%d",&a[i]);
return n;
}
void putint(int a){ printf("%d",a);}
void putch(int a){ printf("%c",a); }
void putarray(int n,int a[]){
printf("%d:",n);
for(int i=0;i<n;i++)printf(" %d",a[i]);
printf("\n");
}
float getfloat(){ float f; scanf("%f",&f); return f; }
void putfloat(float a){ printf("%a",a); }
int getfarray(float a[]){
int n; scanf("%d",&n);
for(int i=0;i<n;i++) scanf("%a",&a[i]);
return n;
}
void putfarray(int n,float a[]){
printf("%d:",n);
for(int i=0;i<n;i++) printf(" %a",a[i]);
printf("\n");
}
static struct timeval _sysy_start, _sysy_end;
void starttime(){ gettimeofday(&_sysy_start, NULL); }
void stoptime(){
gettimeofday(&_sysy_end, NULL);
long us = (_sysy_end.tv_sec - _sysy_start.tv_sec) * 1000000L
+ (_sysy_end.tv_usec - _sysy_start.tv_usec);
fprintf(stderr, "Timer: %ldus\n", us);
}

@ -2,3 +2,20 @@
// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用)
// - 与 sylib.c 配套,按规范逐步补齐声明
#ifndef __SYLIB_H_
#define __SYLIB_H_
#include<stdio.h>
#include<stdarg.h>
#include<sys/time.h>
/* Input & output functions */
int getint(),getch(),getarray(int a[]);
void putint(int a),putch(int a),putarray(int n,int a[]);
float getfloat();
void putfloat(float a);
int getfarray(float a[]);
void putfarray(int n,float a[]);
/* Timing functions */
void starttime();
void stoptime();
#endif

@ -0,0 +1,7 @@
// test/test_case/functional/test_riscv.sy
int main() {
int a = 10;
int b = 20;
int c = a + b;
return c; // 应该返回30
}
Loading…
Cancel
Save