Compare commits

...

10 Commits

1
.gitignore vendored

@ -54,6 +54,7 @@ compile_commands.json
.fleet/
.vs/
*.code-workspace
CLAUDE.md
# CLion
cmake-build-debug/

@ -60,12 +60,14 @@ class Context {
~Context();
ConstantInt* GetConstInt(int v);
ConstantFloat* GetConstFloat(float v);
std::string NextTemp();
std::string NextTemp(); // 用于指令名(数字,连续)
std::string NextLabel(); // 用于块名(字母前缀,独立计数)
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
int temp_index_ = -1;
int label_index_ = -1;
};
// ─── Type ─────────────────────────────────────────────────────────────────────
@ -163,7 +165,7 @@ enum class Opcode {
// 函数调用
Call,
// 类型转换
ZExt, SIToFP, FPToSI,
ZExt,
};
// ICmp 谓词
@ -198,16 +200,28 @@ class GlobalValue : public User {
class GlobalVariable : public Value {
public:
GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements = 1);
int num_elements = 1, bool is_array_decl = false,
bool is_float = false);
bool IsConst() const { return is_const_; }
bool IsFloat() const { return is_float_; }
int GetInitVal() const { return init_val_; }
float GetInitValF() const { return init_val_f_; }
int GetNumElements() const { return num_elements_; }
bool IsArray() const { return num_elements_ > 1; }
// GlobalVariable 的"指针类型"是 i32*,访问时使用 load/store
bool IsArray() const { return is_array_decl_ || num_elements_ > 1; }
void SetInitVals(std::vector<int> v) { init_vals_ = std::move(v); }
void SetInitValsF(std::vector<float> v) { init_vals_f_ = std::move(v); }
const std::vector<int>& GetInitVals() const { return init_vals_; }
const std::vector<float>& GetInitValsF() const { return init_vals_f_; }
bool HasInitVals() const { return !init_vals_.empty() || !init_vals_f_.empty(); }
private:
bool is_const_;
bool is_float_;
int init_val_;
float init_val_f_;
int num_elements_;
bool is_array_decl_;
std::vector<int> init_vals_;
std::vector<float> init_vals_f_;
};
// ─── Instruction ──────────────────────────────────────────────────────────────
@ -299,7 +313,7 @@ class FCmpInst : public Instruction {
private:
FCmpPredicate pred_;
};
/*
// 有符号整数转浮点i32 → f32
class SIToFPInst : public Instruction {
public:
@ -313,7 +327,7 @@ class FPToSIInst : public Instruction {
FPToSIInst(Value* val, std::string name);
Value* GetSrc() const;
};
*/
// return 语句val 为 nullptr 表示 void return
class ReturnInst : public Instruction {
public:
@ -409,6 +423,8 @@ class Function : public Value {
Argument* GetArgument(size_t i) const;
size_t GetNumArgs() const { return args_.size(); }
bool IsVoidReturn() const { return type_->IsVoid(); }
// 将某个块移动到 blocks_ 列表末尾(用于确保块顺序正确)
void MoveBlockToEnd(BasicBlock* bb);
private:
BasicBlock* entry_ = nullptr;
@ -437,7 +453,9 @@ class Module {
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
GlobalVariable* CreateGlobalVariable(const std::string& name, bool is_const,
int init_val, int num_elements = 1);
int init_val, int num_elements = 1,
bool is_array_decl = false,
bool is_float = false);
GlobalVariable* GetGlobalVariable(const std::string& name) const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobalVariables() const;
@ -494,9 +512,12 @@ class IRBuilder {
AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaF32(const std::string& name);
AllocaInst* CreateAllocaArray(int num_elements, const std::string& name);
AllocaInst* CreateAllocaArrayF32(int num_elements, const std::string& name);
GepInst* CreateGep(Value* base_ptr, Value* index, const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);
// 零初始化数组emit memset call
void CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod);
// 控制流
ReturnInst* CreateRet(Value* v);
@ -515,8 +536,8 @@ class IRBuilder {
// 类型转换
ZExtInst* CreateZExt(Value* val, const std::string& name);
SIToFPInst* CreateSIToFP(Value* val, const std::string& name);
FPToSIInst* CreateFPToSI(Value* val, const std::string& name);
//SIToFPInst* CreateSIToFP(Value* val, const std::string& name);
//FPToSIInst* CreateFPToSI(Value* val, const std::string& name);
private:
Context& ctx_;

@ -19,39 +19,170 @@ class MIRContext {
MIRContext& DefaultContext();
enum class PhysReg { W0, W8, W9, X29, X30, SP };
// RISC-V 64位寄存器定义
enum class PhysReg {
// 通用寄存器
ZERO, // x0, 恒为0
RA, // x1, 返回地址
SP, // x2, 栈指针
GP, // x3, 全局指针
TP, // x4, 线程指针
T0, // x5, 临时寄存器
T1, // x6, 临时寄存器
T2, // x7, 临时寄存器
S0, // x8, 帧指针/保存寄存器
S1, // x9, 保存寄存器
A0, // x10, 参数/返回值
A1, // x11, 参数
A2, // x12, 参数
A3, // x13, 参数
A4, // x14, 参数
A5, // x15, 参数
A6, // x16, 参数
A7, // x17, 参数
S2, // x18, 保存寄存器
S3, // x19, 保存寄存器
S4, // x20, 保存寄存器
S5, // x21, 保存寄存器
S6, // x22, 保存寄存器
S7, // x23, 保存寄存器
S8, // x24, 保存寄存器
S9, // x25, 保存寄存器
S10, // x26, 保存寄存器
S11, // x27, 保存寄存器
T3, // x28, 临时寄存器
T4, // x29, 临时寄存器
T5, // x30, 临时寄存器
T6, // x31, 临时寄存器
FT0, FT1, FT2, FT3, FT4, FT5, FT6, FT7,
FS0, FS1,
FA0, FA1, FA2, FA3, FA4, FA5, FA6, FA7,
FT8, FT9, FT10, FT11,
};
const char* PhysRegName(PhysReg reg);
// 在 MIR.h 中添加(在 Opcode 枚举之前)
struct GlobalVarInfo {
std::string name;
int value;
float valueF;
bool isConst;
bool isArray;
bool isFloat;
std::vector<int> arrayValues;
std::vector<float> arrayValuesF;
int arraySize;
};
enum class Opcode {
Prologue,
Epilogue,
MovImm,
LoadStack,
StoreStack,
AddRR,
Load,
Store,
Add,
Addi,
Sub,
Mul,
Div,
Rem,
Slt,
Slti,
Slli,
Sltu, // 无符号小于
Sltiu,
Xori,
LoadGlobalAddr,
LoadGlobal,
StoreGlobal,
LoadIndirect, // lw rd, 0(rs1) 从寄存器地址加载
StoreIndirect, // sw rs2, 0(rs1)
Call,
GEP,
LoadAddr,
Ret,
// 浮点指令
FMov, // 浮点移动
FMovWX, // fmv.w.x fs, x 整数寄存器移动到浮点寄存器
FMovXW, // fmv.x.w x, fs 浮点寄存器移动到整数寄存器
FAdd,
FSub,
FMul,
FDiv,
FEq, // 浮点相等比较
FLt, // 浮点小于比较
FLe, // 浮点小于等于比较
FNeg, // 浮点取反
FAbs, // 浮点绝对值
// int 转 float
// float 转 int
LoadFloat, // 浮点加载 (flw)
StoreFloat, // 浮点存储 (fsw)
Br,
CondBr,
Label,
And, // 按位与
Andi, // 按位与立即数
Or, // 按位或
Ori, // 按位或立即数
Xor, // 按位异或
Srli, // 逻辑右移立即数
Srai, // 算术右移立即数
Srl, // 逻辑右移
Sra, // 算术右移
};
enum class GlobalKind {
Data, // .data 段(已初始化)
BSS, // .bss 段未初始化初始为0
RoData // .rodata 段(只读常量)
};
// 全局变量信息
struct GlobalInfo {
std::string name;
GlobalKind kind;
int size; // 大小(字节)
int value; // 初始值(对于简单变量)
bool isArray;
int arraySize;
std::vector<int> dimensions; // 数组维度
};
class Operand {
public:
enum class Kind { Reg, Imm, FrameIndex };
enum class Kind { Reg, Imm, FrameIndex, Global, Func };
static Operand Reg(PhysReg reg);
static Operand Imm(int value);
static Operand Imm64(int64_t value); // 新增:存储 64 位值
static Operand FrameIndex(int index);
static Operand Global(const std::string& name);
static Operand Func(const std::string& name);
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; }
int64_t GetImm64() const { return imm64_; } // 新增
int GetFrameIndex() const { return imm_; }
const std::string& GetGlobalName() const { return global_name_; }
const std::string& GetFuncName() const { return func_name_; }
private:
Operand(Kind kind, PhysReg reg, int imm);
Operand(Kind kind, PhysReg reg, int64_t imm64); // 新增构造函数
Operand(Kind kind, PhysReg reg, int imm, const std::string& name);
Kind kind_;
PhysReg reg_;
int imm_;
int64_t imm64_; // 新增
std::string global_name_;
std::string func_name_;
};
class MachineInstr {
@ -71,7 +202,6 @@ struct FrameSlot {
int size = 4;
int offset = 0;
};
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
@ -93,9 +223,14 @@ class MachineFunction {
explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; }
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
// 基本块管理
MachineBasicBlock* CreateBlock(const std::string& name);
MachineBasicBlock* GetEntry() { return entry_; }
const MachineBasicBlock* GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const { return blocks_; }
// 栈帧管理
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
@ -106,14 +241,15 @@ class MachineFunction {
private:
std::string name_;
MachineBasicBlock entry_;
MachineBasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0;
};
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
//std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os);
} // namespace mir
//void PrintAsm(const MachineFunction& function, std::ostream& os);
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module);
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os);
} // namespace mir

@ -0,0 +1,5 @@
define i32 @main() {
entry:
ret i32 42
}

@ -0,0 +1,309 @@
.text
.global main
.type main, @function
main:
addi sp, sp, -272
sw ra, 264(sp)
sw s0, 256(sp)
addi a0, sp, -4
li a1, 0
li a2, 32
call
addi a0, sp, -8
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -8
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -8
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -8
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -8
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -8
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -8
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -8
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -8
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
addi a0, sp, -44
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -44
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -44
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -44
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -44
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -44
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -44
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -44
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -44
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
addi a0, sp, -80
li a1, 0
li a2, 32
call
li t2, 1
addi t0, sp, -80
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 2
addi t0, sp, -80
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -80
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -80
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -80
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -80
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -112(sp)
li t0, 1
lw t1, -112(sp)
add t0, t0, t1
sw t0, -116(sp)
addi t0, sp, -80
lw t1, -116(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -124(sp)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -128(sp)
li t0, 1
lw t1, -128(sp)
add t0, t0, t1
sw t0, -132(sp)
addi t0, sp, -44
lw t1, -132(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -140(sp)
addi a0, sp, -108
li a1, 0
li a2, 32
call
lw t2, -124(sp)
addi t0, sp, -108
li t1, 0
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
lw t2, -140(sp)
addi t0, sp, -108
li t1, 1
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 3
addi t0, sp, -108
li t1, 2
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 4
addi t0, sp, -108
li t1, 3
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 5
addi t0, sp, -108
li t1, 4
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 6
addi t0, sp, -108
li t1, 5
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 7
addi t0, sp, -108
li t1, 6
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t2, 8
addi t0, sp, -108
li t1, 7
slli t1, t1, 2
add t0, t0, t1
sw t2, 0(t0)
li t0, 3
li t1, 2
mul t0, t0, t1
sw t0, -176(sp)
li t0, 1
lw t1, -176(sp)
add t0, t0, t1
sw t0, -180(sp)
addi t0, sp, -108
lw t1, -180(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -188(sp)
li t0, 0
li t1, 2
mul t0, t0, t1
sw t0, -192(sp)
li t0, 0
lw t1, -192(sp)
add t0, t0, t1
sw t0, -196(sp)
addi t0, sp, -108
lw t1, -196(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -204(sp)
lw t0, -188(sp)
lw t1, -204(sp)
add t0, t0, t1
sw t0, -208(sp)
li t0, 0
li t1, 2
mul t0, t0, t1
sw t0, -212(sp)
li t0, 1
lw t1, -212(sp)
add t0, t0, t1
sw t0, -216(sp)
addi t0, sp, -108
lw t1, -216(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -224(sp)
lw t0, -208(sp)
lw t1, -224(sp)
add t0, t0, t1
sw t0, -228(sp)
li t0, 2
li t1, 2
mul t0, t0, t1
sw t0, -232(sp)
li t0, 0
lw t1, -232(sp)
add t0, t0, t1
sw t0, -236(sp)
addi t0, sp, -4
lw t1, -236(sp)
slli t1, t1, 2
add t0, t0, t1
lw t0, 0(t0)
sw t0, -244(sp)
lw t0, -228(sp)
lw t1, -244(sp)
add t0, t0, t1
sw t0, -248(sp)
lw a0, -248(sp)
lw ra, 264(sp)
lw s0, 256(sp)
addi sp, sp, 272
ret
.size main, .-main

@ -0,0 +1,22 @@
compUnit
|-- funcDef
| |-- funcType
| | `-- Int: int
| |-- Ident: main
| |-- L_PAREN: (
| |-- R_PAREN: )
| `-- block
| |-- L_BRACE: {
| |-- blockItem
| | `-- stmt
| | |-- Return: return
| | |-- exp
| | | `-- addExp
| | | `-- mulExp
| | | `-- unaryExp
| | | `-- primaryExp
| | | `-- number
| | | `-- IntConst: 42
| | `-- Semi: ;
| `-- R_BRACE: }
`-- EOF: <EOF>

@ -0,0 +1,137 @@
#!/bin/bash
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/mir"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
if [ ! -x "$COMPILER" ]; then
echo "错误:编译器不存在或不可执行: $COMPILER"
exit 1
fi
mkdir -p "$TEST_RESULT_DIR"
echo "=========================================="
echo "RISC-V 后端测试"
echo "=========================================="
echo ""
# 收集测试用例
mapfile -t test_files < <(find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort)
total=${#test_files[@]}
pass_gen=0
fail_gen=0
pass_run=0
fail_run=0
timeout_cnt=0
echo "=== 阶段1汇编生成 ==="
echo ""
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.s"
mkdir -p "$(dirname "$output_file")"
"$COMPILER" --emit-asm "$test_file" 2>/dev/null > "$output_file"
if [ $? -eq 0 ] && [ -s "$output_file" ]; then
echo -e " ${GREEN}${NC} $relative_path"
((pass_gen++))
else
echo -e " ${RED}${NC} $relative_path"
((fail_gen++))
fi
done
echo ""
echo "--- 汇编生成: 通过 $pass_gen / 失败 $fail_gen / 总计 $total ---"
echo ""
echo "=== 阶段2运行验证 ==="
echo ""
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
stem="${relative_path%.sy}"
asm_file="$TEST_RESULT_DIR/${stem}.s"
exe_file="$TEST_RESULT_DIR/${stem}"
expected_file="${test_file%.sy}.out"
if [ ! -s "$asm_file" ]; then
echo -e " ${YELLOW}${NC} $relative_path (跳过)"
continue
fi
riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe_file" -no-pie 2>/dev/null
if [ $? -ne 0 ]; then
echo -e " ${RED}${NC} $relative_path (链接失败)"
((fail_run++))
continue
fi
# 运行程序,设置超时 5 秒
timeout 5 qemu-riscv64 "$exe_file" 2>/dev/null
exit_code=$?
# 检查是否超时
if [ $exit_code -eq 124 ]; then
echo -e " ${YELLOW}${NC} $relative_path (超时)"
((timeout_cnt++))
continue
fi
# 获取程序输出(需要单独捕获,因为 timeout 会改变输出)
program_output=$(timeout 5 qemu-riscv64 "$exe_file" 2>/dev/null)
if [ $? -eq 124 ]; then
echo -e " ${YELLOW}${NC} $relative_path (超时)"
((timeout_cnt++))
continue
fi
if [ -f "$expected_file" ]; then
expected=$(cat "$expected_file" | tr -d '\n')
# 判断期望文件是输出内容还是退出码
if [ -z "$expected" ] || [[ "$expected" =~ ^[0-9]+$ ]]; then
# 期望退出码
if [ $exit_code -eq "$expected" ] 2>/dev/null; then
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (退出码: 期望 $expected, 实际 $exit_code)"
((fail_run++))
fi
else
# 期望输出内容
if [ "$program_output" = "$expected" ]; then
echo -e " ${GREEN}${NC} $relative_path (输出匹配)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (输出不匹配)"
((fail_run++))
fi
fi
else
# 没有期望文件,默认通过
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
((pass_run++))
fi
done
echo ""
echo "--- 运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt ---"
echo ""
echo "=========================================="
echo "测试完成"
echo "汇编生成: 通过 $pass_gen / 失败 $fail_gen"
echo "运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt"
echo "=========================================="

@ -0,0 +1,155 @@
#!/bin/bash
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/mir"
SYLIB_C="$PROJECT_ROOT/sylib/sylib.c"
SYLIB_O="/tmp/sylib.o"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
if [ ! -x "$COMPILER" ]; then
echo "错误:编译器不存在或不可执行: $COMPILER"
exit 1
fi
# 编译 sylib 运行时库
echo "编译运行时库..."
if [ ! -f "$SYLIB_O" ]; then
riscv64-linux-gnu-gcc -c "$SYLIB_C" -o "$SYLIB_O" 2>/dev/null
if [ $? -ne 0 ]; then
echo -e "${YELLOW}警告:无法编译 sylib.c部分测试可能链接失败${NC}"
else
echo -e "${GREEN}✓ sylib.o 编译成功${NC}"
fi
fi
echo ""
mkdir -p "$TEST_RESULT_DIR"
echo "=========================================="
echo "RISC-V 后端测试"
echo "=========================================="
echo ""
# 收集测试用例
mapfile -t test_files < <(find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort)
total=${#test_files[@]}
pass_gen=0
fail_gen=0
pass_run=0
fail_run=0
timeout_cnt=0
echo "=== 阶段1汇编生成 ==="
echo ""
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.s"
mkdir -p "$(dirname "$output_file")"
"$COMPILER" --emit-asm "$test_file" 2>/dev/null > "$output_file"
if [ $? -eq 0 ] && [ -s "$output_file" ]; then
echo -e " ${GREEN}${NC} $relative_path"
((pass_gen++))
else
echo -e " ${RED}${NC} $relative_path"
((fail_gen++))
fi
done
echo ""
echo "--- 汇编生成: 通过 $pass_gen / 失败 $fail_gen / 总计 $total ---"
echo ""
for test_file in "${test_files[@]}"; do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
stem="${relative_path%.sy}"
asm_file="$TEST_RESULT_DIR/${stem}.s"
exe_file="$TEST_RESULT_DIR/${stem}"
expected_file="${test_file%.sy}.out"
if [ ! -s "$asm_file" ]; then
echo -e " ${YELLOW}${NC} $relative_path (跳过)"
continue
fi
# 链接
if [ -f "$SYLIB_O" ]; then
riscv64-linux-gnu-gcc -static "$asm_file" "$SYLIB_O" -o "$exe_file" -no-pie 2>/dev/null
else
riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe_file" -no-pie 2>/dev/null
fi
if [ $? -ne 0 ]; then
echo -e " ${RED}${NC} $relative_path (链接失败)"
((fail_run++))
continue
fi
# 运行程序
input_file="${test_file%.sy}.in"
tmp_out=$(mktemp)
if [ -f "$input_file" ]; then
timeout 10 qemu-riscv64 "$exe_file" < "$input_file" > "$tmp_out" 2>&1
else
timeout 10 qemu-riscv64 "$exe_file" > "$tmp_out" 2>&1
fi
exit_code=$?
if [ $exit_code -eq 124 ]; then
echo -e " ${YELLOW}${NC} $relative_path (超时)"
((timeout_cnt++))
rm -f "$tmp_out"
continue
fi
program_output=$(cat "$tmp_out" | tr -d '\n' | sed 's/[[:space:]]*$//')
rm -f "$tmp_out"
if [ -f "$expected_file" ]; then
expected=$(cat "$expected_file" | tr -d '\n' | sed 's/[[:space:]]*$//')
if [[ "$expected" =~ ^[0-9]+$ ]] && [ "$expected" -ge 0 ] && [ "$expected" -le 255 ] && [ -z "$program_output" ]; then
# 期望退出码(且没有输出)
if [ $exit_code -eq "$expected" ] 2>/dev/null; then
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (退出码: 期望 $expected, 实际 $exit_code)"
((fail_run++))
fi
else
# 期望输出内容
if [ "$program_output" = "$expected" ]; then
echo -e " ${GREEN}${NC} $relative_path (输出匹配)"
((pass_run++))
else
echo -e " ${RED}${NC} $relative_path (输出不匹配: 期望 '$expected', 实际 '$program_output')"
((fail_run++))
fi
fi
else
# 没有期望文件
echo -e " ${GREEN}${NC} $relative_path (退出码: $exit_code, 输出: '$program_output')"
((pass_run++))
fi
done
echo ""
echo "--- 运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt ---"
echo ""
echo "=========================================="
echo "测试完成"
echo "汇编生成: 通过 $pass_gen / 失败 $fail_gen"
echo "运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt"
echo "=========================================="

@ -0,0 +1,65 @@
#!/bin/bash
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_DIR="$PROJECT_ROOT/test/test_case/basic"
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m'
if [ ! -f "$COMPILER" ]; then
echo "错误: 编译器不存在: $COMPILER"
exit 1
fi
echo "=========================================="
echo "RISC-V 浮点转换测试"
echo "=========================================="
TESTS="
float_conv:3
float_add:13
float_mul:30
"
PASS=0
FAIL=0
for test in $TESTS; do
name=$(echo $test | cut -d: -f1)
expected=$(echo $test | cut -d: -f2)
echo -n "测试 $name (期望 $expected) ... "
"$COMPILER" "$TEST_DIR/$name.sy" --emit-asm > /tmp/test_$name.s 2>&1
if [ $? -ne 0 ]; then
echo -e "${RED}失败 (汇编错误)${NC}"
cat /tmp/test_$name.s | head -3
FAIL=$((FAIL + 1))
continue
fi
riscv64-linux-gnu-gcc -static /tmp/test_$name.s -o /tmp/test_$name -no-pie 2>/dev/null
if [ $? -ne 0 ]; then
echo -e "${RED}失败 (链接错误)${NC}"
FAIL=$((FAIL + 1))
continue
fi
qemu-riscv64 /tmp/test_$name > /dev/null 2>&1
exit_code=$?
if [ $exit_code -eq $expected ]; then
echo -e "${GREEN}通过${NC}"
PASS=$((PASS + 1))
else
echo -e "${RED}失败 (实际 $exit_code)${NC}"
FAIL=$((FAIL + 1))
fi
done
echo "=========================================="
echo -e "测试结果: ${GREEN}通过 $PASS${NC} / ${RED}失败 $FAIL${NC}"
echo "=========================================="

@ -4,6 +4,9 @@ PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/ir"
VERIFY_SCRIPT="$PROJECT_ROOT/scripts/verify_ir.sh"
PARALLEL=${PARALLEL:-$(nproc)}
LOG_FILE="$TEST_RESULT_DIR/verify.log"
if [ ! -x "$COMPILER" ]; then
echo "错误:编译器不存在或不可执行: $COMPILER"
@ -12,47 +15,130 @@ if [ ! -x "$COMPILER" ]; then
fi
mkdir -p "$TEST_RESULT_DIR"
> "$LOG_FILE"
pass_count=0
fail_count=0
failed_cases=()
# ── 阶段1IR 生成(并行)────────────────────────────────────────────────────
echo "=== 阶段1IR 生成 ===" | tee -a "$LOG_FILE"
echo "" | tee -a "$LOG_FILE"
echo "=== 开始测试 IR 生成 ==="
echo ""
GEN_TMPDIR=$(mktemp -d)
while IFS= read -r test_file; do
gen_one() {
test_file="$1"
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.ll"
mkdir -p "$(dirname "$output_file")"
echo -n "测试: $relative_path ... "
"$COMPILER" --emit-ir "$test_file" > "$output_file" 2>&1
exit_code=$?
# Use a per-case tmp file to avoid concurrent write issues
case_id=$(echo "$relative_path" | tr '/' '_')
if [ $exit_code -eq 0 ] && [ -s "$output_file" ] && ! grep -q '\[error\]' "$output_file"; then
echo "通过"
pass_count=$((pass_count + 1))
echo "通过: $relative_path" > "$GEN_TMPDIR/pass_${case_id}"
else
echo "失败"
fail_count=$((fail_count + 1))
echo "$relative_path" > "$GEN_TMPDIR/fail_${case_id}"
echo "失败: $relative_path" > "$GEN_TMPDIR/line_fail_${case_id}"
fi
}
export -f gen_one
export COMPILER TEST_CASE_DIR TEST_RESULT_DIR GEN_TMPDIR
find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort | \
xargs -P "$PARALLEL" -I{} bash -c 'gen_one "$@"' _ {}
# Collect results in sorted order
failed_cases=()
for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f")
case_id=$(echo "$relative_path" | tr '/' '_')
if [ -f "$GEN_TMPDIR/pass_${case_id}" ]; then
cat "$GEN_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE"
elif [ -f "$GEN_TMPDIR/fail_${case_id}" ]; then
cat "$GEN_TMPDIR/line_fail_${case_id}" | tee -a "$LOG_FILE"
failed_cases+=("$relative_path")
echo " 错误信息已保存到: $output_file"
fi
done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort)
done
pass_count=$(ls "$GEN_TMPDIR"/pass_* 2>/dev/null | wc -l)
fail_count=${#failed_cases[@]}
rm -rf "$GEN_TMPDIR"
echo "" | tee -a "$LOG_FILE"
echo "--- 生成完成: 通过 $pass_count / 失败 $fail_count ---" | tee -a "$LOG_FILE"
# ── 阶段2IR 运行验证(并行,需要 llc + clang──────────────────────────────
if ! command -v llc >/dev/null 2>&1 || ! command -v clang >/dev/null 2>&1; then
echo "" | tee -a "$LOG_FILE"
echo "=== 跳过阶段2未找到 llc 或 clang无法运行 IR ===" | tee -a "$LOG_FILE"
else
echo "" | tee -a "$LOG_FILE"
echo "=== 阶段2IR 运行验证 ===" | tee -a "$LOG_FILE"
echo "" | tee -a "$LOG_FILE"
VRF_TMPDIR=$(mktemp -d)
verify_one() {
test_file="$1"
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
relative_dir=$(dirname "$relative_path")
out_dir="$TEST_RESULT_DIR/$relative_dir"
stem=$(basename "${test_file%.sy}")
case_log="$out_dir/$stem.verify.log"
case_id=$(echo "$relative_path" | tr '/' '_')
if bash "$VERIFY_SCRIPT" "$test_file" "$out_dir" --run > "$case_log" 2>&1; then
echo "通过: $relative_path" > "$VRF_TMPDIR/pass_${case_id}"
else
extra=$(grep -E '(退出码|输出不匹配|错误)' "$case_log" | head -3 | sed 's/^/ /' || true)
{ echo "失败: $relative_path"; [ -n "$extra" ] && echo "$extra"; } > "$VRF_TMPDIR/fail_${case_id}"
echo "$relative_path" > "$VRF_TMPDIR/failname_${case_id}"
fi
}
export -f verify_one
export TEST_CASE_DIR TEST_RESULT_DIR VERIFY_SCRIPT VRF_TMPDIR
find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort | \
xargs -P "$PARALLEL" -I{} bash -c 'verify_one "$@"' _ {}
# Collect results in sorted order
verify_failed_cases=()
for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f")
case_id=$(echo "$relative_path" | tr '/' '_')
if [ -f "$VRF_TMPDIR/pass_${case_id}" ]; then
cat "$VRF_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE"
elif [ -f "$VRF_TMPDIR/fail_${case_id}" ]; then
cat "$VRF_TMPDIR/fail_${case_id}" | tee -a "$LOG_FILE"
verify_failed_cases+=("$relative_path")
fi
done
verify_pass=$(ls "$VRF_TMPDIR"/pass_* 2>/dev/null | wc -l)
verify_fail=${#verify_failed_cases[@]}
rm -rf "$VRF_TMPDIR"
echo "" | tee -a "$LOG_FILE"
echo "--- 验证完成: 通过 $verify_pass / 失败 $verify_fail ---" | tee -a "$LOG_FILE"
if [ ${#verify_failed_cases[@]} -gt 0 ]; then
echo "" | tee -a "$LOG_FILE"
echo "=== 验证失败的用例 ===" | tee -a "$LOG_FILE"
for f in "${verify_failed_cases[@]}"; do
[ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE"
done
fi
fi
echo ""
echo "=== 测试完成 ==="
echo "通过: $pass_count"
echo "失败: $fail_count"
echo "结果保存在: $TEST_RESULT_DIR"
# ── 汇总 ─────────────────────────────────────────────────────────────────────
echo "" | tee -a "$LOG_FILE"
echo "=== 测试完成 ===" | tee -a "$LOG_FILE"
echo "IR生成 通过: $pass_count 失败: $fail_count" | tee -a "$LOG_FILE"
echo "结果保存在: $TEST_RESULT_DIR" | tee -a "$LOG_FILE"
echo "日志保存在: $LOG_FILE" | tee -a "$LOG_FILE"
if [ ${#failed_cases[@]} -gt 0 ]; then
echo ""
echo "=== 失败的用例 ==="
echo "" | tee -a "$LOG_FILE"
echo "=== IR生成失败的用例 ===" | tee -a "$LOG_FILE"
for f in "${failed_cases[@]}"; do
echo " - $f"
[ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE"
done
exit 1
fi

@ -0,0 +1,85 @@
#!/bin/bash
# 获取项目根目录
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_DIR="$PROJECT_ROOT/test/test_case/basic"
# 颜色定义
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m'
# 检查编译器
if [ ! -f "$COMPILER" ]; then
echo "错误: 编译器不存在: $COMPILER"
exit 1
fi
# 检查工具链
if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
echo "错误: 未找到 riscv64-linux-gnu-gcc"
exit 1
fi
if ! command -v qemu-riscv64 >/dev/null 2>&1; then
echo "错误: 未找到 qemu-riscv64"
exit 1
fi
echo "=========================================="
echo "RISC-V 基础功能测试"
echo "=========================================="
# 定义测试用例
TESTS="arith:50 add:30 sub:7 mul:50 div:25 mod:2 var:43"
PASS=0
FAIL=0
for test in $TESTS; do
name=$(echo $test | cut -d: -f1)
expected=$(echo $test | cut -d: -f2)
echo -n "测试 $name (期望 $expected) ... "
# 生成汇编
"$COMPILER" "$TEST_DIR/$name.sy" --emit-asm > /tmp/test_$name.s 2>&1
if [ $? -ne 0 ]; then
echo -e "${RED}失败 (汇编错误)${NC}"
FAIL=$((FAIL + 1))
continue
fi
# 链接
riscv64-linux-gnu-gcc -static /tmp/test_$name.s -o /tmp/test_$name -no-pie 2>/dev/null
if [ $? -ne 0 ]; then
echo -e "${RED}失败 (链接错误)${NC}"
FAIL=$((FAIL + 1))
continue
fi
# 运行
qemu-riscv64 /tmp/test_$name > /dev/null 2>&1
exit_code=$?
if [ $exit_code -eq $expected ]; then
echo -e "${GREEN}通过${NC}"
PASS=$((PASS + 1))
else
echo -e "${RED}失败 (实际 $exit_code)${NC}"
FAIL=$((FAIL + 1))
fi
done
echo "=========================================="
echo -e "测试结果: ${GREEN}通过 $PASS${NC} / ${RED}失败 $FAIL${NC}"
echo "=========================================="
if [ $FAIL -eq 0 ]; then
echo -e "${GREEN}✓ 所有基础测试通过!${NC}"
exit 0
else
echo -e "${RED}✗ 有 $FAIL 个测试失败${NC}"
exit 1
fi

@ -3,6 +3,8 @@
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
@ -31,7 +33,7 @@ if [[ ! -f "$input" ]]; then
exit 1
fi
compiler="./build/bin/compiler"
compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建(如: mkdir -p build && cd build && cmake .. && make -j" >&2
exit 1
@ -60,13 +62,13 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" -o "$exe"
clang "$obj" "$PROJECT_ROOT/sylib/sylib.c" -o "$exe" -lm
echo "运行 $exe ..."
set +e
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
(ulimit -s unlimited; "$exe" < "$stdin_file") > "$stdout_file"
else
"$exe" > "$stdout_file"
(ulimit -s unlimited; "$exe") > "$stdout_file"
fi
status=$?
set -e

@ -0,0 +1,101 @@
#!/usr/bin/env bash
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
fi
input=$1
out_dir="test/test_result/mir"
run_exec=false
input_dir=$(dirname "$input")
shift
while [[ $# -gt 0 ]]; do
case "$1" in
--run)
run_exec=true
;;
*)
out_dir="$1"
;;
esac
shift
done
if [[ ! -f "$input" ]]; then
echo "输入文件不存在: $input" >&2
exit 1
fi
compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建。" >&2
exit 1
fi
mkdir -p "$out_dir"
base=$(basename "$input")
stem=${base%.sy}
mir_file="$out_dir/$stem.mir"
asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
# 生成 MIR
"$compiler" --emit-mir "$input" > "$mir_file"
echo "MIR 已生成: $mir_file"
# 生成汇编
"$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file"
if [[ "$run_exec" == true ]]; then
if ! command -v riscv64-linux-gnu-gcc >/dev/null 2>&1; then
echo "未找到 riscv64-linux-gnu-gcc" >&2
exit 1
fi
if ! command -v qemu-riscv64 >/dev/null 2>&1; then
echo "未找到 qemu-riscv64" >&2
exit 1
fi
riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe" -no-pie
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
echo "运行 $exe ..."
set +e
if [[ -f "$stdin_file" ]]; then
qemu-riscv64 "$exe" < "$stdin_file" > "$stdout_file"
else
qemu-riscv64 "$exe" > "$stdout_file"
fi
status=$?
set -e
cat "$stdout_file"
echo "退出码: $status"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then
echo "输出匹配: $expected_file"
else
echo "输出不匹配: $expected_file" >&2
exit 1
fi
else
echo "未找到预期输出文件,跳过比对: $expected_file"
fi
fi

@ -0,0 +1,256 @@
#pragma once
#include <initializer_list>
#include <iosfwd>
#include <memory>
#include <string>
#include <vector>
namespace ir {
class Module;
}
namespace mir {
class MIRContext {
public:
MIRContext() = default;
};
MIRContext& DefaultContext();
// RISC-V 64位寄存器定义
enum class PhysReg {
// 通用寄存器
ZERO, // x0, 恒为0
RA, // x1, 返回地址
SP, // x2, 栈指针
GP, // x3, 全局指针
TP, // x4, 线程指针
T0, // x5, 临时寄存器
T1, // x6, 临时寄存器
T2, // x7, 临时寄存器
S0, // x8, 帧指针/保存寄存器
S1, // x9, 保存寄存器
A0, // x10, 参数/返回值
A1, // x11, 参数
A2, // x12, 参数
A3, // x13, 参数
A4, // x14, 参数
A5, // x15, 参数
A6, // x16, 参数
A7, // x17, 参数
S2, // x18, 保存寄存器
S3, // x19, 保存寄存器
S4, // x20, 保存寄存器
S5, // x21, 保存寄存器
S6, // x22, 保存寄存器
S7, // x23, 保存寄存器
S8, // x24, 保存寄存器
S9, // x25, 保存寄存器
S10, // x26, 保存寄存器
S11, // x27, 保存寄存器
T3, // x28, 临时寄存器
T4, // x29, 临时寄存器
T5, // x30, 临时寄存器
T6, // x31, 临时寄存器
FT0, FT1, FT2, FT3, FT4, FT5, FT6, FT7,
FS0, FS1,
FA0, FA1, FA2, FA3, FA4, FA5, FA6, FA7,
FT8, FT9, FT10, FT11,
};
const char* PhysRegName(PhysReg reg);
// 在 MIR.h 中添加(在 Opcode 枚举之前)
struct GlobalVarInfo {
std::string name;
int value;
float valueF;
bool isConst;
bool isArray;
bool isFloat;
std::vector<int> arrayValues;
std::vector<float> arrayValuesF;
int arraySize;
};
enum class Opcode {
Prologue,
Epilogue,
MovImm,
Load,
Store,
Add,
Addi,
Sub,
Mul,
Div,
Rem,
Slt,
Slti,
Slli,
Sltu, // 无符号小于
Xori,
LoadGlobalAddr,
LoadGlobal,
StoreGlobal,
LoadIndirect, // lw rd, 0(rs1) 从寄存器地址加载
StoreIndirect, // sw rs2, 0(rs1)
Call,
GEP,
LoadAddr,
Ret,
// 浮点指令
FMov, // 浮点移动
FMovWX, // fmv.w.x fs, x 整数寄存器移动到浮点寄存器
FMovXW, // fmv.x.w x, fs 浮点寄存器移动到整数寄存器
FAdd,
FSub,
FMul,
FDiv,
FEq, // 浮点相等比较
FLt, // 浮点小于比较
FLe, // 浮点小于等于比较
FNeg, // 浮点取反
FAbs, // 浮点绝对值
SIToFP, // int 转 float
FPToSI, // float 转 int
LoadFloat, // 浮点加载 (flw)
StoreFloat, // 浮点存储 (fsw)
Br,
CondBr,
Label,
Srli,
Srai,
Srl,
Sra,
And,
Andi,
Or,
Ori,
Xor,
FNeg,
FAbs
};
enum class GlobalKind {
Data, // .data 段(已初始化)
BSS, // .bss 段未初始化初始为0
RoData // .rodata 段(只读常量)
};
// 全局变量信息
struct GlobalInfo {
std::string name;
GlobalKind kind;
int size; // 大小(字节)
int value; // 初始值(对于简单变量)
bool isArray;
int arraySize;
std::vector<int> dimensions; // 数组维度
};
class Operand {
public:
enum class Kind { Reg, Imm, FrameIndex, Global, Func };
static Operand Reg(PhysReg reg);
static Operand Imm(int value);
static Operand Imm64(int64_t value); // 新增:存储 64 位值
static Operand FrameIndex(int index);
static Operand Global(const std::string& name);
static Operand Func(const std::string& name);
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; }
int64_t GetImm64() const { return imm64_; } // 新增
int GetFrameIndex() const { return imm_; }
const std::string& GetGlobalName() const { return global_name_; }
const std::string& GetFuncName() const { return func_name_; }
private:
Operand(Kind kind, PhysReg reg, int imm);
Operand(Kind kind, PhysReg reg, int64_t imm64); // 新增构造函数
Operand(Kind kind, PhysReg reg, int imm, const std::string& name);
Kind kind_;
PhysReg reg_;
int imm_;
int64_t imm64_; // 新增
std::string global_name_;
std::string func_name_;
};
class MachineInstr {
public:
MachineInstr(Opcode opcode, std::vector<Operand> operands = {});
Opcode GetOpcode() const { return opcode_; }
const std::vector<Operand>& GetOperands() const { return operands_; }
private:
Opcode opcode_;
std::vector<Operand> operands_;
};
struct FrameSlot {
int index = 0;
int size = 4;
int offset = 0;
};
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
const std::string& GetName() const { return name_; }
std::vector<MachineInstr>& GetInstructions() { return instructions_; }
const std::vector<MachineInstr>& GetInstructions() const { return instructions_; }
MachineInstr& Append(Opcode opcode,
std::initializer_list<Operand> operands = {});
private:
std::string name_;
std::vector<MachineInstr> instructions_;
};
class MachineFunction {
public:
explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; }
// 基本块管理
MachineBasicBlock* CreateBlock(const std::string& name);
MachineBasicBlock* GetEntry() { return entry_; }
const MachineBasicBlock* GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const { return blocks_; }
// 栈帧管理
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; }
int GetFrameSize() const { return frame_size_; }
void SetFrameSize(int size) { frame_size_ = size; }
private:
std::string name_;
MachineBasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0;
};
//std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function);
//void PrintAsm(const MachineFunction& function, std::ostream& os);
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module);
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os);
} // namespace mir

@ -29,4 +29,10 @@ std::string Context::NextTemp() {
return oss.str();
}
std::string Context::NextLabel() {
std::ostringstream oss;
oss << "L" << ++label_index_;
return oss.str();
}
} // namespace ir

@ -17,6 +17,17 @@ BasicBlock* Function::CreateBlock(const std::string& name) {
return ptr;
}
void Function::MoveBlockToEnd(BasicBlock* bb) {
for (size_t i = 0; i < blocks_.size(); ++i) {
if (blocks_[i].get() == bb) {
auto tmp = std::move(blocks_[i]);
blocks_.erase(blocks_.begin() + i);
blocks_.push_back(std::move(tmp));
return;
}
}
}
BasicBlock* Function::GetEntry() { return entry_; }
const BasicBlock* Function::GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {

@ -136,6 +136,15 @@ AllocaInst* IRBuilder::CreateAllocaArray(int num_elements,
num_elements, name);
}
AllocaInst* IRBuilder::CreateAllocaArrayF32(int num_elements,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(Type::GetPtrFloat32Type(),
num_elements, name);
}
GepInst* IRBuilder::CreateGep(Value* base_ptr, Value* index,
const std::string& name) {
if (!insert_block_) {
@ -222,7 +231,7 @@ ZExtInst* IRBuilder::CreateZExt(Value* val, const std::string& name) {
}
return insert_block_->Append<ZExtInst>(val, name);
}
/*
SIToFPInst* IRBuilder::CreateSIToFP(Value* val, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -236,5 +245,21 @@ FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) {
}
return insert_block_->Append<FPToSIInst>(val, name);
}
*/
void IRBuilder::CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
// declare memset if not already declared
if (!mod.HasExternalDecl("memset")) {
mod.DeclareExternalFunc("memset", Type::GetVoidType(),
{Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
}
int byte_count = num_elements * 4;
insert_block_->Append<CallInst>(
std::string("memset"), Type::GetVoidType(),
std::vector<Value*>{ptr, ctx.GetConstInt(0), ctx.GetConstInt(byte_count)},
std::string(""));
}
} // namespace ir

@ -1,6 +1,9 @@
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
@ -50,7 +53,13 @@ static std::string ValStr(const Value* v) {
return std::to_string(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::to_string(cf->GetValue());
// LLVM IR 要求 float 常量用 64 位十六进制表示double 精度)
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << bits;
return oss.str();
}
// BasicBlock: 打印为 label %name
if (dynamic_cast<const BasicBlock*>(v)) {
@ -59,10 +68,11 @@ static std::string ValStr(const Value* v) {
// GlobalVariable: 打印为 @name
if (auto* gv = dynamic_cast<const GlobalVariable*>(v)) {
if (gv->IsArray()) {
// 数组全局变量的指针getelementptr [N x i32], [N x i32]* @name, i32 0, i32 0
return "getelementptr ([" + std::to_string(gv->GetNumElements()) +
" x i32], [" + std::to_string(gv->GetNumElements()) +
" x i32]* @" + gv->GetName() + ", i32 0, i32 0)";
// 数组全局变量的指针getelementptr [N x T], [N x T]* @name, i32 0, i32 0
const char* elem_ty = gv->IsFloat() ? "float" : "i32";
return std::string("getelementptr ([") + std::to_string(gv->GetNumElements()) +
" x " + elem_ty + "], [" + std::to_string(gv->GetNumElements()) +
" x " + elem_ty + "]* @" + gv->GetName() + ", i32 0, i32 0)";
}
return "@" + v->GetName();
}
@ -76,8 +86,12 @@ static std::string TypeVal(const Value* v) {
std::to_string(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::string(TypeToStr(*cf->GetType())) + " " +
std::to_string(cf->GetValue());
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
return oss.str();
}
return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v);
}
@ -86,13 +100,34 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
// 1. 全局变量/常量
for (const auto& g : module.GetGlobalVariables()) {
if (g->IsArray()) {
// 全局数组zeroinitializer
if (g->IsConst()) {
os << "@" << g->GetName() << " = constant [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
const char* linkage = g->IsConst() ? "constant" : "global";
const char* elem_ty = g->IsFloat() ? "float" : "i32";
os << "@" << g->GetName() << " = " << linkage
<< " [" << g->GetNumElements() << " x " << elem_ty << "] ";
if (g->HasInitVals()) {
os << "[";
if (g->IsFloat()) {
const auto& vals = g->GetInitValsF();
for (int i = 0; i < g->GetNumElements(); ++i) {
if (i > 0) os << ", ";
float fv = (i < (int)vals.size()) ? vals[i] : 0.0f;
double d = static_cast<double>(fv);
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
os << oss.str();
}
} else {
const auto& vals = g->GetInitVals();
for (int i = 0; i < g->GetNumElements(); ++i) {
if (i > 0) os << ", ";
os << "i32 " << (i < (int)vals.size() ? vals[i] : 0);
}
}
os << "]\n";
} else {
os << "@" << g->GetName() << " = global [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
os << "zeroinitializer\n";
}
} else {
if (g->IsConst()) {
@ -209,8 +244,11 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
}
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
bool is_float_ptr = gep->GetBasePtr()->GetType()->IsPtrFloat32();
const char* elem_type = is_float_ptr ? "float" : "i32";
const char* ptr_type = is_float_ptr ? "float*" : "i32*";
os << " %" << gep->GetName()
<< " = getelementptr i32, i32* "
<< " = getelementptr " << elem_type << ", " << ptr_type << " "
<< ValStr(gep->GetBasePtr()) << ", i32 "
<< ValStr(gep->GetIndex()) << "\n";
break;
@ -285,6 +323,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
<< ValStr(ze->GetSrc()) << " to i32\n";
break;
}
/*
case Opcode::SIToFP: {
auto* si = static_cast<const SIToFPInst*>(inst);
os << " %" << si->GetName() << " = sitofp i32 "
@ -297,6 +336,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
<< ValStr(fp->GetSrc()) << " to i32\n";
break;
}
*/
}
}
}

@ -166,7 +166,7 @@ ZExtInst::ZExtInst(Value* val, std::string name)
}
Value* ZExtInst::GetSrc() const { return GetOperand(0); }
/*
// ─── SIToFPInst ───────────────────────────────────────────────────────────────
SIToFPInst::SIToFPInst(Value* val, std::string name)
: Instruction(Opcode::SIToFP, Type::GetFloat32Type(), std::move(name)) {
@ -188,7 +188,7 @@ FPToSIInst::FPToSIInst(Value* val, std::string name)
}
Value* FPToSIInst::GetSrc() const { return GetOperand(0); }
*/
// ─── ReturnInst ───────────────────────────────────────────────────────────────
ReturnInst::ReturnInst(Value* val)
: Instruction(Opcode::Ret, Type::GetVoidType(), "") {
@ -224,7 +224,11 @@ AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, int num_elements,
// ─── GepInst ──────────────────────────────────────────────────────────────────
GepInst::GepInst(Value* base_ptr, Value* index, std::string name)
: Instruction(Opcode::Gep, Type::GetPtrInt32Type(), std::move(name)) {
: Instruction(Opcode::Gep,
(base_ptr && base_ptr->GetType()->IsPtrFloat32())
? Type::GetPtrFloat32Type()
: Type::GetPtrInt32Type(),
std::move(name)) {
if (!base_ptr || !index) {
throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数"));
}
@ -265,10 +269,15 @@ GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
// ─── GlobalVariable ────────────────────────────────────────────────────────────
GlobalVariable::GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements)
: Value(Type::GetPtrInt32Type(), std::move(name)),
int num_elements, bool is_array_decl,
bool is_float)
: Value(is_float ? Type::GetPtrFloat32Type() : Type::GetPtrInt32Type(),
std::move(name)),
is_const_(is_const),
is_float_(is_float),
init_val_(init_val),
num_elements_(num_elements) {}
init_val_f_(0.0f),
num_elements_(num_elements),
is_array_decl_(is_array_decl) {}
} // namespace ir

@ -28,9 +28,10 @@ const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
// ─── 全局变量管理 ─────────────────────────────────────────────────────────────
GlobalVariable* Module::CreateGlobalVariable(const std::string& name,
bool is_const, int init_val,
int num_elements) {
int num_elements, bool is_array_decl,
bool is_float) {
globals_.push_back(
std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements));
std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements, is_array_decl, is_float));
GlobalVariable* g = globals_.back().get();
global_map_[name] = g;
return g;

@ -117,33 +117,101 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
for (int d : dims) total *= (d > 0 ? d : 1);
if (in_global_scope_) {
auto* gv = module_.CreateGlobalVariable(name, true, 0, total);
auto* gv = module_.CreateGlobalVariable(name, true, 0, total, true);
global_storage_map_[constDef] = gv;
global_array_dims_[constDef] = dims;
// 计算初始值并存入全局变量
if (constDef->constInitVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::vector<int> flat(total, 0);
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) { flat[pos] = EvalConstExprInt(iv->constExp()); return; }
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos; fill(sub,a,sub_stride); cur=a+sub_stride; }
}
};
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur+top_stride-1)/top_stride)*top_stride; fill(sub,a,top_stride); cur=a+top_stride; }
}
}
gv->SetInitVals(flat);
}
} else {
auto* slot = builder_.CreateAllocaArray(total, name);
storage_map_[constDef] = slot;
array_dims_[constDef] = dims;
// 扁平化初始化
// 按 C 语义扁平化初始化(子列表对齐到维度边界)
if (constDef->constInitVal()) {
std::vector<int> flat;
flat.reserve(total);
std::function<void(SysYParser::ConstInitValContext*)> flatten =
[&](SysYParser::ConstInitValContext* iv) {
if (!iv) return;
if (iv->constExp()) {
flat.push_back(EvalConstExprInt(iv->constExp()));
} else {
for (auto* sub : iv->constInitVal()) flatten(sub);
}
};
flatten(constDef->constInitVal());
std::vector<int> flat(total, 0);
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) {
flat[pos] = EvalConstExprInt(iv->constExp());
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else {
int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
}
}
};
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else {
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
for (int i = 0; i < total; ++i) {
int v = (i < (int)flat.size()) ? flat[i] : 0;
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(v), ptr);
builder_.CreateStore(builder_.CreateConstInt(flat[i]), ptr);
}
}
}
@ -194,9 +262,97 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
} else {
int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1);
auto* gv = module_.CreateGlobalVariable(name, false, 0, total);
auto* gv = module_.CreateGlobalVariable(name, false, 0, total, true, is_float);
global_storage_map_[ctx] = gv;
global_array_dims_[ctx] = dims;
// 计算初始值
if (ctx->initVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
if (is_float) {
std::vector<float> flat(total, 0.0f);
std::function<void(SysYParser::InitValContext*, int, int)> fill_f;
fill_f = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<float>(sem::EvaluateExp(*iv->exp()->addExp()).float_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill_f(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<float>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).float_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill_f(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitValsF(flat);
} else {
std::vector<int> flat(total, 0);
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<int>(sem::EvaluateExp(*iv->exp()->addExp()).int_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<int>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).int_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitVals(flat);
}
}
}
} else {
if (storage_map_.count(ctx)) {
@ -211,6 +367,14 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
ir::Value* init;
if (ctx->initVal() && ctx->initVal()->exp()) {
init = EvalExpr(*ctx->initVal()->exp());
// Coerce init value to slot type
if (!is_float && init->IsFloat32()) {
init = ToInt(init);
} else if (is_float && init->IsInt32()) {
init = ToFloat(init);
} else if (!is_float && init->IsInt1()) {
init = ToI32(init);
}
} else {
init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
@ -219,40 +383,95 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
} else {
int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1);
auto* slot = builder_.CreateAllocaArray(total, name);
auto* slot = is_float ? builder_.CreateAllocaArrayF32(total, module_.GetContext().NextTemp())
: builder_.CreateAllocaArray(total, module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
array_dims_[ctx] = dims;
ir::Value* zero_init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
if (ctx->initVal()) {
// 收集扁平化初始值
std::vector<ir::Value*> flat;
flat.reserve(total);
std::function<void(SysYParser::InitValContext*)> flatten =
[&](SysYParser::InitValContext* iv) {
if (!iv) return;
if (iv->exp()) {
flat.push_back(EvalExpr(*iv->exp()));
} else {
for (auto* sub : iv->initVal()) flatten(sub);
}
};
flatten(ctx->initVal());
for (int i = 0; i < total; ++i) {
ir::Value* v = (i < (int)flat.size()) ? flat[i]
: builder_.CreateConstInt(0);
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(v, ptr);
// 按 C 语义扁平化初始值:子列表对齐到对应维度边界
std::vector<ir::Value*> flat(total, zero_init);
// 计算各维度的 stridestride[i] = dims[i]*dims[i+1]*...*dims[n-1]
// 但我们需要「子列表对应第几维的 stride」
// 顶层stride = total / dims[0](即每行的元素数)
// 递归时 stride 继续除以当前维度大小
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0]; // 每个顶层子列表占用的元素数
// fill(iv, pos, stride):将 iv 的内容填入 flat[pos..pos+stride)
// stride 表示当前层子列表对应的元素个数
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
flat[pos] = EvalExpr(*iv->exp());
return;
}
// 子列表内的 stride = stride / (当前层首维大小)
// 找到对应的 strides 层stride == strides[k] → 子stride = strides[k+1]
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k) {
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
}
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else {
// 对齐到 sub_stride 边界
int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
}
}
};
// 顶层扫描
int cur = 0;
if (ctx->initVal()->exp()) {
flat[0] = EvalExpr(*ctx->initVal()->exp());
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else {
// 对齐到 top_stride 边界
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
} else {
// 零初始化
// 先 memset 归零,再只写入非零元素
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
for (int i = 0; i < total; ++i) {
bool is_zero = false;
if (auto* ci = dynamic_cast<ir::ConstantInt*>(flat[i])) {
is_zero = (ci->GetValue() == 0);
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(flat[i])) {
is_zero = (cf->GetValue() == 0.0f);
}
if (is_zero) continue;
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(0), ptr);
builder_.CreateStore(flat[i], ptr);
}
} else {
// 零初始化:用 memset 归零
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
(void)zero_init;
}
}
}

@ -10,10 +10,15 @@
// ─── 辅助 ─────────────────────────────────────────────────────────────────────
// 把 i32 值转成 i1icmp ne i32 v, 0
// 把 i32/float 值转成 i1
ir::Value* IRGenImpl::ToI1(ir::Value* v) {
if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value"));
if (v->IsInt1()) return v;
if (v->IsFloat32()) {
return builder_.CreateFCmp(ir::FCmpPredicate::ONE, v,
builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
return builder_.CreateICmp(ir::ICmpPredicate::NE, v,
builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
@ -26,27 +31,23 @@ ir::Value* IRGenImpl::ToI32(ir::Value* v) {
return builder_.CreateZExt(v, module_.GetContext().NextTemp());
}
// 转换为 float如果是 int
ir::Value* IRGenImpl::ToFloat(ir::Value* v) {
if (!v) throw std::runtime_error(FormatError("irgen", "ToFloat: null value"));
if (v->IsFloat32()) return v;
if (v->IsInt32()) return builder_.CreateSIToFP(v, module_.GetContext().NextTemp());
if (v->IsInt1()) {
auto* i32 = ToI32(v);
return builder_.CreateSIToFP(i32, module_.GetContext().NextTemp());
ir::Value* IRGenImpl::ToInt(ir::Value* val) {
if (val->GetType()->IsInt32()) return val;
if (val->GetType()->IsFloat32()) {
throw std::runtime_error("不支持 float 转 int");
}
throw std::runtime_error(FormatError("irgen", "ToFloat: 不支持的类型"));
return val;
}
// 转换为 int如果是 float
ir::Value* IRGenImpl::ToInt(ir::Value* v) {
if (!v) throw std::runtime_error(FormatError("irgen", "ToInt: null value"));
if (v->IsInt32()) return v;
if (v->IsFloat32()) return builder_.CreateFPToSI(v, module_.GetContext().NextTemp());
if (v->IsInt1()) return ToI32(v);
throw std::runtime_error(FormatError("irgen", "ToInt: 不支持的类型"));
ir::Value* IRGenImpl::ToFloat(ir::Value* val) {
if (val->GetType()->IsFloat32()) return val;
if (val->GetType()->IsInt32()) {
throw std::runtime_error("不支持 int 转 float");
}
return val;
}
// 隐式类型转换确保两个操作数类型一致int 转 float
void IRGenImpl::ImplicitConvert(ir::Value*& lhs, ir::Value*& rhs) {
if (!lhs || !rhs) return;
@ -87,7 +88,13 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
} else if (name == "getch") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {});
} else if (name == "getfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); // 近似
module_.DeclareExternalFunc(name, ir::Type::GetFloat32Type(), {});
} else if (name == "getarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrInt32Type()});
} else if (name == "getfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrFloat32Type()});
} else if (name == "putint") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
@ -95,10 +102,16 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
} else if (name == "putfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {});
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetFloat32Type()});
} else if (name == "putarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
{ir::Type::GetInt32Type(),
ir::Type::GetPtrInt32Type()});
} else if (name == "putfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type(),
ir::Type::GetPtrFloat32Type()});
} else if (name == "starttime" || name == "stoptime") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
@ -227,13 +240,113 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
std::vector<ir::Value*> args;
if (ctx->funcRParams()) {
for (auto* exp : ctx->funcRParams()->exp()) {
args.push_back(EvalExpr(*exp));
// 检查是否是数组变量(无索引的 lVar若是则传指针而非 load
ir::Value* arg = nullptr;
auto* add = exp->addExp();
if (add && add->mulExp().size() == 1) {
auto* mul = add->mulExp(0);
if (mul && mul->unaryExp().size() == 1) {
auto* unary = mul->unaryExp(0);
if (unary && !unary->unaryOp() && unary->primaryExp()) {
auto* primary = unary->primaryExp();
if (primary && primary->lVar() && primary->lVar()->exp().empty()) {
auto* lvar = primary->lVar();
auto* decl = sema_.ResolveVarUse(lvar->Ident());
if (decl) {
// 检查是否是数组参数storage_map_ 里存的是指针)
auto it = storage_map_.find(decl);
if (it != storage_map_.end()) {
auto* val = it->second;
if (val && (val->IsPtrInt32() || val->IsPtrFloat32())) {
// 检查是否是 Argument数组参数直接传指针
if (dynamic_cast<ir::Argument*>(val)) {
arg = val;
} else if (array_dims_.count(decl)) {
// 本地数组(含 dims 记录):传首元素地址
arg = builder_.CreateGep(val, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
// 检查全局数组
if (!arg) {
auto git = global_storage_map_.find(decl);
if (git != global_storage_map_.end()) {
auto* gv = dynamic_cast<ir::GlobalVariable*>(git->second);
if (gv && gv->IsArray()) {
arg = builder_.CreateGep(git->second,
builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
}
}
}
}
}
// Also handle partially-indexed multi-dim arrays: arr[i] where arr is
// int arr[M][N] should pass &arr[i*N] as i32*, not load arr[i] as i32.
if (!arg) {
auto* add2 = exp->addExp();
if (add2 && add2->mulExp().size() == 1) {
auto* mul2 = add2->mulExp(0);
if (mul2 && mul2->unaryExp().size() == 1) {
auto* unary2 = mul2->unaryExp(0);
if (unary2 && !unary2->unaryOp() && unary2->primaryExp()) {
auto* primary2 = unary2->primaryExp();
if (primary2 && primary2->lVar() && !primary2->lVar()->exp().empty()) {
auto* lvar2 = primary2->lVar();
auto* decl2 = sema_.ResolveVarUse(lvar2->Ident());
if (decl2) {
std::vector<int> dims2;
ir::Value* base2 = nullptr;
auto it2 = array_dims_.find(decl2);
if (it2 != array_dims_.end()) {
dims2 = it2->second;
auto sit = storage_map_.find(decl2);
if (sit != storage_map_.end()) base2 = sit->second;
} else {
auto git2 = global_array_dims_.find(decl2);
if (git2 != global_array_dims_.end()) {
dims2 = git2->second;
auto gsit = global_storage_map_.find(decl2);
if (gsit != global_storage_map_.end()) base2 = gsit->second;
}
}
// Partially indexed: fewer indices than dimensions -> pass pointer
bool is_param = !dims2.empty() && dims2[0] == -1;
size_t effective_dims = is_param ? dims2.size() - 1 : dims2.size();
if (base2 && !dims2.empty() &&
lvar2->exp().size() < effective_dims + (is_param ? 1 : 0)) {
arg = EvalLVarAddr(lvar2);
}
}
}
}
}
}
}
if (!arg) arg = EvalExpr(*exp);
args.push_back(arg);
}
}
// 模块内已知函数?
ir::Function* callee = module_.GetFunction(callee_name);
if (callee) {
// Coerce args to match parameter types
for (size_t i = 0; i < args.size() && i < callee->GetNumArgs(); ++i) {
auto* param = callee->GetArgument(i);
if (!param || !args[i]) continue;
if (param->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name =
callee->IsVoidReturn() ? "" : module_.GetContext().NextTemp();
auto* call =
@ -246,15 +359,28 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// 外部函数
EnsureExternalDecl(callee_name);
// 获取返回类型
// 获取返回类型和参数类型
std::shared_ptr<ir::Type> ret_type = ir::Type::GetInt32Type();
std::vector<std::shared_ptr<ir::Type>> param_types;
for (const auto& decl : module_.GetExternalDecls()) {
if (decl.name == callee_name) {
ret_type = decl.ret_type;
param_types = decl.param_types;
break;
}
}
bool is_void = ret_type->IsVoid();
// Coerce args to match external function parameter types
for (size_t i = 0; i < args.size() && i < param_types.size(); ++i) {
if (!args[i]) continue;
if (param_types[i]->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param_types[i]->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param_types[i]->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name = is_void ? "" : module_.GetContext().NextTemp();
auto* call = builder_.CreateCallExternal(callee_name, ret_type,
std::move(args), ret_name);
@ -331,40 +457,26 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
throw std::runtime_error(FormatError("irgen", "数组索引维度过多"));
}
ir::Value* offset = builder_.CreateConstInt(0);
ir::Value* offset = nullptr;
if (is_array_param) {
// 数组参数dims[0]=-1, dims[1..n]是已知维度
// 索引indices[0]对应第一维indices[1]对应第二维...
for (size_t i = 0; i < indices.size(); ++i) {
ir::Value* idx = EvalExpr(*indices[i]);
if (i == 0) {
// 第一维stride = dims[1] * dims[2] * ... (如果有的话)
int stride = 1;
for (size_t j = 1; j < dims.size(); ++j) {
stride *= dims[j];
}
if (stride > 1) {
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
} else {
offset = builder_.CreateAdd(offset, idx,
module_.GetContext().NextTemp());
}
int stride = 1;
size_t start = (i == 0) ? 1 : i + 1;
for (size_t j = start; j < dims.size(); ++j) stride *= dims[j];
ir::Value* term;
if (stride == 1) {
term = idx;
} else {
// 后续维度
int stride = 1;
for (size_t j = i + 1; j < dims.size(); ++j) {
stride *= dims[j];
}
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
}
if (!offset) {
offset = term;
} else {
offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
}
}
} else {
@ -374,15 +486,24 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
stride = (i == (int)dims.size() - 1) ? 1 : stride * dims[i + 1];
if (i < (int)indices.size()) {
ir::Value* idx = EvalExpr(*indices[i]);
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
ir::Value* term;
if (stride == 1) {
term = idx;
} else {
term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
}
if (!offset) {
offset = term;
} else {
offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
}
}
}
}
if (!offset) offset = builder_.CreateConstInt(0);
return builder_.CreateGep(base, offset, module_.GetContext().NextTemp());
}
@ -486,8 +607,8 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.end");
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.end");
builder_.CreateCondBr(result, end_bb, rhs_bb);
builder_.SetInsertPoint(rhs_bb);
@ -498,6 +619,7 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
builder_.CreateBr(end_bb);
}
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
}
@ -523,8 +645,8 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.end");
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.end");
builder_.CreateCondBr(result, rhs_bb, end_bb);
builder_.SetInsertPoint(rhs_bb);
@ -535,6 +657,7 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
builder_.CreateBr(end_bb);
}
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
}

@ -25,6 +25,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->Return()) {
if (ctx->exp()) {
ir::Value* v = EvalExpr(*ctx->exp());
// Coerce return value to function return type
if (func_->GetType()->IsInt32() && v->IsFloat32()) {
v = ToInt(v);
} else if (func_->GetType()->IsFloat32() && v->IsInt32()) {
v = ToFloat(v);
} else if (func_->GetType()->IsInt32() && v->IsInt1()) {
v = ToI32(v);
}
builder_.CreateRet(v);
} else {
builder_.CreateRetVoid();
@ -54,6 +62,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->lVar() && ctx->Assign()) {
ir::Value* rhs = EvalExpr(*ctx->exp());
ir::Value* addr = EvalLVarAddr(ctx->lVar());
// Coerce rhs to match slot type
if (addr->IsPtrInt32() && rhs->IsFloat32()) {
rhs = ToInt(rhs);
} else if (addr->IsPtrFloat32() && rhs->IsInt32()) {
rhs = ToFloat(rhs);
} else if (addr->IsPtrInt32() && rhs->IsInt1()) {
rhs = ToI32(rhs);
}
builder_.CreateStore(rhs, addr);
return BlockFlow::Continue;
}
@ -74,32 +90,47 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
auto stmts = ctx->stmt();
// Step 1: evaluate condition (may create short-circuit blocks with lower
// SSA numbers — must happen before any branch-target blocks are created).
ir::Value* cond_val = EvalCond(*ctx->cond());
// Step 2: create then_bb now (its label number will be >= all short-circuit
// block numbers allocated during EvalCond).
ir::BasicBlock* then_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".if.then");
module_.GetContext().NextLabel() + ".if.then");
// Step 3: create else_bb/merge_bb as placeholders. They will be moved to
// the end of the block list after their predecessors are filled in, so the
// block ordering in the output will be correct even though their label
// numbers are allocated here (before then-body sub-blocks are created).
ir::BasicBlock* else_bb = nullptr;
ir::BasicBlock* merge_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".if.end");
ir::BasicBlock* merge_bb = nullptr;
// 求值条件(可能创建短路求值块)
ir::Value* cond_val = EvalCond(*ctx->cond());
if (stmts.size() >= 2) {
else_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.else");
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
} else {
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
}
// 检查当前块是否已终结(短路求值可能导致)
// Check if current block already terminated (short-circuit may do this)
if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) {
// 条件求值已经终结了当前块,无法继续
// 这种情况下我们需要在merge_bb继续
func_->MoveBlockToEnd(then_bb);
if (else_bb) func_->MoveBlockToEnd(else_bb);
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
if (stmts.size() >= 2) {
// if-else
else_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".if.else");
builder_.CreateCondBr(cond_val, then_bb, else_bb);
} else {
builder_.CreateCondBr(cond_val, then_bb, merge_bb);
}
// then 分支
// then 分支 — visit body (may create many sub-blocks with higher numbers)
func_->MoveBlockToEnd(then_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = VisitStmt(*stmts[0]);
if (then_flow != BlockFlow::Terminated) {
@ -108,6 +139,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
// else 分支
if (else_bb) {
func_->MoveBlockToEnd(else_bb);
builder_.SetInsertPoint(else_bb);
auto else_flow = VisitStmt(*stmts[1]);
if (else_flow != BlockFlow::Terminated) {
@ -115,6 +147,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
}
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
@ -124,28 +157,32 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (!ctx->cond()) {
throw std::runtime_error(FormatError("irgen", "while 缺少条件"));
}
ir::BasicBlock* cond_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.cond");
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.end");
module_.GetContext().NextLabel() + ".while.cond");
// 跳转到条件块
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateBr(cond_bb);
}
// 条件块
// EvalCond MUST come before creating body_bb/after_bb so that
// short-circuit blocks get lower SSA numbers than the loop body blocks.
builder_.SetInsertPoint(cond_bb);
ir::Value* cond_val = EvalCond(*ctx->cond());
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.end");
// 检查条件求值后是否已终结
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateCondBr(cond_val, body_bb, after_bb);
}
// 循环体(压入循环栈)
func_->MoveBlockToEnd(body_bb);
loop_stack_.push_back({cond_bb, after_bb});
builder_.SetInsertPoint(body_bb);
auto stmts = ctx->stmt();
@ -159,6 +196,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
loop_stack_.pop_back();
func_->MoveBlockToEnd(after_bb);
builder_.SetInsertPoint(after_bb);
return BlockFlow::Continue;
}

@ -46,13 +46,16 @@ int main(int argc, char** argv) {
}
if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
// 修改:支持多函数
auto machine_funcs = mir::LowerToMIR(*module);
for (auto& mf : machine_funcs) {
mir::RunRegAlloc(*mf);
mir::RunFrameLowering(*mf);
}
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_func, std::cout);
mir::PrintAsm(machine_funcs, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {
@ -65,4 +68,4 @@ int main(int argc, char** argv) {
return 1;
}
return 0;
}
}

@ -2,9 +2,16 @@
#include <ostream>
#include <stdexcept>
#include <iostream>
#include <vector>
#include <unordered_map>
#include "utils/Log.h"
// 引用全局变量(定义在 Lowering.cpp 中)
extern std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir {
namespace {
@ -16,63 +23,575 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex());
}
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
int offset) {
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
<< "]\n";
void EmitStackLoad(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) {
if (offset >= -2048 && offset <= 2047) {
os << " lw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " lw " << PhysRegName(dst) << ", 0(t4)\n";
}
}
} // namespace
void EmitStackStore(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) {
if (offset >= -2048 && offset <= 2047) {
os << " sw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " sw " << PhysRegName(src) << ", 0(t4)\n";
}
}
void PrintAsm(const MachineFunction& function, std::ostream& os) {
os << ".text\n";
os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n";
os << function.GetName() << ":\n";
for (const auto& inst : function.GetEntry().GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
void EmitStackLoadFloat(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) {
if (offset >= -2048 && offset <= 2047) {
os << " flw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " flw " << PhysRegName(dst) << ", 0(t4)\n";
}
}
void EmitStackStoreFloat(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) {
if (offset >= -2048 && offset <= 2047) {
os << " fsw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n";
} else {
os << " li t4, " << offset << "\n";
os << " add t4, " << PhysRegName(base) << ", t4\n";
os << " fsw " << PhysRegName(src) << ", 0(t4)\n";
}
}
// 输出单个函数的汇编
void PrintAsmFunction(const MachineFunction& function, std::ostream& os) {
// 收集所有基本块名称
std::unordered_map<const MachineBasicBlock*, std::string> block_names;
for (const auto& block_ptr : function.GetBlocks()) {
block_names[block_ptr.get()] = block_ptr->GetName();
}
int total_frame_size = 16 + function.GetFrameSize();
bool prologue_done = false;
for (const auto& block_ptr : function.GetBlocks()) {
const auto& block = *block_ptr;
// 输出基本块标签(入口块不输出,因为函数名已经是标签)
if (block.GetName() != "entry") {
os << block.GetName() << ":\n";
}
for (const auto& inst : block.GetInstructions()) {
const auto& ops = inst.GetOperands();
// 在入口块的第一条指令前输出序言
if (!prologue_done && block.GetName() == "entry") {
// 处理大栈帧的情况
if (total_frame_size <= 2047) {
os << " addi sp, sp, -" << total_frame_size << "\n";
} else {
os << " li t4, -" << total_frame_size << "\n";
os << " add sp, sp, t4\n";
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
// 保存 ra 和 s0
int ra_offset = total_frame_size - 8;
int s0_offset = total_frame_size - 16;
if (ra_offset <= 2047) {
os << " sw ra, " << ra_offset << "(sp)\n";
} else {
os << " li t4, " << ra_offset << "\n";
os << " add t4, sp, t4\n";
os << " sw ra, 0(t4)\n";
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break;
if (s0_offset <= 2047) {
os << " sw s0, " << s0_offset << "(sp)\n";
} else {
os << " li t4, " << s0_offset << "\n";
os << " add t4, sp, t4\n";
os << " sw s0, 0(t4)\n";
}
prologue_done = true;
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
switch (inst.GetOpcode()) {
case Opcode::Prologue:
case Opcode::Epilogue:
break;
case Opcode::MovImm:
os << " li " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::Load: {
if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
} else {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackLoad(os, ops.at(0).GetReg(), slot.offset);
}
break;
}
case Opcode::Store: {
if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
} else if (ops.size() == 3 && ops.at(1).GetKind() == Operand::Kind::Reg
&& ops.at(2).GetKind() == Operand::Kind::Imm) {
// 新增:支持 sw rs, offset(rt)
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(2).GetImm() << "(" << PhysRegName(ops.at(1).GetReg()) << ")\n";
} else {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackStore(os, ops.at(0).GetReg(), slot.offset);
}
break;
}
case Opcode::Add:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Addi:
if (ops.size() == 3 && ops.at(2).GetKind() == Operand::Kind::Imm) {
os << " addi " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
}
break;
case Opcode::Sub:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Mul: {
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
} else {
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
}
break;
}
case Opcode::Div:
os << " div " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Rem:
os << " rem " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slt:
os << " slt " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Slti:
os << " slti " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Sltu:
os << " sltu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Sltiu: // <-- 添加这个
os << " sltiu " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Srli:
os << " srli " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Srai:
os << " srai " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Srl:
os << " srl " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Sra:
os << " sra " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Xori:
os << " xori " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::And:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Andi:
os << " andi " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Or:
os << " or " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Ori:
os << " ori " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::Xor:
os << " xor " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::LoadGlobalAddr: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la " << PhysRegName(ops.at(0).GetReg()) << ", " << global_name << "\n";
break;
}
case Opcode::LoadGlobal:
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::StoreGlobal: {
std::string global_name = ops.at(1).GetGlobalName();
os << " la t1, " << global_name << "\n";
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0(t1)\n";
break;
}
case Opcode::GEP:
break;
case Opcode::LoadIndirect:
os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::Call: {
std::string func_name = "memset";
if (!ops.empty() && ops[0].GetKind() == Operand::Kind::Func) {
func_name = ops[0].GetFuncName();
}
os << " call " << func_name << "\n";
break;
}
case Opcode::LoadAddr: {
int frame_idx = ops.at(1).GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
if (slot.offset >= -2048 && slot.offset <= 2047) {
os << " addi " << PhysRegName(ops.at(0).GetReg()) << ", sp, " << slot.offset << "\n";
} else {
os << " li " << PhysRegName(ops.at(0).GetReg()) << ", " << slot.offset << "\n";
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", sp, "
<< PhysRegName(ops.at(0).GetReg()) << "\n";
}
break;
}
case Opcode::Slli:
os << " slli " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::StoreIndirect:
os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0("
<< PhysRegName(ops.at(1).GetReg()) << ")\n";
break;
case Opcode::Ret:{
// 恢复 ra 和 s0
int ra_offset = total_frame_size - 8;
int s0_offset = total_frame_size - 16;
if (ra_offset <= 2047) {
os << " lw ra, " << ra_offset << "(sp)\n";
} else {
os << " li t4, " << ra_offset << "\n";
os << " add t4, sp, t4\n";
os << " lw ra, 0(t4)\n";
}
if (s0_offset <= 2047) {
os << " lw s0, " << s0_offset << "(sp)\n";
} else {
os << " li t4, " << s0_offset << "\n";
os << " add t4, sp, t4\n";
os << " lw s0, 0(t4)\n";
}
// 恢复 sp
if (total_frame_size <= 2047) {
os << " addi sp, sp, " << total_frame_size << "\n";
} else {
os << " li t4, " << total_frame_size << "\n";
os << " add sp, sp, t4\n";
}
os << " ret\n";
break;
}
case Opcode::Br: {
auto* target = reinterpret_cast<MachineBasicBlock*>(ops[0].GetImm64());
os << " j " << target->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* true_target = reinterpret_cast<MachineBasicBlock*>(ops[1].GetImm64());
auto* false_target = reinterpret_cast<MachineBasicBlock*>(ops[2].GetImm64());
auto true_it = block_names.find(true_target);
auto false_it = block_names.find(false_target);
if (true_it == block_names.end() || false_it == block_names.end()) {
throw std::runtime_error(FormatError("mir", "CondBr: 找不到基本块名称"));
}
os << " bnez " << PhysRegName(ops[0].GetReg()) << ", "
<< true_it->second << "\n";
os << " j " << false_it->second << "\n";
break;
}
// 浮点运算
case Opcode::FAdd:
os << " fadd.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FSub:
os << " fsub.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMul:
os << " fmul.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FDiv:
os << " fdiv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FEq:
os << " feq.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLt:
os << " flt.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FLe:
os << " fle.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << ", "
<< PhysRegName(ops[2].GetReg()) << "\n";
break;
case Opcode::FMov:
os << " fmv.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovWX:
os << " fmv.w.x " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FMovXW:
os << " fmv.x.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FNeg:
os << " fneg.s " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FAbs:
os << " fabs.s " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
/*
case Opcode::SIToFP:
os << " fcvt.s.w " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvt.w.s " << PhysRegName(ops[0].GetReg()) << ", "
<< PhysRegName(ops[1].GetReg()) << "\n";
break;
*/
case Opcode::LoadFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " flw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackLoadFloat(os, ops[0].GetReg(), slot.offset);
}
break;
case Opcode::StoreFloat:
if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) {
os << " fsw " << PhysRegName(ops[0].GetReg()) << ", 0("
<< PhysRegName(ops[1].GetReg()) << ")\n";
} else {
int frame_idx = ops[1].GetFrameIndex();
const auto& slot = function.GetFrameSlot(frame_idx);
EmitStackStoreFloat(os, ops[0].GetReg(), slot.offset);
}
break;
default:
break;
}
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName()
<< "\n";
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
}
} // namespace
// 输出多个函数的汇编
void PrintAsm(const std::vector<std::unique_ptr<MachineFunction>>& functions, std::ostream& os) {
// ========== 输出全局变量 ==========
// 输出 .data 段(已初始化的全局变量)
bool hasData = false;
for (const auto& gv : g_globalVars) {
if (!gv.isConst) {
if (!hasData) {
os << ".data\n";
hasData = true;
}
os << " .global " << gv.name << "\n";
os << " .type " << gv.name << ", @object\n";
if (gv.isArray && gv.arraySize > 1) {
int totalSize = gv.arraySize * 4;
os << " .size " << gv.name << ", " << totalSize << "\n";
os << gv.name << ":\n";
if (!gv.arrayValues.empty()) {
for (int val : gv.arrayValues) {
os << " .word " << val << "\n";
}
} else {
for (int i = 0; i < gv.arraySize; i++) {
os << " .word 0\n";
}
}
} else {
os << " .size " << gv.name << ", 4\n";
os << gv.name << ":\n";
os << " .word " << gv.value << "\n";
}
}
}
// 输出 .rodata 段(只读常量)
bool hasRodata = false;
for (const auto& gv : g_globalVars) {
if (gv.isConst) {
if (!hasRodata) {
os << ".section .rodata\n";
hasRodata = true;
}
os << " .global " << gv.name << "\n";
os << " .type " << gv.name << ", @object\n";
if (gv.isArray && gv.arraySize > 1) {
int totalSize = gv.arraySize * 4;
os << " .size " << gv.name << ", " << totalSize << "\n";
os << gv.name << ":\n";
if (!gv.arrayValues.empty()) {
for (int val : gv.arrayValues) {
os << " .word " << val << "\n";
}
} else {
for (int i = 0; i < gv.arraySize; i++) {
os << " .word 0\n";
}
}
} else {
os << " .size " << gv.name << ", 4\n";
os << gv.name << ":\n";
os << " .word " << gv.value << "\n";
}
}
}
// ========== 输出代码段 ==========
os << ".text\n";
// 输出每个函数
for (const auto& func_ptr : functions) {
os << ".global " << func_ptr->GetName() << "\n";
os << ".type " << func_ptr->GetName() << ", @function\n";
os << func_ptr->GetName() << ":\n";
PrintAsmFunction(*func_ptr, os);
os << "\n"; // 函数之间加空行
}
}
} // namespace mir
} // namespace mir

@ -15,10 +15,12 @@ target_link_libraries(mir_core PUBLIC
ir
)
target_compile_options(mir_core PRIVATE -Wno-unused-parameter)
add_subdirectory(passes)
add_library(mir INTERFACE)
target_link_libraries(mir INTERFACE
mir_core
mir_passes
)
)

@ -18,9 +18,9 @@ void RunFrameLowering(MachineFunction& function) {
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
if (-cursor < -256) {
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
}
//if (-cursor < -2048) {
//throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
//}
}
cursor = 0;
@ -30,7 +30,8 @@ void RunFrameLowering(MachineFunction& function) {
}
function.SetFrameSize(AlignTo(cursor, 16));
auto& insts = function.GetEntry().GetInstructions();
// 修复GetEntry() 返回指针,使用 ->
auto& insts = function.GetEntry()->GetInstructions();
std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) {
@ -42,4 +43,4 @@ void RunFrameLowering(MachineFunction& function) {
insts = std::move(lowered);
}
} // namespace mir
} // namespace mir

@ -1,123 +1,716 @@
#include "mir/MIR.h"
#include <iostream>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include <cstring>
#include "ir/IR.h"
#include "utils/Log.h"
std::vector<mir::GlobalVarInfo> g_globalVars;
namespace mir {
namespace {
using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
static std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> block_map;
MachineBasicBlock* GetOrCreateBlock(const ir::BasicBlock* ir_block,
MachineFunction& function) {
auto it = block_map.find(ir_block);
if (it != block_map.end()) {
return it->second;
}
std::string name = ir_block->GetName();
if (name.empty()) {
name = "block_" + std::to_string(block_map.size());
}
auto* block = function.CreateBlock(name);
block_map[ir_block] = block;
return block;
}
void EmitValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
const ValueSlotMap& slots, MachineBasicBlock& block,
bool for_address=false) {
if (auto* arg = dynamic_cast<const ir::Argument*>(value)) {
auto it = slots.find(arg);
if (it != slots.end()) {
// 从栈槽加载参数值
if (value->GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
} else {
block.Append(Opcode::Load,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
}
return;
}
}
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
int64_t val = constant->GetValue();
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(constant->GetValue())});
{Operand::Reg(target), Operand::Imm(static_cast<int>(val))});
return;
}
// 处理浮点常量
if (auto* fconstant = dynamic_cast<const ir::ConstantFloat*>(value)) {
float val = fconstant->GetValue();
uint32_t bits;
memcpy(&bits, &val, sizeof(val));
// 检查目标是否是浮点寄存器
bool target_is_fp = (target == PhysReg::FT0 || target == PhysReg::FT1 ||
target == PhysReg::FT2 || target == PhysReg::FT3 ||
target == PhysReg::FT4 || target == PhysReg::FT5 ||
target == PhysReg::FT6 || target == PhysReg::FT7 ||
target == PhysReg::FA0 || target == PhysReg::FA1 ||
target == PhysReg::FA2 || target == PhysReg::FA3 ||
target == PhysReg::FA4 || target == PhysReg::FA5 ||
target == PhysReg::FA6 || target == PhysReg::FA7);
if (target_is_fp) {
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::T0), Operand::Imm(static_cast<int>(bits))});
block.Append(Opcode::FMovWX, {Operand::Reg(target), Operand::Reg(PhysReg::T0)});
} else {
// 目标是整数寄存器,直接加载
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(static_cast<int>(bits))});
}
return;
}
if (auto* gep = dynamic_cast<const ir::GepInst*>(value)) {
EmitValueToReg(gep->GetBasePtr(), target, slots, block, true);
EmitValueToReg(gep->GetIndex(), PhysReg::T1, slots, block);
block.Append(Opcode::Slli, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Imm(2)});
block.Append(Opcode::Add, {Operand::Reg(target),
Operand::Reg(target),
Operand::Reg(PhysReg::T1)});
return;
}
if (auto* alloca = dynamic_cast<const ir::AllocaInst*>(value)) {
auto it = slots.find(alloca);
if (it != slots.end()) {
block.Append(Opcode::LoadAddr,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
return;
}
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(value)) {
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(target), Operand::Global(global->GetName())});
if (!for_address) {
if (global->IsFloat()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::Reg(target)});
} else {
block.Append(Opcode::LoadGlobal,
{Operand::Reg(target), Operand::Reg(target)});
}
}
return;
}
// 关键:在 slots 中查找,并根据类型生成正确的加载指令
auto it = slots.find(value);
if (it == slots.end()) {
throw std::runtime_error(
FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
if (it != slots.end()) {
if (value->GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
} else {
block.Append(Opcode::Load,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
}
return;
}
std::cerr << "未找到的值: " << value << std::endl;
std::cerr << " 名称: " << value->GetName() << std::endl;
std::cerr << " 类型: " << (value->GetType()->IsFloat32() ? "float" : "int") << std::endl;
std::cerr << " 是否是 ConstantInt: " << (dynamic_cast<const ir::ConstantInt*>(value) != nullptr) << std::endl;
std::cerr << " 是否是 ConstantFloat: " << (dynamic_cast<const ir::ConstantFloat*>(value) != nullptr) << std::endl;
std::cerr << " 是否是 Instruction: " << (dynamic_cast<const ir::Instruction*>(value) != nullptr) << std::endl;
block.Append(Opcode::LoadStack,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
throw std::runtime_error(
FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
}
void StoreRegToSlot(PhysReg reg, int slot, MachineBasicBlock& block, bool isFloat = false) {
if (isFloat) {
block.Append(Opcode::StoreFloat,
{Operand::Reg(reg), Operand::FrameIndex(slot)});
} else {
block.Append(Opcode::Store,
{Operand::Reg(reg), Operand::FrameIndex(slot)});
}
}
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots) {
auto& block = function.GetEntry();
// 将 LowerInstruction 重命名为 LowerInstructionToBlock并添加 MachineBasicBlock 参数
void LowerInstructionToBlock(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots, MachineBasicBlock& block) {
switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: {
slots.emplace(&inst, function.CreateFrameIndex());
auto& alloca = static_cast<const ir::AllocaInst&>(inst);
int size = 4;
if (alloca.GetNumElements() > 1) {
size = alloca.GetNumElements() * 4;
}
slots.emplace(&inst, function.CreateFrameIndex(size));
return;
}
case ir::Opcode::Store: {
auto& store = static_cast<const ir::StoreInst&>(inst);
if (dynamic_cast<const ir::GepInst*>(store.GetPtr())) {
EmitValueToReg(store.GetValue(), PhysReg::T2, slots, block);
EmitValueToReg(store.GetPtr(), PhysReg::T0, slots, block, true);
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::T2), Operand::Reg(PhysReg::T0)});
return;
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(store.GetPtr())) {
EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block);
std::string global_name = global->GetName();
block.Append(Opcode::StoreGlobal,
{Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
return;
}
auto dst = slots.find(store.GetPtr());
if (dst == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行写入"));
if (dst != slots.end()) {
EmitValueToReg(store.GetValue(), PhysReg::T0, slots, block);
StoreRegToSlot(PhysReg::T0, dst->second, block);
return;
}
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
return;
throw std::runtime_error(FormatError("mir", "Store: 无法处理的指针类型"));
}
case ir::Opcode::Load: {
auto& load = static_cast<const ir::LoadInst&>(inst);
if (dynamic_cast<const ir::GepInst*>(load.GetPtr())) {
EmitValueToReg(load.GetPtr(), PhysReg::T0, slots, block, true);
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
int dst_slot = function.CreateFrameIndex(4);
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
slots.emplace(&inst, dst_slot);
return;
}
if (auto* global = dynamic_cast<const ir::GlobalVariable*>(load.GetPtr())) {
int dst_slot = function.CreateFrameIndex(4);
std::string global_name = global->GetName();
block.Append(Opcode::LoadGlobalAddr,
{Operand::Reg(PhysReg::T0), Operand::Global(global_name)});
if (global->IsFloat()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, true);
} else {
block.Append(Opcode::LoadGlobal,
{Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
return;
}
auto src = slots.find(load.GetPtr());
if (src == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行读取"));
if (src != slots.end()) {
int dst_slot = function.CreateFrameIndex(4);
if (load.GetType()->IsFloat32()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(src->second)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
} else {
block.Append(Opcode::Load,
{Operand::Reg(PhysReg::T0), Operand::FrameIndex(src->second)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
}
slots.emplace(&inst, dst_slot);
return;
}
throw std::runtime_error(FormatError("mir", "Load: 无法处理的指针类型"));
}
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Mod: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
EmitValueToReg(bin.GetLhs(), PhysReg::T0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::T1, slots, block);
Opcode op;
switch (inst.GetOpcode()) {
case ir::Opcode::Add: op = Opcode::Add; break;
case ir::Opcode::Sub: op = Opcode::Sub; break;
case ir::Opcode::Mul: op = Opcode::Mul; break;
case ir::Opcode::Div: op = Opcode::Div; break;
case ir::Opcode::Mod: op = Opcode::Rem; break;
default: op = Opcode::Add; break;
}
block.Append(op, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Add: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
case ir::Opcode::Gep: {
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
block.Append(Opcode::Ret);
case ir::Opcode::Call: {
auto& call = static_cast<const ir::CallInst&>(inst);
size_t num_args = call.GetNumArgs();
// 前 8 个参数通过寄存器 a0-a7 传递
for (size_t i = 0; i < num_args && i < 8; i++) {
PhysReg argReg = static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + i);
EmitValueToReg(call.GetArg(i), argReg, slots, block);
}
// 超过 8 个参数通过栈传递
// RISC-V 调用约定:栈参数从高地址到低地址传递
// 需要在调用前预留空间
if (num_args > 8) {
// 计算需要预留的栈空间(每个参数 4 字节,按 16 字节对齐)
int stack_args = num_args - 8;
int stack_space = (stack_args * 4 + 15) & ~15; // 16 字节对齐
// 预留栈空间
block.Append(Opcode::Addi, {Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::SP),
Operand::Imm(-stack_space)});
// 将参数存入栈(从最后一个参数开始)
int stack_offset = 0;
for (size_t i = 8; i < num_args; i++) {
// 将参数值加载到临时寄存器
EmitValueToReg(call.GetArg(i), PhysReg::T3, slots, block);
// 存储到栈上
if (stack_offset <= 2047) {
block.Append(Opcode::Store, {Operand::Reg(PhysReg::T5),
Operand::Reg(PhysReg::SP),
Operand::Imm(stack_offset)});
} else {
// 大偏移处理
block.Append(Opcode::Addi, {Operand::Reg(PhysReg::T4),
Operand::Reg(PhysReg::SP),
Operand::Imm(stack_offset)});
block.Append(Opcode::StoreIndirect, {Operand::Reg(PhysReg::T5),
Operand::Reg(PhysReg::T4)});
}
stack_offset += 4;
}
}
std::string func_name = call.GetCalleeName();
block.Append(Opcode::Call, {Operand::Func(func_name)});
// 调用后恢复栈指针(如果有栈参数)
if (num_args > 8) {
int stack_args = num_args - 8;
int stack_space = (stack_args * 4 + 15) & ~15;
block.Append(Opcode::Addi, {Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::SP),
Operand::Imm(stack_space)});
}
if (!call.GetType()->IsVoid()) {
int dst_slot = function.CreateFrameIndex();
StoreRegToSlot(PhysReg::A0, dst_slot, block);
slots.emplace(&inst, dst_slot);
}
return;
}
case ir::Opcode::Sub:
case ir::Opcode::Mul:
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算"));
}
case ir::Opcode::ICmp: {
auto& icmp = static_cast<const ir::ICmpInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(icmp.GetLhs(), PhysReg::T0, slots, block);
EmitValueToReg(icmp.GetRhs(), PhysReg::T1, slots, block);
ir::ICmpPredicate pred = icmp.GetPredicate();
switch (pred) {
case ir::ICmpPredicate::EQ:
block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
break;
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
}
case ir::ICmpPredicate::NE:
block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0),
Operand::Imm(1)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
} // namespace
case ir::ICmpPredicate::SLT:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1)});
break;
case ir::ICmpPredicate::SLE:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
case ir::ICmpPredicate::SGT:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
break;
case ir::ICmpPredicate::SGE:
block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::T0)});
block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::T1),
Operand::Imm(1)});
break;
}
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::ZExt: {
auto& zext = static_cast<const ir::ZExtInst&>(inst);
int dst_slot = function.CreateFrameIndex(4); // i32 是 4 字节
// 获取源操作数的值
EmitValueToReg(zext.GetSrc(), PhysReg::T0, slots, block);
// 存储到新栈槽
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
EmitValueToReg(bin.GetLhs(), PhysReg::FT0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::FT1, slots, block);
Opcode op;
switch (inst.GetOpcode()) {
case ir::Opcode::FAdd: op = Opcode::FAdd; break;
case ir::Opcode::FSub: op = Opcode::FSub; break;
case ir::Opcode::FMul: op = Opcode::FMul; break;
case ir::Opcode::FDiv: op = Opcode::FDiv; break;
default: op = Opcode::FAdd; break;
}
block.Append(op, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) {
DefaultContext();
case ir::Opcode::FCmp: {
auto& fcmp = static_cast<const ir::FCmpInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
EmitValueToReg(fcmp.GetLhs(), PhysReg::FT0, slots, block);
EmitValueToReg(fcmp.GetRhs(), PhysReg::FT1, slots, block);
ir::FCmpPredicate pred = fcmp.GetPredicate();
switch (pred) {
case ir::FCmpPredicate::OEQ:
block.Append(Opcode::FEq, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
case ir::FCmpPredicate::OLT:
block.Append(Opcode::FLt, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
case ir::FCmpPredicate::OLE:
block.Append(Opcode::FLe, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
default:
block.Append(Opcode::FEq, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::FT1)});
break;
}
block.Append(Opcode::FMov, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block);
slots.emplace(&inst, dst_slot);
return;
}
/*
case ir::Opcode::SIToFP: {
auto& conv = static_cast<const ir::SIToFPInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
auto src_it = slots.find(conv.GetSrc());
if (src_it == slots.end()) {
throw std::runtime_error(FormatError("mir", "SIToFP: 找不到源操作数的栈槽"));
}
block.Append(Opcode::Load,
{Operand::Reg(PhysReg::T0), Operand::FrameIndex(src_it->second)});
block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::FT0),
Operand::Reg(PhysReg::T0)});
StoreRegToSlot(PhysReg::FT0, dst_slot, block, true);
slots.emplace(&inst, dst_slot);
return;
}
if (module.GetFunctions().size() != 1) {
throw std::runtime_error(FormatError("mir", "暂不支持多个函数"));
}
case ir::Opcode::FPToSI: {
auto& conv = static_cast<const ir::FPToSIInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
auto src_it = slots.find(conv.GetSrc());
if (src_it == slots.end()) {
throw std::runtime_error(FormatError("mir", "FPToSI: 找不到源操作数的栈槽"));
}
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(src_it->second)});
block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::T0),
Operand::Reg(PhysReg::FT0)});
StoreRegToSlot(PhysReg::T0, dst_slot, block, false);
slots.emplace(&inst, dst_slot);
return;
}
*/
case ir::Opcode::Br: {
auto& br = static_cast<const ir::BrInst&>(inst);
auto* target = br.GetTarget();
MachineBasicBlock* target_block = GetOrCreateBlock(target, function);
block.Append(Opcode::Br, {Operand::Imm64(reinterpret_cast<intptr_t>(target_block))});
return;
}
case ir::Opcode::CondBr: {
auto& condbr = static_cast<const ir::CondBrInst&>(inst);
auto* true_bb = condbr.GetTrueBB();
auto* false_bb = condbr.GetFalseBB();
// 如果条件涉及函数调用,需要特殊处理
// 简单方案:将条件值保存到栈槽
int cond_slot = function.CreateFrameIndex(4);
EmitValueToReg(condbr.GetCond(), PhysReg::T0, slots, block);
// 保存条件值到栈
block.Append(Opcode::Store, {Operand::Reg(PhysReg::T0), Operand::FrameIndex(cond_slot)});
// 从栈加载条件值(确保函数调用后还能获取)
block.Append(Opcode::Load, {Operand::Reg(PhysReg::T0), Operand::FrameIndex(cond_slot)});
block.Append(Opcode::Sltu, {Operand::Reg(PhysReg::T1),
Operand::Reg(PhysReg::ZERO),
Operand::Reg(PhysReg::T0)});
MachineBasicBlock* true_block = GetOrCreateBlock(true_bb, function);
MachineBasicBlock* false_block = GetOrCreateBlock(false_bb, function);
const auto& func = *module.GetFunctions().front();
if (func.GetName() != "main") {
throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数"));
block.Append(Opcode::CondBr, {Operand::Reg(PhysReg::T1),
Operand::Imm64(reinterpret_cast<intptr_t>(true_block)),
Operand::Imm64(reinterpret_cast<intptr_t>(false_block))});
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (ret.GetValue()) {
auto val = ret.GetValue();
if (val->GetType()->IsFloat32()) {
auto it = slots.find(val);
if (it != slots.end()) {
block.Append(Opcode::LoadFloat,
{Operand::Reg(PhysReg::FT0), Operand::FrameIndex(it->second)});
//block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::A0),
//Operand::Reg(PhysReg::FT0)});
} else {
throw std::runtime_error(FormatError("mir", "Ret: 找不到浮点返回值的栈槽"));
}
} else {
EmitValueToReg(val, PhysReg::A0, slots, block);
}
} else {
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::A0), Operand::Imm(0)});
}
block.Append(Opcode::Ret);
return;
}
default: {
break;
}
}
}
} // namespace
std::unique_ptr<MachineFunction> LowerFunctionToMIR(const ir::Function& func) {
block_map.clear();
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
ValueSlotMap slots;
const auto* entry = func.GetEntry();
if (!entry) {
throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块"));
// ========== 新增:为函数参数分配栈槽 ==========
// ========== 为函数参数分配栈槽 ==========
for (size_t i = 0; i < func.GetNumArgs(); i++) {
ir::Argument* arg = func.GetArgument(i);
int slot = machine_func->CreateFrameIndex(4);
MachineBasicBlock* entry = machine_func->GetEntry();
if (i < 8) {
// 前 8 个参数通过寄存器传递
PhysReg argReg = static_cast<PhysReg>(static_cast<int>(PhysReg::A0) + i);
if (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat32()) {
entry->Append(Opcode::Store, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
} else if (arg->GetType()->IsInt32()) {
entry->Append(Opcode::Store, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
} else if (arg->GetType()->IsFloat32()) {
entry->Append(Opcode::StoreFloat, {Operand::Reg(argReg), Operand::FrameIndex(slot)});
}
} else {
// 超过 8 个的参数通过栈传递
// 栈参数位于进入函数时的 sp + (i - 8) * 4 位置
// 但由于函数序言减去了栈帧,实际偏移需要加上栈帧大小
// 我们先用临时寄存器加载栈参数,然后保存到栈槽
int stack_offset = (i - 8) * 4;
int frame_size = machine_func->GetFrameSize(); // 注意:此时栈帧大小还未计算!
// 由于此时还不知道栈帧大小,我们先记录栈参数的偏移
// 在 FrameLowering 后再处理?不,我们可以直接用 sp 的相对偏移
// 在函数入口sp 已经减去了栈帧,所以栈参数在 sp + frame_size + stack_offset
// 简单方案:先用 t3 从 sp + stack_offset 加载(假设在调整 sp 之前)
// 但我们的序言在 Lowering 之后才插入,所以这里直接生成加载指令
// 使用 t3 作为临时寄存器
entry->Append(Opcode::Addi, {Operand::Reg(PhysReg::T3),
Operand::Reg(PhysReg::SP),
Operand::Imm(stack_offset)});
if (arg->GetType()->IsFloat32()) {
entry->Append(Opcode::LoadFloat, {Operand::Reg(PhysReg::FT0), Operand::Reg(PhysReg::T3)});
entry->Append(Opcode::StoreFloat, {Operand::Reg(PhysReg::FT0), Operand::FrameIndex(slot)});
} else {
entry->Append(Opcode::Load, {Operand::Reg(PhysReg::T3), Operand::Reg(PhysReg::T3)});
entry->Append(Opcode::Store, {Operand::Reg(PhysReg::T3), Operand::FrameIndex(slot)});
}
}
slots[arg] = slot;
}
for (const auto& inst : entry->GetInstructions()) {
LowerInstruction(*inst, *machine_func, slots);
// 第一遍:创建所有 IR 基本块对应的 MIR 基本块
for (const auto& ir_block : func.GetBlocks()) {
GetOrCreateBlock(ir_block.get(), *machine_func);
}
// 第二遍:遍历所有基本块,降低指令
for (const auto& ir_block : func.GetBlocks()) {
MachineBasicBlock* mbb = GetOrCreateBlock(ir_block.get(), *machine_func);
for (const auto& inst : ir_block->GetInstructions()) {
LowerInstructionToBlock(*inst, *machine_func, slots, *mbb);
}
}
return machine_func;
}
} // namespace mir
std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module) {
DefaultContext();
// 收集全局变量(只做一次)
g_globalVars.clear();
for (const auto& global : module.GetGlobalVariables()) {
GlobalVarInfo info;
info.name = global->GetName();
info.isConst = global->IsConst();
info.isArray = global->IsArray();
info.arraySize = global->GetNumElements();
info.isFloat = global->IsFloat();
info.value = 0;
info.valueF = 0.0f;
if (info.isArray) {
if (info.isFloat) {
const auto& initVals = global->GetInitValsF();
for (float val : initVals) {
info.arrayValuesF.push_back(val);
}
} else {
if (global->HasInitVals()) {
const auto& initVals = global->GetInitVals();
for (int val : initVals) {
info.arrayValues.push_back(val);
}
}
}
} else {
if (info.isFloat) {
info.valueF = global->GetInitValF();
} else {
info.value = global->GetInitVal();
}
}
g_globalVars.push_back(info);
}
const auto& functions = module.GetFunctions();
if (functions.empty()) {
throw std::runtime_error(FormatError("mir", "模块中没有函数"));
}
std::vector<std::unique_ptr<MachineFunction>> result;
// 为每个函数生成 MachineFunction
for (const auto& func : functions) {
auto machine_func = LowerFunctionToMIR(*func);
result.push_back(std::move(machine_func));
}
return result;
}
} // namespace mir

@ -8,7 +8,16 @@
namespace mir {
MachineFunction::MachineFunction(std::string name)
: name_(std::move(name)), entry_("entry") {}
: name_(std::move(name)) {
entry_ = CreateBlock("entry");
}
MachineBasicBlock* MachineFunction::CreateBlock(const std::string& name) {
auto block = std::make_unique<MachineBasicBlock>(name);
auto* ptr = block.get();
blocks_.push_back(std::move(block));
return ptr;
}
int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size());
@ -30,4 +39,4 @@ const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
return frame_slots_[index];
}
} // namespace mir
} // namespace mir

@ -6,18 +6,34 @@ namespace mir {
Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm) {}
Operand::Operand(Kind kind, PhysReg reg, int64_t imm64)
: kind_(kind), reg_(PhysReg::ZERO), imm_(0), imm64_(imm64) {}
// 新增构造函数
Operand::Operand(Kind kind, PhysReg reg, int imm, const std::string& name)
: kind_(kind), reg_(reg), imm_(imm), global_name_(name) {}
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value);
return Operand(Kind::Imm, PhysReg::ZERO, value);
}
Operand Operand::Imm64(int64_t value) {
return Operand(Kind::Imm, PhysReg::ZERO, value);
}
Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::W0, index);
return Operand(Kind::FrameIndex, PhysReg::ZERO, index);
}
// 新增
Operand Operand::Global(const std::string& name) {
return Operand(Kind::Global, PhysReg::ZERO, 0, name);
}
Operand Operand::Func(const std::string& name) {
Operand op(Kind::Func, PhysReg::ZERO, 0);
op.func_name_ = name;
return op;
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
: opcode_(opcode), operands_(std::move(operands)) {}
} // namespace mir
} // namespace mir

@ -9,12 +9,66 @@ namespace {
bool IsAllowedReg(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::X29:
case PhysReg::X30:
// 临时寄存器
case PhysReg::T0:
case PhysReg::T1:
case PhysReg::T2:
case PhysReg::T3:
case PhysReg::T4:
case PhysReg::T5:
case PhysReg::T6:
// 参数/返回值寄存器
case PhysReg::A0:
case PhysReg::A1:
case PhysReg::A2:
case PhysReg::A3:
case PhysReg::A4:
case PhysReg::A5:
case PhysReg::A6:
case PhysReg::A7:
// 保存寄存器
case PhysReg::S0:
case PhysReg::S1:
case PhysReg::S2:
case PhysReg::S3:
case PhysReg::S4:
case PhysReg::S5:
case PhysReg::S6:
case PhysReg::S7:
case PhysReg::S8:
case PhysReg::S9:
case PhysReg::S10:
case PhysReg::S11:
// 特殊寄存器
case PhysReg::ZERO:
case PhysReg::RA:
case PhysReg::SP:
case PhysReg::GP:
case PhysReg::TP:
case PhysReg::FT0:
case PhysReg::FT1:
case PhysReg::FT2:
case PhysReg::FT3:
case PhysReg::FT4:
case PhysReg::FT5:
case PhysReg::FT6:
case PhysReg::FT7:
case PhysReg::FT8:
case PhysReg::FT9:
case PhysReg::FT10:
case PhysReg::FT11:
// 浮点保存寄存器
case PhysReg::FS0:
case PhysReg::FS1:
// 浮点参数寄存器
case PhysReg::FA0:
case PhysReg::FA1:
case PhysReg::FA2:
case PhysReg::FA3:
case PhysReg::FA4:
case PhysReg::FA5:
case PhysReg::FA6:
case PhysReg::FA7:
return true;
}
return false;
@ -23,7 +77,8 @@ bool IsAllowedReg(PhysReg reg) {
} // namespace
void RunRegAlloc(MachineFunction& function) {
for (const auto& inst : function.GetEntry().GetInstructions()) {
// 修复GetEntry() 返回指针,使用 ->
for (const auto& inst : function.GetEntry()->GetInstructions()) {
for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) {
@ -33,4 +88,4 @@ void RunRegAlloc(MachineFunction& function) {
}
}
} // namespace mir
} // namespace mir

@ -8,20 +8,96 @@ namespace mir {
const char* PhysRegName(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
return "w0";
case PhysReg::W8:
return "w8";
case PhysReg::W9:
return "w9";
case PhysReg::X29:
return "x29";
case PhysReg::X30:
return "x30";
// 整数寄存器
case PhysReg::ZERO:
return "zero";
case PhysReg::RA:
return "ra";
case PhysReg::SP:
return "sp";
case PhysReg::GP:
return "gp";
case PhysReg::TP:
return "tp";
case PhysReg::T0:
return "t0";
case PhysReg::T1:
return "t1";
case PhysReg::T2:
return "t2";
case PhysReg::S0:
return "s0";
case PhysReg::S1:
return "s1";
case PhysReg::A0:
return "a0";
case PhysReg::A1:
return "a1";
case PhysReg::A2:
return "a2";
case PhysReg::A3:
return "a3";
case PhysReg::A4:
return "a4";
case PhysReg::A5:
return "a5";
case PhysReg::A6:
return "a6";
case PhysReg::A7:
return "a7";
case PhysReg::S2:
return "s2";
case PhysReg::S3:
return "s3";
case PhysReg::S4:
return "s4";
case PhysReg::S5:
return "s5";
case PhysReg::S6:
return "s6";
case PhysReg::S7:
return "s7";
case PhysReg::S8:
return "s8";
case PhysReg::S9:
return "s9";
case PhysReg::S10:
return "s10";
case PhysReg::S11:
return "s11";
case PhysReg::T3:
return "t3";
case PhysReg::T4:
return "t4";
case PhysReg::T5:
return "t5";
case PhysReg::T6:
return "t6";
// 浮点寄存器
case PhysReg::FT0: return "ft0";
case PhysReg::FT1: return "ft1";
case PhysReg::FT2: return "ft2";
case PhysReg::FT3: return "ft3";
case PhysReg::FT4: return "ft4";
case PhysReg::FT5: return "ft5";
case PhysReg::FT6: return "ft6";
case PhysReg::FT7: return "ft7";
case PhysReg::FS0: return "fs0";
case PhysReg::FS1: return "fs1";
case PhysReg::FA0: return "fa0";
case PhysReg::FA1: return "fa1";
case PhysReg::FA2: return "fa2";
case PhysReg::FA3: return "fa3";
case PhysReg::FA4: return "fa4";
case PhysReg::FA5: return "fa5";
case PhysReg::FA6: return "fa6";
case PhysReg::FA7: return "fa7";
case PhysReg::FT8: return "ft8";
case PhysReg::FT9: return "ft9";
case PhysReg::FT10: return "ft10";
case PhysReg::FT11: return "ft11";
}
throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
}
} // namespace mir
} // namespace mir

@ -1,5 +1,6 @@
#include "sem/func.h"
#include <cstring>
#include <stdexcept>
#include <string>
@ -7,6 +8,12 @@
namespace sem {
// Truncate double to float32 precision (mimics C float arithmetic)
static double ToFloat32(double v) {
float f = static_cast<float>(v);
return static_cast<double>(f);
}
// 编译时求值常量表达式
ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) {
return EvaluateExp(*ctx.addExp());
@ -73,14 +80,65 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp()) {
return EvaluateExp(*ctx.exp()->addExp());
} else if (ctx.lVar()) {
// 处理变量引用(必须是已定义的常量)
// 处理变量引用:向上遍历 AST 找到对应的常量定义并求值
auto* ident = ctx.lVar()->Ident();
if (!ident) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
}
std::string name = ident->getText();
// 这里简化处理,实际应该在符号表中查找常量
// 暂时假设常量已经在前面被处理过
// 向上遍历 AST 找到作用域内的 constDef
antlr4::ParserRuleContext* scope =
dynamic_cast<antlr4::ParserRuleContext*>(ctx.lVar()->parent);
while (scope) {
// 检查当前作用域中的所有 constDecl
for (auto* tree_child : scope->children) {
auto* child = dynamic_cast<antlr4::ParserRuleContext*>(tree_child);
if (!child) continue;
auto* block_item = dynamic_cast<SysYParser::BlockItemContext*>(child);
if (block_item && block_item->decl()) {
auto* decl = block_item->decl();
if (decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
// compUnit 级别的 constDecl
auto* decl = dynamic_cast<SysYParser::DeclContext*>(child);
if (decl && decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
// If declared as int, truncate to integer
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
scope = dynamic_cast<antlr4::ParserRuleContext*>(scope->parent);
}
// 未找到常量定义,返回 0
ConstValue val;
val.is_int = true;
val.int_val = 0;
@ -94,11 +152,11 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
ConstValue val;
if (int_const) {
val.is_int = true;
val.int_val = std::stoll(int_const->getText());
val.int_val = std::stoll(int_const->getText(), nullptr, 0);
val.float_val = static_cast<double>(val.int_val);
} else if (float_const) {
val.is_int = false;
val.float_val = std::stod(float_const->getText());
val.float_val = ToFloat32(std::stod(float_const->getText()));
val.int_val = static_cast<long long>(val.float_val);
} else {
throw std::runtime_error(FormatError("sema", "非法数字字面量"));
@ -127,8 +185,9 @@ ConstValue AddValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) +
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l + r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -143,8 +202,9 @@ ConstValue SubValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) -
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l - r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -159,8 +219,9 @@ ConstValue MulValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) *
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l * r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -175,8 +236,9 @@ ConstValue DivValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) /
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l / r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;

@ -2,3 +2,41 @@
// - 按实验/评测规范提供 I/O 等函数实现
// - 与编译器生成的目标代码链接,支撑运行时行为
#include<stdio.h>
#include"sylib.h"
/* Input & output functions */
int getint(){int t; scanf("%d",&t); return t; }
int getch(){char c; scanf("%c",&c); return (int)c; }
int getarray(int a[]){
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)scanf("%d",&a[i]);
return n;
}
void putint(int a){ printf("%d",a);}
void putch(int a){ printf("%c",a); }
void putarray(int n,int a[]){
printf("%d:",n);
for(int i=0;i<n;i++)printf(" %d",a[i]);
printf("\n");
}
float getfloat(){ float f; scanf("%f",&f); return f; }
void putfloat(float a){ printf("%a",a); }
int getfarray(float a[]){
int n; scanf("%d",&n);
for(int i=0;i<n;i++) scanf("%a",&a[i]);
return n;
}
void putfarray(int n,float a[]){
printf("%d:",n);
for(int i=0;i<n;i++) printf(" %a",a[i]);
printf("\n");
}
static struct timeval _sysy_start, _sysy_end;
void starttime(){ gettimeofday(&_sysy_start, NULL); }
void stoptime(){
gettimeofday(&_sysy_end, NULL);
long us = (_sysy_end.tv_sec - _sysy_start.tv_sec) * 1000000L
+ (_sysy_end.tv_usec - _sysy_start.tv_usec);
fprintf(stderr, "Timer: %ldus\n", us);
}

@ -2,3 +2,20 @@
// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用)
// - 与 sylib.c 配套,按规范逐步补齐声明
#ifndef __SYLIB_H_
#define __SYLIB_H_
#include<stdio.h>
#include<stdarg.h>
#include<sys/time.h>
/* Input & output functions */
int getint(),getch(),getarray(int a[]);
void putint(int a),putch(int a),putarray(int n,int a[]);
float getfloat();
void putfloat(float a);
int getfarray(float a[]);
void putfarray(int n,float a[]);
/* Timing functions */
void starttime();
void stoptime();
#endif

@ -0,0 +1,14 @@
.text
.global main
.type main, @function
main:
addi sp, sp, -16
sw ra, 8(sp)
sw s0, 0(sp)
li a0, 42
lw ra, 8(sp)
lw s0, 0(sp)
addi sp, sp, 16
ret
.size main, .-main

@ -0,0 +1 @@
int main() { return 42; }

@ -0,0 +1,6 @@
int main() {
int a = 10;
int b = 20;
int c = a + b;
return c;
}

@ -0,0 +1,9 @@
// 基础算术运算测试
int main() {
int a = 10;
int b = 20;
int c = a + b; // 30
int d = c - 5; // 25
int e = d * 2; // 50
return e;
}

@ -0,0 +1,6 @@
// 除法测试
int main() {
int a = 100;
int b = 4;
return a / b; // 25
}

@ -0,0 +1,6 @@
int main() {
float a = 10;
float b = 3;
float c = a + b;
return (int)c;
}

@ -0,0 +1,7 @@
const float arr[3] = {1.1, 2.2, 3.3};
float carr[3] = {4.4, 5.5, 6.6};
int main() {
float sum = arr[0] + carr[0];
return (int)sum; // 5
}

@ -0,0 +1,8 @@
int main() {
float a = 5.0;
float b = 3.0;
int cmp1 = (a > b); // 1
int cmp2 = (a < b); // 0
int cmp3 = (a == 5.0); // 1
return cmp1 + cmp2 + cmp3; // 2
}

@ -0,0 +1,7 @@
int main() {
float a = (float)10;
float b = (float)3;
float c = a / b;
int d = c;
return d;
}

@ -0,0 +1,9 @@
const float pi = 3.14159;
float g = 2.71828;
int main() {
float a = pi;
float b = g;
float c = a + b;
return (int)c; // 5
}

@ -0,0 +1,6 @@
int main() {
float a = 10;
float b = 3;
float c = a * b;
return (int)c;
}

@ -0,0 +1,9 @@
int main() {
float a = 10.0;
float b = 3.0;
float add = a + b; // 13
float sub = a - b; // 7
float mul = a * b; // 30
float div = a / b; // 3.333...
return (int)(add + sub + mul + (int)div);
}

@ -0,0 +1,6 @@
int main() {
int a = 5;
int b = 10;
int c = a < b;
return c;
}

@ -0,0 +1,6 @@
int main() {
int a = 5;
int b = 10;
int c = a < b;
return c;
}

@ -0,0 +1,9 @@
int main() {
int a = 5;
int b = 10;
int c = 0;
if (a < b) {
c = 1;
}
return c;
}

@ -0,0 +1,13 @@
int main() {
int a = 10;
int b = 20;
int c = 0;
if (a < b) {
c = 100;
} else {
c = 200;
}
return c;
}

@ -0,0 +1,6 @@
// 取模测试
int main() {
int a = 17;
int b = 5;
return a % b; // 2
}

@ -0,0 +1,6 @@
// 乘法测试
int main() {
int a = 5;
int b = 10;
return a * b; // 50
}

@ -0,0 +1,18 @@
int main() {
int a = 5;
int b = 10;
int c = 15;
int result = 0;
if (a < b) {
if (b < c) {
result = 100;
} else {
result = 200;
}
} else {
result = 300;
}
return result;
}

@ -0,0 +1,3 @@
int main() {
return 42;
}

@ -0,0 +1,5 @@
int main() {
int a = 10;
int b = 3;
return a - b;
}

@ -0,0 +1,4 @@
int main() {
float a = 3.14;
return a;
}

@ -0,0 +1,7 @@
// 变量读写测试
int main() {
int x = 42;
int y = x;
int z = y + 1;
return z; // 43
}

@ -0,0 +1,11 @@
int main() {
int i = 0;
int sum = 0;
while (i < 10) {
sum = sum + i;
i = i + 1;
}
return sum;
}

@ -0,0 +1,9 @@
int main(){
const int a[4][2] = {{1, 2}, {3, 4}, {}, 7};
const int N = 3;
int b[4][2] = {};
int c[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
int d[N + 1][2] = {1, 2, {3}, {5}, a[3][0], 8};
int e[4][2][1] = {{d[2][1], {c[2][1]}}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1][0] + e[0][0][0] + e[0][1][0] + d[3][0];
}

@ -0,0 +1,11 @@
int a;
int func(int p){
p = p - 1;
return p;
}
int main(){
int b;
a = 10;
b = func(a);
return b;
}

@ -0,0 +1,7 @@
//test add
int main(){
int a, b;
a = 10;
b = -1;
return a + b;
}

@ -0,0 +1,7 @@
//test sub
const int a = 10;
int main(){
int b;
b = 2;
return b - a;
}

@ -0,0 +1,69 @@
const int V = 4;
const int space = 32;
const int LF = 10;
void printSolution(int color[]) {
int i = 0;
while (i < V) {
putint(color[i]);
putch(space);
i = i + 1;
}
putch(LF);
}
void printMessage() {
putch(78);putch(111);putch(116);
putch(space);
putch(101);putch(120);putch(105);putch(115);putch(116);
}
int isSafe(int graph[][V], int color[]) {
int i = 0;
while (i < V) {
int j = i + 1;
while (j < V) {
if (graph[i][j] && color[j] == color[i])
return 0;
j = j + 1;
}
i = i + 1;
}
return 1;
}
int graphColoring(int graph[][V], int m, int i, int color[]) {
if (i == V) {
if (isSafe(graph, color)) {
printSolution(color);
return 1;
}
return 0;
}
int j = 1;
while (j <= m) {
color[i] = j;
if (graphColoring(graph, m, i + 1, color))
return 1;
color[i] = 0;
j = j + 1;
}
return 0;
}
int main() {
int graph[V][V] = {
{0, 1, 1, 1},
{1, 0, 1, 0},
{1, 1, 0, 1},
{1, 0, 1, 0}
}, m = 3;
int color[V], i = 0;
while (i < V) {
color[i] = 0;
i = i + 1;
}
if (!graphColoring(graph, m, 0, color))
printMessage();
return 0;
}

@ -0,0 +1,10 @@
4 4
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
4 3
9 5 1
10 6 2
11 7 3
12 8 4

@ -0,0 +1,5 @@
110 70 30
278 174 70
446 278 110
614 382 150
0

@ -0,0 +1,60 @@
const int MAX_SIZE = 100;
int a[MAX_SIZE][MAX_SIZE], b[MAX_SIZE][MAX_SIZE];
int res[MAX_SIZE][MAX_SIZE];
int n1, m1, n2, m2;
void matrix_multiply() {
int i = 0;
while (i < m1) {
int j = 0;
while (j < n2) {
int k = 0;
while (k < n1) {
res[i][j] = res[i][j] + a[i][k] * b[k][j];
k = k + 1;
}
j = j + 1;
}
i = i + 1;
}
}
int main()
{
int i, j;
m1 = getint();
n1 = getint();
i = 0;
while (i < m1) {
j = 0;
while (j < n1) {
a[i][j] = getint();
j = j + 1;
}
i = i + 1;
}
m2 = getint();
n2 = getint();
i = 0;
while (i < m2) {
j = 0;
while (j < n2) {
b[i][j] = getint();
j = j + 1;
}
i = i + 1;
}
matrix_multiply();
i = 0;
while (i < m1) {
j = 0;
while (j < n2) {
putint(res[i][j]);
putch(32);
j = j + 1;
}
putch(10);
i = i + 1;
}
return 0;
}

@ -0,0 +1 @@
int main() { /* scope test */ putch(97); putch(10); int a = 1, putch = 0; { a = a + 2; int b = a + 3; b = b + 4; putch = putch + a + b; { b = b + 5; int main = b + 6; a = a + main; putch = putch + a + b + main; { b = b + a; int a = main + 7; a = a + 8; putch = putch + a + b + main; { b = b + a; int b = main + 9; a = a + 10; const int a = 11; b = b + 12; putch = putch + a + b + main; { main = main + b; int main = b + 13; main = main + a; putch = putch + a + b + main; } putch = putch - main; } putch = putch - b; } putch = putch - a; } } return putch % 77; }

@ -0,0 +1,15 @@
//test break
int main(){
int i;
i = 0;
int sum;
sum = 0;
while(i < 100){
if(i == 50){
break;
}
sum = sum + i;
i = i + 1;
}
return sum;
}

@ -0,0 +1,9 @@
//test the priority of add and mul
int main(){
int a, b, c, d;
a = 10;
b = 4;
c = 2;
d = 2;
return (c + a) * (b - d);
}

@ -0,0 +1,13 @@
10
0x1.999999999999ap-4 0x1.999999999999ap-3 0x1.3333333333333p-2 0x1.999999999999ap-2 0x1.0000000000000p-1
0x1.3333333333333p-1 0x1.6666666666666p-1 0x1.999999999999ap-1 0x1.ccccccccccccdp-1 0x1.0000000000000p+0
0x1.199999999999ap+0
0x1.199999999999ap+1
0x1.a666666666666p+1
0x1.199999999999ap+2
0x1.6000000000000p+2
0x1.a666666666666p+2
0x1.ecccccccccccdp+2
0x1.199999999999ap+3
0x1.3cccccccccccdp+3
0x1.4333333333333p+3

@ -0,0 +1,19 @@
ok
ok
ok
ok
ok
ok
ok
ok
0x1.e691e6p+1 3
0x1.e691e6p+3 12
0x1.11b21p+5 28
0x1.e691e6p+5 50
0x1.7c21fcp+6 78
0x1.11b21p+7 113
0x1.7487b2p+7 153
0x1.e691e6p+7 201
0x1.33e85p+8 254
10: 0x1.333334p+0 0x1.333334p+1 0x1.ccccccp+1 0x1.333334p+2 0x1.8p+2 0x1.ccccccp+2 0x1.0cccccp+3 0x1.333334p+3 0x1.599998p+3 0x1p+0
0

@ -0,0 +1,98 @@
// float global constants
const float RADIUS = 5.5, PI = 03.141592653589793, EPS = 1e-6;
// hexadecimal float constant
const float PI_HEX = 0x1.921fb6p+1, HEX2 = 0x.AP-3;
// float constant evaluation
const float FACT = -.33E+5, EVAL1 = PI * RADIUS * RADIUS, EVAL2 = 2 * PI_HEX * RADIUS, EVAL3 = PI * 2 * RADIUS;
// float constant implicit conversion
const float CONV1 = 233, CONV2 = 0xfff;
const int MAX = 1e9, TWO = 2.9, THREE = 3.2, FIVE = TWO + THREE;
// float -> float function
float float_abs(float x) {
if (x < 0) return -x;
return x;
}
// int -> float function & float/int expression
float circle_area(int radius) {
return (PI * radius * radius + (radius * radius) * PI) / 2;
}
// float -> float -> int function & float/int expression
int float_eq(float a, float b) {
if (float_abs(a - b) < EPS) {
return 1 * 2. / 2;
} else {
return 0;
}
}
void error() {
putch(101);
putch(114);
putch(114);
putch(111);
putch(114);
putch(10);
}
void ok() {
putch(111);
putch(107);
putch(10);
}
void assert(int cond) {
if (!cond) {
error();
} else {
ok();
}
}
void assert_not(int cond) {
if (cond) {
error();
} else {
ok();
}
}
int main() {
assert_not(float_eq(HEX2, FACT));
assert_not(float_eq(EVAL1, EVAL2));
assert(float_eq(EVAL2, EVAL3));
assert(float_eq(circle_area(RADIUS) /* f->i implicit conversion */,
circle_area(FIVE)));
assert_not(float_eq(CONV1, CONV2) /* i->f implicit conversion */);
// float conditional expressions
if (1.5) ok();
if (!!3.3) ok();
if (.0 && 3) error();
if (0 || 0.3) ok();
// float array & I/O functions
int i = 1, p = 0;
float arr[10] = {1., 2};
int len = getfarray(arr);
while (i < MAX) {
float input = getfloat();
float area = PI * input * input, area_trunc = circle_area(input);
arr[p] = arr[p] + input;
putfloat(area);
putch(32);
putint(area_trunc); // f->i implicit conversion
putch(10);
i = i * - -1e1;
p = p + 1;
}
putfarray(len, arr);
return 0;
}

@ -0,0 +1,5 @@
int main() {
int a = 1;
int b = 2;
return a + b;
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,89 @@
const int N = 1024;
void mm(int n, int A[][N], int B[][N], int C[][N]){
int i, j, k;
i = 0; j = 0;
while (i < n){
j = 0;
while (j < n){
C[i][j] = 0;
j = j + 1;
}
i = i + 1;
}
i = 0; j = 0; k = 0;
while (k < n){
i = 0;
while (i < n){
if (A[i][k] == 0){
i = i + 1;
continue;
}
j = 0;
while (j < n){
C[i][j] = C[i][j] + A[i][k] * B[k][j];
j = j + 1;
}
i = i + 1;
}
k = k + 1;
}
}
int A[N][N];
int B[N][N];
int C[N][N];
int main(){
int n = getint();
int i, j;
i = 0;
j = 0;
while (i < n){
j = 0;
while (j < n){
A[i][j] = getint();
j = j + 1;
}
i = i + 1;
}
i = 0;
j = 0;
while (i < n){
j = 0;
while (j < n){
B[i][j] = getint();
j = j + 1;
}
i = i + 1;
}
starttime();
i = 0;
while (i < 5){
mm(n, A, B, C);
mm(n, A, C, B);
i = i + 1;
}
int ans = 0;
i = 0;
while (i < n){
j = 0;
while (j < n){
ans = ans + B[i][j];
j = j + 1;
}
i = i + 1;
}
stoptime();
putint(ans);
putch(10);
return 0;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

@ -0,0 +1,71 @@
int x;
const int N = 2010;
void mv(int n, int A[][N], int b[], int res[]){
int x, y;
y = 0;
x = 11;
int i, j;
i = 0;
while(i < n){
res[i] = 0;
i = i + 1;
}
i = 0;
j = 0;
while (i < n){
j = 0;
while (j < n){
if (A[i][j] == 0){
x = x * b[i] + b[j];
y = y - x;
}else{
res[i] = res[i] + A[i][j] * b[j];
}
j = j + 1;
}
i = i + 1;
}
}
int A[N][N];
int B[N];
int C[N];
int main(){
int n = getint();
int i, j;
i = 0;
while (i < n){
j = 0;
while (j < n){
A[i][j] = getint();
j = j + 1;
}
i = i + 1;
}
i = 0;
while (i < n){
B[i] = getint();
i = i + 1;
}
starttime();
i = 0;
while (i < 50){
mv(n, A, B, C);
mv(n, A, C, B);
i = i + 1;
}
stoptime();
putarray(n, C);
return 0;
}

File diff suppressed because one or more lines are too long

@ -0,0 +1,106 @@
const int base = 16;
int getMaxNum(int n, int arr[]){
int ret = 0;
int i = 0;
while (i < n){
if (arr[i] > ret) ret = arr[i];
i = i + 1;
}
return ret;
}
int getNumPos(int num, int pos){
int tmp = 1;
int i = 0;
while (i < pos){
num = num / base;
i = i + 1;
}
return num % base;
}
void radixSort(int bitround, int a[], int l, int r){
int head[base] = {};
int tail[base] = {};
int cnt[base] = {};
if (bitround == -1 || l + 1 >= r) return;
{
int i = l;
while (i < r){
cnt[getNumPos(a[i], bitround)]
= cnt[getNumPos(a[i], bitround)] + 1;
i = i + 1;
}
head[0] = l;
tail[0] = l + cnt[0];
i = 1;
while (i < base){
head[i] = tail[i - 1];
tail[i] = head[i] + cnt[i];
i = i + 1;
}
i = 0;
while (i < base){
while (head[i] < tail[i]){
int v = a[head[i]];
while (getNumPos(v, bitround) != i){
int t = v;
v = a[head[getNumPos(t, bitround)]];
a[head[getNumPos(t, bitround)]] = t;
head[getNumPos(t, bitround)] = head[getNumPos(t, bitround)] + 1;
}
a[head[i]] = v;
head[i] = head[i] + 1;
}
i = i + 1;
}
}
{
int i = l;
head[0] = l;
tail[0] = l + cnt[0];
i = 0;
while (i < base){
if (i > 0){
head[i] = tail[i - 1];
tail[i] = head[i] + cnt[i];
}
radixSort(bitround - 1, a, head[i], tail[i]);
i = i + 1;
}
}
return;
}
int a[30000010];
int ans;
int main(){
int n = getarray(a);
starttime();
radixSort(8, a, 0, n);
int i = 0;
while (i < n){
ans = ans + i * (a[i] % (2 + i));
i = i + 1;
}
if (ans < 0)
ans = -ans;
stoptime();
putint(ans);
putch(10);
return 0;
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,110 @@
int A[1024][1024];
int B[1024][1024];
int C[1024][1024];
int main() {
int T = getint(); // 矩阵规模
int R = getint(); // 重复次数
int i = 0;
while (i < T) {
if (i < T / 2) {
getarray(A[i]);
}
i = i + 1;
}
i = 0;
while (i < T) {
if (i >= T / 2) {
getarray(B[i]);
}
i = i + 1;
}
starttime();
i = 0;
while (i < T) {
if (i >= T / 2) {
int j = 0;
while (j < T) {
A[i][j] = -1;
j = j + 1;
}
}
i = i + 1;
}
i = 0;
while (i < T) {
if (i < T / 2) {
int j = 0;
while (j < T) {
B[i][j] = -1;
j = j + 1;
}
}
i = i + 1;
}
i = 0;
while (i < T) {
int j = 0;
while (j < T) {
C[i][j] = A[i][j] * 2 + B[i][j] * 3;
j = j + 1;
}
i = i + 1;
}
i = 0;
while (i < T) {
int j = 0;
while (j < T) {
int val = C[i][j];
val = val * val + 7;
val = val / 3;
C[i][j] = val;
j = j + 1;
}
i = i + 1;
}
i = 0;
while (i < T) {
int j = 0;
while (j < T) {
int k = 0;
int sum = 0;
while (k < T) {
sum = sum + C[i][k] * A[k][j];
k = k + 1;
}
A[i][j] = sum;
j = j + 1;
}
i = i + 1;
}
int total = 0;
int r = 0;
while (r < R) {
i = 0;
while (i < T) {
int j = 0;
while (j < T) {
total = total + A[i][j] * A[i][j];
j = j + 1;
}
i = i + 1;
}
r = r + 1;
}
stoptime();
putint(total);
putch(10);
return 0;
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -0,0 +1,82 @@
const int mod = 998244353;
int d;
int multiply(int a, int b){
if (b == 0) return 0;
if (b == 1) return a % mod;
int cur = multiply(a, b/2);
cur = (cur + cur) % mod;
if (b % 2 == 1) return (cur + a) % mod;
else return cur;
}
int power(int a, int b){
if (b == 0) return 1;
int cur = power(a, b/2);
cur = multiply(cur, cur);
if (b % 2 == 1) return multiply(cur, a);
else return cur;
}
const int maxlen = 2097152;
int temp[maxlen], a[maxlen], b[maxlen], c[maxlen];
int memmove(int dst[], int dst_pos, int src[], int len){
int i = 0;
while (i < len){
dst[dst_pos + i] = src[i];
i = i + 1;
}
return i;
}
int fft(int arr[], int begin_pos, int n, int w){
if (n == 1) return 1;
int i = 0;
while (i < n){
if (i % 2 == 0) temp[i / 2] = arr[i + begin_pos];
else temp[n / 2 + i / 2] = arr[i + begin_pos];
i = i + 1;
}
memmove(arr, begin_pos, temp, n);
fft(arr, begin_pos, n / 2, multiply(w, w));
fft(arr, begin_pos + n / 2, n / 2, multiply(w, w));
i = 0;
int wn = 1;
while (i < n / 2){
int x = arr[begin_pos + i];
int y = arr[begin_pos + i + n / 2];
arr[begin_pos + i] = (x + multiply(wn, y)) % mod;
arr[begin_pos + i + n / 2] = (x - multiply(wn, y) + mod) % mod;
wn = multiply(wn, w);
i = i + 1;
}
return 0;
}
int main(){
int n = getarray(a);
int m = getarray(b);
starttime();
d = 1;
while (d < n + m - 1){
d = d * 2;
}
fft(a, 0, d, power(3, (mod - 1) / d));
fft(b, 0, d, power(3, (mod - 1) / d));
int i = 0;
while (i < d){
a[i] = multiply(a[i], b[i]);
i = i + 1;
}
fft(a, 0, d, power(3, mod-1 - (mod-1)/d));
i = 0;
while (i < d){
a[i] = multiply(a[i], power(d, mod-2));
i = i + 1;
}
stoptime();
putarray(n + m - 1, a);
return 0;
}

@ -0,0 +1,51 @@
50 50 353434
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
.....................#..#.........................
.....................#..#.........................
...................##.##.##.......................
.....................#..#.........................
.....................#..#.........................
...................##.##.##.......................
.....................#..#.........................
.....................#..#.........................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................
..................................................

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save