Compare commits

..

1 Commits

Author SHA1 Message Date
mxr e79d677644 fix(sem)解决常量数组初始化错误问题
2 months ago

3
.gitignore vendored

@ -73,5 +73,4 @@ test/test_result/
# mxr
# =========================
result.txt
build.sh
gdb.sh
build.sh

@ -109,25 +109,18 @@ class Context {
std::string NextTemp();
private:
// 数组常量缓存需要添加到类中
struct ArrayKey {
std::shared_ptr<ArrayType> type;
std::vector<ConstantValue*> elements;
bool operator==(const ArrayKey& other) const;
};
struct ArrayKeyHash {
size_t operator()(const ArrayKey& key) const;
};
std::unordered_map<ArrayKey, std::unique_ptr<ConstantArray>, ArrayKeyHash> array_cache_;
// 其他现有成员...
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
// 浮点常量:使用整数表示浮点数位模式作为键(避免浮点精度问题)
std::unordered_map<uint32_t, std::unique_ptr<ConstantFloat>> const_floats_;
// 零常量缓存(按类型指针)
std::unordered_map<Type*, std::unique_ptr<ConstantZero>> zero_constants_;
std::unordered_map<Type*, std::unique_ptr<ConstantAggregateZero>> aggregate_zeros_;
// 数组常量简单存储,不去重(因为数组常量通常组合多样,去重成本高)
std::vector<std::unique_ptr<ConstantArray>> const_arrays_;
int temp_index_ = -1;
};
@ -364,18 +357,18 @@ class User : public Value {
// GlobalValue 是全局值/全局变量体系的空壳占位类。
// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。
// ir/IR.h - GlobalValue 定义需要添加这些方法
// ir/IR.h - 修正 GlobalValue 定义
// ir/IR.h - 修正 GlobalValue 定义
class GlobalValue : public User {
private:
std::vector<ConstantValue*> initializer_;
bool is_constant_ = false;
bool is_extern_ = false;
std::vector<ConstantValue*> initializer_; // 初始化值列表
bool is_constant_ = false; // 是否为常量如const变量
bool is_extern_ = false; // 是否为外部声明
public:
GlobalValue(std::shared_ptr<Type> ty, std::string name);
// 初始化器相关
// 初始化器相关 - 使用 ConstantValue*
void SetInitializer(ConstantValue* init);
void SetInitializer(const std::vector<ConstantValue*>& init);
const std::vector<ConstantValue*>& GetInitializer() const { return initializer_; }
@ -389,28 +382,17 @@ public:
void SetExtern(bool is_extern) { is_extern_ = is_extern; }
bool IsExtern() const { return is_extern_; }
// 类型判断
// 类型判断 - 使用 Type 的方法
bool IsArray() const { return GetType()->IsArray(); }
bool IsScalar() const { return GetType()->IsInt32() || GetType()->IsFloat(); }
// 数组常量相关方法
bool IsArrayConstant() const;
ConstantValue* GetArrayElement(size_t index) const;
size_t GetArraySize() const;
// 获取数组大小(如果是数组类型)
int GetArraySizeInElements() const {
int GetArraySize() const {
if (auto* array_ty = dynamic_cast<ArrayType*>(GetType().get())) {
return array_ty->GetElementCount();
}
return 0;
}
private:
// 辅助方法
std::shared_ptr<Type> GetValueType() const;
bool CheckTypeCompatibility(std::shared_ptr<Type> value_type,
ConstantValue* init) const;
};
class Instruction : public User {
@ -760,20 +742,6 @@ class BasicBlock : public Value {
return ptr;
}
// Constructs an instruction of type T from `args`, records this block as its
// parent, and stores it in instructions_. When HasTerminator() reports true the
// new instruction is placed immediately before the last stored instruction;
// otherwise it is appended at the end. Returns a non-owning pointer to the new
// instruction (ownership stays with instructions_).
template <typename T, typename... Args>
T* InsertBeforeTerminator(Args&&... args) {
  std::unique_ptr<T> owned = std::make_unique<T>(std::forward<Args>(args)...);
  T* const raw = owned.get();
  raw->SetParent(this);
  if (HasTerminator()) {
    // Keep the final instruction last by inserting in front of it.
    instructions_.insert(instructions_.end() - 1, std::move(owned));
  } else {
    instructions_.insert(instructions_.end(), std::move(owned));
  }
  return raw;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
@ -844,7 +812,6 @@ class IRBuilder {
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
AllocaInst* CreateAlloca(std::shared_ptr<Type> ty, const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaFloat(const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);

@ -7,23 +7,12 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <iostream>
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include "ir/IR.h"
#include "sem/Sema.h"
//#define DEBUG_IRGen
#ifdef DEBUG_IRGen
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[IRGen Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace ir {
class Module;
class Function;
@ -33,10 +22,7 @@ class Value;
class IRGenImpl final : public SysYBaseVisitor {
public:
// 修改构造函数,添加 SymbolTable 参数
IRGenImpl(ir::Module& module,
const SemanticContext& sema,
const SymbolTable& sym_table); // 新增
IRGenImpl(ir::Module& module, const SemanticContext& sema);
// 顶层
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
@ -81,21 +67,9 @@ public:
ir::Value* EvalCond(SysYParser::CondContext& cond);
std::any visitCallExp(SysYParser::UnaryExpContext* ctx);
std::vector<ir::Value*> ProcessNestedInitVals(SysYParser::InitValContext* ctx);
// 带维度感知的展平:按 C 语言花括号对齐规则填充 total_size 个槽位
// dims[0] 是最外层维度dims.back() 是最内层维度(元素层)
// 返回已展平并补零的 total_size 大小的向量
std::vector<ir::Value*> FlattenInitVal(SysYParser::InitValContext* ctx,
const std::vector<int>& dims,
bool is_float);
int TryEvaluateConstInt(SysYParser::ConstExpContext* ctx);
void AddRuntimeFunctions();
ir::Function* CreateRuntimeFunctionDecl(const std::string& funcName);
ir::BasicBlock* EnsureCleanupBlock();
void RegisterCleanup(ir::Function* free_func, ir::Value* ptr);
ir::AllocaInst* CreateEntryAlloca(std::shared_ptr<ir::Type> ty,
const std::string& name);
ir::AllocaInst* CreateEntryAllocaI32(const std::string& name);
ir::AllocaInst* CreateEntryAllocaFloat(const std::string& name);
private:
// 辅助函数声明
enum class BlockFlow{
@ -134,7 +108,6 @@ private:
ir::Module& module_;
const SemanticContext& sema_;
const SymbolTable& symbol_table_; // 新增成员
ir::Function* func_;
ir::IRBuilder builder_;
ir::Value* EvalAssign(SysYParser::StmtContext* ctx);
@ -146,8 +119,6 @@ private:
std::unordered_map<std::string, ir::Value*> local_var_map_; // 局部变量
std::unordered_map<std::string, ir::GlobalValue*> global_map_; // 全局变量
std::unordered_map<std::string, ir::Value*> param_map_; // 函数参数
std::unordered_set<std::string> pointer_param_names_; // 指针/数组形参名
std::unordered_set<std::string> heap_local_array_names_; // 堆分配的局部数组名
// 常量映射:常量名 -> 常量值(标量常量)
std::unordered_map<std::string, ir::ConstantValue*> const_value_map_;
@ -160,23 +131,21 @@ private:
std::unordered_map<SysYParser::VarDefContext*, ArrayInfo> array_info_map_;
std::string current_function_name_;
bool current_function_is_recursive_ = false;
ir::AllocaInst* function_return_slot_ = nullptr;
ir::BasicBlock* function_cleanup_block_ = nullptr;
std::vector<std::pair<ir::Function*, ir::Value*>> function_cleanup_actions_;
// 新增:处理全局和局部变量的辅助函数
// 修改处理函数的签名,使用 Symbol* 参数
std::any HandleGlobalVariable(SysYParser::VarDefContext* ctx,
const std::string& varName,
const Symbol* sym);
const std::string& varName,
bool is_array);
std::any HandleLocalVariable(SysYParser::VarDefContext* ctx,
const std::string& varName,
const Symbol* sym);
const std::string& varName,
bool is_array);
// 常量求值辅助函数
int EvaluateConstAddExp(SysYParser::AddExpContext* ctx);
int EvaluateConstMulExp(SysYParser::MulExpContext* ctx);
int EvaluateConstUnaryExp(SysYParser::UnaryExpContext* ctx);
int EvaluateConstPrimaryExp(SysYParser::PrimaryExpContext* ctx);
int EvaluateConstExp(SysYParser::ExpContext* ctx);
};
// 修改 GenerateIR 函数签名
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemaResult& sema_result);
const SemanticContext& sema);

@ -19,161 +19,41 @@ class MIRContext {
MIRContext& DefaultContext();
// Physical registers known to the MIR layer: 32-bit W registers, 64-bit X
// registers, 32-bit floating-point S registers, plus SP and ZR.
enum class PhysReg {
// Earlier minimal set: W0, W8, W9, X29, X30, SP
// 32-bit general-purpose registers
W0, W1, W2, W3, W4, W5, W6, W7, // argument passing / temporaries
W8, W9, // temporaries (the ones mainly used at present)
W10, W11, W12, W13, W14, W15, // temporaries (extended set)
W16, W17, // intra-procedure-call temporaries
W18, // reserved by the platform
W19, W20, W21, W22, W23, W24, // callee-saved (for extended use)
W25, W26, W27, W28, // callee-saved
W29, // frame pointer (FP)
W30, // link register (LR)
// 64-bit views
X0, X1, X2, X3, X4, X5, X6, X7,
X8, X9, X10, X11, X12, X13, X14, X15,
X16, X17, X18,
X19, X20, X21, X22, X23, X24, X25, X26, X27, X28,
X29, // FP
X30, // LR
// Floating-point registers (32-bit)
S0, S1, S2, S3, S4, S5, S6, S7,
S8, S9, S10, S11, S12, S13, S14, S15,
S16, S17, S18, S19, S20, S21, S22, S23,
S24, S25, S26, S27, S28, S29, S30, S31,
// Special registers
SP, // stack pointer
ZR, // zero register
};
enum class PhysReg { W0, W8, W9, X29, X30, SP };
const char* PhysRegName(PhysReg reg);
// ========== 条件码枚举(用于 BCond 指令)==========
// Condition codes carried by conditional-branch (BCond) instructions.
enum class CondCode {
EQ, // equal
NE, // not equal
CS, // carry set / unsigned greater-or-equal
CC, // carry clear / unsigned less-than
MI, // minus (negative)
PL, // plus (non-negative)
VS, // overflow set
VC, // overflow clear (no overflow)
HI, // unsigned higher (greater-than)
LS, // unsigned lower-or-same (less-or-equal)
GE, // signed greater-or-equal
LT, // signed less-than
GT, // signed greater-than
LE, // signed less-or-equal
AL, // always
};
const char* CondCodeName(CondCode cc);
// ========== MIR 指令操作码枚举 ==========
// Opcodes of MIR instructions. The trailing comment on each enumerator shows
// the assembly form it stands for.
// NOTE(review): the final group of enumerators (Prologue ... Ret) repeats names
// already declared earlier in this enum; this looks like two revisions of the
// list merged together - confirm which set is intended.
enum class Opcode {
// ---------- Stack frame ----------
Prologue, // function prologue (pseudo-instruction)
Epilogue, // function epilogue (pseudo-instruction)
// ---------- Data movement ----------
MovImm, // move immediate into register: MOV w8, #imm
MovReg, // register-to-register move: MOV w8, w9
LoadStack, // load from a stack slot: LDR w8, [sp, #offset]
StoreStack, // store to a stack slot: STR w8, [sp, #offset]
LoadStackPair,// paired load: LDP x29, x30, [sp], #16
StoreStackPair,// paired store: STP x29, x30, [sp, #-16]!
// ---------- Integer arithmetic ----------
AddRR, // add: ADD w8, w8, w9
AddRI, // add immediate: ADD w8, w8, #imm
SubRR, // subtract: SUB w8, w8, w9
SubRI, // subtract immediate: SUB w8, w8, #imm
MulRR, // multiply: MUL w8, w8, w9
SDivRR, // signed divide: SDIV w8, w8, w9
UDivRR, // unsigned divide: UDIV w8, w8, w9
// ---------- Floating-point arithmetic ----------
FAddRR, // float add: FADD s0, s0, s1
FSubRR, // float subtract: FSUB s0, s0, s1
FMulRR, // float multiply: FMUL s0, s0, s1
FDivRR, // float divide: FDIV s0, s0, s1
// ---------- Comparison ----------
CmpRR, // compare registers: CMP w8, w9
CmpRI, // compare with immediate: CMP w8, #imm
FCmpRR, // float compare: FCMP s0, s1
// ---------- Type conversion ----------
SIToFP, // signed integer to float: SCVTF s0, w0
FPToSI, // float to signed integer: FCVTZS w0, s0
ZExt, // zero-extend i1 -> i32: AND w8, w8, #1
// ---------- Control flow ----------
B, // unconditional branch: B label
BCond, // conditional branch: B.EQ label, B.NE label, B.GT label, ...
Call, // function call: BL target
Ret, // function return: RET
// ---------- Logical operations ----------
AndRR, // bitwise and: AND w8, w8, w9
OrRR, // bitwise or: ORR w8, w8, w9
EorRR, // bitwise exclusive-or: EOR w8, w8, w9
LslRR, // logical shift left: LSL w8, w8, w9
LsrRR, // logical shift right: LSR w8, w8, w9
AsrRR, // arithmetic shift right: ASR w8, w8, w9
// ---------- Special ----------
Nop, // no operation: NOP
Label, // inline label; emits no instruction, only the label name
// Added
Movk, // movk Rd, #imm16, lsl #shift
// Added
LoadStackAddr, // load a stack-frame address into a register (add xd, sp, #offset)
// Used to compute global-variable addresses
Adrp, // ADRP Xd, label
AddLabel, // ADD Xd, Xn, :lo12:label
// Newly added
Sxtw, // sign-extend word to doubleword: sxtw Xd, Wn
Prologue,
Epilogue,
MovImm,
LoadStack,
StoreStack,
AddRR,
Ret,
};
// ========== 操作数类 ==========
// A single operand of a machine instruction. Instances are created through the
// static factory functions; the constructors are private.
// NOTE(review): `Kind` is declared twice and there are two private constructors
// with different parameter lists; this looks like two revisions of the class
// merged together - confirm which declarations are intended.
class Operand {
public:
// Discriminator telling which accessor is meaningful for this operand.
enum class Kind { Reg, Imm, FrameIndex, Cond, Label };
enum class Kind { Reg, Imm, FrameIndex };
// Factory functions, one per operand kind.
static Operand Reg(PhysReg reg);
static Operand Imm(int value);
static Operand FrameIndex(int index);
static Operand Cond(CondCode cc);
static Operand Label(const std::string& label);
// Accessors. They return the stored field without checking kind_.
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; }
// The frame index is read from the same field as the immediate (imm_).
int GetFrameIndex() const { return imm_; }
CondCode GetCondCode() const { return cc_; }
const std::string& GetLabel() const { return label_; }
private:
Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string& label);
Operand(Kind kind, PhysReg reg, int imm);
Kind kind_;
PhysReg reg_;
int imm_; // holds either an immediate value or a frame index
CondCode cc_;
std::string label_;
};
// ========== MIR 指令类 ==========
class MachineInstr {
public:
MachineInstr(Opcode opcode, std::vector<Operand> operands = {});
@ -186,14 +66,12 @@ class MachineInstr {
std::vector<Operand> operands_;
};
// ========== 栈槽结构 ==========
// Describes one slot of a function's stack frame.
// A default-constructed slot has index 0, size 4 and offset 0
// (size/offset are presumably in bytes - confirm against the frame lowering).
struct FrameSlot {
  int index{0};
  int size{4};
  int offset{0};
};
// ========== MIR 基本块 ==========
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
@ -204,133 +82,38 @@ class MachineBasicBlock {
MachineInstr& Append(Opcode opcode,
std::initializer_list<Operand> operands = {});
MachineInstr& Append(Opcode opcode, std::vector<Operand> operands);
// 控制流信息
std::vector<MachineBasicBlock*>& GetSuccessors() { return successors_; }
const std::vector<MachineBasicBlock*>& GetSuccessors() const { return successors_; }
void AddSuccessor(MachineBasicBlock* succ) { successors_.push_back(succ); }
private:
std::string name_;
std::vector<MachineInstr> instructions_;
std::vector<MachineBasicBlock*> successors_;
};
// ========== MIR 函数 ==========
// Machine-level representation of one function: its name, an entry block,
// a list of further owned basic blocks, its stack-frame slots and the frame size.
class MachineFunction {
public:
explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; }
// Basic-block management
// The entry block is stored by value, separately from basic_blocks_.
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
std::vector<std::unique_ptr<MachineBasicBlock>>& GetBasicBlocks() {
return basic_blocks_;
}
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBasicBlocks() const {
return basic_blocks_;
}
// Takes ownership of `bb` and appends it to basic_blocks_.
void AddBasicBlock(std::unique_ptr<MachineBasicBlock> bb) {
basic_blocks_.push_back(std::move(bb));
}
// Stack-slot management
// CreateFrameIndex / GetFrameSlot are defined out of line.
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
std::vector<FrameSlot>& GetFrameSlots() { return frame_slots_; }
const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; }
// Stack-frame size
int GetFrameSize() const { return frame_size_; }
void SetFrameSize(int size) { frame_size_ = size; }
private:
std::string name_;
MachineBasicBlock entry_;
std::vector<std::unique_ptr<MachineBasicBlock>> basic_blocks_;
std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0; // 0 until SetFrameSize is called
};
// ========== MIR 模块 ==========
class MachineModule {
public:
MachineModule() = default;
// 添加 MachineFunction
void AddFunction(std::unique_ptr<MachineFunction> func) {
functions_.push_back(std::move(func));
}
// 获取所有函数
const std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() const {
return functions_;
}
std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() {
return functions_;
}
// 根据名称查找函数
MachineFunction* GetFunction(const std::string& name) {
for (auto& func : functions_) {
if (func->GetName() == name) {
return func.get();
}
}
return nullptr;
}
const MachineFunction* GetFunction(const std::string& name) const {
for (const auto& func : functions_) {
if (func->GetName() == name) {
return func.get();
}
}
return nullptr;
}
struct GlobalDecl {
std::string name;
int size; // 字节大小
int alignment; // 对齐要求(通常为 4 或 8
bool is_zero_init; // 是否为零初始化
bool has_init_data; // 是否包含初始化数据(用于标量常量)
uint64_t init_data; // 初始化数据≤8字节
// 构造函数,默认零初始化
GlobalDecl(const std::string& n, int sz, int align, bool zero = true,
bool has_data = false, uint64_t data = 0)
: name(n), size(sz), alignment(align), is_zero_init(zero),
has_init_data(has_data), init_data(data) {}
};
void AddGlobal(const std::string& name, int size, int alignment,
bool is_zero_init = true,
bool has_init_data = false, uint64_t init_data = 0) {
globals_.emplace_back(name, size, alignment, is_zero_init,
has_init_data, init_data);
}
const std::vector<GlobalDecl>& GetGlobals() const { return globals_; }
private:
std::vector<std::unique_ptr<MachineFunction>> functions_;
std::vector<GlobalDecl> globals_;
};
// ========== 后端流程函数 ==========
/* std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os); */
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineModule& module);
void RunFrameLowering(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);
void PrintAsm(const MachineFunction& function, std::ostream& os);
} // namespace mir

@ -7,7 +7,7 @@
#include "SysYParser.h"
#include "ir/IR.h"
#include "sem/SymbolTable.h"
// 表达式信息结构
struct ExprInfo {
std::shared_ptr<ir::Type> type = nullptr;
@ -91,12 +91,4 @@ private:
// 目前仅检查:
// - 变量先声明后使用
// - 局部变量不允许重复定义
// SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);
// 新增:语义分析结果结构体
// Bundles the two products of semantic analysis so RunSema can return both.
struct SemaResult {
SemanticContext context; // semantic context built during analysis
SymbolTable symbol_table; // symbol table built during analysis
};
// 修改 RunSema 的返回类型
SemaResult RunSema(SysYParser::CompUnitContext& comp_unit);
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);

@ -1,7 +1,6 @@
// 极简符号表:记录局部变量定义点。
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
@ -18,113 +17,46 @@ enum class SymbolKind {
Constant
};
// 符号条目
// 符号条目
// One symbol-table entry: name, kind, type, scope information, optional
// constant value(s) and the syntax-tree node that introduced the symbol.
// NOTE(review): several members below are declared twice (once bare, once with
// a trailing comment) and the constant-value section contains both a
// standalone `union ConstantValue {...};` and a `} const_value;` tail; this
// looks like two revisions of the struct merged together - confirm which
// declarations are intended.
struct Symbol {
// Basic information
std::string name;
SymbolKind kind;
std::shared_ptr<ir::Type> type;
int scope_level = 0;
int stack_offset = -1;
bool is_initialized = false;
bool is_builtin = false;
std::shared_ptr<ir::Type> type; // shared pointer to the Type object
int scope_level = 0; // scope depth at the point of definition
int stack_offset = -1; // stack offset of a local/parameter (-1 for globals)
bool is_initialized = false; // whether the symbol has been initialized
bool is_builtin = false; // whether this is a library function
// Array-parameter information
std::vector<int> array_dims;
bool is_array_param = false;
// For array parameters, stores the dimension information
std::vector<int> array_dims; // length of each dimension; the first dimension of an array parameter may be 0 (omitted)
bool is_array_param = false; // whether this is an array parameter
// Function information
// For functions, additionally stores the parameter list (the types are already part of the function type; kept here only for quick access)
std::vector<std::shared_ptr<ir::Type>> param_types;
// Constant-value storage
// For constants, stores the constant value (int32 and float are supported)
union ConstantValue {
int i32;
float f32;
};
// Scalar constant
bool is_int_const = true;
ConstantValue const_value;
// Array constant (stored flattened)
bool is_array_const = false;
std::vector<ConstantValue> array_const_values;
} const_value;
bool is_int_const = true; // tags the constant's type: distinguishes int from float
// Syntax-tree nodes
// Associated syntax-tree node (for error locations or further analysis)
SysYParser::VarDefContext* var_def_ctx = nullptr;
SysYParser::ConstDefContext* const_def_ctx = nullptr;
SysYParser::FuncFParamContext* param_def_ctx = nullptr;
SysYParser::FuncDefContext* func_def_ctx = nullptr;
// Helper methods
// True when the symbol is a constant whose type is not an array.
bool IsScalarConstant() const {
return kind == SymbolKind::Constant && !type->IsArray();
}
// True when the symbol is a constant whose type is an array.
bool IsArrayConstant() const {
return kind == SymbolKind::Constant && type->IsArray();
}
// Returns the int value of a scalar integer constant.
// Throws std::runtime_error if the symbol is not a scalar constant or is not
// tagged as an int constant.
int GetIntConstant() const {
if (!IsScalarConstant()) {
throw std::runtime_error("不是标量常量");
}
if (!is_int_const) {
throw std::runtime_error("不是整型常量");
}
return const_value.i32;
}
// Returns the value of a scalar constant as float; an int constant is
// converted with static_cast. Throws std::runtime_error if the symbol is not
// a scalar constant.
float GetFloatConstant() const {
if (!IsScalarConstant()) {
throw std::runtime_error("不是标量常量");
}
if (is_int_const) {
return static_cast<float>(const_value.i32);
}
return const_value.f32;
}
// Returns element `index` of the flattened array constant.
// Throws std::runtime_error if the symbol is not an array constant or the
// index is out of range.
ConstantValue GetArrayElement(size_t index) const {
if (!IsArrayConstant()) {
throw std::runtime_error("不是数组常量");
}
if (index >= array_const_values.size()) {
throw std::runtime_error("数组下标越界");
}
return array_const_values[index];
}
// Number of stored flattened elements, or 0 when this is not an array constant.
size_t GetArraySize() const {
if (!IsArrayConstant()) return 0;
return array_const_values.size();
}
};
class SymbolTable {
public:
SymbolTable();
~SymbolTable() = default;
// 添加调试方法
size_t getScopeCount() const { return active_scope_stack_.size(); }
// Debug helper: prints every scope stored in scopes_ to std::cerr. Each scope
// line shows its index and is tagged [active] when that index occurs in
// active_scope_stack_, [inactive] otherwise; below it, every symbol of the
// scope is listed with its name, numeric kind and scope level.
void dump() const {
  std::cerr << "=== SymbolTable Dump ===" << std::endl;
  const size_t scope_total = scopes_.size();
  for (size_t scope_idx = 0; scope_idx != scope_total; ++scope_idx) {
    std::cerr << "Scope " << scope_idx << " (depth=" << scope_idx << ")";
    const auto stack_end = active_scope_stack_.end();
    const char* tag = " [active]";
    if (std::find(active_scope_stack_.begin(), stack_end, scope_idx) == stack_end) {
      tag = " [inactive]";
    }
    std::cerr << tag << std::endl;
    for (const auto& entry : scopes_[scope_idx]) {
      const auto& symbol = entry.second;
      std::cerr << " " << entry.first;
      std::cerr << " (kind=" << static_cast<int>(symbol.kind);
      std::cerr << ", level=" << symbol.scope_level << ")" << std::endl;
    }
  }
}
// ----- 作用域管理 -----
void enterScope(); // 进入新作用域
void exitScope(); // 退出当前作用域
int currentScopeLevel() const { return static_cast<int>(active_scope_stack_.size()) - 1; }
int currentScopeLevel() const { return static_cast<int>(scopes_.size()) - 1; }
// ----- 符号操作(推荐使用)-----
bool addSymbol(const Symbol& sym); // 添加符号到当前作用域
@ -132,9 +64,6 @@ class SymbolTable {
Symbol* lookupCurrent(const std::string& name); // 仅在当前作用域查找
const Symbol* lookup(const std::string& name) const;
const Symbol* lookupCurrent(const std::string& name) const;
const Symbol* lookupAll(const std::string& name) const; // 所有作用域查找,包括已结束的作用域
const Symbol* lookupByVarDef(const SysYParser::VarDefContext* decl) const; // 通过定义节点查找符号
const Symbol* lookupByConstDef(const SysYParser::ConstDefContext* decl) const; // 通过常量定义节点查找符号
// ----- 与原接口兼容(保留原有功能)-----
void Add(const std::string& name, SysYParser::VarDefContext* decl);
@ -161,9 +90,17 @@ class SymbolTable {
float float_val;
};
};
void flattenInit(SysYParser::ConstInitValContext* ctx,
std::vector<ConstValue>& out,
std::shared_ptr<ir::Type> base_type) const;
void fillArray(
std::vector<ConstValue>& values,
size_t& index,
SysYParser::ConstInitValContext* ctx,
const std::vector<int>& dims,
size_t dim_idx,
std::shared_ptr<ir::Type> base_type) const;
void fillZero(std::vector<ConstValue>& values, size_t& index,
const std::vector<int>& dims, size_t dim_idx,
std::shared_ptr<ir::Type> base_type) const;
std::vector<ConstValue> EvaluateConstInitVal(
SysYParser::ConstInitValContext* ctx,
const std::vector<int>& dims,
@ -174,7 +111,6 @@ class SymbolTable {
private:
// 作用域栈:每个元素是一个从名字到符号的映射
std::vector<std::unordered_map<std::string, Symbol>> scopes_;
std::vector<size_t> active_scope_stack_;
static constexpr int GLOBAL_SCOPE = 0; // 全局作用域索引

Binary file not shown.

@ -1,492 +0,0 @@
; ModuleID = 'optimized.bc'
source_filename = "./build/test_compiler/performance/03_sort1.ll"
@a = global [30000010 x i32] zeroinitializer
@ans = local_unnamed_addr global i32 0
declare i32 @getarray(ptr) local_unnamed_addr
declare void @putint(i32) local_unnamed_addr
declare void @putch(i32) local_unnamed_addr
declare void @starttime() local_unnamed_addr
declare void @stoptime() local_unnamed_addr
declare ptr @sysy_alloc_i32(i32) local_unnamed_addr
declare void @sysy_free_i32(ptr) local_unnamed_addr
declare void @sysy_zero_i32(ptr, i32) local_unnamed_addr
; Function Attrs: nofree norecurse nosync nounwind memory(argmem: read)
; getMaxNum(n, arr): scans arr[0 .. n-1] and returns the largest value seen,
; with the running maximum starting at 0. Returns 0 when n <= 0.
define i32 @getMaxNum(i32 %n, ptr nocapture readonly %arr) local_unnamed_addr #0 {
entry:
; Skip the loop entirely unless n > 0.
%t95 = icmp sgt i32 %n, 0
br i1 %t95, label %while.body.t5, label %while.exit.t6
while.body.t5: ; preds = %entry, %while.body.t5
; %t3_i.07 is the loop index, %t2_ret.06 the running maximum (both start at 0).
%t3_i.07 = phi i32 [ %t21, %while.body.t5 ], [ 0, %entry ]
%t2_ret.06 = phi i32 [ %spec.select, %while.body.t5 ], [ 0, %entry ]
%0 = zext nneg i32 %t3_i.07 to i64
%t13 = getelementptr i32, ptr %arr, i64 %0
%t14 = load i32, ptr %t13, align 4
; New maximum = signed max of arr[i] and the previous maximum.
%spec.select = tail call i32 @llvm.smax.i32(i32 %t14, i32 %t2_ret.06)
%t21 = add nuw nsw i32 %t3_i.07, 1
; Continue while the incremented index is still below n.
%t9 = icmp slt i32 %t21, %n
br i1 %t9, label %while.body.t5, label %while.exit.t6
while.exit.t6: ; preds = %while.body.t5, %entry
; Result is 0 when the loop never ran, otherwise the last computed maximum.
%t2_ret.0.lcssa = phi i32 [ 0, %entry ], [ %spec.select, %while.body.t5 ]
ret i32 %t2_ret.0.lcssa
}
; Function Attrs: nofree norecurse nosync nounwind memory(none)
; getNumPos(num, pos): divides num by 16 (signed division) once per loop
; iteration, `pos` times when pos > 0, then returns the signed remainder of
; the result modulo 16. When pos <= 0 it returns num srem 16 directly.
define i32 @getNumPos(i32 %num, i32 %pos) local_unnamed_addr #1 {
entry:
; Skip the division loop unless pos > 0.
%t333 = icmp sgt i32 %pos, 0
br i1 %t333, label %while.body.t29, label %while.exit.t30
while.body.t29: ; preds = %entry, %while.body.t29
; %t27_i.05 counts iterations from 0; %t24.04 is the value being divided.
%t27_i.05 = phi i32 [ %t37, %while.body.t29 ], [ 0, %entry ]
%t24.04 = phi i32 [ %t35, %while.body.t29 ], [ %num, %entry ]
%t35 = sdiv i32 %t24.04, 16
%t37 = add nuw nsw i32 %t27_i.05, 1
; Continue while the incremented counter is still below pos.
%t33 = icmp slt i32 %t37, %pos
br i1 %t33, label %while.body.t29, label %while.exit.t30
while.exit.t30: ; preds = %while.body.t29, %entry
; Either the untouched num (loop skipped) or the last quotient.
%t24.0.lcssa = phi i32 [ %num, %entry ], [ %t35, %while.body.t29 ]
%t39 = srem i32 %t24.0.lcssa, 16
ret i32 %t39
}
define void @radixSort(i32 %bitround, ptr nocapture %a, i32 %l, i32 %r) local_unnamed_addr {
entry:
%t43 = tail call ptr @sysy_alloc_i32(i32 16)
tail call void @sysy_zero_i32(ptr %t43, i32 16)
%t46 = tail call ptr @sysy_alloc_i32(i32 16)
tail call void @sysy_zero_i32(ptr %t46, i32 16)
%t48 = tail call ptr @sysy_alloc_i32(i32 16)
tail call void @sysy_zero_i32(ptr %t48, i32 16)
%t54 = icmp eq i32 %bitround, -1
%t56 = add i32 %l, 1
%t58 = icmp sge i32 %t56, %r
%t59 = or i1 %t54, %t58
br i1 %t59, label %cleanup.t44, label %while.cond.t62.preheader
while.cond.t62.preheader: ; preds = %entry
%t6796 = icmp slt i32 %l, %r
br i1 %t6796, label %while.body.t63.lr.ph, label %while.exit.t64
while.body.t63.lr.ph: ; preds = %while.cond.t62.preheader
%t333.i = icmp sgt i32 %bitround, 0
br label %while.body.t63
cleanup.t44: ; preds = %merge.t196, %entry
tail call void @sysy_free_i32(ptr %t48)
tail call void @sysy_free_i32(ptr %t46)
tail call void @sysy_free_i32(ptr %t43)
ret void
while.body.t63: ; preds = %while.body.t63.lr.ph, %getNumPos.exit28
%storemerge97 = phi i32 [ %l, %while.body.t63.lr.ph ], [ %t83, %getNumPos.exit28 ]
%0 = sext i32 %storemerge97 to i64
%t69 = getelementptr i32, ptr %a, i64 %0
%t70 = load i32, ptr %t69, align 4
br i1 %t333.i, label %while.body.t29.i, label %getNumPos.exit.thread
getNumPos.exit.thread: ; preds = %while.body.t63
%t39.i80 = srem i32 %t70, 16
%1 = sext i32 %t39.i80 to i64
%t7381 = getelementptr i32, ptr %t48, i64 %1
%t7482 = load i32, ptr %t7381, align 4
br label %getNumPos.exit28
while.body.t29.i: ; preds = %while.body.t63, %while.body.t29.i
%t27_i.05.i = phi i32 [ %t37.i, %while.body.t29.i ], [ 0, %while.body.t63 ]
%t24.04.i = phi i32 [ %t35.i, %while.body.t29.i ], [ %t70, %while.body.t63 ]
%t35.i = sdiv i32 %t24.04.i, 16
%t37.i = add nuw nsw i32 %t27_i.05.i, 1
%t33.i = icmp slt i32 %t37.i, %bitround
br i1 %t33.i, label %while.body.t29.i, label %getNumPos.exit
getNumPos.exit: ; preds = %while.body.t29.i
%t39.i = srem i32 %t35.i, 16
%2 = sext i32 %t39.i to i64
%t73 = getelementptr i32, ptr %t48, i64 %2
%t74 = load i32, ptr %t73, align 4
br label %while.body.t29.i22
while.body.t29.i22: ; preds = %getNumPos.exit, %while.body.t29.i22
%t27_i.05.i23 = phi i32 [ %t37.i26, %while.body.t29.i22 ], [ 0, %getNumPos.exit ]
%t24.04.i24 = phi i32 [ %t35.i25, %while.body.t29.i22 ], [ %t70, %getNumPos.exit ]
%t35.i25 = sdiv i32 %t24.04.i24, 16
%t37.i26 = add nuw nsw i32 %t27_i.05.i23, 1
%t33.i27 = icmp slt i32 %t37.i26, %bitround
br i1 %t33.i27, label %while.body.t29.i22, label %getNumPos.exit28.loopexit
getNumPos.exit28.loopexit: ; preds = %while.body.t29.i22
%.pre114 = srem i32 %t35.i25, 16
%.pre115 = sext i32 %.pre114 to i64
br label %getNumPos.exit28
getNumPos.exit28: ; preds = %getNumPos.exit28.loopexit, %getNumPos.exit.thread
%.pre-phi116 = phi i64 [ %.pre115, %getNumPos.exit28.loopexit ], [ %1, %getNumPos.exit.thread ]
%t7584.in = phi i32 [ %t74, %getNumPos.exit28.loopexit ], [ %t7482, %getNumPos.exit.thread ]
%t7584 = add i32 %t7584.in, 1
%t81 = getelementptr i32, ptr %t48, i64 %.pre-phi116
store i32 %t7584, ptr %t81, align 4
%t83 = add nsw i32 %storemerge97, 1
%t67 = icmp slt i32 %t83, %r
br i1 %t67, label %while.body.t63, label %while.exit.t64
while.exit.t64: ; preds = %getNumPos.exit28, %while.cond.t62.preheader
store i32 %l, ptr %t43, align 4
%t88 = load i32, ptr %t48, align 4
%t89 = add i32 %t88, %l
store i32 %t89, ptr %t46, align 4
%invariant.gep = getelementptr i32, ptr %t46, i64 -1
%t101 = getelementptr i32, ptr %t43, i64 1
store i32 %t89, ptr %t101, align 4
%t106 = getelementptr i32, ptr %t48, i64 1
%t107 = load i32, ptr %t106, align 4
%t108 = add i32 %t107, %t89
%t110 = getelementptr i32, ptr %t46, i64 1
store i32 %t108, ptr %t110, align 4
%t101.1 = getelementptr i32, ptr %t43, i64 2
store i32 %t108, ptr %t101.1, align 4
%t106.1 = getelementptr i32, ptr %t48, i64 2
%t107.1 = load i32, ptr %t106.1, align 4
%t108.1 = add i32 %t107.1, %t108
%t110.1 = getelementptr i32, ptr %t46, i64 2
store i32 %t108.1, ptr %t110.1, align 4
%t101.2 = getelementptr i32, ptr %t43, i64 3
store i32 %t108.1, ptr %t101.2, align 4
%t106.2 = getelementptr i32, ptr %t48, i64 3
%t107.2 = load i32, ptr %t106.2, align 4
%t108.2 = add i32 %t107.2, %t108.1
%t110.2 = getelementptr i32, ptr %t46, i64 3
store i32 %t108.2, ptr %t110.2, align 4
%t101.3 = getelementptr i32, ptr %t43, i64 4
store i32 %t108.2, ptr %t101.3, align 4
%t106.3 = getelementptr i32, ptr %t48, i64 4
%t107.3 = load i32, ptr %t106.3, align 4
%t108.3 = add i32 %t107.3, %t108.2
%t110.3 = getelementptr i32, ptr %t46, i64 4
store i32 %t108.3, ptr %t110.3, align 4
%t101.4 = getelementptr i32, ptr %t43, i64 5
store i32 %t108.3, ptr %t101.4, align 4
%t106.4 = getelementptr i32, ptr %t48, i64 5
%t107.4 = load i32, ptr %t106.4, align 4
%t108.4 = add i32 %t107.4, %t108.3
%t110.4 = getelementptr i32, ptr %t46, i64 5
store i32 %t108.4, ptr %t110.4, align 4
%t101.5 = getelementptr i32, ptr %t43, i64 6
store i32 %t108.4, ptr %t101.5, align 4
%t106.5 = getelementptr i32, ptr %t48, i64 6
%t107.5 = load i32, ptr %t106.5, align 4
%t108.5 = add i32 %t107.5, %t108.4
%t110.5 = getelementptr i32, ptr %t46, i64 6
store i32 %t108.5, ptr %t110.5, align 4
%t101.6 = getelementptr i32, ptr %t43, i64 7
store i32 %t108.5, ptr %t101.6, align 4
%t106.6 = getelementptr i32, ptr %t48, i64 7
%t107.6 = load i32, ptr %t106.6, align 4
%t108.6 = add i32 %t107.6, %t108.5
%t110.6 = getelementptr i32, ptr %t46, i64 7
store i32 %t108.6, ptr %t110.6, align 4
%t101.7 = getelementptr i32, ptr %t43, i64 8
store i32 %t108.6, ptr %t101.7, align 4
%t106.7 = getelementptr i32, ptr %t48, i64 8
%t107.7 = load i32, ptr %t106.7, align 4
%t108.7 = add i32 %t107.7, %t108.6
%t110.7 = getelementptr i32, ptr %t46, i64 8
store i32 %t108.7, ptr %t110.7, align 4
%t101.8 = getelementptr i32, ptr %t43, i64 9
store i32 %t108.7, ptr %t101.8, align 4
%t106.8 = getelementptr i32, ptr %t48, i64 9
%t107.8 = load i32, ptr %t106.8, align 4
%t108.8 = add i32 %t107.8, %t108.7
%t110.8 = getelementptr i32, ptr %t46, i64 9
store i32 %t108.8, ptr %t110.8, align 4
%t101.9 = getelementptr i32, ptr %t43, i64 10
store i32 %t108.8, ptr %t101.9, align 4
%t106.9 = getelementptr i32, ptr %t48, i64 10
%t107.9 = load i32, ptr %t106.9, align 4
%t108.9 = add i32 %t107.9, %t108.8
%t110.9 = getelementptr i32, ptr %t46, i64 10
store i32 %t108.9, ptr %t110.9, align 4
%t101.10 = getelementptr i32, ptr %t43, i64 11
store i32 %t108.9, ptr %t101.10, align 4
%t106.10 = getelementptr i32, ptr %t48, i64 11
%t107.10 = load i32, ptr %t106.10, align 4
%t108.10 = add i32 %t107.10, %t108.9
%t110.10 = getelementptr i32, ptr %t46, i64 11
store i32 %t108.10, ptr %t110.10, align 4
%t101.11 = getelementptr i32, ptr %t43, i64 12
store i32 %t108.10, ptr %t101.11, align 4
%t106.11 = getelementptr i32, ptr %t48, i64 12
%t107.11 = load i32, ptr %t106.11, align 4
%t108.11 = add i32 %t107.11, %t108.10
%t110.11 = getelementptr i32, ptr %t46, i64 12
store i32 %t108.11, ptr %t110.11, align 4
%t101.12 = getelementptr i32, ptr %t43, i64 13
store i32 %t108.11, ptr %t101.12, align 4
%t106.12 = getelementptr i32, ptr %t48, i64 13
%t107.12 = load i32, ptr %t106.12, align 4
%t108.12 = add i32 %t107.12, %t108.11
%t110.12 = getelementptr i32, ptr %t46, i64 13
store i32 %t108.12, ptr %t110.12, align 4
%t101.13 = getelementptr i32, ptr %t43, i64 14
store i32 %t108.12, ptr %t101.13, align 4
%t106.13 = getelementptr i32, ptr %t48, i64 14
%t107.13 = load i32, ptr %t106.13, align 4
%t108.13 = add i32 %t107.13, %t108.12
%t110.13 = getelementptr i32, ptr %t46, i64 14
store i32 %t108.13, ptr %t110.13, align 4
%t101.14 = getelementptr i32, ptr %t43, i64 15
store i32 %t108.13, ptr %t101.14, align 4
%t106.14 = getelementptr i32, ptr %t48, i64 15
%t107.14 = load i32, ptr %t106.14, align 4
%t108.14 = add i32 %t107.14, %t108.13
%t110.14 = getelementptr i32, ptr %t46, i64 15
store i32 %t108.14, ptr %t110.14, align 4
%t333.i29 = icmp sgt i32 %bitround, 0
br label %while.cond.t118.preheader
while.cond.t118.preheader: ; preds = %while.exit.t64, %while.exit.t120
%storemerge17104 = phi i32 [ 0, %while.exit.t64 ], [ %t180, %while.exit.t120 ]
%3 = zext nneg i32 %storemerge17104 to i64
%t122 = getelementptr i32, ptr %t43, i64 %3
%t125 = getelementptr i32, ptr %t46, i64 %3
%t123100 = load i32, ptr %t122, align 4
%t126101 = load i32, ptr %t125, align 4
%t127102 = icmp slt i32 %t123100, %t126101
br i1 %t127102, label %while.body.t119, label %while.exit.t120
while.body.t191.peel.next: ; preds = %while.exit.t120
store i32 %l, ptr %t43, align 4
%t187 = load i32, ptr %t48, align 4
%t188 = add i32 %t187, %l
store i32 %t188, ptr %t46, align 4
%t215 = add i32 %bitround, -1
%t218.peel.pre = load i32, ptr %t43, align 4
tail call void @radixSort(i32 %t215, ptr %a, i32 %t218.peel.pre, i32 %t188)
br label %merge.t196
while.body.t119: ; preds = %while.cond.t118.preheader, %while.exit.t136
%t123103 = phi i32 [ %t176, %while.exit.t136 ], [ %t123100, %while.cond.t118.preheader ]
%4 = sext i32 %t123103 to i64
%t132 = getelementptr i32, ptr %a, i64 %4
%t133 = load i32, ptr %t132, align 4
br label %while.cond.t134
while.exit.t120: ; preds = %while.exit.t136, %while.cond.t118.preheader
%t180 = add nuw nsw i32 %storemerge17104, 1
%t117 = icmp ult i32 %storemerge17104, 15
br i1 %t117, label %while.cond.t118.preheader, label %while.body.t191.peel.next
while.cond.t134: ; preds = %getNumPos.exit78, %while.body.t119
%t15099 = phi i32 [ %t150129, %getNumPos.exit78 ], [ %t133, %while.body.t119 ]
br i1 %t333.i29, label %while.body.t29.i32, label %getNumPos.exit38.thread
while.body.t29.i32: ; preds = %while.cond.t134, %while.body.t29.i32
%t27_i.05.i33 = phi i32 [ %t37.i36, %while.body.t29.i32 ], [ 0, %while.cond.t134 ]
%t24.04.i34 = phi i32 [ %t35.i35, %while.body.t29.i32 ], [ %t15099, %while.cond.t134 ]
%t35.i35 = sdiv i32 %t24.04.i34, 16
%t37.i36 = add nuw nsw i32 %t27_i.05.i33, 1
%t33.i37 = icmp slt i32 %t37.i36, %bitround
br i1 %t33.i37, label %while.body.t29.i32, label %getNumPos.exit38
getNumPos.exit38: ; preds = %while.body.t29.i32
%t39.i31 = srem i32 %t35.i35, 16
%t141.not = icmp eq i32 %t39.i31, %storemerge17104
br i1 %t141.not, label %while.exit.t136, label %while.body.t29.i42
getNumPos.exit38.thread: ; preds = %while.cond.t134
%t39.i3186 = srem i32 %t15099, 16
%t141.not87 = icmp eq i32 %t39.i3186, %storemerge17104
br i1 %t141.not87, label %while.exit.t136, label %getNumPos.exit48.thread
getNumPos.exit48.thread: ; preds = %getNumPos.exit38.thread
%5 = sext i32 %t39.i3186 to i64
%t147125 = getelementptr i32, ptr %t43, i64 %5
%t148126 = load i32, ptr %t147125, align 4
%6 = sext i32 %t148126 to i64
%t149127 = getelementptr i32, ptr %a, i64 %6
%t150128 = load i32, ptr %t149127, align 4
br label %getNumPos.exit68.thread
while.body.t29.i42: ; preds = %getNumPos.exit38, %while.body.t29.i42
%t27_i.05.i43 = phi i32 [ %t37.i46, %while.body.t29.i42 ], [ 0, %getNumPos.exit38 ]
%t24.04.i44 = phi i32 [ %t35.i45, %while.body.t29.i42 ], [ %t15099, %getNumPos.exit38 ]
%t35.i45 = sdiv i32 %t24.04.i44, 16
%t37.i46 = add nuw nsw i32 %t27_i.05.i43, 1
%t33.i47 = icmp slt i32 %t37.i46, %bitround
br i1 %t33.i47, label %while.body.t29.i42, label %getNumPos.exit48
getNumPos.exit48: ; preds = %while.body.t29.i42
%.pre117 = srem i32 %t35.i45, 16
%7 = sext i32 %.pre117 to i64
%t147 = getelementptr i32, ptr %t43, i64 %7
%t148 = load i32, ptr %t147, align 4
%8 = sext i32 %t148 to i64
%t149 = getelementptr i32, ptr %a, i64 %8
%t150 = load i32, ptr %t149, align 4
br i1 %t333.i29, label %while.body.t29.i52, label %getNumPos.exit68.thread
while.body.t29.i52: ; preds = %getNumPos.exit48, %while.body.t29.i52
%t27_i.05.i53 = phi i32 [ %t37.i56, %while.body.t29.i52 ], [ 0, %getNumPos.exit48 ]
%t24.04.i54 = phi i32 [ %t35.i55, %while.body.t29.i52 ], [ %t15099, %getNumPos.exit48 ]
%t35.i55 = sdiv i32 %t24.04.i54, 16
%t37.i56 = add nuw nsw i32 %t27_i.05.i53, 1
%t33.i57 = icmp slt i32 %t37.i56, %bitround
br i1 %t33.i57, label %while.body.t29.i52, label %while.body.t29.i62.preheader
while.body.t29.i62.preheader: ; preds = %while.body.t29.i52
%t39.i51 = srem i32 %t35.i55, 16
%9 = sext i32 %t39.i51 to i64
%t155 = getelementptr i32, ptr %t43, i64 %9
%t156 = load i32, ptr %t155, align 4
%10 = sext i32 %t156 to i64
%t157 = getelementptr i32, ptr %a, i64 %10
store i32 %t15099, ptr %t157, align 4
br label %while.body.t29.i62
getNumPos.exit68.thread: ; preds = %getNumPos.exit48, %getNumPos.exit48.thread
%t150130 = phi i32 [ %t150128, %getNumPos.exit48.thread ], [ %t150, %getNumPos.exit48 ]
%t39.i51.c = srem i32 %t15099, 16
%11 = sext i32 %t39.i51.c to i64
%t155.c = getelementptr i32, ptr %t43, i64 %11
%t156.c = load i32, ptr %t155.c, align 4
%12 = sext i32 %t156.c to i64
%t157.c = getelementptr i32, ptr %a, i64 %12
store i32 %t15099, ptr %t157.c, align 4
%t16191 = getelementptr i32, ptr %t43, i64 %11
%t16292 = load i32, ptr %t16191, align 4
br label %getNumPos.exit78
while.body.t29.i62: ; preds = %while.body.t29.i62.preheader, %while.body.t29.i62
%t27_i.05.i63 = phi i32 [ %t37.i66, %while.body.t29.i62 ], [ 0, %while.body.t29.i62.preheader ]
%t24.04.i64 = phi i32 [ %t35.i65, %while.body.t29.i62 ], [ %t15099, %while.body.t29.i62.preheader ]
%t35.i65 = sdiv i32 %t24.04.i64, 16
%t37.i66 = add nuw nsw i32 %t27_i.05.i63, 1
%t33.i67 = icmp slt i32 %t37.i66, %bitround
br i1 %t33.i67, label %while.body.t29.i62, label %getNumPos.exit68
getNumPos.exit68: ; preds = %while.body.t29.i62
%t39.i61 = srem i32 %t35.i65, 16
%13 = sext i32 %t39.i61 to i64
%t161 = getelementptr i32, ptr %t43, i64 %13
%t162 = load i32, ptr %t161, align 4
br label %while.body.t29.i72
while.body.t29.i72: ; preds = %getNumPos.exit68, %while.body.t29.i72
%t27_i.05.i73 = phi i32 [ %t37.i76, %while.body.t29.i72 ], [ 0, %getNumPos.exit68 ]
%t24.04.i74 = phi i32 [ %t35.i75, %while.body.t29.i72 ], [ %t15099, %getNumPos.exit68 ]
%t35.i75 = sdiv i32 %t24.04.i74, 16
%t37.i76 = add nuw nsw i32 %t27_i.05.i73, 1
%t33.i77 = icmp slt i32 %t37.i76, %bitround
br i1 %t33.i77, label %while.body.t29.i72, label %getNumPos.exit78.loopexit
getNumPos.exit78.loopexit: ; preds = %while.body.t29.i72
%.pre118 = srem i32 %t35.i75, 16
%.pre119 = sext i32 %.pre118 to i64
br label %getNumPos.exit78
getNumPos.exit78: ; preds = %getNumPos.exit78.loopexit, %getNumPos.exit68.thread
%t150129 = phi i32 [ %t150, %getNumPos.exit78.loopexit ], [ %t150130, %getNumPos.exit68.thread ]
%.pre-phi120 = phi i64 [ %.pre119, %getNumPos.exit78.loopexit ], [ %11, %getNumPos.exit68.thread ]
%t16394.in = phi i32 [ %t162, %getNumPos.exit78.loopexit ], [ %t16292, %getNumPos.exit68.thread ]
%t16394 = add i32 %t16394.in, 1
%t167 = getelementptr i32, ptr %t43, i64 %.pre-phi120
store i32 %t16394, ptr %t167, align 4
br label %while.cond.t134
while.exit.t136: ; preds = %getNumPos.exit38.thread, %getNumPos.exit38
%t171 = load i32, ptr %t122, align 4
%14 = sext i32 %t171 to i64
%t172 = getelementptr i32, ptr %a, i64 %14
store i32 %t15099, ptr %t172, align 4
%t175 = load i32, ptr %t122, align 4
%t176 = add i32 %t175, 1
store i32 %t176, ptr %t122, align 4
%t126 = load i32, ptr %t125, align 4
%t127 = icmp slt i32 %t176, %t126
br i1 %t127, label %while.body.t119, label %while.exit.t120
merge.t196: ; preds = %while.body.t191.peel.next, %merge.t196
%storemerge18107 = phi i32 [ 1, %while.body.t191.peel.next ], [ %t224, %merge.t196 ]
%15 = zext nneg i32 %storemerge18107 to i64
%gep106 = getelementptr i32, ptr %invariant.gep, i64 %15
%t202 = load i32, ptr %gep106, align 4
%t204 = getelementptr i32, ptr %t43, i64 %15
store i32 %t202, ptr %t204, align 4
%t209 = getelementptr i32, ptr %t48, i64 %15
%t210 = load i32, ptr %t209, align 4
%t211 = add i32 %t210, %t202
%t213 = getelementptr i32, ptr %t46, i64 %15
store i32 %t211, ptr %t213, align 4
%t218.pre = load i32, ptr %t204, align 4
tail call void @radixSort(i32 %t215, ptr %a, i32 %t218.pre, i32 %t211)
%t224 = add nuw nsw i32 %storemerge18107, 1
%t194 = icmp ult i32 %storemerge18107, 15
br i1 %t194, label %merge.t196, label %cleanup.t44, !llvm.loop !0
}
; Program entry point. Calls @getarray on the global array @a, runs
; @radixSort over it, folds the array into the global @ans, makes @ans
; non-negative, prints it followed by character code 10, and returns 0.
define noundef i32 @main() local_unnamed_addr {
entry:
; %t231 is the i32 result of @getarray; it is used below as the upper bound
; of the accumulation loop (presumably the number of elements read - confirm
; against the runtime library).
%t231 = tail call i32 @getarray(ptr nonnull @a)
; @starttime / @stoptime bracket the measured region (presumably runtime
; timing hooks - their bodies are not visible here).
tail call void @starttime()
; radixSort(8, @a, 0, %t231)
tail call void @radixSort(i32 8, ptr nonnull @a, i32 0, i32 %t231)
; The current value of @ans is loaded once and carried in a register through
; the loop instead of being reloaded on every iteration.
%ans.promoted = load i32, ptr @ans, align 4
; Skip the loop entirely when %t231 <= 0.
%t2427 = icmp sgt i32 %t231, 0
br i1 %t2427, label %while.body.t238, label %while.exit.t239
while.body.t238: ; preds = %entry, %while.body.t238
; %t236_i.09 is the loop index i (starts at 0); %t25268 is the running sum.
%t236_i.09 = phi i32 [ %t254, %while.body.t238 ], [ 0, %entry ]
%t25268 = phi i32 [ %t252, %while.body.t238 ], [ %ans.promoted, %entry ]
%0 = zext nneg i32 %t236_i.09 to i64
%t246 = getelementptr [30000010 x i32], ptr @a, i64 0, i64 %0
%t247 = load i32, ptr %t246, align 4
; sum += i * (a[i] srem (i + 2))
%t249 = add nuw i32 %t236_i.09, 2
%t250 = srem i32 %t247, %t249
%t251 = mul i32 %t250, %t236_i.09
%t252 = add i32 %t251, %t25268
; ++i; keep looping while i < %t231.
%t254 = add nuw nsw i32 %t236_i.09, 1
%t242 = icmp slt i32 %t254, %t231
br i1 %t242, label %while.body.t238, label %while.cond.t237.while.exit.t239_crit_edge
while.cond.t237.while.exit.t239_crit_edge: ; preds = %while.body.t238
; Loop exit edge: write the accumulated sum back to @ans.
store i32 %t252, ptr @ans, align 4
br label %while.exit.t239
while.exit.t239: ; preds = %while.cond.t237.while.exit.t239_crit_edge, %entry
; %t257 is the value @ans holds at this point on either incoming path.
%t257 = phi i32 [ %t252, %while.cond.t237.while.exit.t239_crit_edge ], [ %ans.promoted, %entry ]
%t258 = icmp slt i32 %t257, 0
br i1 %t258, label %then.t255, label %merge.t256
then.t255: ; preds = %while.exit.t239
; Negative result: store 0 - value into @ans.
%t260 = sub i32 0, %t257
store i32 %t260, ptr @ans, align 4
br label %merge.t256
merge.t256: ; preds = %then.t255, %while.exit.t239
tail call void @stoptime()
; Output @ans via @putint, then pass 10 to @putch (10 is the ASCII code of
; '\n', so presumably a newline).
%t262 = load i32, ptr @ans, align 4
tail call void @putint(i32 %t262)
tail call void @putch(i32 10)
ret i32 0
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.smax.i32(i32, i32) #2
attributes #0 = { nofree norecurse nosync nounwind memory(argmem: read) }
attributes #1 = { nofree norecurse nosync nounwind memory(none) }
attributes #2 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!0 = distinct !{!0, !1}
!1 = !{!"llvm.loop.peeled.count", i32 1}

File diff suppressed because it is too large Load Diff

@ -4,13 +4,7 @@ set -euo pipefail
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
COMPILER="$ROOT_DIR/build/bin/compiler"
TMP_DIR="$ROOT_DIR/build/test_compiler"
RESULT_BASE_DIR="$ROOT_DIR/test/test_result"
TEST_DIRS=("$ROOT_DIR/test/test_case/functional" "$ROOT_DIR/test/test_case/performance")
CC_BIN="${CC:-cc}"
RUNTIME_SRC="$ROOT_DIR/sylib/sylib.c"
RUNTIME_OBJ="$TMP_DIR/sylib.o"
LLC_BIN="${LLC:-llc}"
CLANG_BIN="${CLANG:-clang}"
if [[ ! -x "$COMPILER" ]]; then
echo "未找到编译器: $COMPILER"
@ -20,30 +14,6 @@ fi
mkdir -p "$TMP_DIR"
if ! command -v "$LLC_BIN" >/dev/null 2>&1; then
echo "未找到 llc: $LLC_BIN"
echo "请安装 LLVM或通过 LLC 环境变量指定 llc 路径"
exit 1
fi
if ! command -v "$CLANG_BIN" >/dev/null 2>&1; then
echo "未找到 clang: $CLANG_BIN"
echo "请安装 Clang或通过 CLANG 环境变量指定 clang 路径"
exit 1
fi
# 编译运行库(供链接生成的可执行文件)
runtime_ready=0
if [[ -f "$RUNTIME_SRC" ]]; then
if "$CC_BIN" -c "$RUNTIME_SRC" -o "$RUNTIME_OBJ" >/dev/null 2>&1; then
runtime_ready=1
else
echo "[WARN] 运行库编译失败,生成的可执行文件将不链接 sylib: $RUNTIME_SRC"
fi
else
echo "[WARN] 未找到运行库源码: $RUNTIME_SRC"
fi
ir_total=0
ir_pass=0
result_total=0
@ -67,17 +37,7 @@ for test_dir in "${TEST_DIRS[@]}"; do
ir_total=$((ir_total+1))
base=$(basename "$input")
stem=${base%.sy}
case "$(basename "$test_dir")" in
functional)
out_dir="$RESULT_BASE_DIR/functional/ir"
;;
performance)
out_dir="$RESULT_BASE_DIR/performance/ir"
;;
*)
out_dir="$RESULT_BASE_DIR/$(basename "$test_dir")"
;;
esac
out_dir="$TMP_DIR/$(basename "$test_dir")"
mkdir -p "$out_dir"
ll_file="$out_dir/$stem.ll"
stdout_file="$out_dir/$stem.stdout"
@ -105,15 +65,10 @@ for test_dir in "${TEST_DIRS[@]}"; do
continue
fi
# 从混杂输出中提取 IR
# - 顶层实体define/declare/@global
# - 基本块标签
# - 缩进的指令行
# - 函数结束花括号
grep -E '^(define |declare |@|[[:space:]]|})|^[A-Za-z_.$%][A-Za-z0-9_.$%]*:$' "$raw_ll" > "$ll_file"
# 检查是否生成了有效函数定义
if ! grep -qE '^define ' "$ll_file"; then
# 检查是否生成了有效的函数定义(在过滤后的内容中检查)
# 先过滤一下看看是否有define
filtered_content=$(sed -E '/^\[DEBUG\]|^SymbolTable::|^Check|^绑定|^保存|^dim_count:/d' "$raw_ll")
if ! echo "$filtered_content" | grep -qE '^define '; then
echo " [IR] 失败: 未生成有效函数定义"
ir_failures+=("$input: invalid IR output")
# 失败:保留原始输出
@ -121,7 +76,17 @@ for test_dir in "${TEST_DIRS[@]}"; do
rm -f "$raw_ll"
continue
fi
# 编译成功过滤掉所有调试输出只保留IR
# 过滤规则:
# 1. 以 [DEBUG] 开头的行
# 2. SymbolTable:: 开头的行
# 3. CheckLValue: 开头的行
# 4. 绑定变量: 开头的行
# 5. dim_count: 开头的行
# 6. 空行(可选)
sed -E '/^(\[DEBUG|SymbolTable::|Check|绑定|保存|dim_)/d' "$raw_ll" > "$ll_file"
# 可选:删除多余的空行
sed -i '/^$/N;/\n$/D' "$ll_file"
@ -131,72 +96,30 @@ for test_dir in "${TEST_DIRS[@]}"; do
echo " [IR] 生成成功 (IR已保存到: $ll_file)"
# 运行测试
# 运行测试部分
if [[ -f "$expected_file" ]]; then
result_total=$((result_total+1))
# 运行生成的可执行文件(优先链接运行库)
# 运行LLVM IR
run_status=0
obj_file="$out_dir/$stem.o"
exe_file="$out_dir/$stem"
if ! "$LLC_BIN" -filetype=obj "$ll_file" -o "$obj_file" > "$stdout_file" 2>&1; then
echo " [RUN] llc 失败"
result_failures+=("$input: llc failed")
continue
fi
if [[ $runtime_ready -eq 1 ]]; then
if ! "$CLANG_BIN" "$obj_file" "$RUNTIME_OBJ" -o "$exe_file" >> "$stdout_file" 2>&1; then
echo " [RUN] clang 链接失败"
result_failures+=("$input: clang link failed")
continue
fi
else
if ! "$CLANG_BIN" "$obj_file" -o "$exe_file" >> "$stdout_file" 2>&1; then
echo " [RUN] clang 链接失败"
result_failures+=("$input: clang link failed")
continue
fi
fi
if [[ -f "$stdin_file" ]]; then
"$exe_file" < "$stdin_file" > "$stdout_file" 2>&1 || run_status=$?
lli "$ll_file" < "$stdin_file" > "$stdout_file" 2>&1 || run_status=$?
else
"$exe_file" > "$stdout_file" 2>&1 || run_status=$?
lli "$ll_file" > "$stdout_file" 2>&1 || run_status=$?
fi
# 读取预期文件内容
expected_content=$(normalize_file "$expected_file")
# 读取预期返回值
expected=$(normalize_file "$expected_file")
# 判断预期文件是只包含退出码,还是包含输出+退出码
if [[ "$expected_content" =~ ^[0-9]+$ ]]; then
# 只包含退出码
expected=$expected_content
if [[ "$run_status" -eq "$expected" ]]; then
result_pass=$((result_pass+1))
echo " [RUN] 返回值匹配: $run_status"
rm -f "$stdout_file"
else
echo " [RUN] 返回值不匹配: got $run_status, expected $expected"
result_failures+=("$input: exit code mismatch (got $run_status, expected $expected)")
fi
# 比较返回值
if [[ "$run_status" -eq "$expected" ]]; then
result_pass=$((result_pass+1))
echo " [RUN] 返回值匹配: $run_status"
# 成功:保留已清理的.ll文件删除输出文件
rm -f "$stdout_file"
else
# 包含输出和退出码(最后一行是退出码)
expected_output=$(head -n -1 <<< "$expected_content")
expected_exit=$(tail -n 1 <<< "$expected_content")
actual_output=$(cat "$stdout_file")
if [[ "$run_status" -eq "$expected_exit" ]] && [[ "$actual_output" == "$expected_output" ]]; then
result_pass=$((result_pass+1))
echo " [RUN] 成功: 退出码和输出都匹配"
rm -f "$stdout_file"
else
echo " [RUN] 不匹配: 退出码 got $run_status, expected $expected_exit"
if [[ "$actual_output" != "$expected_output" ]]; then
echo " 输出不匹配"
fi
result_failures+=("$input: mismatch")
fi
echo " [RUN] 返回值不匹配: got $run_status, expected $expected"
result_failures+=("$input: exit code mismatch (got $run_status, expected $expected)")
# 失败:.ll文件已经保留输出文件也保留用于调试
fi
else
echo " [RUN] 未找到预期返回值文件 $expected_file,跳过结果验证"

@ -1,167 +0,0 @@
#!/bin/bash
# Test runner (chooses the output directory from the test's sub-directory).
# Usage: ./run_tests.sh [--verbose] [--mode <mode>] [<single test file>]
# Modes: parse-tree, ir, asm, run (default: run)

COMPILER="./build/bin/compiler"
TEST_SCRIPT="./scripts/verify_asm.sh"
TEST_DIR="test/test_case"
RESULT_FILE="result.txt"
VERBOSE=0
MODE="run"      # default mode: full test through the verify script
SINGLE_FILE=""  # set when a single test file is given on the command line

# Parse command-line arguments.
while [[ $# -gt 0 ]]; do
    case "$1" in
        --verbose|-v)
            VERBOSE=1
            shift
            ;;
        --mode)
            # `shift 2` fails WITHOUT shifting when the value is missing, which
            # would make this loop spin forever on the same `--mode` argument.
            if [[ $# -lt 2 ]]; then
                echo "错误: --mode 需要一个参数" >&2
                exit 1
            fi
            MODE="$2"
            shift 2
            ;;
        --emit-parse-tree)
            MODE="parse-tree"
            shift
            ;;
        --emit-ir)
            MODE="ir"
            shift
            ;;
        --emit-asm)
            MODE="asm"
            shift
            ;;
        --run)
            MODE="run"
            shift
            ;;
        -*)
            echo "未知选项: $1"
            exit 1
            ;;
        *)
            # Any non-option argument is taken as the test file to run.
            SINGLE_FILE="$1"
            shift
            ;;
    esac
done

# The compiler binary must exist.
if [ ! -f "$COMPILER" ]; then
    echo "错误: 编译器未找到于 $COMPILER"
    echo "请先完成项目构建 (cmake 和 make)"
    exit 1
fi

# In run mode the verify script must exist as well.
if [ "$MODE" = "run" ] && [ ! -f "$TEST_SCRIPT" ]; then
    echo "错误: 测试脚本未找到于 $TEST_SCRIPT"
    exit 1
fi

# If a single file was requested, it must exist.
if [ -n "$SINGLE_FILE" ] && [ ! -f "$SINGLE_FILE" ]; then
    echo "错误: 文件 $SINGLE_FILE 不存在"
    exit 1
fi

# Truncate (or create) the result file.
: > "$RESULT_FILE"

# Counters.
total=0
passed=0
failed=0

echo "开始测试 (模式: $MODE)..."
echo "输出将保存到 $RESULT_FILE"
echo "------------------------"

# Build the list of test files.
if [ -n "$SINGLE_FILE" ]; then
    TEST_FILES=("$SINGLE_FILE")
else
    mapfile -t TEST_FILES < <(find "$TEST_DIR" -type f -name "*.sy" | sort)
fi

# Run every test according to the selected mode.
for file in "${TEST_FILES[@]}"; do
    total=$((total + 1))
    if [ "$VERBOSE" -eq 1 ]; then
        echo "测试文件: $file"
    else
        echo -n "测试 $file ... "
    fi

    echo "========== $file ==========" >> "$RESULT_FILE"

    # Build the command as an argv array rather than a string passed to
    # `eval`, so file names containing quotes, spaces or `$` reach the
    # compiler unchanged.
    case "$MODE" in
        parse-tree)
            cmd=("$COMPILER" --emit-parse-tree "$file")
            ;;
        ir)
            cmd=("$COMPILER" --emit-ir "$file")
            ;;
        asm)
            cmd=("$COMPILER" --emit-asm "$file")
            ;;
        run)
            # Pick the output directory from the first path component after
            # test/test_case/ (functional or performance).
            rel_path="${file#"$TEST_DIR"/}"
            subdir="${rel_path%%/*}"
            if [[ "$subdir" != "functional" && "$subdir" != "performance" ]]; then
                echo "警告: 未知子目录 $subdir,使用默认输出目录" >> "$RESULT_FILE"
                out_dir="test/test_result/asm"
            else
                out_dir="test/test_result/$subdir/asm"
            fi
            cmd=("$TEST_SCRIPT" "$file" "$out_dir" --run)
            ;;
        *)
            echo "未知模式: $MODE" >&2
            exit 1
            ;;
    esac

    if [ "$VERBOSE" -eq 1 ]; then
        "${cmd[@]}" 2>&1 | tee -a "$RESULT_FILE"
        # Exit status of the command itself, not of tee.
        result=${PIPESTATUS[0]}
    else
        "${cmd[@]}" >> "$RESULT_FILE" 2>&1
        result=$?
    fi

    echo "" >> "$RESULT_FILE"

    if [ "$result" -eq 0 ]; then
        if [ "$VERBOSE" -eq 0 ]; then
            echo "通过"
        fi
        passed=$((passed + 1))
    else
        if [ "$VERBOSE" -eq 0 ]; then
            echo "失败"
        else
            echo ">>> 测试失败: $file"
        fi
        failed=$((failed + 1))
    fi
done

echo "------------------------"
echo "总计: $total"
echo "通过: $passed"
echo "失败: $failed"
echo "详细输出已保存至 $RESULT_FILE"

# Non-zero exit status when at least one test failed.
if [ "$failed" -gt 0 ]; then
    exit 1
else
    exit 0
fi

@ -52,8 +52,7 @@ expected_file="$input_dir/$stem.out"
"$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file"
# 静态链接
aarch64-linux-gnu-gcc -no-pie "$asm_file" -L./sylib -lsysy -static -o "$exe"
aarch64-linux-gnu-gcc "$asm_file" -o "$exe"
echo "可执行文件已生成: $exe"
if [[ "$run_exec" == true ]]; then
@ -84,7 +83,7 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -w -u "$expected_file" "$actual_file"; then
if diff -u "$expected_file" "$actual_file"; then
echo "输出匹配: $expected_file"
else
echo "输出不匹配: $expected_file" >&2

@ -1,8 +1,9 @@
// 管理基础类型、整型常量池和临时名生成。
// ir/IR.cpp
#include "ir/IR.h"
#include <cstring>
#include <cstring> // for memcpy
#include <sstream>
#include <functional>
namespace ir {
@ -16,7 +17,9 @@ ConstantInt* Context::GetConstInt(int v) {
return inserted->second.get();
}
// 新增:获取浮点常量
ConstantFloat* Context::GetConstFloat(float v) {
// 使用浮点数的二进制表示作为键,避免精度问题
uint32_t key;
std::memcpy(&key, &v, sizeof(float));
@ -32,68 +35,16 @@ ConstantFloat* Context::GetConstFloat(float v) {
return ptr;
}
// 新增:创建数组常量
ConstantArray* Context::GetConstArray(std::shared_ptr<ArrayType> ty,
std::vector<ConstantValue*> elements) {
// 验证数组常量
size_t expected_size = ty->GetElementCount();
if (elements.size() != expected_size) {
// 如果元素数量不匹配,可能需要补零或报错
// 这里根据需求处理
if (elements.size() < expected_size) {
// 补零
auto elem_type = ty->GetElementType();
while (elements.size() < expected_size) {
if (elem_type->IsInt32()) {
elements.push_back(GetConstInt(0));
} else if (elem_type->IsFloat()) {
elements.push_back(GetConstFloat(0.0f));
}
}
} else {
throw std::runtime_error("Array constant size mismatch");
}
}
// 构建缓存键
struct ArrayKey {
std::shared_ptr<ArrayType> type;
std::vector<ConstantValue*> elements;
bool operator==(const ArrayKey& other) const {
if (type != other.type) return false;
if (elements.size() != other.elements.size()) return false;
for (size_t i = 0; i < elements.size(); ++i) {
if (elements[i] != other.elements[i]) return false;
}
return true;
}
};
struct ArrayKeyHash {
size_t operator()(const ArrayKey& key) const {
size_t hash = std::hash<Type*>{}(key.type.get());
for (auto* elem : key.elements) {
hash ^= std::hash<ConstantValue*>{}(elem) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
}
return hash;
}
};
// 使用静态缓存(需要作为成员变量)
static std::unordered_map<ArrayKey, std::unique_ptr<ConstantArray>, ArrayKeyHash> cache;
ArrayKey key{ty, elements};
auto it = cache.find(key);
if (it != cache.end()) {
return it->second.get();
}
auto constant = std::make_unique<ConstantArray>(ty, std::move(elements));
auto* ptr = constant.get();
cache[std::move(key)] = std::move(constant);
const_arrays_.push_back(std::move(constant));
return ptr;
}
// 新增:获取零常量
ConstantZero* Context::GetZeroConstant(std::shared_ptr<Type> ty) {
auto it = zero_constants_.find(ty.get());
if (it != zero_constants_.end()) {
@ -106,6 +57,7 @@ ConstantZero* Context::GetZeroConstant(std::shared_ptr<Type> ty) {
return ptr;
}
// 新增:获取聚合类型的零常量
ConstantAggregateZero* Context::GetAggregateZero(std::shared_ptr<Type> ty) {
auto it = aggregate_zeros_.find(ty.get());
if (it != aggregate_zeros_.end()) {
@ -124,4 +76,5 @@ std::string Context::NextTemp() {
return oss.str();
}
} // namespace ir

@ -1,30 +1,9 @@
// ir/GlobalValue.cpp
#include "ir/IR.h"
#include <stdexcept>
namespace ir {
namespace {

// Returns a zero constant for the scalar kinds i32, float and i1, or nullptr
// for every other type. Each constant is allocated the first time its branch
// is reached and is never deleted, so every later call for the same kind gets
// the same object.
ConstantValue* GetScalarZeroConstant(const Type& type) {
  ConstantValue* zero = nullptr;
  if (type.IsInt32()) {
    static ConstantInt* const kZeroI32 =
        new ConstantInt(Type::GetInt32Type(), 0);
    zero = kZeroI32;
  } else if (type.IsFloat()) {
    static ConstantFloat* const kZeroF32 =
        new ConstantFloat(Type::GetFloatType(), 0.0f);
    zero = kZeroF32;
  } else if (type.IsInt1()) {
    static ConstantInt* const kZeroI1 =
        new ConstantInt(Type::GetInt1Type(), 0);
    zero = kZeroI1;
  }
  return zero;
}

}  // namespace
// Constructs a global value by forwarding its type and name to the User base
// class. The constructor body is empty; no initializer is set here.
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {}
@ -34,10 +13,42 @@ void GlobalValue::SetInitializer(ConstantValue* init) {
}
// 获取实际的值类型(用于类型检查)
std::shared_ptr<Type> value_type = GetValueType();
std::shared_ptr<Type> value_type;
// 如果当前类型是指针,获取指向的值类型
if (GetType()->IsPtrInt32()) {
value_type = Type::GetInt32Type();
} else if (GetType()->IsPtrFloat()) {
value_type = Type::GetFloatType();
} else if (GetType()->IsPtrInt1()) {
value_type = Type::GetInt1Type();
} else {
// 非指针类型:直接使用当前类型
value_type = GetType();
}
// 类型检查
bool type_match = CheckTypeCompatibility(value_type, init);
bool type_match = false;
// 检查标量类型
if (value_type->IsInt32() && init->GetType()->IsInt32()) {
type_match = true;
} else if (value_type->IsFloat() && init->GetType()->IsFloat()) {
type_match = true;
} else if (value_type->IsInt1() && init->GetType()->IsInt1()) {
type_match = true;
}
// 检查数组类型:允许用单个标量初始化整个数组
else if (value_type->IsArray()) {
auto* array_ty = static_cast<ArrayType*>(value_type.get());
auto* elem_type = array_ty->GetElementType().get();
if (elem_type->IsInt32() && init->GetType()->IsInt32()) {
type_match = true;
} else if (elem_type->IsFloat() && init->GetType()->IsFloat()) {
type_match = true;
}
}
if (!type_match) {
throw std::runtime_error("GlobalValue::SetInitializer: type mismatch");
@ -49,14 +60,23 @@ void GlobalValue::SetInitializer(ConstantValue* init) {
void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
if (init.empty()) {
initializer_.clear();
return;
}
// 获取实际的值类型
std::shared_ptr<Type> value_type = GetValueType();
std::shared_ptr<Type> value_type;
// 类型检查
if (GetType()->IsPtrInt32()) {
value_type = Type::GetInt32Type();
} else if (GetType()->IsPtrFloat()) {
value_type = Type::GetFloatType();
} else if (GetType()->IsPtrInt1()) {
value_type = Type::GetInt1Type();
} else {
value_type = GetType();
}
// 检查类型
if (value_type->IsArray()) {
auto* array_ty = static_cast<ArrayType*>(value_type.get());
size_t array_size = array_ty->GetElementCount();
@ -67,23 +87,16 @@ void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
// 检查每个初始化值的类型
auto* elem_type = array_ty->GetElementType().get();
for (size_t i = 0; i < init.size(); ++i) {
auto* elem = init[i];
if (!elem) {
throw std::runtime_error("GlobalValue::SetInitializer: null initializer at index " + std::to_string(i));
}
for (auto* elem : init) {
bool elem_match = false;
if (elem_type->IsInt32() && elem->GetType()->IsInt32()) {
elem_match = true;
} else if (elem_type->IsFloat() && elem->GetType()->IsFloat()) {
elem_match = true;
} else if (elem_type->IsInt1() && elem->GetType()->IsInt1()) {
elem_match = true;
}
if (!elem_match) {
throw std::runtime_error("GlobalValue::SetInitializer: element type mismatch at index " + std::to_string(i));
throw std::runtime_error("GlobalValue::SetInitializer: element type mismatch");
}
}
}
@ -92,10 +105,6 @@ void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
throw std::runtime_error("GlobalValue::SetInitializer: scalar requires exactly one initializer");
}
if (!init[0]) {
throw std::runtime_error("GlobalValue::SetInitializer: null initializer");
}
if ((value_type->IsInt32() && !init[0]->GetType()->IsInt32()) ||
(value_type->IsFloat() && !init[0]->GetType()->IsFloat()) ||
(value_type->IsInt1() && !init[0]->GetType()->IsInt1())) {
@ -109,87 +118,4 @@ void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
initializer_ = init;
}
// Helper: returns the type of the value held by this global. When the global's
// own type is one of the scalar pointer types (i32*, float*, i1*), the matching
// scalar type is returned; any other type is returned unchanged.
std::shared_ptr<Type> GlobalValue::GetValueType() const {
  const auto& own_type = GetType();
  if (own_type->IsPtrInt32()) {
    return Type::GetInt32Type();
  }
  if (own_type->IsPtrFloat()) {
    return Type::GetFloatType();
  }
  if (own_type->IsPtrInt1()) {
    return Type::GetInt1Type();
  }
  return own_type;
}
// 辅助方法:检查类型兼容性
bool GlobalValue::CheckTypeCompatibility(std::shared_ptr<Type> value_type,
ConstantValue* init) const {
// 检查标量类型
if (value_type->IsInt32() && init->GetType()->IsInt32()) {
return true;
} else if (value_type->IsFloat() && init->GetType()->IsFloat()) {
return true;
} else if (value_type->IsInt1() && init->GetType()->IsInt1()) {
return true;
}
// 检查数组类型:允许用单个标量初始化整个数组
else if (value_type->IsArray()) {
auto* array_ty = static_cast<ArrayType*>(value_type.get());
auto* elem_type = array_ty->GetElementType().get();
if (elem_type->IsInt32() && init->GetType()->IsInt32()) {
return true;
} else if (elem_type->IsFloat() && init->GetType()->IsFloat()) {
return true;
} else if (elem_type->IsInt1() && init->GetType()->IsInt1()) {
return true;
}
// 也可以允许 ConstantArray 作为初始化器
else if (init->GetType()->IsArray()) {
auto* init_array = static_cast<ConstantArray*>(init);
return init_array->IsValid();
}
}
// 检查指针类型(用于数组参数)
else if (value_type->IsPtrInt32() && init->GetType()->IsInt32()) {
return true;
} else if (value_type->IsPtrFloat() && init->GetType()->IsFloat()) {
return true;
}
return false;
}
// Convenience accessor for one element of an array-typed global.
// Returns nullptr when this global's type is not an array or when `index` is
// outside the array's element count. For an in-range index that has no
// explicit initializer, the scalar zero constant of the element type is
// returned instead.
ConstantValue* GlobalValue::GetArrayElement(size_t index) const {
  ArrayType* array_ty = nullptr;
  if (GetType()->IsArray()) {
    array_ty = dynamic_cast<ArrayType*>(GetType().get());
  }
  if (!array_ty ||
      index >= static_cast<size_t>(array_ty->GetElementCount())) {
    return nullptr;
  }
  return index < initializer_.size()
             ? initializer_[index]
             : GetScalarZeroConstant(*array_ty->GetElementType());
}
// Number of explicitly stored initializer elements; 0 unless this global is
// an array constant (see IsArrayConstant).
size_t GlobalValue::GetArraySize() const {
  return IsArrayConstant() ? initializer_.size() : 0;
}
// True when this global has an array type and a non-empty initializer list.
bool GlobalValue::IsArrayConstant() const {
  if (!GetType()->IsArray()) {
    return false;
  }
  return !initializer_.empty();
}
} // namespace ir

@ -119,17 +119,6 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
return CreateBinary(Opcode::Add, lhs, rhs, name);
}
// Appends an alloca of type `ty`, named `name`, to the current insertion
// block and returns the new instruction.
// Throws std::runtime_error when no insertion block is set or when `ty` is
// null; the insertion-block check runs first.
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> ty,
                                    const std::string& name) {
  const bool has_insert_point = insert_block_ ? true : false;
  if (!has_insert_point) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }

  const bool has_type = ty ? true : false;
  if (!has_type) {
    throw std::runtime_error(
        FormatError("ir", "IRBuilder::CreateAlloca 缺少类型"));
  }

  AllocaInst* created = insert_block_->Append<AllocaInst>(ty, name);
  return created;
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -201,21 +190,18 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
}
} else if (ptr_ty->IsPtrFloat()) {
if (!val_ty->IsFloat()) {
throw std::runtime_error(
FormatError("ir", "存储类型不匹配:期望 float, 实际 kind=" +
std::to_string(static_cast<int>(val_ty->GetKind()))));
throw std::runtime_error(FormatError("ir", "存储类型不匹配:期望 float"));
}
} else if (ptr_ty->IsArray()) {
// 数组存储支持两种形式:
// 1. 标量元素写入(通常配合 GEP 后落到元素指针,不会走到这里)
// 2. 聚合数组整体写入,例如 `store [16 x i32] zeroinitializer, [16 x i32]* %arr`
if (!val_ty->IsArray()) {
throw std::runtime_error(
FormatError("ir", "数组地址仅支持聚合数组整体存储"));
}
if (val_ty->GetKind() != ptr_ty->GetKind()) {
throw std::runtime_error(
FormatError("ir", "聚合数组存储类型不匹配"));
// 数组存储:检查元素类型
auto* array_ty = dynamic_cast<ArrayType*>(ptr_ty.get());
if (array_ty) {
auto elem_ty = array_ty->GetElementType();
if (elem_ty->IsInt32() && !val_ty->IsInt32()) {
throw std::runtime_error(FormatError("ir", "数组元素类型不匹配:期望 int32"));
} else if (elem_ty->IsFloat() && !val_ty->IsFloat()) {
throw std::runtime_error(FormatError("ir", "数组元素类型不匹配:期望 float"));
}
}
}
@ -226,6 +212,10 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!v) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateRet 缺少返回值"));
}
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
}
@ -396,8 +386,7 @@ ZExtInst* IRBuilder::CreateZExt(Value* value, std::shared_ptr<Type> target_ty,
FormatError("ir", "ZExt 目标类型必须是整数类型"));
}
const std::string inst_name = name.empty() ? ctx_.NextTemp() : name;
return insert_block_->Append<ZExtInst>(value, target_ty, inst_name);
return insert_block_->Append<ZExtInst>(value, target_ty, name);
}
// 创建截断指令
@ -427,8 +416,7 @@ TruncInst* IRBuilder::CreateTrunc(Value* value, std::shared_ptr<Type> target_ty,
FormatError("ir", "Trunc 目标类型必须是整数类型"));
}
const std::string inst_name = name.empty() ? ctx_.NextTemp() : name;
return insert_block_->Append<TruncInst>(value, target_ty, inst_name);
return insert_block_->Append<TruncInst>(value, target_ty, name);
}
// 便捷方法i1 转 i32
@ -478,9 +466,7 @@ BinaryInst* IRBuilder::CreateAnd(Value* lhs, Value* rhs, const std::string& name
if (!rhs) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateAnd 缺少 rhs"));
}
auto result_ty = lhs->GetType()->IsInt1() ? Type::GetInt1Type()
: Type::GetInt32Type();
return insert_block_->Append<BinaryInst>(Opcode::And, result_ty, lhs, rhs, name);
return insert_block_->Append<BinaryInst>(Opcode::And, Type::GetInt32Type(), lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name) {
@ -493,9 +479,7 @@ BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name)
if (!rhs) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateOr 缺少 rhs"));
}
auto result_ty = lhs->GetType()->IsInt1() ? Type::GetInt1Type()
: Type::GetInt32Type();
return insert_block_->Append<BinaryInst>(Opcode::Or, result_ty, lhs, rhs, name);
return insert_block_->Append<BinaryInst>(Opcode::Or, Type::GetInt32Type(), lhs, rhs, name);
}
IcmpInst* IRBuilder::CreateNot(Value* val, const std::string& name) {
@ -505,12 +489,7 @@ IcmpInst* IRBuilder::CreateNot(Value* val, const std::string& name) {
if (!val) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateNot 缺少 operand"));
}
if (val->GetType()->IsInt1()) {
auto* ext = CreateZExtI1ToI32(val, "");
auto* zero = CreateConstInt(0);
return CreateICmpEQ(ext, zero, name);
}
auto* zero = CreateConstInt(0);
auto zero = CreateConstInt(0);
return CreateICmpEQ(val, zero, name);
}
@ -532,29 +511,8 @@ GEPInst* IRBuilder::CreateGEP(Value* base,
}
}
// 结果类型推断:
// - 对 i32*/float* 基址,结果仍分别为 i32*/float*
// - 对数组基址,按多索引向下剥离元素类型;若到达标量则返回对应标量指针
// (本项目没有“指向数组的指针类型”,未完全剥离时退回数组类型)
std::shared_ptr<Type> result_ty = base->GetType();
if (base->GetType()->IsPtrInt32()) {
result_ty = Type::GetPtrInt32Type();
} else if (base->GetType()->IsPtrFloat()) {
result_ty = Type::GetPtrFloatType();
} else if (base->GetType()->IsArray()) {
std::shared_ptr<Type> cur = base->GetType();
for (size_t i = 1; i < indices.size(); ++i) {
auto* at = dynamic_cast<ArrayType*>(cur.get());
if (!at) break;
cur = at->GetElementType();
}
if (cur->IsInt32()) result_ty = Type::GetPtrInt32Type();
else if (cur->IsFloat()) result_ty = Type::GetPtrFloatType();
else result_ty = cur;
}
return insert_block_->Append<GEPInst>(result_ty, base, indices, name);
// GEP返回指针类型假设与base类型相同
return insert_block_->Append<GEPInst>(base->GetType(), base, indices, name);
}
@ -651,20 +609,12 @@ FcmpInst* IRBuilder::CreateFCmpOGE(Value* lhs, Value* rhs, const std::string& na
// 类型转换
SIToFPInst* IRBuilder::CreateSIToFP(Value* value, std::shared_ptr<Type> target_ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
const std::string inst_name = name.empty() ? ctx_.NextTemp() : name;
return insert_block_->Append<SIToFPInst>(value, target_ty, inst_name);
return insert_block_->Append<SIToFPInst>(value, target_ty, name);
}
// Appends an FPToSIInst converting `value` to `target_ty` at the current
// insertion point.
//
// @param value      Operand to convert.
// @param target_ty  Result type of the conversion.
// @param name       Result name; when empty, a fresh temporary name is taken
//                   from the context.
// @return The newly appended instruction.
// @throws std::runtime_error if no insertion block has been set.
FPToSIInst* IRBuilder::CreateFPToSI(Value* value, std::shared_ptr<Type> target_ty,
                                    const std::string& name) {
  if (!insert_block_) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }
  // An unnamed result would be printed without an identifier, so fall back to
  // a context-generated temporary.
  const std::string inst_name = name.empty() ? ctx_.NextTemp() : name;
  return insert_block_->Append<FPToSIInst>(value, target_ty, inst_name);
}
} // namespace ir

@ -4,9 +4,6 @@
#include "ir/IR.h"
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <ostream>
#include <stdexcept>
#include <string>
@ -15,109 +12,7 @@
namespace ir {
static std::string TypeToString(const Type& ty);
// Renders the array type formed by dimensions dims[begin..] over the element
// type `base_ty`, wrapping from the innermost dimension outwards. With
// dims = {4, 2}, begin = 0 and an element type printed as "i32", the result
// is "[4 x [2 x i32]]". When begin >= dims.size() only the element type text
// is returned.
static std::string ArrayTypeToStringFrom(const Type& base_ty,
                                         const std::vector<int>& dims,
                                         size_t begin) {
  std::string text = TypeToString(base_ty);
  size_t idx = dims.size();
  while (idx > begin) {
    --idx;
    text = "[" + std::to_string(dims[idx]) + " x " + text + "]";
  }
  return text;
}
// Reports whether `value` can be represented by an all-zero-bits initializer.
//
// - A null pointer counts as zero (a missing initializer element).
// - ConstantInt is zero when its value is 0.
// - ConstantFloat is zero only when its bit pattern is all zeros, i.e. +0.0f.
//   -0.0f compares equal to 0.0f but has the sign bit set, so folding it into
//   `zeroinitializer` would silently change the stored value.
// - ConstantZero / ConstantAggregateZero are always zero.
// - ConstantArray is zero when every element is (recursively) zero.
// Any other constant kind is reported as non-zero.
static bool IsZeroConstant(const ConstantValue* value) {
  if (!value) {
    return true;
  }
  if (auto* ci = dynamic_cast<const ConstantInt*>(value)) {
    return ci->GetValue() == 0;
  }
  if (auto* cf = dynamic_cast<const ConstantFloat*>(value)) {
    static_assert(sizeof(float) == sizeof(uint32_t),
                  "float is expected to be 32 bits wide");
    const float f = cf->GetValue();
    uint32_t bits = 0;
    memcpy(&bits, &f, sizeof(bits));
    return bits == 0;
  }
  if (dynamic_cast<const ConstantZero*>(value) ||
      dynamic_cast<const ConstantAggregateZero*>(value)) {
    return true;
  }
  if (auto* arr = dynamic_cast<const ConstantArray*>(value)) {
    for (auto* elem : arr->GetElements()) {
      if (!IsZeroConstant(elem)) {
        return false;
      }
    }
    return true;
  }
  return false;
}
// Returns the number of scalar elements covered by one aggregate at nesting
// depth `level`, i.e. the product of dims[level..]. Yields 1 when `level` is
// at or past the end of `dims`.
static size_t AggregateSpan(const std::vector<int>& dims, size_t level) {
  size_t product = 1;
  size_t idx = level;
  while (idx < dims.size()) {
    product *= static_cast<size_t>(dims[idx]);
    ++idx;
  }
  return product;
}
// Checks whether the `count` flattened initializer slots starting at `begin`
// are all zero constants. Slots past the end of `init` are skipped, so a
// range that is entirely out of bounds is reported as zero.
static bool IsZeroRange(const std::vector<ConstantValue*>& init,
                        size_t begin,
                        size_t count) {
  for (size_t offset = 0; offset != count; ++offset) {
    const size_t pos = begin + offset;
    const bool in_bounds = pos < init.size();
    if (in_bounds && !IsZeroConstant(init[pos])) {
      return false;
    }
  }
  return true;
}
// Prints the initializer body of the (sub-)array at nesting depth `level`,
// reading scalar elements from the flattened list `init`.
//
// @param os          Output stream.
// @param base_ty     Scalar element type of the array.
// @param dims        Array dimensions, outermost first.
// @param level       Index into `dims` of the dimension being printed.
// @param init        Flattened initializer; may be shorter than the array, in
//                    which case missing elements are printed as zero.
// @param flat_index  In/out cursor into `init`; advanced by the number of
//                    scalar slots consumed by this call.
//
// A sub-array whose slots are all zero is collapsed to `zeroinitializer`.
// Float elements are written in the 64-bit hexadecimal form, which is the
// same form this file uses for float operands and scalar float globals.
static void PrintFlatArrayBody(std::ostream& os,
                               const Type& base_ty,
                               const std::vector<int>& dims,
                               size_t level,
                               const std::vector<ConstantValue*>& init,
                               size_t& flat_index) {
  const size_t span = AggregateSpan(dims, level);
  if (IsZeroRange(init, flat_index, span)) {
    os << "zeroinitializer";
    flat_index += span;
    return;
  }

  // Writes a float as the bit pattern of its double-widened value: "0x" plus
  // 16 upper-case hex digits (18 characters + NUL fits in the buffer).
  const auto print_float = [&os](float value) {
    const double widened = static_cast<double>(value);
    uint64_t bits = 0;
    memcpy(&bits, &widened, sizeof(bits));
    char buf[20];
    snprintf(buf, sizeof(buf), "0x%016llX",
             static_cast<unsigned long long>(bits));
    os << buf;
  };

  const bool has_inner_dims = level + 1 < dims.size();
  os << "[";
  for (int i = 0; i < dims[level]; ++i) {
    if (i > 0) os << ", ";
    if (has_inner_dims) {
      os << ArrayTypeToStringFrom(base_ty, dims, level + 1) << " ";
      PrintFlatArrayBody(os, base_ty, dims, level + 1, init, flat_index);
      continue;
    }

    os << TypeToString(base_ty) << " ";
    // A slot past the end of `init` (or a null entry) is an implicit zero.
    const ConstantValue* elem =
        flat_index < init.size() ? init[flat_index] : nullptr;
    if (auto* ci = dynamic_cast<const ConstantInt*>(elem)) {
      os << ci->GetValue();
    } else if (auto* cf = dynamic_cast<const ConstantFloat*>(elem)) {
      print_float(cf->GetValue());
    } else if (base_ty.IsFloat()) {
      // Missing/zero-kind element of a float array: emit a float zero rather
      // than the integer literal "0".
      print_float(0.0f);
    } else {
      os << "0";
    }
    ++flat_index;
  }
  os << "]";
}
static std::string TypeToString(const Type& ty) {
static const char* TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void: return "void";
case Type::Kind::Int32: return "i32";
@ -125,23 +20,10 @@ static std::string TypeToString(const Type& ty) {
case Type::Kind::PtrInt32: return "i32*";
case Type::Kind::PtrFloat: return "float*";
case Type::Kind::Label: return "label";
case Type::Kind::Array: return "array";
case Type::Kind::Function: return "function";
case Type::Kind::Int1: return "i1";
case Type::Kind::PtrInt1: return "i1*";
case Type::Kind::Array: {
// 打印数组类型为 LLVM 风格,如 [4 x [2 x i32]]
auto* at = dynamic_cast<const ArrayType*>(&ty);
if (!at) return "array";
// 递归构建类型字符串
std::string elem = TypeToString(*at->GetElementType());
const auto& dims = at->GetDimensions();
// 从外到内构建
std::string s = elem;
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
s = "[" + std::to_string(*it) + " x " + s + "]";
}
return s;
}
default: return "unknown";
}
throw std::runtime_error(FormatError("ir", "未知类型"));
@ -172,9 +54,9 @@ static const char* OpcodeToString(Opcode op) {
case Opcode::Icmp:
return "icmp";
case Opcode::Div:
return "sdiv";
return "div";
case Opcode::Mod:
return "srem";
return "mod";
case Opcode::ZExt:
return "zext";
case Opcode::Trunc:
@ -200,29 +82,12 @@ static const char* OpcodeToString(Opcode op) {
return "?";
}
// Converts a float into the 64-bit hexadecimal floating-point literal form
// accepted by LLVM IR: the value is widened to double and its raw bit pattern
// is printed as "0x" followed by 16 upper-case hex digits.
static std::string FloatToLLVMHex(float f) {
  const double widened = static_cast<double>(f);
  uint64_t raw = 0;
  memcpy(&raw, &widened, sizeof(raw));
  // "0x" + 16 digits + terminating NUL = 19 bytes.
  char text[20];
  snprintf(text, sizeof(text), "0x%016llX",
           static_cast<unsigned long long>(raw));
  return std::string(text);
}
static std::string ValueToString(const Value* v) {
if (!v) {
return "<null>";
}
if (dynamic_cast<const ConstantZero*>(v) ||
dynamic_cast<const ConstantAggregateZero*>(v)) {
return "zeroinitializer";
}
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return FloatToLLVMHex(cf->GetValue());
if (!v) {
return "<null>";
}
const auto& name = v->GetName();
if (name.empty()) {
@ -237,107 +102,27 @@ static std::string ValueToString(const Value* v) {
return "%" + name;
}
// Returns the textual type to print for a load/store address operand: the
// plain type text, with a trailing '*' appended when the type is an array.
static std::string MemoryTypeToString(const Type& ty) {
  const std::string base = TypeToString(ty);
  return ty.IsArray() ? base + "*" : base;
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& global : module.GetGlobals()) {
if (!global) continue;
os << "@" << global->GetName() << " = "
<< (global->IsConstant() ? "constant " : "global ");
os << "@" << global->GetName() << " = global ";
if (global->GetType()->IsPtrInt32()) {
os << "i32 ";
if (global->HasInitializer()) {
auto* ci = dynamic_cast<const ConstantInt*>(global->GetInitializer().front());
os << (ci ? ci->GetValue() : 0);
} else {
os << "0";
}
os << "\n";
continue;
}
if (global->GetType()->IsPtrFloat()) {
os << "float ";
if (global->HasInitializer()) {
auto* cf = dynamic_cast<const ConstantFloat*>(global->GetInitializer().front());
os << (cf ? ValueToString(cf) : FloatToLLVMHex(0.0f));
} else {
os << FloatToLLVMHex(0.0f);
}
os << "\n";
continue;
}
if (global->GetType()->IsArray()) {
auto* at = dynamic_cast<const ArrayType*>(global->GetType().get());
os << TypeToString(*global->GetType()) << " ";
if (!at || !global->HasInitializer() ||
IsZeroRange(global->GetInitializer(), 0, AggregateSpan(at->GetDimensions(), 0))) {
os << "zeroinitializer\n";
continue;
}
size_t flat_index = 0;
PrintFlatArrayBody(os,
*at->GetElementType(),
at->GetDimensions(),
0,
global->GetInitializer(),
flat_index);
os << "\n";
continue;
os << "i32 0\n";
} else if (global->GetType()->IsPtrFloat()) {
os << "float 0.0\n";
} else {
os << TypeToString(*global->GetType()) << " zeroinitializer\n";
}
os << TypeToString(*global->GetType()) << " zeroinitializer\n";
}
auto print_func_params = [&](const Function* func,
const FunctionType* func_ty) {
for (const auto& func : module.GetFunctions()) {
auto* func_ty = static_cast<const FunctionType*>(func->GetType().get());
os << "define " << TypeToString(*func_ty->GetReturnType()) << " @" << func->GetName() << "(";
bool first = true;
if (!func->GetArguments().empty()) {
for (const auto& arg : func->GetArguments()) {
if (!first) os << ", ";
first = false;
os << TypeToString(*arg->GetType()) << " %" << arg->GetName();
}
return;
}
for (const auto& pty : func_ty->GetParamTypes()) {
for (const auto& arg : func->GetArguments()) {
if (!first) os << ", ";
first = false;
os << TypeToString(*pty);
os << TypeToString(*arg->GetType()) << " %" << arg->GetName();
}
};
auto is_declaration_only = [](const Function* func) {
const auto& blocks = func->GetBlocks();
if (blocks.size() != 1) return false;
const auto& only = blocks.front();
if (!only) return false;
return only->GetInstructions().empty();
};
for (const auto& func : module.GetFunctions()) {
auto* func_ty = static_cast<const FunctionType*>(func->GetType().get());
if (is_declaration_only(func.get())) {
os << "declare " << TypeToString(*func_ty->GetReturnType()) << " @"
<< func->GetName() << "(";
print_func_params(func.get(), func_ty);
os << ")\n";
continue;
}
os << "define " << TypeToString(*func_ty->GetReturnType()) << " @"
<< func->GetName() << "(";
print_func_params(func.get(), func_ty);
os << ") {\n";
for (const auto& bb : func->GetBlocks()) {
if (!bb) {
@ -354,11 +139,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
case Opcode::Mod:
case Opcode::And:
case Opcode::Not:
case Opcode::Or:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::Or:
{
auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = "
@ -385,7 +166,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< MemoryTypeToString(*load->GetPtr()->GetType()) << " "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
@ -393,29 +174,21 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
auto* store = static_cast<const StoreInst*>(inst);
os << " store " << TypeToString(*store->GetValue()->GetType()) << " "
<< ValueToString(store->GetValue())
<< ", " << MemoryTypeToString(*store->GetPtr()->GetType()) << " "
<< ", " << TypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
if (!ret->GetValue()) {
os << " ret void\n";
} else {
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n";
}
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n";
break;
}
// CallInst类在 include/ir/IR.h 中定义
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
os << " ";
if (!call->GetType()->IsVoid()) {
os << call->GetName() << " = ";
}
os << "call " << TypeToString(*call->GetType()) << " @"
<< call->GetCallee()->GetName() << "(";
os << " " << call->GetName() << " = call "
<< TypeToString(*call->GetType()) << " @" << call->GetCallee()->GetName() << "(";
bool first = true;
for (auto* arg : call->GetArgs()) {
if (!first) os << ", ";
@ -475,81 +248,16 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
break;
}
case Opcode::GEP:{
// 打印为类似 LLVM 的 getelementptr 形式:
// getelementptr <elem_ty>, <base_ty> <base>, i32 <idx0>, i32 <idx1>, ...
// 简化打印:只打印基本信息和操作数数量
os << " " << inst->GetName() << " = getelementptr ";
// 基地址类型使用第一个操作数的类型
Value* base = inst->GetOperand(0);
// GEP 的第一个类型参数应是基址指向的元素类型pointee
std::string elem_ty;
if (base->GetType()->IsPtrInt32()) elem_ty = "i32";
else if (base->GetType()->IsPtrFloat()) elem_ty = "float";
else if (base->GetType()->IsArray()) elem_ty = TypeToString(*base->GetType());
else elem_ty = TypeToString(*inst->GetType());
std::string base_ty = TypeToString(*base->GetType());
if (base->GetType()->IsArray()) {
base_ty += "*";
}
os << elem_ty << ", " << base_ty << " " << ValueToString(base);
// 后续操作数为索引,按照 i32 打印
// 特殊处理:如果 base 是标量指针i32*/float*)且第一个索引是常量 0
// 且后续还有索引,则丢弃第一个 0对 T* 来说多余且会导致无效 IR
size_t start_idx = 1;
if ((base->GetType()->IsPtrInt32() || base->GetType()->IsPtrFloat()) &&
inst->GetNumOperands() >= 3) {
// 检查第一个索引是否为常量 0
auto* first_idx = inst->GetOperand(1);
if (auto* ci = dynamic_cast<const ConstantInt*>(first_idx)) {
if (ci->GetValue() == 0) {
start_idx = 2; // 跳过第一个 0
}
}
}
for (size_t i = start_idx; i < inst->GetNumOperands(); ++i) {
os << ", i32 " << ValueToString(inst->GetOperand(i));
os << TypeToString(*inst->GetType()) << " (";
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (i > 0) os << ", ";
os << ValueToString(inst->GetOperand(i));
}
os << "\n";
break;
}
case Opcode::FCmp: {
auto* fcmp = static_cast<const FcmpInst*>(inst);
os << " " << fcmp->GetName() << " = fcmp ";
switch (fcmp->GetPredicate()) {
case FcmpInst::Predicate::OEQ: os << "oeq"; break;
case FcmpInst::Predicate::ONE: os << "one"; break;
case FcmpInst::Predicate::OLT: os << "olt"; break;
case FcmpInst::Predicate::OLE: os << "ole"; break;
case FcmpInst::Predicate::OGT: os << "ogt"; break;
case FcmpInst::Predicate::OGE: os << "oge"; break;
default: os << "oeq"; break;
}
os << " " << TypeToString(*fcmp->GetLhs()->GetType())
<< " " << ValueToString(fcmp->GetLhs())
<< ", " << ValueToString(fcmp->GetRhs()) << "\n";
break;
}
case Opcode::SIToFP: {
auto* sitofp = static_cast<const SIToFPInst*>(inst);
os << " " << sitofp->GetName() << " = sitofp "
<< TypeToString(*sitofp->GetValue()->GetType()) << " "
<< ValueToString(sitofp->GetValue()) << " to "
<< TypeToString(*sitofp->GetType()) << "\n";
break;
}
case Opcode::FPToSI: {
auto* fptosi = static_cast<const FPToSIInst*>(inst);
os << " " << fptosi->GetName() << " = fptosi "
<< TypeToString(*fptosi->GetValue()->GetType()) << " "
<< ValueToString(fptosi->GetValue()) << " to "
<< TypeToString(*fptosi->GetType()) << "\n";
os << ")\n";
break;
}
default: {
// 处理未知操作码
os << " ; 未知指令: " << OpcodeToString(inst->GetOpcode()) << "\n";
@ -578,10 +286,10 @@ void IRPrinter::PrintConstant(const ConstantValue* constant, std::ostream& os) {
}
os << "]";
}
else if (dynamic_cast<const ConstantZero*>(constant)) {
else if (auto* zero = dynamic_cast<const ConstantZero*>(constant)) {
os << "zero";
}
else if (dynamic_cast<const ConstantAggregateZero*>(constant)) {
else if (auto* agg_zero = dynamic_cast<const ConstantAggregateZero*>(constant)) {
os << "zeroinitializer";
}
}

@ -73,10 +73,6 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
case Opcode::Mod:
case Opcode::And:
case Opcode::Or:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
// 有效的二元操作符
break;
case Opcode::Not:
@ -100,26 +96,21 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
throw std::runtime_error(FormatError("ir", "BinaryInst 操作数类型不匹配"));
}
bool is_logical = (op == Opcode::And || op == Opcode::Or);
// 检查操作数类型是否支持
if (is_logical) {
if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsInt1()) {
throw std::runtime_error(
FormatError("ir", "逻辑运算仅支持 i32/i1"));
}
} else {
if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsFloat()) {
throw std::runtime_error(
FormatError("ir", "BinaryInst 只支持 int32 和 float 类型"));
}
if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsFloat()) {
throw std::runtime_error(
FormatError("ir", "BinaryInst 只支持 int32 和 float 类型"));
}
// 对于算术运算,结果类型应与操作数类型相同
bool is_logical = (op == Opcode::And || op == Opcode::Or);
if (is_logical) {
// 逻辑运算结果类型应与操作数一致i1 或 i32
if (type_->GetKind() != lhs->GetType()->GetKind()) {
// 比较和逻辑运算的结果应该是整数类型
if (!type_->IsInt32()) {
throw std::runtime_error(
FormatError("ir", "逻辑运算结果类型与操作数类型不匹配"));
FormatError("ir", "比较和逻辑运算的结果类型必须是 int32"));
}
} else {
// 算术运算的结果类型应与操作数类型相同
@ -139,27 +130,21 @@ Value* BinaryInst::GetRhs() const { return GetOperand(1); }
// Constructs a return instruction carrying `val` as its single operand.
//
// @param void_ty  Type of the instruction itself; must be the void type.
// @param val      Value to return; must not be null.
// @throws std::runtime_error if `val` is null or `void_ty` is not void.
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
    : Instruction(Opcode::Ret, std::move(void_ty), "") {
  if (!val) {
    throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值"));
  }
  if (!type_ || !type_->IsVoid()) {
    throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
  }
  // Register the value exactly once; a second registration would create a
  // duplicate operand for the same value.
  AddOperand(val);
}
// Returns the returned value (operand 0), or nullptr when the instruction
// has no operands.
Value* ReturnInst::GetValue() const {
  return GetNumOperands() == 0 ? nullptr : GetOperand(0);
}
// Returns the returned value, i.e. operand 0 of this instruction.
// NOTE(review): no operand-count check is performed here; presumably the
// constructor guarantees an operand exists — confirm.
Value* ReturnInst::GetValue() const { return GetOperand(0); }
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
if (!type_ ||
(!type_->IsPtrInt32() && !type_->IsPtrFloat() && !type_->IsArray())) {
throw std::runtime_error(
FormatError("ir", "AllocaInst 仅支持 i32* / float* / array"));
if (!type_ || !type_->IsPtrInt32()) {
throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*"));
}
}
@ -168,15 +153,12 @@ LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
}
if (!type_ || (!type_->IsInt32() && !type_->IsFloat() && !type_->IsInt1())) {
throw std::runtime_error(
FormatError("ir", "LoadInst 仅支持加载 i32/float/i1"));
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32"));
}
if (!ptr->GetType() ||
(!ptr->GetType()->IsPtrInt32() && !ptr->GetType()->IsPtrFloat() &&
!ptr->GetType()->IsArray() && !ptr->GetType()->IsPtrInt1())) {
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "LoadInst 仅支持从指针或数组地址加载"));
FormatError("ir", "LoadInst 当前只支持从 i32* 加载"));
}
AddOperand(ptr);
}
@ -194,25 +176,13 @@ StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
}
if (!val->GetType() ||
(!val->GetType()->IsInt32() && !val->GetType()->IsFloat() &&
!val->GetType()->IsInt1() && !val->GetType()->IsArray())) {
throw std::runtime_error(
FormatError("ir", "StoreInst 仅支持存储 i32/float/i1/array"));
}
if (!ptr->GetType() ||
(!ptr->GetType()->IsPtrInt32() && !ptr->GetType()->IsPtrFloat() &&
!ptr->GetType()->IsArray() && !ptr->GetType()->IsPtrInt1())) {
throw std::runtime_error(FormatError("ir", "StoreInst 仅支持写入指针或数组地址"));
if (!val->GetType() || !val->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32"));
}
if (ptr->GetType()->IsArray()) {
if (!val->GetType()->IsArray() ||
val->GetType()->GetKind() != ptr->GetType()->GetKind()) {
throw std::runtime_error(
FormatError("ir", "StoreInst 聚合存储要求 value/ptr 具有相同数组类型"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "StoreInst 当前只支持写入 i32*"));
}
AddOperand(val);
AddOperand(ptr);
}

@ -14,7 +14,6 @@ size_t Type::Size() const {
case Kind::PtrInt32: return 8; // 假设 64 位指针
case Kind::PtrFloat: return 8;
case Kind::Label: return 8; // 标签地址大小(指针大小)
case Kind::PtrInt1: return 8; // 指向 i1 的指针大小
default: return 0; // 派生类应重写
}
}

File diff suppressed because it is too large Load Diff

@ -6,11 +6,10 @@
#include "ir/IR.h"
#include "utils/Log.h"
// IR generation entry point: creates an empty ir::Module, constructs an
// IRGenImpl over it, lets the parse tree accept that visitor, and returns the
// populated module to the caller.
// NOTE(review): two alternative parameter lines (SemaResult / SemanticContext)
// and two matching IRGenImpl constructions appear below — looks like two
// revisions of the signature side by side; confirm which one is intended.
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemaResult& sema_result) {
const SemanticContext& sema) {
auto module = std::make_unique<ir::Module>();
IRGenImpl gen(*module, sema_result.context, sema_result.symbol_table);
IRGenImpl gen(*module, sema);
// Walk the whole compilation unit; the visitor emits IR into `module`.
tree.accept(&gen);
return module;
}

@ -21,84 +21,96 @@
// - 条件与比较表达式
// - ...
// 表达式生成
// Evaluates an expression subtree by dispatching to this visitor and
// unwrapping the std::any result into an ir::Value*.
// Throws std::runtime_error when the visit yields an empty any or (first
// handler statement below) when the any does not hold an ir::Value*; any
// std::exception raised along the way is logged and rethrown.
// NOTE(review): each log point appears twice (DEBUG_MSG and std::cout/cerr);
// looks like two logging variants side by side — confirm which is intended.
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
DEBUG_MSG("[DEBUG IRGEN] EvalExpr: 开始处理表达式 " << expr.getText());
std::cout << "[DEBUG IRGEN] EvalExpr: " << expr.getText() << std::endl;
try {
auto result_any = expr.accept(this);
// An empty any means the visit function produced no value at all.
if (!result_any.has_value()) {
DEBUG_MSG("[ERROR] EvalExpr: result_any has no value");
std::cerr << "[ERROR] EvalExpr: result_any has no value" << std::endl;
throw std::runtime_error("表达式求值结果为空");
}
try {
ir::Value* result = std::any_cast<ir::Value*>(result_any);
DEBUG_MSG("[DEBUG] EvalExpr: success, result = " << (void*)result);
std::cerr << "[DEBUG] EvalExpr: success, result = " << (void*)result << std::endl;
return result;
} catch (const std::bad_any_cast& e) {
// The any holds something other than ir::Value*.
DEBUG_MSG("[ERROR] EvalExpr: bad any_cast - " << e.what());
DEBUG_MSG(" Type info: " << result_any.type().name());
throw std::runtime_error(FormatError("irgen", "表达式求值返回了错误的类型"));
// NOTE(review): everything from here to the end of this handler follows an
// unconditional throw and is not reached as written.
std::cerr << "[ERROR] EvalExpr: bad any_cast - " << e.what() << std::endl;
std::cerr << " Type info: " << result_any.type().name() << std::endl;
// Try other possible types.
try {
// Check whether this is an empty any (possibly from a visit function returning {}).
std::cerr << "[DEBUG] EvalExpr: Trying to handle empty any" << std::endl;
return nullptr;
} catch (...) {
throw std::runtime_error(FormatError("irgen", "表达式求值返回了错误的类型"));
}
}
} catch (const std::exception& e) {
// Log and propagate unchanged so callers still see the original exception.
DEBUG_MSG("[ERROR] Exception in EvalExpr: " << e.what());
std::cerr << "[ERROR] Exception in EvalExpr: " << e.what() << std::endl;
throw;
}
}
// Evaluates a condition subtree by dispatching to this visitor and unwrapping
// the result as an ir::Value*. std::any_cast throws std::bad_any_cast when the
// visit result holds any other type.
ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) {
  DEBUG_MSG("[DEBUG IRGEN] EvalCond: 开始处理条件表达式 " << cond.getText());
  auto visited = cond.accept(this);
  return std::any_cast<ir::Value*>(visited);
}
// 基本表达式:数字、变量、括号表达式
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitPrimaryExp: 开始处理基本表达式 " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] visitPrimaryExp: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少基本表达式"));
}
DEBUG_MSG("[DEBUG] visitPrimaryExp");
std::cerr << "[DEBUG] visitPrimaryExp" << std::endl;
// 处理数字字面量
if (ctx->DECIMAL_INT()) {
int value = std::stoi(ctx->DECIMAL_INT()->getText());
ir::Value* const_int = builder_.CreateConstInt(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant int " << value
<< " created as " << (void*)const_int);
std::cerr << "[DEBUG] visitPrimaryExp: constant int " << value
<< " created as " << (void*)const_int << std::endl;
return static_cast<ir::Value*>(const_int);
}
if (ctx->HEX_FLOAT()) {
std::string hex_float_str = ctx->HEX_FLOAT()->getText();
float value = 0.0f;
// 解析十六进制浮点数
try {
// C++11 的 std::stof 支持十六进制浮点数表示
value = std::stof(hex_float_str);
} catch (const std::exception& e) {
DEBUG_MSG("[WARNING] 无法解析十六进制浮点数: " << hex_float_str
<< "使用0.0代替");
std::cerr << "[WARNING] 无法解析十六进制浮点数: " << hex_float_str
<< "使用0.0代替" << std::endl;
value = 0.0f;
}
ir::Value* const_float = builder_.CreateConstFloat(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant hex float " << value
<< " created as " << (void*)const_float);
std::cerr << "[DEBUG] visitPrimaryExp: constant hex float " << value
<< " created as " << (void*)const_float << std::endl;
return static_cast<ir::Value*>(const_float);
}
// 处理十进制浮点常量
if (ctx->DEC_FLOAT()) {
std::string dec_float_str = ctx->DEC_FLOAT()->getText();
float value = 0.0f;
try {
value = std::stof(dec_float_str);
} catch (const std::exception& e) {
DEBUG_MSG("[WARNING] 无法解析十进制浮点数: " << dec_float_str
<< "使用0.0代替");
std::cerr << "[WARNING] 无法解析十进制浮点数: " << dec_float_str
<< "使用0.0代替" << std::endl;
value = 0.0f;
}
ir::Value* const_float = builder_.CreateConstFloat(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant dec float " << value
<< " created as " << (void*)const_float);
std::cerr << "[DEBUG] visitPrimaryExp: constant dec float " << value
<< " created as " << (void*)const_float << std::endl;
return static_cast<ir::Value*>(const_float);
}
@ -106,8 +118,6 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
std::string hex = ctx->HEX_INT()->getText();
int value = std::stoi(hex, nullptr, 16);
ir::Value* const_int = builder_.CreateConstInt(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant hex int " << value
<< " created as " << (void*)const_int);
return static_cast<ir::Value*>(const_int);
}
@ -115,132 +125,79 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
std::string oct = ctx->OCTAL_INT()->getText();
int value = std::stoi(oct, nullptr, 8);
ir::Value* const_int = builder_.CreateConstInt(value);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant octal int " << value
<< " created as " << (void*)const_int);
return static_cast<ir::Value*>(const_int);
}
if (ctx->ZERO()) {
ir::Value* const_int = builder_.CreateConstInt(0);
DEBUG_MSG("[DEBUG] visitPrimaryExp: constant zero int created");
return static_cast<ir::Value*>(const_int);
}
// 处理变量
if (ctx->lVal()) {
DEBUG_MSG("[DEBUG] visitPrimaryExp: visiting lVal");
std::cerr << "[DEBUG] visitPrimaryExp: visiting lVal" << std::endl;
return ctx->lVal()->accept(this);
}
// 处理括号表达式
if (ctx->L_PAREN() && ctx->exp()) {
DEBUG_MSG("[DEBUG] visitPrimaryExp: visiting parenthesized expression");
std::cerr << "[DEBUG] visitPrimaryExp: visiting parenthesized expression" << std::endl;
return EvalExpr(*ctx->exp());
}
DEBUG_MSG("[ERROR] visitPrimaryExp: unsupported primary expression type");
std::cerr << "[ERROR] visitPrimaryExp: unsupported primary expression type" << std::endl;
throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型"));
}
// 左值(变量)处理
// 1. 先通过语义分析结果把变量使用绑定回声明;
// 2. 再通过 storage_map_ 找到该声明对应的栈槽位;
// 3. 最后生成 load把内存中的值读出来。
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitLVal: 开始处理左值 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
}
std::string varName = ctx->Ident()->getText();
DEBUG_MSG("[DEBUG] visitLVal: " << varName);
// 先检查语义分析中常量绑定
const SysYParser::ConstDefContext* const_decl = sema_.ResolveConstUse(ctx);
const Symbol* sym = nullptr;
if (const_decl) {
sym = symbol_table_.lookupByConstDef(const_decl);
if (!sym) {
sym = symbol_table_.lookupAll(varName);
}
} else {
sym = symbol_table_.lookup(varName);
}
// 如果是常量,直接返回常量值
if (sym && sym->kind == SymbolKind::Constant) {
DEBUG_MSG("[DEBUG] visitLVal: 找到常量 " << varName);
if (sym->IsScalarConstant()) {
if (sym->type->IsInt32()) {
ir::ConstantValue* const_val = builder_.CreateConstInt(sym->GetIntConstant());
return static_cast<ir::Value*>(const_val);
} else if (sym->type->IsFloat()) {
ir::ConstantValue* const_val = builder_.CreateConstFloat(sym->GetFloatConstant());
return static_cast<ir::Value*>(const_val);
}
} else if (sym->IsArrayConstant()) {
auto it = const_global_map_.find(varName);
if (it != const_global_map_.end()) {
ir::GlobalValue* global_array = it->second;
// 尝试获取类型信息,用于维度判断与下标线性化
auto* array_ty = dynamic_cast<ir::ArrayType*>(sym->type.get());
if (!array_ty) {
// 无法获取数组类型,退回返回全局对象
return static_cast<ir::Value*>(global_array);
}
size_t ndims = array_ty->GetDimensions().size();
// 有下标访问
if (!ctx->exp().empty()) {
size_t provided = ctx->exp().size();
// 完全索引(所有维度都有下标)——直接返回常量元素,不生成 Load
if (provided == ndims) {
std::vector<int> idxs;
idxs.reserve(provided);
for (auto* exp : ctx->exp()) {
ir::Value* v = EvalExpr(*exp);
if (!v || !v->IsConstant()) {
throw std::runtime_error(FormatError("irgen", "常量数组索引必须为常量整数: " + varName));
}
auto* ci = dynamic_cast<ir::ConstantInt*>(v);
if (!ci) {
throw std::runtime_error(FormatError("irgen", "常量数组索引非整型常量: " + varName));
}
idxs.push_back(ci->GetValue());
}
// 计算线性下标(行主序)
const auto& dims = array_ty->GetDimensions();
int flat = idxs[0];
for (size_t i = 1; i < ndims; ++i) {
flat = flat * dims[i] + idxs[i];
}
ir::ConstantValue* elem = global_array->GetArrayElement(static_cast<size_t>(flat));
return static_cast<ir::Value*>(elem);
}
// 部分索引:返回指针(不做 Load由上层按需处理
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* exp : ctx->exp()) {
indices.push_back(EvalExpr(*exp));
}
return static_cast<ir::Value*>(
builder_.CreateGEP(global_array, indices, module_.GetContext().NextTemp()));
} else {
// 无下标,直接返回全局常量对象
return static_cast<ir::Value*>(global_array);
}
std::cerr << "[DEBUG] visitLVal: " << varName << std::endl;
// 优先检查是否是常量
auto const_it = const_value_map_.find(varName);
if (const_it != const_value_map_.end()) {
// 常量直接返回值不需要load
std::cerr << "[DEBUG] visitLVal: constant " << varName << std::endl;
return static_cast<ir::Value*>(const_it->second);
}
// 检查全局常量
auto const_global_it = const_global_map_.find(varName);
if (const_global_it != const_global_map_.end()) {
// 全局常量需要load
ir::Value* ptr = const_global_it->second;
if (!ctx->exp().empty()) {
// 数组访问
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* exp : ctx->exp()) {
ir::Value* index = EvalExpr(*exp);
indices.push_back(index);
}
ir::Value* elem_ptr = builder_.CreateGEP(
ptr, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(
builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp()));
} else {
return static_cast<ir::Value*>(
builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
}
}
// 不是常量,按正常变量处理
// ... 原有的变量查找代码 ...
auto* decl = sema_.ResolveVarUse(ctx);
ir::Value* ptr = nullptr;
if (decl) {
auto it = storage_map_.find(decl);
if (it != storage_map_.end()) {
@ -277,124 +234,27 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
// 检查是否有数组下标
bool is_array_access = !ctx->exp().empty();
if (is_array_access) {
// 收集下标表达式不含前导0
std::vector<ir::Value*> idx_vals;
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* exp : ctx->exp()) {
ir::Value* index = EvalExpr(*exp);
idx_vals.push_back(index);
}
const Symbol* var_sym = sym;
if (!var_sym) {
var_sym = symbol_table_.lookup(varName);
}
if (!var_sym && decl) {
var_sym = symbol_table_.lookupByVarDef(decl);
indices.push_back(index);
}
if (!var_sym) {
var_sym = symbol_table_.lookupAll(varName);
}
std::vector<int> dims;
if (var_sym) {
if (var_sym->is_array_param && !var_sym->array_dims.empty()) {
dims = var_sym->array_dims;
} else if (var_sym->type && var_sym->type->IsArray()) {
auto* at = dynamic_cast<ir::ArrayType*>(var_sym->type.get());
if (at) dims = at->GetDimensions();
}
}
if (dims.empty() && ptr->GetType()->IsArray()) {
if (auto* at = dynamic_cast<ir::ArrayType*>(ptr->GetType().get())) {
dims = at->GetDimensions();
}
}
// 兜底:从语法树声明提取维度,避免作用域关闭后符号查询不完整。
if (dims.empty() && const_decl) {
auto* mutable_const_decl = const_cast<SysYParser::ConstDefContext*>(const_decl);
for (auto* cexp : mutable_const_decl->constExp()) {
dims.push_back(symbol_table_.EvaluateConstExp(cexp));
}
}
if (dims.empty() && decl) {
for (auto* cexp : decl->constExp()) {
dims.push_back(symbol_table_.EvaluateConstExp(cexp));
}
}
const bool is_partial_array_access =
!dims.empty() && idx_vals.size() < dims.size();
// 如果 base 是标量指针(例如局部扁平数组或数组参数),
// 需要把多维下标折合为单一线性下标,然后用一个索引进行 GEP。
if (ptr->GetType()->IsPtrInt32() || ptr->GetType()->IsPtrFloat()) {
// 如果没有维度信息,仍尝试用运行时算术合并下标(按后维乘积)
// flat = idx0 * (prod dims[1..]) + idx1 * (prod dims[2..]) + ...
ir::Value* flat = nullptr;
for (size_t i = 0; i < idx_vals.size(); ++i) {
ir::Value* term = idx_vals[i];
if (!term) continue;
// 计算乘数(后续维度乘积)
int mult = 1;
if (!dims.empty() && i + 1 < dims.size()) {
for (size_t j = i + 1; j < dims.size(); ++j) {
// 数组参数首维可能是 0表示省略不参与乘数。
if (dims[j] > 0) mult *= dims[j];
}
}
if (mult != 1) {
auto* mval = builder_.CreateConstInt(mult);
term = builder_.CreateMul(term, mval, module_.GetContext().NextTemp());
}
if (!flat) flat = term;
else flat = builder_.CreateAdd(flat, term, module_.GetContext().NextTemp());
}
if (!flat) flat = builder_.CreateConstInt(0);
// 使用单一索引创建 GEP
std::vector<ir::Value*> gep_indices = { flat };
ir::Value* elem_ptr = builder_.CreateGEP(ptr, gep_indices, module_.GetContext().NextTemp());
if (is_partial_array_access) {
return elem_ptr;
}
return static_cast<ir::Value*>(builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp()));
}
std::vector<ir::Value*> indices;
// 标量指针T*使用单索引数组对象使用前导0进入首层。
if (ptr->GetType()->IsPtrInt32() || ptr->GetType()->IsPtrFloat()) {
for (auto* v : idx_vals) indices.push_back(v);
} else {
indices.push_back(builder_.CreateConstInt(0));
for (auto* v : idx_vals) indices.push_back(v);
}
ir::Value* elem_ptr = builder_.CreateGEP(ptr, indices, module_.GetContext().NextTemp());
if (is_partial_array_access) {
return elem_ptr;
}
return static_cast<ir::Value*>(builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp()));
ir::Value* elem_ptr = builder_.CreateGEP(
ptr, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(
builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp()));
} else {
if ((sym && sym->is_array_param) ||
pointer_param_names_.find(varName) != pointer_param_names_.end() ||
heap_local_array_names_.find(varName) != heap_local_array_names_.end()) {
return ptr;
}
if (ptr->GetType()->IsArray()) {
return ptr;
}
return static_cast<ir::Value*>(builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
return static_cast<ir::Value*>(
builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
}
}
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitAddExp: 开始处理加法表达式 " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] visitAddExp: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
@ -418,10 +278,10 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
}
ir::Value* right = std::any_cast<ir::Value*>(right_any);
DEBUG_MSG("[DEBUG] visitAddExp: left=" << (void*)left
std::cerr << "[DEBUG] visitAddExp: left=" << (void*)left
<< ", type=" << (left->GetType()->IsFloat() ? "float" : "int")
<< ", right=" << (void*)right
<< ", type=" << (right->GetType()->IsFloat() ? "float" : "int"));
<< ", type=" << (right->GetType()->IsFloat() ? "float" : "int") << std::endl;
// 处理类型转换:如果操作数类型不同,需要进行类型转换
if (left->GetType()->IsFloat() != right->GetType()->IsFloat()) {
@ -458,7 +318,7 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitMulExp: 开始处理乘法表达式 " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] visitMulExp: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
@ -482,10 +342,10 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
}
ir::Value* right = std::any_cast<ir::Value*>(right_any);
DEBUG_MSG("[DEBUG] visitMulExp: left=" << (void*)left
std::cerr << "[DEBUG] visitMulExp: left=" << (void*)left
<< ", type=" << (left->GetType()->IsFloat() ? "float" : "int")
<< ", right=" << (void*)right
<< ", type=" << (right->GetType()->IsFloat() ? "float" : "int"));
<< ", type=" << (right->GetType()->IsFloat() ? "float" : "int") << std::endl;
// 处理类型转换:如果操作数类型不同,需要进行类型转换
if (left->GetType()->IsFloat() != right->GetType()->IsFloat()) {
@ -532,7 +392,6 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
// 逻辑与
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitLAndExp: 开始处理逻辑与表达式 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
if (!ctx->lAndExp()) {
@ -541,28 +400,14 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
ir::Value* left = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
ir::Value* right = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
auto to_bool = [&](ir::Value* v) -> ir::Value* {
if (v->GetType()->IsInt1()) {
return v;
}
if (v->GetType()->IsFloat()) {
return builder_.CreateFCmpONE(v, builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
return builder_.CreateICmpNE(v, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
};
auto* left_bool = to_bool(left);
auto* right_bool = to_bool(right);
return static_cast<ir::Value*>(
builder_.CreateAnd(left_bool, right_bool, module_.GetContext().NextTemp()));
auto zero = builder_.CreateConstInt(0);
auto left_bool = builder_.CreateICmpNE(left, zero, module_.GetContext().NextTemp());
auto right_bool = builder_.CreateICmpNE(right, zero, module_.GetContext().NextTemp());
return builder_.CreateAnd(left_bool, right_bool, module_.GetContext().NextTemp());
}
// 逻辑或
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitLOrExp: 开始处理逻辑或表达式 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
if (!ctx->lOrExp()) {
@ -571,52 +416,36 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
ir::Value* left = std::any_cast<ir::Value*>(ctx->lOrExp()->accept(this));
ir::Value* right = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
auto to_bool = [&](ir::Value* v) -> ir::Value* {
if (v->GetType()->IsInt1()) {
return v;
}
if (v->GetType()->IsFloat()) {
return builder_.CreateFCmpONE(v, builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
return builder_.CreateICmpNE(v, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
};
auto* left_bool = to_bool(left);
auto* right_bool = to_bool(right);
return static_cast<ir::Value*>(
builder_.CreateOr(left_bool, right_bool, module_.GetContext().NextTemp()));
auto zero = builder_.CreateConstInt(0);
auto left_bool = builder_.CreateICmpNE(left, zero, module_.GetContext().NextTemp());
auto right_bool = builder_.CreateICmpNE(right, zero, module_.GetContext().NextTemp());
return builder_.CreateOr(left_bool, right_bool, module_.GetContext().NextTemp());
}
// Lowers an `exp` parse node by visiting its additive sub-expression.
//
// Returns whatever the addExp visitor returns (wrapped in std::any).
// Throws std::runtime_error when `ctx` is null or has no addExp child, so a
// malformed parse tree is reported instead of dereferencing a null pointer.
std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) {
  DEBUG_MSG("[DEBUG IRGEN] visitExp: 开始处理表达式 " << (ctx ? ctx->getText() : "<null>"));
  // Guard the child as well as the node: the accessor yields null when the
  // child is absent (same pattern as the other visitors' `!ctx || !ctx->X()`).
  if (!ctx || !ctx->addExp()) {
    throw std::runtime_error(FormatError("irgen", "非法表达式"));
  }
  return ctx->addExp()->accept(this);
}
// Lowers a `cond` parse node by visiting its logical-or sub-expression.
//
// Returns whatever the lOrExp visitor returns (wrapped in std::any).
// Throws std::runtime_error when `ctx` is null or has no lOrExp child, so a
// malformed parse tree is reported instead of dereferencing a null pointer.
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
  DEBUG_MSG("[DEBUG IRGEN] visitCond: 开始处理条件 " << (ctx ? ctx->getText() : "<null>"));
  // Guard the child as well as the node: the accessor yields null when the
  // child is absent (same pattern as the other visitors' `!ctx || !ctx->X()`).
  if (!ctx || !ctx->lOrExp()) {
    throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
  }
  return ctx->lOrExp()->accept(this);
}
std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 开始处理函数调用 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "非法函数调用"));
}
std::string funcName = ctx->Ident()->getText();
DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 调用函数 " << funcName);
std::cout << "[DEBUG IRGEN] visitCallExp: 调用函数 " << funcName << std::endl;
// 查找函数对象
ir::Function* callee = module_.FindFunction(funcName);
// 如果没找到,可能是运行时函数还没声明,尝试动态声明
if (!callee) {
DEBUG_MSG("[DEBUG IRGEN] 函数 " << funcName << " 未找到,尝试动态声明");
std::cout << "[DEBUG IRGEN] 函数 " << funcName << " 未找到,尝试动态声明" << std::endl;
// 根据函数名动态创建运行时函数声明
callee = CreateRuntimeFunctionDecl(funcName);
@ -631,36 +460,9 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) {
auto argList = ctx->funcRParams()->accept(this);
try {
args = std::any_cast<std::vector<ir::Value*>>(argList);
DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 收集到 " << args.size() << " 个参数");
std::cout << "[DEBUG IRGEN] visitCallExp: 收集到 " << args.size() << " 个参数" << std::endl;
} catch (const std::bad_any_cast& e) {
DEBUG_MSG("[ERROR] visitCallExp: 函数调用参数类型错误: " << e.what());
}
}
// 按形参类型修正实参(数组衰减为指针等)。
if (auto* fty = dynamic_cast<ir::FunctionType*>(callee->GetType().get())) {
const auto& param_tys = fty->GetParamTypes();
size_t n = std::min(param_tys.size(), args.size());
for (size_t i = 0; i < n; ++i) {
if (!args[i] || !param_tys[i]) continue;
// 数组实参传给指针形参时,执行数组到指针衰减。
if (args[i]->GetType()->IsArray() &&
(param_tys[i]->IsPtrInt32() || param_tys[i]->IsPtrFloat())) {
std::vector<ir::Value*> idx;
idx.push_back(builder_.CreateConstInt(0));
idx.push_back(builder_.CreateConstInt(0));
args[i] = builder_.CreateGEP(args[i], idx, module_.GetContext().NextTemp());
}
// 标量实参的隐式类型转换int <-> float
if (param_tys[i]->IsFloat() && args[i]->GetType()->IsInt32()) {
args[i] = builder_.CreateSIToFP(args[i], ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
} else if (param_tys[i]->IsInt32() && args[i]->GetType()->IsFloat()) {
args[i] = builder_.CreateFPToSI(args[i], ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
std::cerr << "[ERROR] visitCallExp: 函数调用参数类型错误: " << e.what() << std::endl;
}
}
@ -673,13 +475,13 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) {
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
}
DEBUG_MSG("[DEBUG IRGEN] visitCallExp: 函数调用完成,返回值 " << (void*)callResult);
std::cout << "[DEBUG IRGEN] visitCallExp: 函数调用完成,返回值 " << (void*)callResult << std::endl;
return static_cast<ir::Value*>(callResult);
}
// 动态创建运行时函数声明的辅助函数
ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) {
DEBUG_MSG("[DEBUG IRGEN] CreateRuntimeFunctionDecl: 开始创建运行时函数声明 " << funcName);
std::cout << "[DEBUG IRGEN] CreateRuntimeFunctionDecl: " << funcName << std::endl;
// 根据常见运行时函数名创建对应的函数类型
if (funcName == "getint" || funcName == "getch") {
@ -696,7 +498,7 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName)
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetInt32Type(),
{ir::Type::GetPtrInt32Type()}));
{ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()}));
}
else if (funcName == "putarray") {
return module_.CreateFunction(funcName,
@ -720,27 +522,15 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName)
ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()}));
}
else if (funcName == "getfloat") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(ir::Type::GetFloatType(), {}));
}
else if (funcName == "putfloat") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetFloatType()}));
else if (funcName == "read_map") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {}));
}
else if (funcName == "getfarray") {
return module_.CreateFunction(funcName,
else if (funcName == "float_eq") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetInt32Type(),
{ir::Type::GetPtrFloatType()}));
}
else if (funcName == "putfarray") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()}));
{ir::Type::GetFloatType(), ir::Type::GetFloatType()}));
}
else if (funcName == "memset") {
return module_.CreateFunction(funcName,
@ -750,49 +540,12 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName)
ir::Type::GetInt32Type(),
ir::Type::GetInt32Type()}));
}
else if (funcName == "sysy_alloc_i32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetPtrInt32Type(),
{ir::Type::GetInt32Type()}));
}
else if (funcName == "sysy_alloc_f32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetPtrFloatType(),
{ir::Type::GetInt32Type()}));
}
else if (funcName == "sysy_free_i32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrInt32Type()}));
}
else if (funcName == "sysy_free_f32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrFloatType()}));
}
else if (funcName == "sysy_zero_i32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()}));
}
else if (funcName == "sysy_zero_f32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()}));
}
// 其他函数不支持动态创建
return nullptr;
}
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitUnaryExp: 开始处理一元表达式 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
}
@ -834,15 +587,14 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
ir::Value* zero;
if (operand->GetType()->IsFloat()) {
zero = builder_.CreateConstFloat(0.0f);
// 浮点逻辑非x == 0.0
ir::Value* cmp = builder_.CreateFCmpOEQ(operand, zero, module_.GetContext().NextTemp());
// 浮点比较:不等于0
ir::Value* cmp = builder_.CreateFCmpONE(operand, zero, module_.GetContext().NextTemp());
// 将bool转换为int
return static_cast<ir::Value*>(
builder_.CreateZExt(cmp, ir::Type::GetInt32Type()));
} else {
zero = builder_.CreateConstInt(0);
return static_cast<ir::Value*>(
builder_.CreateNot(operand, module_.GetContext().NextTemp()));
return builder_.CreateNot(operand, module_.GetContext().NextTemp());
}
}
}
@ -852,7 +604,6 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// 实现函数调用
std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitFuncRParams: 开始处理函数参数 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx) return std::vector<ir::Value*>{};
std::vector<ir::Value*> args;
for (auto* exp : ctx->exp()) {
@ -861,37 +612,67 @@ std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
return args;
}
// visitConstExp - 处理常量表达式
// 修改 visitConstExp 以支持常量表达式求值
std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitConstExp: 开始处理常量表达式 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx || !ctx->addExp()) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法常量表达式"));
}
auto result = ctx->addExp()->accept(this);
if (!result.has_value()) {
throw std::runtime_error(FormatError("irgen", "常量表达式求值失败"));
}
try {
return std::any_cast<ir::Value*>(result);
} catch (const std::bad_any_cast& e) {
throw std::runtime_error(FormatError("irgen",
"常量表达式返回类型错误: " + std::string(e.what())));
if (ctx->addExp()) {
// 尝试获取数值
auto result = ctx->addExp()->accept(this);
if (result.has_value()) {
try {
ir::Value* value = std::any_cast<ir::Value*>(result);
// 尝试判断是否是 ConstantInt
// 暂时简化:返回 IR 值
return static_cast<ir::Value*>(value);
} catch (const std::bad_any_cast&) {
// 可能是其他类型
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
}
}
}
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
} catch (const std::exception& e) {
std::cerr << "[WARNING] visitConstExp: 常量表达式求值失败: " << e.what()
<< "返回0" << std::endl;
// 如果普通表达式求值失败返回0
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
}
}
// visitConstInitVal - 处理常量初始化值
std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitConstInitVal: 开始处理常量初始化值 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法常量初始化值"));
}
// 如果是单个常量表达式
if (ctx->constExp()) {
return ctx->constExp()->accept(this);
try {
auto result = ctx->constExp()->accept(this);
if (result.has_value()) {
try {
ir::Value* value = std::any_cast<ir::Value*>(result);
// 尝试提取常量值
if (auto* const_int = dynamic_cast<ir::ConstantInt*>(value)) {
return static_cast<ir::Value*>(const_int);
} else {
// 如果不是常量,尝试计算数值
int int_val = TryEvaluateConstInt(ctx->constExp());
return static_cast<ir::Value*>(builder_.CreateConstInt(int_val));
}
} catch (const std::bad_any_cast&) {
int int_val = TryEvaluateConstInt(ctx->constExp());
return static_cast<ir::Value*>(builder_.CreateConstInt(int_val));
}
}
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
} catch (const std::exception& e) {
std::cerr << "[WARNING] visitConstInitVal: " << e.what() << std::endl;
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
}
}
// 如果是聚合初始化(花括号列表)
else if (!ctx->constInitVal().empty()) {
@ -899,24 +680,22 @@ std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
for (auto* init_val : ctx->constInitVal()) {
auto result = init_val->accept(this);
if (!result.has_value()) {
throw std::runtime_error(FormatError("irgen", "常量初始化值求值失败"));
}
try {
// 尝试获取单个常量值
ir::Value* value = std::any_cast<ir::Value*>(result);
all_values.push_back(value);
} catch (const std::bad_any_cast&) {
if (result.has_value()) {
try {
// 尝试获取值列表(嵌套情况)
std::vector<ir::Value*> nested_values =
std::any_cast<std::vector<ir::Value*>>(result);
all_values.insert(all_values.end(),
nested_values.begin(), nested_values.end());
} catch (const std::bad_any_cast& e) {
throw std::runtime_error(FormatError("irgen",
"不支持的常量初始化值类型: " + std::string(e.what())));
// 尝试获取单个常量值
ir::Value* value = std::any_cast<ir::Value*>(result);
all_values.push_back(value);
} catch (const std::bad_any_cast&) {
try {
// 尝试获取值列表(嵌套情况)
std::vector<ir::Value*> nested_values =
std::any_cast<std::vector<ir::Value*>>(result);
all_values.insert(all_values.end(),
nested_values.begin(), nested_values.end());
} catch (const std::bad_any_cast&) {
throw std::runtime_error(
FormatError("irgen", "不支持的常量初始化值类型"));
}
}
}
}
@ -929,7 +708,6 @@ std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitRelExp: 开始处理关系表达式 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
@ -940,10 +718,10 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
auto* lhs = std::any_cast<ir::Value*>(left_any);
auto* rhs = std::any_cast<ir::Value*>(right_any);
DEBUG_MSG("[DEBUG] visitRelExp: left=" << (void*)lhs
std::cerr << "[DEBUG] visitRelExp: left=" << (void*)lhs
<< ", type=" << (lhs->GetType()->IsFloat() ? "float" : "int")
<< ", right=" << (void*)rhs
<< ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int"));
<< ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int") << std::endl;
// 处理类型转换:如果操作数类型不同,需要进行类型转换
if (lhs->GetType()->IsFloat() != rhs->GetType()->IsFloat()) {
@ -1004,7 +782,6 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
}
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitEqExp: 开始处理相等表达式 " << (ctx ? ctx->getText() : "<null>"));
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
@ -1015,10 +792,10 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
auto* lhs = std::any_cast<ir::Value*>(left_any);
auto* rhs = std::any_cast<ir::Value*>(right_any);
DEBUG_MSG("[DEBUG] visitEqExp: left=" << (void*)lhs
std::cerr << "[DEBUG] visitEqExp: left=" << (void*)lhs
<< ", type=" << (lhs->GetType()->IsFloat() ? "float" : "int")
<< ", right=" << (void*)rhs
<< ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int"));
<< ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int") << std::endl;
// 处理类型转换:如果操作数类型不同,需要进行类型转换
if (lhs->GetType()->IsFloat() != rhs->GetType()->IsFloat()) {
@ -1062,8 +839,7 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] EvalAssign: 开始处理赋值语句 " << (ctx ? ctx->getText() : "<null>"));
DEBUG_MSG("[DEBUG IRGEN] visitCond: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] visitCond: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->lVal() || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法赋值语句"));
}
@ -1088,29 +864,15 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) {
ir::Value* base_ptr = it->second;
auto convert_for_store = [&](ir::Value* value, ir::Value* ptr) -> ir::Value* {
if (ptr->GetType()->IsPtrFloat() && value->GetType()->IsInt32()) {
return builder_.CreateSIToFP(value, ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
}
if (ptr->GetType()->IsPtrInt32() && value->GetType()->IsFloat()) {
return builder_.CreateFPToSI(value, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
return value;
};
// 检查是否有数组下标
auto exp_list = lval->exp();
if (!exp_list.empty()) {
// 这是数组元素赋值需要生成GEP指令
std::vector<ir::Value*> indices;
// 标量指针参数T*不应添加前导0数组对象需要前导0。
if (!(base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat())) {
indices.push_back(builder_.CreateConstInt(0));
}
// 第一个索引是0假设一维数组
indices.push_back(builder_.CreateConstInt(0));
// 添加用户提供的下标
for (auto* exp : exp_list) {
ir::Value* index = EvalExpr(*exp);
@ -1122,20 +884,18 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) {
base_ptr, indices, module_.GetContext().NextTemp());
// 生成store指令
rhs = convert_for_store(rhs, elem_ptr);
builder_.CreateStore(rhs, elem_ptr);
} else {
// 普通标量赋值
// 调试输出指针类型
DEBUG_MSG("[DEBUG] base_ptr type: " << base_ptr->GetType());
DEBUG_MSG("[DEBUG] rhs type: " << rhs->GetType());
std::cerr << "[DEBUG] base_ptr type: " << base_ptr->GetType() << std::endl;
std::cerr << "[DEBUG] rhs type: " << rhs->GetType()<< std::endl;
// 如果 base_ptr 不是标量指针类型,可能需要特殊处理
if (!base_ptr->GetType()->IsPtrInt32() && !base_ptr->GetType()->IsPtrFloat()) {
DEBUG_MSG("[ERROR] base_ptr is not a pointer type!");
// 如果 base_ptr 不是指针类型,可能需要特殊处理
if (!base_ptr->GetType()->IsPtrInt32()) {
std::cerr << "[ERROR] base_ptr is not a pointer type!" << std::endl;
throw std::runtime_error("尝试存储到非指针类型");
}
rhs = convert_for_store(rhs, base_ptr);
builder_.CreateStore(rhs, base_ptr);
}
} else {

@ -20,41 +20,18 @@ void VerifyFunctionStructure(const ir::Function& func) {
}
}
// Reports whether `node`, or any rule context beneath it, is a UnaryExpContext
// whose Ident text equals `func_name`.
//
// A null `node` yields false. Children that are not ParserRuleContext
// instances (e.g. terminal nodes) are not descended into.
bool HasDirectSelfCall(antlr4::ParserRuleContext* node,
                       const std::string& func_name) {
  if (node == nullptr) {
    return false;
  }

  // Check the node itself first.
  auto* call = dynamic_cast<SysYParser::UnaryExpContext*>(node);
  if (call != nullptr && call->Ident() &&
      call->Ident()->getText() == func_name) {
    return true;
  }

  // Otherwise search the sub-rules depth-first, stopping at the first hit.
  for (auto* kid : node->children) {
    auto* sub_rule = dynamic_cast<antlr4::ParserRuleContext*>(kid);
    if (sub_rule != nullptr && HasDirectSelfCall(sub_rule, func_name)) {
      return true;
    }
  }
  return false;
}
} // namespace
// Constructor taking the module, the semantic context and the symbol table.
// Initializes module_, sema_ and symbol_table_ from the arguments, builds
// builder_ from the module's context with a null second argument, sets func_
// to null, and then calls AddRuntimeFunctions().
// NOTE(review): members are initialized in class-declaration order, not in the
// order written here — confirm builder_/func_ ordering against the header.
IRGenImpl::IRGenImpl(ir::Module& module,
const SemanticContext& sema,
const SymbolTable& sym_table)
: module_(module), sema_(sema), symbol_table_(sym_table),
builder_(module.GetContext(), nullptr), func_(nullptr) {
AddRuntimeFunctions();
}
// Constructor taking only the module and the semantic context.
// Initializes module_ and sema_ from the arguments, sets func_ to null, builds
// builder_ from the module's context with a null second argument, and then
// calls AddRuntimeFunctions().
// NOTE(review): symbol_table_ is not in this initializer list — confirm it has
// a default initializer (or does not exist) in the matching class definition.
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module),
sema_(sema),
func_(nullptr),
builder_(module.GetContext(), nullptr) {
AddRuntimeFunctions();
}
void IRGenImpl::AddRuntimeFunctions() {
DEBUG_MSG("[DEBUG IRGEN] 添加运行时库函数声明");
std::cout << "[DEBUG IRGEN] 添加运行时库函数声明" << std::endl;
// 输入函数(返回 int
module_.CreateFunction("getint",
@ -66,7 +43,7 @@ void IRGenImpl::AddRuntimeFunctions() {
module_.CreateFunction("getarray",
ir::Type::GetFunctionType(
ir::Type::GetInt32Type(),
{ir::Type::GetPtrInt32Type()}));
{ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()}));
// 输出函数(返回 void
module_.CreateFunction("putint",
@ -106,22 +83,16 @@ void IRGenImpl::AddRuntimeFunctions() {
module_.CreateFunction("stoptime",
ir::Type::GetFunctionType(ir::Type::GetVoidType(), {}));
// 浮点 I/O
module_.CreateFunction("getfloat",
ir::Type::GetFunctionType(ir::Type::GetFloatType(), {}));
module_.CreateFunction("putfloat",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetFloatType()}));
module_.CreateFunction("getfarray",
// 其他可能需要的函数
module_.CreateFunction("read_map",
ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {}));
// 浮点数
module_.CreateFunction("float_eq",
ir::Type::GetFunctionType(
ir::Type::GetInt32Type(),
{ir::Type::GetPtrFloatType()}));
module_.CreateFunction("putfarray",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()}));
{ir::Type::GetFloatType(), ir::Type::GetFloatType()}));
// 内存操作函数
module_.CreateFunction("memset",
ir::Type::GetFunctionType(
@ -129,48 +100,13 @@ void IRGenImpl::AddRuntimeFunctions() {
{ir::Type::GetPtrInt32Type(),
ir::Type::GetInt32Type(),
ir::Type::GetInt32Type()}));
module_.CreateFunction("sysy_alloc_i32",
ir::Type::GetFunctionType(
ir::Type::GetPtrInt32Type(),
{ir::Type::GetInt32Type()}));
module_.CreateFunction("sysy_alloc_f32",
ir::Type::GetFunctionType(
ir::Type::GetPtrFloatType(),
{ir::Type::GetInt32Type()}));
module_.CreateFunction("sysy_free_i32",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrInt32Type()}));
module_.CreateFunction("sysy_free_f32",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrFloatType()}));
module_.CreateFunction("sysy_zero_i32",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()}));
module_.CreateFunction("sysy_zero_f32",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()}));
DEBUG_MSG("[DEBUG IRGEN] 运行时库函数声明完成");
std::cout << "[DEBUG IRGEN] 运行时库函数声明完成" << std::endl;
}
// 修正:没有 mainFuncDef通过函数名找到 main
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitCompUnit");
DEBUG_MSG("[DEBUG] IRGen: 符号表地址 = " << &symbol_table_);
DEBUG_MSG("[DEBUG] IRGen: 开始生成 IR");
// 尝试查找 main 函数
const Symbol* main_sym = symbol_table_.lookup("main");
if (main_sym) {
DEBUG_MSG("[DEBUG] IRGen: 找到 main 函数符号");
} else {
DEBUG_MSG("[DEBUG] IRGen: 未找到 main 函数符号");
}
std::cout << "[DEBUG IRGEN] visitCompUnit" << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
}
@ -193,7 +129,7 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
}
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
@ -255,44 +191,31 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
auto func_type = ir::Type::GetFunctionType(ret_type, param_types);
// 调试输出
DEBUG_MSG("[DEBUG] visitFuncDef: 创建函数 " << funcName
std::cerr << "[DEBUG] visitFuncDef: 创建函数 " << funcName
<< ",返回类型: " << (ret_type->IsVoid() ? "void" : ret_type->IsFloat() ? "float" : "int")
<< ",参数数量: " << param_types.size());
<< ",参数数量: " << param_types.size() << std::endl;
// 创建函数对象
func_ = module_.CreateFunction(funcName, func_type);
// 检查函数是否成功创建
if (!func_) {
DEBUG_MSG("[ERROR] visitFuncDef: 创建函数失败func_ 为 nullptr!");
std::cerr << "[ERROR] visitFuncDef: 创建函数失败func_ 为 nullptr!" << std::endl;
throw std::runtime_error(FormatError("irgen", "创建函数失败: " + funcName));
}
DEBUG_MSG("[DEBUG] visitFuncDef: 函数对象地址: " << (void*)func_);
std::cerr << "[DEBUG] visitFuncDef: 函数对象地址: " << (void*)func_ << std::endl;
// 设置插入点
auto* entry_block = func_->GetEntry();
if (!entry_block) {
DEBUG_MSG("[ERROR] visitFuncDef: 函数入口基本块为空!");
std::cerr << "[ERROR] visitFuncDef: 函数入口基本块为空!" << std::endl;
throw std::runtime_error(FormatError("irgen", "函数入口基本块为空: " + funcName));
}
builder_.SetInsertPoint(entry_block);
storage_map_.clear();
param_map_.clear();
pointer_param_names_.clear();
heap_local_array_names_.clear();
current_function_name_ = funcName;
current_function_is_recursive_ = HasDirectSelfCall(ctx->block(), funcName);
function_cleanup_block_ = nullptr;
function_cleanup_actions_.clear();
function_return_slot_ = nullptr;
if (ret_type->IsInt32()) {
function_return_slot_ = CreateEntryAllocaI32(module_.GetContext().NextTemp() + ".retval");
} else if (ret_type->IsFloat()) {
function_return_slot_ = CreateEntryAllocaFloat(module_.GetContext().NextTemp() + ".retval");
}
// 函数参数处理
if (ctx->funcFParams()) {
@ -324,15 +247,15 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
// 检查函数对象是否有效
if (!func_) {
DEBUG_MSG("[ERROR] visitFuncDef: func_ 在添加参数时变为 nullptr!");
std::cerr << "[ERROR] visitFuncDef: func_ 在添加参数时变为 nullptr!" << std::endl;
throw std::runtime_error(FormatError("irgen", "函数对象无效"));
}
DEBUG_MSG("[DEBUG] visitFuncDef: 为函数 " << funcName
std::cerr << "[DEBUG] visitFuncDef: 为函数 " << funcName
<< " 添加参数 " << name << ",类型: "
<< (param_ty->IsInt32() ? "int32" : param_ty->IsFloat() ? "float" :
param_ty->IsPtrInt32() ? "ptr_int32" : param_ty->IsPtrFloat() ? "ptr_float" : "other")
);
<< std::endl;
// 创建参数并添加到函数
auto arg = std::make_unique<ir::Argument>(param_ty, name);
@ -344,130 +267,58 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
auto* added_arg = func_->AddArgument(std::move(arg));
if (!added_arg) {
DEBUG_MSG("[ERROR] visitFuncDef: AddArgument 返回 nullptr!");
std::cerr << "[ERROR] visitFuncDef: AddArgument 返回 nullptr!" << std::endl;
throw std::runtime_error(FormatError("irgen", "添加参数失败: " + name));
}
// 标量参数:入栈到本地槽位;数组参数(指针)直接作为地址使用。
if (param_ty->IsPtrInt32() || param_ty->IsPtrFloat()) {
param_map_[name] = added_arg;
pointer_param_names_.insert(name);
// 为参数创建存储槽位
ir::AllocaInst* slot = nullptr;
if (param_ty->IsInt32() || param_ty->IsPtrInt32()) {
slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
} else if (param_ty->IsFloat() || param_ty->IsPtrFloat()) {
slot = builder_.CreateAllocaFloat(module_.GetContext().NextTemp());
} else {
ir::AllocaInst* slot = nullptr;
if (param_ty->IsInt32()) {
slot = CreateEntryAllocaI32(module_.GetContext().NextTemp());
} else if (param_ty->IsFloat()) {
slot = CreateEntryAllocaFloat(module_.GetContext().NextTemp());
} else {
throw std::runtime_error(FormatError("irgen", "不支持的参数类型"));
}
if (!slot) {
throw std::runtime_error(FormatError("irgen", "创建参数存储槽位失败: " + name));
}
builder_.CreateStore(added_arg, slot);
param_map_[name] = slot;
pointer_param_names_.erase(name);
throw std::runtime_error(FormatError("irgen", "不支持的参数类型"));
}
DEBUG_MSG("[DEBUG] visitFuncDef: 参数 " << name << " 处理完成");
if (!slot) {
throw std::runtime_error(FormatError("irgen", "创建参数存储槽位失败: " + name));
}
builder_.CreateStore(added_arg, slot);
param_map_[name] = slot;
std::cerr << "[DEBUG] visitFuncDef: 参数 " << name << " 处理完成" << std::endl;
}
}
// 生成函数体
DEBUG_MSG("[DEBUG] visitFuncDef: 开始生成函数体");
std::cerr << "[DEBUG] visitFuncDef: 开始生成函数体" << std::endl;
ctx->block()->accept(this);
// 如果当前插入块没有终止指令,添加默认返回
if (auto* cur = builder_.GetInsertBlock(); cur && !cur->HasTerminator()) {
DEBUG_MSG("[DEBUG] visitFuncDef: 函数体没有终止指令,添加默认返回");
if (function_cleanup_block_) {
if (ret_type->IsFloat()) {
builder_.CreateStore(builder_.CreateConstFloat(0.0f), function_return_slot_);
} else if (ret_type->IsInt32()) {
builder_.CreateStore(builder_.CreateConstInt(0), function_return_slot_);
}
builder_.CreateBr(function_cleanup_block_);
} else if (ret_type->IsVoid()) {
builder_.CreateRet(nullptr);
} else if (ret_type->IsFloat()) {
builder_.CreateRet(builder_.CreateConstFloat(0.0f));
} else {
builder_.CreateRet(builder_.CreateConstInt(0));
}
}
if (function_cleanup_block_ && !function_cleanup_block_->HasTerminator()) {
builder_.SetInsertPoint(function_cleanup_block_);
for (auto it = function_cleanup_actions_.rbegin();
it != function_cleanup_actions_.rend(); ++it) {
builder_.CreateCall(it->first, {it->second}, module_.GetContext().NextTemp());
}
if (ret_type->IsVoid()) {
builder_.CreateRet(nullptr);
} else {
builder_.CreateRet(builder_.CreateLoad(function_return_slot_, module_.GetContext().NextTemp()));
}
// 如果函数没有终止指令,添加默认返回
if (!func_->GetEntry()->HasTerminator()) {
std::cerr << "[DEBUG] visitFuncDef: 函数体没有终止指令,添加默认返回" << std::endl;
auto retVal = builder_.CreateConstInt(0);
builder_.CreateRet(retVal);
}
// 验证函数结构
try {
VerifyFunctionStructure(*func_);
} catch (const std::exception& e) {
DEBUG_MSG("[ERROR] visitFuncDef: 验证函数结构失败: " << e.what());
std::cerr << "[ERROR] visitFuncDef: 验证函数结构失败: " << e.what() << std::endl;
throw;
}
DEBUG_MSG("[DEBUG] visitFuncDef: 函数 " << funcName << " 生成完成");
std::cerr << "[DEBUG] visitFuncDef: 函数 " << funcName << " 生成完成" << std::endl;
func_ = nullptr;
current_function_name_.clear();
current_function_is_recursive_ = false;
function_return_slot_ = nullptr;
function_cleanup_block_ = nullptr;
function_cleanup_actions_.clear();
return {};
}
// Returns the current function's cleanup block, creating it on first use.
//
// The new block is named "cleanup." followed by a fresh temp name from the
// context, with a leading '%' removed if the temp name has one.
ir::BasicBlock* IRGenImpl::EnsureCleanupBlock() {
  if (function_cleanup_block_) {
    return function_cleanup_block_;  // already created
  }

  std::string suffix = module_.GetContext().NextTemp();
  const bool has_sigil = !suffix.empty() && suffix[0] == '%';
  if (has_sigil) {
    suffix.erase(0, 1);
  }

  function_cleanup_block_ = func_->CreateBlock("cleanup." + suffix);
  return function_cleanup_block_;
}
// Records a (free_func, ptr) pair in function_cleanup_actions_ and makes sure
// the cleanup block exists. Does nothing when either argument is null.
void IRGenImpl::RegisterCleanup(ir::Function* free_func, ir::Value* ptr) {
  if (free_func && ptr) {
    EnsureCleanupBlock();
    function_cleanup_actions_.push_back({free_func, ptr});
  }
}
// Inserts an AllocaInst built from (`ty`, `name`) into the current function's
// entry block, ahead of that block's terminator, and returns it.
//
// Throws std::runtime_error when there is no current function or the function
// has no entry block.
ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr<ir::Type> ty,
                                             const std::string& name) {
  auto* entry = func_ ? func_->GetEntry() : nullptr;
  if (entry == nullptr) {
    throw std::runtime_error(FormatError("irgen", "缺少函数入口块,无法创建入口栈槽位"));
  }
  return entry->InsertBeforeTerminator<ir::AllocaInst>(ty, name);
}
// Convenience wrapper: forwards to CreateEntryAlloca with the type returned by
// ir::Type::GetPtrInt32Type() and the given `name`.
ir::AllocaInst* IRGenImpl::CreateEntryAllocaI32(const std::string& name) {
  auto slot_type = ir::Type::GetPtrInt32Type();
  return CreateEntryAlloca(slot_type, name);
}
// Convenience wrapper: forwards to CreateEntryAlloca with the type returned by
// ir::Type::GetPtrFloatType() and the given `name`.
ir::AllocaInst* IRGenImpl::CreateEntryAllocaFloat(const std::string& name) {
  auto slot_type = ir::Type::GetPtrFloatType();
  return CreateEntryAlloca(slot_type, name);
}
std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
}
@ -482,8 +333,8 @@ std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
}
auto* cur = builder_.GetInsertBlock();
DEBUG_MSG("[DEBUG] current insert block: "
<< (cur ? cur->GetName() : "<null>"));
std::cout << "[DEBUG] current insert block: "
<< (cur ? cur->GetName() : "<null>") << std::endl;
if (cur && cur->HasTerminator()) {
break;
}
@ -500,7 +351,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
}
// 用于遍历块内项,返回是否继续访问后续项(如遇到 return/break/continue 则终止访问)
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少块内项"));
}

@ -16,7 +16,7 @@
// - 空语句、块语句嵌套分发之外的更多语句形态
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句"));
}
@ -65,7 +65,7 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
// 修改 HandleReturnStmt 函数
IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少 return 语句"));
}
@ -88,12 +88,8 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
// 表达式被忽略(可计算但不使用)
EvalExpr(*ctx->exp());
}
if (function_cleanup_block_) {
builder_.CreateBr(function_cleanup_block_);
} else {
// 对于void函数创建返回指令不传参数
builder_.CreateRet(nullptr);
}
// 对于void函数创建返回指令不传参数
builder_.CreateRet(nullptr);
} else {
ir::Value* retValue = nullptr;
if (ctx->exp()) {
@ -119,12 +115,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
retValue = builder_.CreateConstInt(0); // fallback
}
}
if (function_cleanup_block_) {
builder_.CreateStore(retValue, function_return_slot_);
builder_.CreateBr(function_cleanup_block_);
} else {
builder_.CreateRet(retValue);
}
builder_.CreateRet(retValue);
}
return BlockFlow::Terminated;
}
@ -132,183 +123,162 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
// Lowers an `if` / `if-else` statement to IR.
//
// Steps, as written below:
//   1. Create the `then` block, an `else` block (only when the statement has
//      an Else token and a second sub-statement) and a `merge` block.
//   2. Evaluate the condition with EvalCond; if its type is not IsInt1(),
//      compare it against zero (FCmpONE for float, ICmpNE otherwise).
//   3. Emit a conditional branch to then/else, or to then/merge when there is
//      no else block.
//   4. Generate each branch and append a branch to `merge` unless the branch
//      reported BlockFlow::Terminated.
//   5. Set the builder's insert point for the code that follows the `if`.
//
// Always returns BlockFlow::Continue.
//
// NOTE(review): `ctx` is dereferenced without a null check, although the debug
// output guards against a null `ctx` — confirm callers never pass null.
// NOTE(review): this listing declares thenBlock/elseBlock/mergeBlock twice
// (first with unique-suffixed names, then with plain names) and pairs every
// DEBUG_MSG with a std::cout line; it looks like two revisions interleaved —
// confirm which one is intended before building.
IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
auto* cond = ctx->cond();
auto* thenStmt = ctx->stmt(0);
auto* elseStmt = ctx->stmt(1);
// Create the basic blocks (with unique names, to avoid duplicate labels).
// `uniq` appends a fresh temporary name, minus any leading '%', to the prefix.
auto uniq = [&](const std::string& prefix) {
std::string t = module_.GetContext().NextTemp();
if (!t.empty() && t[0] == '%') t.erase(0, 1);
return prefix + "." + t;
};
auto* thenBlock = func_->CreateBlock(uniq("then"));
auto* elseBlock = (ctx->Else() && elseStmt) ? func_->CreateBlock(uniq("else")) : nullptr;
auto* mergeBlock = func_->CreateBlock(uniq("merge"));
DEBUG_MSG("[DEBUG IF] thenBlock: " << thenBlock->GetName());
if (elseBlock) DEBUG_MSG("[DEBUG IF] elseBlock: " << elseBlock->GetName());
DEBUG_MSG("[DEBUG IF] mergeBlock: " << mergeBlock->GetName());
DEBUG_MSG("[DEBUG IF] current insert block before cond: "
<< builder_.GetInsertBlock()->GetName());
// Create the basic blocks (plain, non-unique names).
auto* thenBlock = func_->CreateBlock("then");
auto* elseBlock = (ctx->Else() && elseStmt) ? func_->CreateBlock("else") : nullptr;
auto* mergeBlock = func_->CreateBlock("merge");
std::cout << "[DEBUG IF] thenBlock: " << thenBlock->GetName() << std::endl;
if (elseBlock) std::cout << "[DEBUG IF] elseBlock: " << elseBlock->GetName() << std::endl;
std::cout << "[DEBUG IF] mergeBlock: " << mergeBlock->GetName() << std::endl;
std::cout << "[DEBUG IF] current insert block before cond: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
// Generate the condition; a non-i1 result is turned into a "!= 0" comparison.
auto* condValue = EvalCond(*cond);
if (!condValue->GetType()->IsInt1()) {
if (condValue->GetType()->IsFloat()) {
condValue = builder_.CreateFCmpONE(
condValue, builder_.CreateConstFloat(0.0f), module_.GetContext().NextTemp());
} else {
condValue = builder_.CreateICmpNE(
condValue, builder_.CreateConstInt(0), module_.GetContext().NextTemp());
}
}
// Emit the conditional branch: the false edge goes to `else` when it exists,
// otherwise straight to `merge`.
if (elseBlock) {
DEBUG_MSG("[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << elseBlock->GetName());
std::cout << "[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << elseBlock->GetName() << std::endl;
builder_.CreateCondBr(condValue, thenBlock, elseBlock);
} else {
DEBUG_MSG("[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << mergeBlock->GetName());
std::cout << "[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << mergeBlock->GetName() << std::endl;
builder_.CreateCondBr(condValue, thenBlock, mergeBlock);
}
// Generate the then branch.
DEBUG_MSG("[DEBUG IF] Generating then branch in block: " << thenBlock->GetName());
std::cout << "[DEBUG IF] Generating then branch in block: " << thenBlock->GetName() << std::endl;
builder_.SetInsertPoint(thenBlock);
auto thenResult = thenStmt->accept(this);
bool thenTerminated = (std::any_cast<BlockFlow>(thenResult) == BlockFlow::Terminated);
DEBUG_MSG("[DEBUG IF] then branch terminated: " << thenTerminated);
std::cout << "[DEBUG IF] then branch terminated: " << thenTerminated << std::endl;
if (!thenTerminated) {
DEBUG_MSG("[DEBUG IF] Adding br to merge block from then");
std::cout << "[DEBUG IF] Adding br to merge block from then" << std::endl;
builder_.CreateBr(mergeBlock);
}
DEBUG_MSG("[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator());
std::cout << "[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator() << std::endl;
// Generate the else branch (only when an else block was created).
bool elseTerminated = false;
if (elseBlock) {
DEBUG_MSG("[DEBUG IF] Generating else branch in block: " << elseBlock->GetName());
std::cout << "[DEBUG IF] Generating else branch in block: " << elseBlock->GetName() << std::endl;
builder_.SetInsertPoint(elseBlock);
auto elseResult = elseStmt->accept(this);
elseTerminated = (std::any_cast<BlockFlow>(elseResult) == BlockFlow::Terminated);
DEBUG_MSG("[DEBUG IF] else branch terminated: " << elseTerminated);
std::cout << "[DEBUG IF] else branch terminated: " << elseTerminated << std::endl;
if (!elseTerminated) {
DEBUG_MSG("[DEBUG IF] Adding br to merge block from else");
std::cout << "[DEBUG IF] Adding br to merge block from else" << std::endl;
builder_.CreateBr(mergeBlock);
}
DEBUG_MSG("[DEBUG IF] else block has terminator: " << elseBlock->HasTerminator());
std::cout << "[DEBUG IF] else block has terminator: " << elseBlock->HasTerminator() << std::endl;
}
// Decide the insert point for the code that follows the `if`.
DEBUG_MSG("[DEBUG IF] thenTerminated=" << thenTerminated
<< ", elseTerminated=" << elseTerminated);
std::cout << "[DEBUG IF] thenTerminated=" << thenTerminated
<< ", elseTerminated=" << elseTerminated << std::endl;
if (elseBlock) {
DEBUG_MSG("[DEBUG IF] Setting insert point to merge block: "
<< mergeBlock->GetName());
builder_.SetInsertPoint(mergeBlock);
// When both branches terminated, nothing branches to `merge`; a fresh
// "after.if" block becomes the insert point instead.
if (thenTerminated && elseTerminated) {
auto* afterIfBlock = func_->CreateBlock("after.if");
std::cout << "[DEBUG IF] Both branches terminated, creating new block: "
<< afterIfBlock->GetName() << std::endl;
builder_.SetInsertPoint(afterIfBlock);
} else {
std::cout << "[DEBUG IF] Setting insert point to merge block: "
<< mergeBlock->GetName() << std::endl;
builder_.SetInsertPoint(mergeBlock);
}
} else {
DEBUG_MSG("[DEBUG IF] No else, setting insert point to merge block: "
<< mergeBlock->GetName());
std::cout << "[DEBUG IF] No else, setting insert point to merge block: "
<< mergeBlock->GetName() << std::endl;
builder_.SetInsertPoint(mergeBlock);
}
DEBUG_MSG("[DEBUG IF] Final insert block: "
<< builder_.GetInsertBlock()->GetName());
std::cout << "[DEBUG IF] Final insert block: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
return BlockFlow::Continue;
}
// Lowers a `while` statement to IR.
//
// Creates three blocks (condition, body, exit), branches from the current
// block into the condition block, and pushes {condBlock, bodyBlock, exitBlock}
// onto `loopStack_` while the condition and body are generated. The condition
// is evaluated with EvalCond and, when its type is not IsInt1(), compared
// against zero (FCmpONE for float, ICmpNE otherwise). The body branches back
// to the condition block unless it reported BlockFlow::Terminated. On exit the
// builder's insert point is the exit block.
//
// Throws std::runtime_error when `ctx`, its condition, or its first
// sub-statement is null. Always returns BlockFlow::Continue otherwise.
//
// NOTE(review): this listing declares condBlock/bodyBlock/exitBlock twice
// (first with unique-suffixed names, then with plain names) and pairs every
// DEBUG_MSG with a std::cout line; it looks like two revisions interleaved —
// confirm which one is intended before building.
IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleWhileStmt: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] HandleWhileStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->cond() || !ctx->stmt(0)) {
throw std::runtime_error(FormatError("irgen", "非法 while 语句"));
}
DEBUG_MSG("[DEBUG WHILE] Current insert block before while: "
<< builder_.GetInsertBlock()->GetName());
std::cout << "[DEBUG WHILE] Current insert block before while: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
// `uniq` appends a fresh temporary name, minus any leading '%', to the prefix.
auto uniq = [&](const std::string& prefix) {
std::string t = module_.GetContext().NextTemp();
if (!t.empty() && t[0] == '%') t.erase(0, 1);
return prefix + "." + t;
};
auto* condBlock = func_->CreateBlock(uniq("while.cond"));
auto* bodyBlock = func_->CreateBlock(uniq("while.body"));
auto* exitBlock = func_->CreateBlock(uniq("while.exit"));
auto* condBlock = func_->CreateBlock("while.cond");
auto* bodyBlock = func_->CreateBlock("while.body");
auto* exitBlock = func_->CreateBlock("while.exit");
DEBUG_MSG("[DEBUG WHILE] condBlock: " << condBlock->GetName());
DEBUG_MSG("[DEBUG WHILE] bodyBlock: " << bodyBlock->GetName());
DEBUG_MSG("[DEBUG WHILE] exitBlock: " << exitBlock->GetName());
std::cout << "[DEBUG WHILE] condBlock: " << condBlock->GetName() << std::endl;
std::cout << "[DEBUG WHILE] bodyBlock: " << bodyBlock->GetName() << std::endl;
std::cout << "[DEBUG WHILE] exitBlock: " << exitBlock->GetName() << std::endl;
DEBUG_MSG("[DEBUG WHILE] Adding br to condBlock from current block");
std::cout << "[DEBUG WHILE] Adding br to condBlock from current block" << std::endl;
builder_.CreateBr(condBlock);
// Record the loop's blocks; HandleBreakStmt / HandleContinueStmt read
// loopStack_.back() to find their branch targets.
loopStack_.push_back({condBlock, bodyBlock, exitBlock});
DEBUG_MSG("[DEBUG WHILE] loopStack size: " << loopStack_.size());
std::cout << "[DEBUG WHILE] loopStack size: " << loopStack_.size() << std::endl;
// Condition block.
DEBUG_MSG("[DEBUG WHILE] Generating condition in block: " << condBlock->GetName());
std::cout << "[DEBUG WHILE] Generating condition in block: " << condBlock->GetName() << std::endl;
builder_.SetInsertPoint(condBlock);
auto* condValue = EvalCond(*ctx->cond());
if (!condValue->GetType()->IsInt1()) {
if (condValue->GetType()->IsFloat()) {
condValue = builder_.CreateFCmpONE(
condValue, builder_.CreateConstFloat(0.0f), module_.GetContext().NextTemp());
} else {
condValue = builder_.CreateICmpNE(
condValue, builder_.CreateConstInt(0), module_.GetContext().NextTemp());
}
}
builder_.CreateCondBr(condValue, bodyBlock, exitBlock);
DEBUG_MSG("[DEBUG WHILE] condBlock has terminator: " << condBlock->HasTerminator());
std::cout << "[DEBUG WHILE] condBlock has terminator: " << condBlock->HasTerminator() << std::endl;
// Loop body.
DEBUG_MSG("[DEBUG WHILE] Generating body in block: " << bodyBlock->GetName());
std::cout << "[DEBUG WHILE] Generating body in block: " << bodyBlock->GetName() << std::endl;
builder_.SetInsertPoint(bodyBlock);
auto bodyResult = ctx->stmt(0)->accept(this);
bool bodyTerminated = (std::any_cast<BlockFlow>(bodyResult) == BlockFlow::Terminated);
DEBUG_MSG("[DEBUG WHILE] body terminated: " << bodyTerminated);
std::cout << "[DEBUG WHILE] body terminated: " << bodyTerminated << std::endl;
if (!bodyTerminated) {
DEBUG_MSG("[DEBUG WHILE] Adding br to condBlock from body");
std::cout << "[DEBUG WHILE] Adding br to condBlock from body" << std::endl;
builder_.CreateBr(condBlock);
}
DEBUG_MSG("[DEBUG WHILE] bodyBlock has terminator: " << bodyBlock->HasTerminator());
std::cout << "[DEBUG WHILE] bodyBlock has terminator: " << bodyBlock->HasTerminator() << std::endl;
loopStack_.pop_back();
DEBUG_MSG("[DEBUG WHILE] loopStack size after pop: " << loopStack_.size());
std::cout << "[DEBUG WHILE] loopStack size after pop: " << loopStack_.size() << std::endl;
// Set the insert point to exitBlock.
DEBUG_MSG("[DEBUG WHILE] Setting insert point to exitBlock: " << exitBlock->GetName());
std::cout << "[DEBUG WHILE] Setting insert point to exitBlock: " << exitBlock->GetName() << std::endl;
builder_.SetInsertPoint(exitBlock);
DEBUG_MSG("[DEBUG WHILE] exitBlock has terminator before return: "
<< exitBlock->HasTerminator());
std::cout << "[DEBUG WHILE] exitBlock has terminator before return: "
<< exitBlock->HasTerminator() << std::endl;
return BlockFlow::Continue;
}
// break语句待实现
IRGenImpl::BlockFlow IRGenImpl::HandleBreakStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleBreakStmt: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] HandleBreakStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (loopStack_.empty()) {
throw std::runtime_error(FormatError("irgen", "break 语句不在循环中"));
}
DEBUG_MSG("[DEBUG BREAK] Current insert block before break: "
<< builder_.GetInsertBlock()->GetName());
DEBUG_MSG("[DEBUG BREAK] Breaking to exitBlock: "
<< loopStack_.back().exitBlock->GetName());
std::cout << "[DEBUG BREAK] Current insert block before break: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
std::cout << "[DEBUG BREAK] Breaking to exitBlock: "
<< loopStack_.back().exitBlock->GetName() << std::endl;
// 跳转到循环退出块
builder_.CreateBr(loopStack_.back().exitBlock);
@ -318,16 +288,16 @@ IRGenImpl::BlockFlow IRGenImpl::HandleBreakStmt(SysYParser::StmtContext* ctx) {
}
IRGenImpl::BlockFlow IRGenImpl::HandleContinueStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleContinueStmt: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] HandleContinueStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (loopStack_.empty()) {
throw std::runtime_error(FormatError("irgen", "continue 语句不在循环中"));
}
DEBUG_MSG("[DEBUG CONTINUE] Current insert block before continue: "
<< builder_.GetInsertBlock()->GetName());
DEBUG_MSG("[DEBUG CONTINUE] Continuing to condBlock: "
<< loopStack_.back().condBlock->GetName());
std::cout << "[DEBUG CONTINUE] Current insert block before continue: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
std::cout << "[DEBUG CONTINUE] Continuing to condBlock: "
<< loopStack_.back().condBlock->GetName() << std::endl;
// 跳转到循环条件块
builder_.CreateBr(loopStack_.back().condBlock);
@ -340,7 +310,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleContinueStmt(SysYParser::StmtContext* ctx)
// 赋值语句
// 赋值语句
IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG IRGEN] HandleAssignStmt: " << (ctx ? ctx->getText() : "<null>"));
std::cout << "[DEBUG IRGEN] HandleAssignStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->lVal() || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法赋值语句"));
@ -354,7 +324,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto* lval = ctx->lVal();
std::string varName = lval->Ident()->getText();
DEBUG_MSG("[DEBUG] HandleAssignStmt: assigning to " << varName);
std::cerr << "[DEBUG] HandleAssignStmt: assigning to " << varName << std::endl;
// 1. 检查是否为常量(不能给常量赋值)
auto* const_decl = sema_.ResolveConstUse(lval);
@ -372,8 +342,8 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto it = storage_map_.find(var_decl);
if (it != storage_map_.end()) {
base_ptr = it->second;
DEBUG_MSG("[DEBUG] HandleAssignStmt: found in storage_map_ for " << varName
<< ", ptr = " << (void*)base_ptr);
std::cerr << "[DEBUG] HandleAssignStmt: found in storage_map_ for " << varName
<< ", ptr = " << (void*)base_ptr << std::endl;
}
}
@ -382,8 +352,8 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto it2 = param_map_.find(varName);
if (it2 != param_map_.end()) {
base_ptr = it2->second;
DEBUG_MSG("[DEBUG] HandleAssignStmt: found in param_map_ for " << varName
<< ", ptr = " << (void*)base_ptr);
std::cerr << "[DEBUG] HandleAssignStmt: found in param_map_ for " << varName
<< ", ptr = " << (void*)base_ptr << std::endl;
}
}
@ -392,8 +362,8 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto it3 = global_map_.find(varName);
if (it3 != global_map_.end()) {
base_ptr = it3->second;
DEBUG_MSG("[DEBUG] HandleAssignStmt: found in global_map_ for " << varName
<< ", ptr = " << (void*)base_ptr);
std::cerr << "[DEBUG] HandleAssignStmt: found in global_map_ for " << varName
<< ", ptr = " << (void*)base_ptr << std::endl;
}
}
@ -402,8 +372,8 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto it4 = local_var_map_.find(varName);
if (it4 != local_var_map_.end()) {
base_ptr = it4->second;
DEBUG_MSG("[DEBUG] HandleAssignStmt: found in local_var_map_ for " << varName
<< ", ptr = " << (void*)base_ptr);
std::cerr << "[DEBUG] HandleAssignStmt: found in local_var_map_ for " << varName
<< ", ptr = " << (void*)base_ptr << std::endl;
}
}
@ -417,110 +387,37 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto exp_list = lval->exp();
if (!exp_list.empty()) {
// 数组元素赋值
std::vector<ir::Value*> idx_vals;
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* exp : exp_list) {
ir::Value* index = EvalExpr(*exp);
idx_vals.push_back(index);
}
ir::Value* elem_ptr = nullptr;
// 扁平数组/数组参数T*)的多维访问:先线性化,再单索引 GEP。
if ((base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat()) && idx_vals.size() > 1) {
const Symbol* var_sym = symbol_table_.lookup(varName);
if (!var_sym && var_decl) {
var_sym = symbol_table_.lookupByVarDef(var_decl);
}
if (!var_sym) {
var_sym = symbol_table_.lookupAll(varName);
}
std::vector<int> dims;
if (var_sym) {
if (var_sym->is_array_param && !var_sym->array_dims.empty()) {
dims = var_sym->array_dims;
} else if (var_sym->type && var_sym->type->IsArray()) {
auto* at = dynamic_cast<ir::ArrayType*>(var_sym->type.get());
if (at) dims = at->GetDimensions();
}
}
if (dims.empty() && var_decl) {
for (auto* cexp : var_decl->constExp()) {
dims.push_back(symbol_table_.EvaluateConstExp(cexp));
}
}
ir::Value* flat = nullptr;
for (size_t i = 0; i < idx_vals.size(); ++i) {
ir::Value* term = idx_vals[i];
if (!term) continue;
int mult = 1;
if (!dims.empty() && i + 1 < dims.size()) {
for (size_t j = i + 1; j < dims.size(); ++j) {
if (dims[j] > 0) mult *= dims[j];
}
}
if (mult != 1) {
auto* mval = builder_.CreateConstInt(mult);
term = builder_.CreateMul(term, mval, module_.GetContext().NextTemp());
}
if (!flat) flat = term;
else flat = builder_.CreateAdd(flat, term, module_.GetContext().NextTemp());
}
if (!flat) flat = builder_.CreateConstInt(0);
std::vector<ir::Value*> gep_indices = {flat};
elem_ptr = builder_.CreateGEP(base_ptr, gep_indices, module_.GetContext().NextTemp());
} else {
std::vector<ir::Value*> indices;
if (base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat()) {
for (auto* v : idx_vals) indices.push_back(v);
} else {
indices.push_back(builder_.CreateConstInt(0));
for (auto* v : idx_vals) indices.push_back(v);
}
elem_ptr = builder_.CreateGEP(base_ptr, indices, module_.GetContext().NextTemp());
indices.push_back(index);
}
if (elem_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) {
rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
} else if (elem_ptr->GetType()->IsPtrInt32() && rhs->GetType()->IsFloat()) {
rhs = builder_.CreateFPToSI(rhs, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
ir::Value* elem_ptr = builder_.CreateGEP(
base_ptr, indices, module_.GetContext().NextTemp());
builder_.CreateStore(rhs, elem_ptr);
} else {
// 普通标量赋值
DEBUG_MSG("[DEBUG] HandleAssignStmt: scalar assignment to " << varName
std::cerr << "[DEBUG] HandleAssignStmt: scalar assignment to " << varName
<< ", ptr = " << (void*)base_ptr
<< ", rhs = " << (void*)rhs);
<< ", rhs = " << (void*)rhs << std::endl;
// 在 HandleAssignStmt 中,存储前添加类型调试
if (base_ptr && base_ptr->GetType()) {
DEBUG_MSG("[DEBUG] Is int32: " << base_ptr->GetType()->IsInt32());
DEBUG_MSG("[DEBUG] Is float: " << base_ptr->GetType()->IsFloat());
DEBUG_MSG("[DEBUG] Is ptr int32: " << base_ptr->GetType()->IsPtrInt32());
DEBUG_MSG("[DEBUG] Is ptr float: " << base_ptr->GetType()->IsPtrFloat());
DEBUG_MSG("[DEBUG] Is array: " << base_ptr->GetType()->IsArray());
std::cerr << "[DEBUG] Is int32: " << base_ptr->GetType()->IsInt32() << std::endl;
std::cerr << "[DEBUG] Is float: " << base_ptr->GetType()->IsFloat() << std::endl;
std::cerr << "[DEBUG] Is ptr int32: " << base_ptr->GetType()->IsPtrInt32() << std::endl;
std::cerr << "[DEBUG] Is ptr float: " << base_ptr->GetType()->IsPtrFloat() << std::endl;
std::cerr << "[DEBUG] Is array: " << base_ptr->GetType()->IsArray() << std::endl;
}
if (rhs && rhs->GetType()) {
DEBUG_MSG("[DEBUG] Value is int32: " << rhs->GetType()->IsInt32());
}
if (base_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) {
rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
} else if (base_ptr->GetType()->IsPtrInt32() && rhs->GetType()->IsFloat()) {
rhs = builder_.CreateFPToSI(rhs, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
std::cerr << "[DEBUG] Value is int32: " << rhs->GetType()->IsInt32() << std::endl;
}
builder_.CreateStore(rhs, base_ptr);
builder_.CreateStore(rhs, base_ptr);
}
return BlockFlow::Continue;

@ -46,17 +46,13 @@ int main(int argc, char** argv) {
}
if (opts.emit_asm) {
//auto machine_func = mir::LowerToMIR(*module);
auto machine_module = mir::LowerToMIR(*module);
//mir::RunRegAlloc(*machine_func);
mir::RunRegAlloc(*machine_module);
//mir::RunFrameLowering(*machine_func);
mir::RunFrameLowering(*machine_module);
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
if (need_blank_line) {
std::cout << "\n";
}
//mir::PrintAsm(*machine_func, std::cout);
mir::PrintAsm(*machine_module, std::cout);
mir::PrintAsm(*machine_func, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {

@ -2,24 +2,12 @@
#include <ostream>
#include <stdexcept>
#include <set>
#include "utils/Log.h"
//#define DEBUG_Asm
#ifdef DEBUG_Asm
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[Asm Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace mir {
namespace {
static void PrintLoadImm64(std::ostream& os, PhysReg reg, uint64_t imm);
const FrameSlot& GetFrameSlot(const MachineFunction& function,
const Operand& operand) {
if (operand.GetKind() != Operand::Kind::FrameIndex) {
@ -28,458 +16,63 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex());
}
// Prints a frame-pointer-relative load/store `insn` of `reg` at x29 + offset.
//
// Offsets in [-256, 255] are printed directly as "[x29, #offset]". Any other
// offset is materialised in x16 first: x29 is copied to x16, then |offset| is
// added or subtracted — as an immediate when |offset| <= 4095, otherwise via
// x17 loaded with PrintLoadImm64 — and the access goes through "[x16]".
void PrintStackAccess(std::ostream& os, const char* insn, PhysReg reg, int64_t offset) {
  const bool fits_directly = -256 <= offset && offset <= 255;
  if (fits_directly) {
    os << " " << insn << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n";
    return;
  }

  // Large offset: compute the address in x16, then access it indirectly.
  const bool upward = offset >= 0;
  const int64_t magnitude = upward ? offset : -offset;
  const char* const arith = upward ? "add" : "sub";

  os << " mov x16, x29\n";
  if (magnitude > 4095) {
    // Too large for an immediate operand: load the magnitude into x17.
    PrintLoadImm64(os, PhysReg::X17, magnitude);
    os << " " << arith << " x16, x16, x17\n";
  } else {
    os << " " << arith << " x16, x16, #" << magnitude << "\n";
  }
  os << " " << insn << " " << PhysRegName(reg) << ", [x16]\n";
}
// 打印单个操作数
void PrintOperand(std::ostream& os, const Operand& op) {
switch (op.GetKind()) {
case Operand::Kind::Reg:
os << PhysRegName(op.GetReg());
break;
case Operand::Kind::Imm:
os << "#" << op.GetImm();
break;
case Operand::Kind::FrameIndex:
os << "[sp, #" << op.GetFrameIndex() << "]";
break;
case Operand::Kind::Cond:
os << CondCodeName(op.GetCondCode());
break;
case Operand::Kind::Label:
DEBUG_MSG("label is" << op.GetLabel());
os << op.GetLabel();
break;
}
// Prints one frame-pointer-relative access: "<mnemonic> <reg>, [x29, #offset]".
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
                      int offset) {
  os << " " << mnemonic << " " << PhysRegName(reg);
  os << ", [x29, #" << offset << "]\n";
}
// Returns true when |imm| can be used as the immediate of an AArch64 ADD/SUB
// instruction: a 12-bit unsigned value, optionally shifted left by 12 bits.
// That is, |imm| is in [0, 4095], or is a multiple of 4096 that is no larger
// than 4095 * 4096 (16773120).
//
// The sign is ignored because the shift rule is symmetric. The magnitude is
// computed in unsigned arithmetic so that INT64_MIN, whose negation does not
// fit in int64_t, is handled without signed overflow and is reported as not
// encodable.
static bool IsLegalAddSubImm(int64_t imm) {
  constexpr uint64_t kImm12Max = 0xFFF;
  const uint64_t bits = static_cast<uint64_t>(imm);
  // Unsigned negation is well defined for every input, including INT64_MIN.
  const uint64_t magnitude = imm < 0 ? uint64_t{0} - bits : bits;
  if (magnitude <= kImm12Max) return true;  // plain imm12
  return (magnitude & kImm12Max) == 0 && magnitude <= (kImm12Max << 12);  // imm12, LSL #12
}
// Helper in the anonymous namespace: prints instructions that load the 64-bit
// constant `imm` into `reg`.
//
// `imm` is split into four 16-bit parts. The lowest part is always emitted
// with `movz` (suffixed with ", lsl #0" only when some higher part is
// non-zero); each non-zero higher part is then emitted with `movk` at shift
// 16, 32 or 48.
//
// NOTE(review): the `} // namespace` line in the middle of this listing looks
// like a line from another revision interleaved into the body — confirm before
// building.
static void PrintLoadImm64(std::ostream& os, PhysReg reg, uint64_t imm) {
// Emit a movz + movk sequence.
uint16_t part0 = imm & 0xFFFF;
uint16_t part1 = (imm >> 16) & 0xFFFF;
uint16_t part2 = (imm >> 32) & 0xFFFF;
uint16_t part3 = (imm >> 48) & 0xFFFF;
os << " movz " << PhysRegName(reg) << ", #" << part0;
if (part1 != 0 || part2 != 0 || part3 != 0) {
os << ", lsl #0";
}
os << "\n";
} // namespace
if (part1 != 0) {
os << " movk " << PhysRegName(reg) << ", #" << part1 << ", lsl #16\n";
}
if (part2 != 0) {
os << " movk " << PhysRegName(reg) << ", #" << part2 << ", lsl #32\n";
}
if (part3 != 0) {
os << " movk " << PhysRegName(reg) << ", #" << part3 << ", lsl #48\n";
}
}
void PrintAsm(const MachineFunction& function, std::ostream& os) {
os << ".text\n";
os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n";
os << function.GetName() << ":\n";
// 打印单条指令
void PrintInstruction(std::ostream& os, const MachineInstr& instr,
const MachineFunction& function) {
const auto& ops = instr.GetOperands();
switch (instr.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
int64_t size = function.GetFrameSize();
if (IsLegalAddSubImm(size)) {
os << " sub sp, sp, #" << size << "\n";
} else {
PrintLoadImm64(os, PhysReg::X16, size);
os << " sub sp, sp, x16\n";
for (const auto& inst : function.GetEntry().GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
}
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
int64_t size = function.GetFrameSize();
if (IsLegalAddSubImm(size)) {
os << " add sp, sp, #" << size << "\n";
} else {
PrintLoadImm64(os, PhysReg::X16, size);
os << " add sp, sp, x16\n";
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
}
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::MovReg:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::StoreStack: {
// 检查第二个操作数的类型
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
// 存储到栈槽
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
// 间接存储:存储到寄存器指向的地址
// STR W9, [X8]
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
} else {
throw std::runtime_error("StoreStack: 无效的操作数类型");
}
break;
}
case Opcode::LoadStack: {
// 检查第二个操作数的类型
if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::FrameIndex) {
// 从栈槽加载
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
} else if (ops.size() >= 2 && ops.at(1).GetKind() == Operand::Kind::Reg) {
// 间接加载:从寄存器指向的地址加载
// LDR W9, [X8]
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
} else {
throw std::runtime_error("LoadStack: 无效的操作数类型");
}
break;
}
case Opcode::StoreStackPair:
// stp x29, x30, [sp, #-16]!
os << " stp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", [sp";
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
int offset = ops.at(2).GetImm();
os << ", #" << offset;
}
os << "]!\n"; // 注意添加 ! 表示 pre-index
break;
case Opcode::LoadStackPair:
// ldp x29, x30, [sp], #16
os << " ldp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", [sp]";
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
int offset = ops.at(2).GetImm();
os << ", #" << offset;
}
os << "\n";
break;
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AddRI:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SubRI:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::MulRR:
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SDivRR:
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::UDivRR:
os << " udiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FAddRR:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSubRR:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMulRR:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDivRR:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::CmpRR:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::CmpRI:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::FCmpRR:
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::ZExt:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #1\n";
break;
case Opcode::AndRR:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::OrRR:
os << " orr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::EorRR:
os << " eor " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::LslRR:
os << " lsl " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::LsrRR:
os << " lsr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AsrRR:
os << " asr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::B:
os << " b ";
PrintOperand(os, ops.at(0));
os << "\n";
break;
case Opcode::BCond:
os << " b.";
PrintOperand(os, ops.at(0));
os << " ";
PrintOperand(os, ops.at(1));
os << "\n";
break;
case Opcode::Call:
os << " bl ";
PrintOperand(os, ops.at(0));
os << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
case Opcode::Nop:
os << " nop\n";
break;
case Opcode::Label:
os << ops.at(0).GetLabel() << ":\n";
break;
case Opcode::Movk:
os << " movk " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << ", lsl #" << ops.at(2).GetImm() << "\n";
break;
case Opcode::LoadStackAddr: {
const FrameSlot& slot = GetFrameSlot(function, ops.at(1));
int64_t offset = slot.offset; // 负值,如 -8
PhysReg dst = ops.at(0).GetReg();
auto tryEmitSimple = [&]() -> bool {
if (offset >= 0 && offset <= 4095) {
os << " add " << PhysRegName(dst) << ", x29, #" << offset << "\n";
return true;
} else if (offset < 0 && offset >= -4095) {
os << " sub " << PhysRegName(dst) << ", x29, #" << (-offset) << "\n";
return true;
}
return false;
};
if (tryEmitSimple()) break;
// 复杂偏移
uint64_t absOffset = (offset >= 0) ? offset : -offset;
PrintLoadImm64(os, PhysReg::X16, absOffset);
if (offset >= 0) {
os << " add " << PhysRegName(dst) << ", x29, x16\n";
} else {
os << " sub " << PhysRegName(dst) << ", x29, x16\n";
}
break;
}
case Opcode::Adrp: {
// adrp Xd, label
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetLabel() << "\n";
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::AddLabel: {
// add Xd, Xn, :lo12:label
}
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", :lo12:"
<< ops.at(2).GetLabel() << "\n";
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
}
case Opcode::Sxtw:
os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
default:
os << " // unknown instruction\n";
case Opcode::Ret:
os << " ret\n";
break;
}
}
// Print one function (single-function overload).
//
// Emits the `.text` / `.global` / `.type` directives and the function label,
// then every instruction of every basic block via PrintInstruction, and
// finally the `.size` directive.
//
// The first basic block gets no label of its own: the function label printed
// in the header is the only label emitted before its instructions. Every
// later block is introduced by `<block name>:`.
void PrintAsm(const MachineFunction& function, std::ostream& os) {
  // Bind once; `const auto&` works whether GetName() returns a reference or a
  // value.
  const auto& name = function.GetName();

  // Function header.
  os << ".text\n";
  os << ".global " << name << "\n";
  os << ".type " << name << ", %function\n";
  os << name << ":\n";

  // Basic blocks, in container order.
  bool firstBlock = true;
  for (const auto& bb : function.GetBasicBlocks()) {
    DEBUG_MSG("block");
    if (!firstBlock) {
      os << bb->GetName() << ":\n";
    }
    firstBlock = false;

    for (const auto& inst : bb->GetInstructions()) {
      DEBUG_MSG("inst");
      PrintInstruction(os, inst, function);
    }
  }

  os << ".size " << name << ", .-" << name << "\n";
}
} // namespace
// Print a whole module (module-level overload).
//
// Output order: an `.arch armv8-a` header, a `.data` section describing every
// global returned by module.GetGlobals() (only when there is at least one),
// then the assembly of each function whose name is not in `externalFuncs`.
void PrintAsm(const MachineModule& module, std::ostream& os) {
// File header.
os << ".arch armv8-a\n";
// Data section: global variables.
const auto& globals = module.GetGlobals();
if (!globals.empty()) {
os << "\n.data\n";
for (const auto& g : globals) {
os << ".global " << g.name << "\n";
os << ".type " << g.name << ", %object\n";
// NOTE(review): GNU as treats the `.align` operand as a power-of-two exponent
// on AArch64 — confirm g.alignment is stored as log2 and not as a byte count.
os << ".align " << g.alignment << "\n";
os << g.name << ":\n";
if (g.is_zero_init) {
os << " .zero " << g.size << "\n";
} else if (g.has_init_data) {
if (g.size == 4) {
os << " .word " << static_cast<uint32_t>(g.init_data) << "\n";
} else if (g.size == 8) {
os << " .quad " << g.init_data << "\n";
} else {
// Scalar size that is not supported yet: fall back to zero initialisation.
os << " .zero " << g.size << " // unhandled init size\n";
}
} else {
// An initial value exists but could not be extracted (e.g. arrays, structs).
os << " .zero " << g.size << " // unhandled initializer\n";
}
os << ".size " << g.name << ", " << g.size << "\n\n";
}
}
// Names of functions for which no assembly is emitted by the loop below.
static const std::set<std::string> externalFuncs = {
"getint", "getch", "getarray", "putint", "putch", "putarray", "puts",
"_sysy_starttime", "_sysy_stoptime", "starttime", "stoptime",
"getfloat", "putfloat", "getfarray", "putfarray", "memset",
"sysy_alloc_i32", "sysy_alloc_f32", "sysy_free_i32", "sysy_free_f32",
"sysy_zero_i32", "sysy_zero_f32"
};
DEBUG_MSG("module");
// Walk all functions and emit their assembly.
for (const auto& func : module.GetFunctions()) {
if (externalFuncs.count(func->GetName())) {
continue; // skip library function stubs
}
DEBUG_MSG("func");
PrintAsm(*func, os);
os << "\n";
}
// NOTE(review): `function` is not declared anywhere in this overload as shown;
// the statement below looks like residue of the single-function PrintAsm from
// the other side of the diff — confirm and remove.
os << ".size " << function.GetName() << ", .-" << function.GetName()
<< "\n";
}
} // namespace mir
} // namespace mir

@ -5,15 +5,6 @@
#include "utils/Log.h"
//#define DEBUG_Frame
#ifdef DEBUG_Frame
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[Frame Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace mir {
namespace {
@ -24,49 +15,31 @@ int AlignTo(int value, int align) {
} // namespace
// Frame lowering: assigns an offset to every frame slot, records the 16-byte
// aligned frame size, and inserts Prologue / Epilogue pseudo-instructions.
//
// NOTE(review): this span interleaves two versions of the pass — one that
// rewrites every basic block and one that only rewrites
// function.GetEntry() — together with the MachineModule overload. As shown the
// braces do not balance (the module overload opens inside the per-function
// body), so this text must be reconciled against the real file before use.
void RunFrameLowering(MachineFunction& function) {
DEBUG_MSG("function RunFrameLowering");
// First pass: reject frames whose accumulated slot size exceeds 256 bytes.
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
if (-cursor < -256) {
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
}
}
// Second pass: each slot's offset is minus the running total of slot sizes.
cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
function.GetFrameSlot(slot.index).offset = -cursor;
}
function.SetFrameSize(AlignTo(cursor, 16));
// Basic blocks.
const auto& blocks = function.GetBasicBlocks();
bool firstBlock = true;
for (const auto& bb : blocks) {
DEBUG_MSG("block");
auto& insts = bb->GetInstructions();
std::vector<MachineInstr> lowered;
// Only the first basic block receives the Prologue pseudo-instruction.
if (firstBlock) {
DEBUG_MSG("empalace Prologue");
lowered.emplace_back(Opcode::Prologue);
}
firstBlock = false;
// Copy the block's instructions, inserting an Epilogue before every Ret.
for (const auto& inst : insts) {
DEBUG_MSG("inst");
if (inst.GetOpcode() == Opcode::Ret) {
DEBUG_MSG("empalace Epilogue");
lowered.emplace_back(Opcode::Epilogue);
}
lowered.push_back(inst);
// NOTE(review): the lines from here to the next `insts = std::move(lowered);`
// appear to belong to the entry-block-only version of this function.
auto& insts = function.GetEntry().GetInstructions();
std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) {
lowered.emplace_back(Opcode::Epilogue);
}
insts = std::move(lowered);
}
}
// Module-level frame layout.
void RunFrameLowering(MachineModule& module) {
// Run frame lowering on every function of the module.
DEBUG_MSG("module RunFrameLowering");
for (auto& func : module.GetFunctions()) {
RunFrameLowering(*func);
// NOTE(review): `lowered`, `inst` and `insts` are not declared in this
// overload; the two statements using them look like residue of the
// per-function version.
lowered.push_back(inst);
}
insts = std::move(lowered);
}
} // namespace mir
} // namespace mir

File diff suppressed because it is too large Load Diff

@ -9,8 +9,7 @@ MachineBasicBlock::MachineBasicBlock(std::string name)
// Appends an instruction built from `opcode` and `operands` to this block and
// returns a reference to the last element of instructions_.
//
// NOTE(review): two emplace_back statements are present below (one passing
// std::move(operands), one building a std::vector<Operand> explicitly). They
// look like the two sides of a diff; as written the instruction would be
// appended twice — confirm which single statement is intended.
MachineInstr& MachineBasicBlock::Append(Opcode opcode,
std::initializer_list<Operand> operands) {
//instructions_.emplace_back(opcode, std::vector<Operand>(operands));
instructions_.emplace_back(opcode, std::move(operands));
instructions_.emplace_back(opcode, std::vector<Operand>(operands));
return instructions_.back();
}

@ -4,25 +4,17 @@
namespace mir {
// Operand constructors and named factory functions.
//
// NOTE(review): this span interleaves two revisions of the class — a
// five-argument constructor (kind, reg, imm, cc, label) with Cond/Label
// factories, and a three-argument constructor (kind, reg, imm). Reg is defined
// twice, Imm contains two return statements, and Cond/Label are nested inside
// FrameIndex as shown; reconcile against the real file before use.

// Five-argument form: stores every field, including condition code and label.
Operand::Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string& label)
: kind_(kind), reg_(reg), imm_(imm), cc_(cc), label_(label) {}
// Three-argument form: stores kind, register and immediate only.
Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm) {}
// Register operand; the remaining fields receive placeholder values.
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0, CondCode::EQ, ""); }
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
// Immediate operand; PhysReg::W0 is passed as a placeholder register.
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value, CondCode::EQ, "");
return Operand(Kind::Imm, PhysReg::W0, value);
}
// Frame-index operand; the index is carried in the immediate field.
Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::W0, index, CondCode::EQ, "");
}
// Condition-code operand.
Operand Operand::Cond(CondCode cc) {
return Operand(Kind::Cond, PhysReg::W0, 0, cc, "");
}
// Label operand.
Operand Operand::Label(const std::string& label) {
return Operand(Kind::Label, PhysReg::W0, 0, CondCode::EQ, label);
return Operand(Kind::FrameIndex, PhysReg::W0, index);
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)

@ -12,8 +12,8 @@ bool IsAllowedReg(PhysReg reg) {
case PhysReg::W0:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::X29: //FP = X29 帧指针
case PhysReg::X30: //LR = X30 链接寄存器
case PhysReg::X29:
case PhysReg::X30:
case PhysReg::SP:
return true;
}
@ -22,61 +22,15 @@ bool IsAllowedReg(PhysReg reg) {
} // namespace
//void RunRegAlloc(MachineFunction& function) {
// for (const auto& inst : function.GetEntry().GetInstructions()) {
// for (const auto& operand : inst.GetOperands()) {
// if (operand.GetKind() == Operand::Kind::Reg &&
// !IsAllowedReg(operand.GetReg())) {
// throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
// }
// }
// }
//}
// Per-function register "allocation". No registers are actually assigned; the
// visible code only validates operands and throws std::runtime_error on a
// violation.
//
// NOTE(review): two revisions are interleaved here — one that walks every
// basic block and range-checks FrameIndex operands, and one that walks only
// function.GetEntry() and rejects registers for which IsAllowedReg() is false.
// The braces do not balance as shown; reconcile against the real file.
void RunRegAlloc(MachineFunction& function) {
// Only a minimal consistency check is performed; no real register allocation.
// The Lab3 stage keeps the stack-slot model, so none is needed.
// Check the instructions of every basic block.
for (auto& bb : function.GetBasicBlocks()) {
for (auto& instr : bb->GetInstructions()) {
// Check that each operand of the instruction is valid.
for (const auto& operand : instr.GetOperands()) {
switch (operand.GetKind()) {
case Operand::Kind::Reg:
// Register operand: no check is performed in this branch.
// Fixed registers such as w0, w8, w9, s0, s1 are currently used.
break;
case Operand::Kind::FrameIndex:
// Frame-slot index: must lie in [0, number of frame slots).
if (operand.GetFrameIndex() < 0 ||
operand.GetFrameIndex() >= static_cast<int>(function.GetFrameSlots().size())) {
throw std::runtime_error(
FormatError("regalloc", "无效的栈槽索引: " +
std::to_string(operand.GetFrameIndex())));
}
break;
case Operand::Kind::Imm:
case Operand::Kind::Cond:
case Operand::Kind::Label:
// Immediates, condition codes and labels need no check.
break;
}
// NOTE(review): the loop below appears to be the entry-block-only revision.
for (const auto& inst : function.GetEntry().GetInstructions()) {
for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) {
throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
}
}
}
// Note: the Lab3 stage does not implement real register allocation.
// All values still use the stack-slot model; registers only hold temporaries.
}
// Module-level register allocation: applies the per-function RunRegAlloc to
// each function of the module, in container order.
void RunRegAlloc(MachineModule& module) {
  auto&& functions = module.GetFunctions();
  for (auto it = functions.begin(); it != functions.end(); ++it) {
    auto& func = *it;
    RunRegAlloc(*func);
  }
}
} // namespace mir

@ -8,111 +8,20 @@ namespace mir {
// Returns the assembly spelling of a physical register.
//
// NOTE(review): two revisions of the switch are merged below. The first lists
// W0-W30, X0-X30, S0-S7, SP and ZR and ends with `default: return "unknown";`;
// the second repeats W0, W8, W9, X29, X30 and SP. Duplicate case labels do not
// compile, and with the `default` in place the throw after the switch cannot
// be reached — reconcile against the real file.
const char* PhysRegName(PhysReg reg) {
switch (reg) {
// 32-bit registers
case PhysReg::W0: return "w0";
case PhysReg::W1: return "w1";
case PhysReg::W2: return "w2";
case PhysReg::W3: return "w3";
case PhysReg::W4: return "w4";
case PhysReg::W5: return "w5";
case PhysReg::W6: return "w6";
case PhysReg::W7: return "w7";
case PhysReg::W8: return "w8";
case PhysReg::W9: return "w9";
case PhysReg::W10: return "w10";
case PhysReg::W11: return "w11";
case PhysReg::W12: return "w12";
case PhysReg::W13: return "w13";
case PhysReg::W14: return "w14";
case PhysReg::W15: return "w15";
case PhysReg::W16: return "w16"; // added
case PhysReg::W17: return "w17"; // added
case PhysReg::W18: return "w18"; // added
case PhysReg::W19: return "w19"; // added
case PhysReg::W20: return "w20"; // added
case PhysReg::W21: return "w21"; // added
case PhysReg::W22: return "w22"; // added
case PhysReg::W23: return "w23"; // added
case PhysReg::W24: return "w24"; // added
case PhysReg::W25: return "w25"; // added
case PhysReg::W26: return "w26"; // added
case PhysReg::W27: return "w27"; // added
case PhysReg::W28: return "w28"; // added
case PhysReg::W29: return "w29";
case PhysReg::W30: return "w30";
// 64-bit registers
case PhysReg::X0: return "x0";
case PhysReg::X1: return "x1";
case PhysReg::X2: return "x2";
case PhysReg::X3: return "x3";
case PhysReg::X4: return "x4";
case PhysReg::X5: return "x5";
case PhysReg::X6: return "x6";
case PhysReg::X7: return "x7";
case PhysReg::X8: return "x8";
case PhysReg::X9: return "x9";
case PhysReg::X10: return "x10"; // added
case PhysReg::X11: return "x11"; // added
case PhysReg::X12: return "x12"; // added
case PhysReg::X13: return "x13"; // added
case PhysReg::X14: return "x14"; // added
case PhysReg::X15: return "x15"; // added
case PhysReg::X16: return "x16"; // added
case PhysReg::X17: return "x17"; // added
case PhysReg::X18: return "x18"; // added
case PhysReg::X19: return "x19"; // added
case PhysReg::X20: return "x20"; // added
case PhysReg::X21: return "x21"; // added
case PhysReg::X22: return "x22"; // added
case PhysReg::X23: return "x23"; // added
case PhysReg::X24: return "x24"; // added
case PhysReg::X25: return "x25"; // added
case PhysReg::X26: return "x26"; // added
case PhysReg::X27: return "x27"; // added
case PhysReg::X28: return "x28"; // added
case PhysReg::X29: return "x29";
case PhysReg::X30: return "x30";
// Floating-point registers
case PhysReg::S0: return "s0";
case PhysReg::S1: return "s1";
case PhysReg::S2: return "s2";
case PhysReg::S3: return "s3";
case PhysReg::S4: return "s4";
case PhysReg::S5: return "s5";
case PhysReg::S6: return "s6";
case PhysReg::S7: return "s7";
// Special registers
case PhysReg::SP: return "sp";
case PhysReg::ZR: return "xzr";
default: return "unknown";
// NOTE(review): the cases below duplicate labels already handled above.
case PhysReg::W0:
return "w0";
case PhysReg::W8:
return "w8";
case PhysReg::W9:
return "w9";
case PhysReg::X29:
return "x29";
case PhysReg::X30:
return "x30";
case PhysReg::SP:
return "sp";
}
throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
}
// Returns the lowercase mnemonic suffix for a condition code ("eq", "ne", ...).
//
// Any value not listed below yields the string "unknown". The original body
// had a `throw` after the switch that could never execute, because the
// `default` label returns on every remaining path; that dead statement has
// been removed without changing the result for any input.
const char* CondCodeName(CondCode cc) {
  switch (cc) {
    case CondCode::EQ: return "eq";
    case CondCode::NE: return "ne";
    case CondCode::CS: return "cs";
    case CondCode::CC: return "cc";
    case CondCode::MI: return "mi";
    case CondCode::PL: return "pl";
    case CondCode::VS: return "vs";
    case CondCode::VC: return "vc";
    case CondCode::HI: return "hi";
    case CondCode::LS: return "ls";
    case CondCode::GE: return "ge";
    case CondCode::LT: return "lt";
    case CondCode::GT: return "gt";
    case CondCode::LE: return "le";
    case CondCode::AL: return "al";
    default: return "unknown";
  }
}
} // namespace mir

@ -4,21 +4,11 @@
#include <stdexcept>
#include <string>
#include <sstream>
#include <iostream>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
//#define DEBUG_SEMA
#ifdef DEBUG_SEMA
#include <iostream>
#define DEBUG_MSG(msg) std::cerr << "[Sema Debug] " << msg << std::endl
#else
#define DEBUG_MSG(msg)
#endif
namespace {
// 获取左值名称的辅助函数
@ -45,6 +35,7 @@ public:
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元"));
}
table_.enterScope(); // 创建全局作用域
for (auto* func : ctx->funcDef()) { // 收集所有函数声明(处理互相调用)
CollectFunctionDeclaration(func);
}
@ -55,6 +46,7 @@ public:
if (func) func->accept(this);
}
CheckMainFunction(); // 检查 main 函数存在且正确
table_.exitScope(); // 退出全局作用域
return {};
}
@ -77,9 +69,10 @@ public:
} else {
return_type = ir::Type::GetInt32Type();
}
DEBUG_MSG("[DEBUG] 进入函数: " << name
std::cout << "[DEBUG] 进入函数: " << name
<< " 返回类型: " << (return_type->IsInt32() ? "int" :
return_type->IsFloat() ? "float" : "void"));
return_type->IsFloat() ? "float" : "void")
<< std::endl;
// 记录当前函数返回类型(用于 return 检查)
current_func_return_type_ = return_type;
@ -92,9 +85,10 @@ public:
if (ctx->block()) { // 处理函数体
ctx->block()->accept(this);
}
DEBUG_MSG("[DEBUG] 函数 " << name
std::cout << "[DEBUG] 函数 " << name
<< " has_return: " << current_func_has_return_
<< " return_type_is_void: " << return_type->IsVoid());
<< " return_type_is_void: " << return_type->IsVoid()
<< std::endl;
if (!return_type->IsVoid() && !current_func_has_return_) { // 检查非 void 函数是否有 return
throw std::runtime_error(FormatError("sema", "非 void 函数 " + name + " 缺少 return 语句"));
}
@ -178,10 +172,10 @@ public:
std::vector<int> dims;
bool is_array = !ctx->constExp().empty();
// 调试输出
DEBUG_MSG("[DEBUG] CheckVarDef: " << name
std::cout << "[DEBUG] CheckVarDef: " << name
<< " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown")
<< " is_array: " << is_array
<< " dim_count: " << ctx->constExp().size());
<< " dim_count: " << ctx->constExp().size() << std::endl;
if (is_array) {
// 处理数组维度
for (auto* dim_exp : ctx->constExp()) {
@ -193,24 +187,26 @@ public:
throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
}
dims.push_back(dim);
DEBUG_MSG("[DEBUG] dim[" << dims.size() - 1 << "] = " << dim);
std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl;
}
// 创建数组类型
type = ir::Type::GetArrayType(base_type, dims);
DEBUG_MSG("[DEBUG] 创建数组类型完成");
DEBUG_MSG("[DEBUG] type->IsArray(): " << type->IsArray());
DEBUG_MSG("[DEBUG] type->GetKind(): " << (int)type->GetKind());
std::cout << "[DEBUG] 创建数组类型完成" << std::endl;
std::cout << "[DEBUG] type->IsArray(): " << type->IsArray() << std::endl;
std::cout << "[DEBUG] type->GetKind(): " << (int)type->GetKind() << std::endl;
// 验证数组类型
if (type->IsArray()) {
auto* arr_type = dynamic_cast<ir::ArrayType*>(type.get());
if (arr_type) {
DEBUG_MSG("[DEBUG] ArrayType dimensions: ");
std::cout << "[DEBUG] ArrayType dimensions: ";
for (int d : arr_type->GetDimensions()) {
DEBUG_MSG(d << " ");
std::cout << d << " ";
}
DEBUG_MSG("[DEBUG] Element type: "
std::cout << std::endl;
std::cout << "[DEBUG] Element type: "
<< (arr_type->GetElementType()->IsInt32() ? "int" :
arr_type->GetElementType()->IsFloat() ? "float" : "unknown"));
arr_type->GetElementType()->IsFloat() ? "float" : "unknown")
<< std::endl;
}
}
}
@ -236,162 +232,11 @@ public:
sym.param_types.clear(); // 确保不混淆
}
table_.addSymbol(sym); // 添加到符号表
DEBUG_MSG("[DEBUG] 符号添加完成: " << name
std::cout << "[DEBUG] 符号添加完成: " << name
<< " type_kind: " << (int)sym.type->GetKind()
<< " is_array: " << sym.type->IsArray()
);
}
// Semantic check for a single constant definition (DEBUG_MSG revision).
//
// Steps, in order: validate the context and reject a redefinition in the
// current scope; evaluate the array dimensions and build the array type;
// visit the dimension and initializer expressions; evaluate the initializer;
// pad array initializers with zeros up to the element count; reject
// initializers with too many values; store the folded value(s) in a Symbol of
// kind Constant and add it to the symbol table.
//
// Throws std::runtime_error for every rejected definition.
void CheckConstDef(SysYParser::ConstDefContext* ctx,
std::shared_ptr<ir::Type> base_type) {
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("sema", "非法常量定义"));
}
std::string name = ctx->Ident()->getText();
if (table_.lookupCurrent(name)) {
throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
}
// Determine the type.
std::shared_ptr<ir::Type> type = base_type;
std::vector<int> dims;
bool is_array = !ctx->constExp().empty();
DEBUG_MSG("[DEBUG] CheckConstDef: " << name
<< " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown")
<< " is_array: " << is_array
<< " dim_count: " << ctx->constExp().size());
if (is_array) {
for (auto* dim_exp : ctx->constExp()) {
int dim = table_.EvaluateConstExp(dim_exp);
if (dim <= 0) {
throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
}
dims.push_back(dim);
DEBUG_MSG("[DEBUG] dim[" << dims.size() - 1 << "] = " << dim);
}
type = ir::Type::GetArrayType(base_type, dims);
DEBUG_MSG("[DEBUG] 创建数组类型完成IsArray: " << type->IsArray());
}
// ========== Bind the dimension expressions ==========
for (auto* dim_exp : ctx->constExp()) {
dim_exp->addExp()->accept(this);
}
// Evaluate the initializer.
std::vector<SymbolTable::ConstValue> init_values;
if (ctx->constInitVal()) {
// ========== Bind the initializer expressions ==========
BindConstInitVal(ctx->constInitVal());
init_values = table_.EvaluateConstInitVal(ctx->constInitVal(), dims, base_type);
DEBUG_MSG("[DEBUG] 初始化值数量: " << init_values.size());
}
// Compute the expected element count (product of the dimensions; 1 for a scalar).
size_t expected_count = 1;
if (is_array) {
expected_count = 1;
for (int d : dims) expected_count *= d;
DEBUG_MSG("[DEBUG] 期望元素数量: " << expected_count);
}
// If the array initializer is too short, pad it with zeros.
if (is_array && init_values.size() < expected_count) {
DEBUG_MSG("[DEBUG] 初始化值不足，补零");
SymbolTable::ConstValue zero;
if (base_type->IsInt32()) {
zero.kind = SymbolTable::ConstValue::INT;
zero.int_val = 0;
} else {
zero.kind = SymbolTable::ConstValue::FLOAT;
zero.float_val = 0.0f;
}
init_values.resize(expected_count, zero);
}
// Check the number of initializer values.
if (init_values.size() > expected_count) {
throw std::runtime_error(FormatError("sema", "初始化值过多"));
}
// Create the symbol.
Symbol sym;
sym.name = name;
sym.kind = SymbolKind::Constant;
DEBUG_MSG("CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind);
sym.type = type;
sym.scope_level = table_.currentScopeLevel();
sym.is_initialized = true;
sym.var_def_ctx = nullptr;
sym.const_def_ctx = ctx;
DEBUG_MSG("保存常量定义上下文: " << name << ", ctx: " << ctx);
// ========== Store the constant value ==========
if (is_array) {
// Store the array constant (flattened).
sym.is_array_const = true;
sym.array_const_values.clear();
for (const auto& val : init_values) {
Symbol::ConstantValue cv;
if (val.kind == SymbolTable::ConstValue::INT) {
cv.i32 = val.int_val;
} else {
cv.f32 = val.float_val;
}
sym.array_const_values.push_back(cv);
}
DEBUG_MSG("[DEBUG] 存储数组常量，共 " << sym.array_const_values.size()
<< " 个元素");
} else if (!init_values.empty()) {
// Store the scalar constant.
if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::INT) {
sym.is_int_const = true;
sym.const_value.i32 = init_values[0].int_val;
DEBUG_MSG("[DEBUG] 存储整型常量: " << init_values[0].int_val);
} else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) {
sym.is_int_const = false;
sym.const_value.f32 = init_values[0].float_val;
DEBUG_MSG("[DEBUG] 存储浮点常量: " << init_values[0].float_val);
} else if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) {
// Int constant initialised from a float: the value must be integral
// (checked with a 1e-6 tolerance), otherwise the definition is rejected.
float f = init_values[0].float_val;
int i = static_cast<int>(f);
if (std::abs(f - i) > 1e-6) {
throw std::runtime_error(FormatError("sema",
"整型常量不能用非整数值的浮点数初始化: " + std::to_string(f)));
}
sym.is_int_const = true;
sym.const_value.i32 = i;
DEBUG_MSG("[DEBUG] 浮点转整型常量: " << f << " -> " << i);
} else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::INT) {
// Float constant initialised from an int: implicit conversion.
sym.is_int_const = false;
sym.const_value.f32 = static_cast<float>(init_values[0].int_val);
DEBUG_MSG("[DEBUG] 整型转浮点常量: " << init_values[0].int_val
<< " -> " << static_cast<float>(init_values[0].int_val));
}
} else {
// No initializer values: for a scalar constant this is an error.
if (!is_array) {
throw std::runtime_error(FormatError("sema", "常量必须有初始化值: " + name));
}
DEBUG_MSG("[DEBUG] 数组常量无初始化器，将全部补零");
// NOTE(review): the next line is a dangling stream continuation with no
// left-hand operand; it looks like residue from the other side of the diff —
// confirm and remove.
<< std::endl;
}
table_.addSymbol(sym);
DEBUG_MSG("CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind);
auto* stored = table_.lookup(name);
DEBUG_MSG("CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx);
DEBUG_MSG("[DEBUG] 常量符号添加完成: " << name
<< " is_array_const: " << sym.is_array_const
<< " element_count: " << sym.array_const_values.size());
}
// ==================== 常量声明 ====================
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
@ -407,25 +252,111 @@ public:
return {};
}
// Semantic check for a single constant definition.
//
// Parameters:
//   ctx       - parse-tree node of the definition; must be non-null and carry
//               an identifier.
//   base_type - declared element type of the constant.
//
// Steps, in order: reject a redefinition in the current scope; evaluate the
// array dimensions and build the array type; visit the dimension and
// initializer expressions; evaluate the initializer; reject initializers with
// more values than the array holds (1 for a scalar); record the symbol.
// Only scalar constants have their folded value stored on the symbol here.
//
// Throws std::runtime_error for every rejected definition.
void CheckConstDef(SysYParser::ConstDefContext* ctx,
                   std::shared_ptr<ir::Type> base_type) {
  if (!ctx || !ctx->Ident()) {
    throw std::runtime_error(FormatError("sema", "非法常量定义"));
  }
  std::string name = ctx->Ident()->getText();
  if (table_.lookupCurrent(name)) {
    throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
  }

  // Determine the type.
  std::shared_ptr<ir::Type> type = base_type;
  std::vector<int> dims;
  bool is_array = !ctx->constExp().empty();
  std::cout << "[DEBUG] CheckConstDef: " << name
            << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown")
            << " is_array: " << is_array
            << " dim_count: " << ctx->constExp().size() << std::endl;
  if (is_array) {
    for (auto* dim_exp : ctx->constExp()) {
      int dim = table_.EvaluateConstExp(dim_exp);
      if (dim <= 0) {
        throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
      }
      dims.push_back(dim);
      std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl;
    }
    type = ir::Type::GetArrayType(base_type, dims);
    std::cout << "[DEBUG] 创建数组类型完成IsArray: " << type->IsArray() << std::endl;
  }

  // Bind the dimension expressions.
  for (auto* dim_exp : ctx->constExp()) {
    dim_exp->addExp()->accept(this);
  }

  // Bind and evaluate the initializer.
  std::vector<SymbolTable::ConstValue> init_values;
  if (ctx->constInitVal()) {
    BindConstInitVal(ctx->constInitVal());
    init_values = table_.EvaluateConstInitVal(ctx->constInitVal(), dims, base_type);
    std::cout << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl;
  }

  // Expected element count: product of the dimensions, or 1 for a scalar.
  size_t expected_count = 1;
  if (is_array) {
    for (int d : dims) expected_count *= d;
    std::cout << "[DEBUG] 期望元素数量: " << expected_count << std::endl;
  }
  if (init_values.size() > expected_count) {
    throw std::runtime_error(FormatError("sema", "初始化值过多"));
  }

  Symbol sym;
  sym.name = name;
  sym.kind = SymbolKind::Constant;
  std::cout << "CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind << std::endl;
  sym.type = type;
  sym.scope_level = table_.currentScopeLevel();
  sym.is_initialized = true;
  sym.var_def_ctx = nullptr;
  sym.const_def_ctx = ctx;  // the original assigned this twice in a row
  std::cout << "保存常量定义上下文: " << name << ", ctx: " << ctx << std::endl;

  // Store the folded value (scalars only).
  // NOTE(review): when the evaluated kind differs from base_type (int constant
  // with a FLOAT value or vice versa) neither branch runs and const_value is
  // left untouched — confirm EvaluateConstInitVal already converts to
  // base_type.
  if (!is_array && !init_values.empty()) {
    if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::INT) {
      sym.is_int_const = true;
      sym.const_value.i32 = init_values[0].int_val;
    } else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) {
      sym.is_int_const = false;
      sym.const_value.f32 = init_values[0].float_val;
    }
  } else if (is_array) {
    std::cout << "[DEBUG] 数组常量，不存储单个常量值" << std::endl;
  }

  table_.addSymbol(sym);
  std::cout << "CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind << std::endl;
  auto* stored = table_.lookup(name);
  std::cout << "CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx << std::endl;
  std::cout << "[DEBUG] 常量符号添加完成" << std::endl;
}
// ==================== 语句语义检查 ====================
// 处理所有语句 - 通过运行时类型判断
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx) return {};
// 调试输出
DEBUG_MSG("[DEBUG] visitStmt: ");
if (ctx->Return()) DEBUG_MSG("Return ");
if (ctx->If()) DEBUG_MSG("If ");
if (ctx->While()) DEBUG_MSG("While ");
if (ctx->Break()) DEBUG_MSG("Break ");
if (ctx->Continue()) DEBUG_MSG("Continue ");
if (ctx->lVal() && ctx->Assign()) DEBUG_MSG("Assign ");
if (ctx->exp() && ctx->Semi()) DEBUG_MSG("ExpStmt ");
if (ctx->block()) DEBUG_MSG("Block ");
std::cout << "[DEBUG] visitStmt: ";
if (ctx->Return()) std::cout << "Return ";
if (ctx->If()) std::cout << "If ";
if (ctx->While()) std::cout << "While ";
if (ctx->Break()) std::cout << "Break ";
if (ctx->Continue()) std::cout << "Continue ";
if (ctx->lVal() && ctx->Assign()) std::cout << "Assign ";
if (ctx->exp() && ctx->Semi()) std::cout << "ExpStmt ";
if (ctx->block()) std::cout << "Block ";
std::cout << std::endl;
// 判断语句类型 - 注意Return() 返回的是 TerminalNode*
if (ctx->Return() != nullptr) {
// return 语句
DEBUG_MSG("[DEBUG] 检测到 return 语句");
std::cout << "[DEBUG] 检测到 return 语句" << std::endl;
return visitReturnStmtInternal(ctx);
} else if (ctx->lVal() != nullptr && ctx->Assign() != nullptr) {
// 赋值语句
@ -454,14 +385,14 @@ public:
// return 语句内部实现
std::any visitReturnStmtInternal(SysYParser::StmtContext* ctx) {
DEBUG_MSG("[DEBUG] visitReturnStmtInternal 被调用");
std::cout << "[DEBUG] visitReturnStmtInternal 被调用" << std::endl;
std::shared_ptr<ir::Type> expected = current_func_return_type_;
if (!expected) {
throw std::runtime_error(FormatError("sema", "return 语句不在函数体内"));
}
if (ctx->exp() != nullptr) {
// 有返回值的 return
DEBUG_MSG("[DEBUG] 有返回值的 return");
std::cout << "[DEBUG] 有返回值的 return" << std::endl;
ExprInfo ret_val = CheckExp(ctx->exp());
if (expected->IsVoid()) {
throw std::runtime_error(FormatError("sema", "void 函数不能返回值"));
@ -474,23 +405,23 @@ public:
}
// 设置 has_return 标志
current_func_has_return_ = true;
DEBUG_MSG("[DEBUG] 设置 current_func_has_return_ = true");
std::cout << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl;
} else {
// 无返回值的 return
DEBUG_MSG("[DEBUG] 无返回值的 return");
std::cout << "[DEBUG] 无返回值的 return" << std::endl;
if (!expected->IsVoid()) {
throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值"));
}
// 设置 has_return 标志
current_func_has_return_ = true;
DEBUG_MSG("[DEBUG] 设置 current_func_has_return_ = true");
std::cout << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl;
}
return {};
}
// 左值表达式(变量引用)
std::any visitLVal(SysYParser::LValContext* ctx) override {
DEBUG_MSG("[DEBUG] visitLVal: " << ctx->getText());
std::cout << "[DEBUG] visitLVal: " << ctx->getText() << std::endl;
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
}
@ -501,17 +432,17 @@ public:
}
// 检查数组访问
bool is_array_access = !ctx->exp().empty();
DEBUG_MSG("[DEBUG] name: " << name
std::cout << "[DEBUG] name: " << name
<< ", is_array_access: " << is_array_access
<< ", subscript_count: " << ctx->exp().size());
<< ", subscript_count: " << ctx->exp().size() << std::endl;
ExprInfo result;
// 判断是否为数组类型或指针类型(数组参数)
bool is_array_or_ptr = false;
if (sym->type) {
is_array_or_ptr = sym->type->IsArray() || sym->type->IsPtrInt32() || sym->type->IsPtrFloat();
DEBUG_MSG("[DEBUG] type_kind: " << (int)sym->type->GetKind()
std::cout << "[DEBUG] type_kind: " << (int)sym->type->GetKind()
<< ", is_array: " << sym->type->IsArray()
<< ", is_ptr: " << (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()));
<< ", is_ptr: " << (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) << std::endl;
}
if (is_array_or_ptr) {
@ -522,7 +453,7 @@ public:
if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
dim_count = arr_type->GetDimensions().size();
elem_type = arr_type->GetElementType();
DEBUG_MSG("[DEBUG] 数组维度: " << dim_count);
std::cout << "[DEBUG] 数组维度: " << dim_count << std::endl;
}
} else if (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) {
dim_count = 1;
@ -531,12 +462,12 @@ public:
} else if (sym->type->IsPtrFloat()) {
elem_type = ir::Type::GetFloatType();
}
DEBUG_MSG("[DEBUG] 指针类型, dim_count: 1");
std::cout << "[DEBUG] 指针类型, dim_count: 1" << std::endl;
}
if (is_array_access) {
DEBUG_MSG("[DEBUG] 有下标访问,期望维度: " << dim_count
<< ", 实际下标数: " << ctx->exp().size());
std::cout << "[DEBUG] 有下标访问,期望维度: " << dim_count
<< ", 实际下标数: " << ctx->exp().size() << std::endl;
if (ctx->exp().size() != dim_count) {
throw std::runtime_error(FormatError("sema", "数组下标个数不匹配"));
}
@ -550,9 +481,9 @@ public:
result.is_lvalue = true;
result.is_const = false;
} else {
DEBUG_MSG("[DEBUG] 无下标访问");
std::cout << "[DEBUG] 无下标访问" << std::endl;
if (sym->type->IsArray()) {
DEBUG_MSG("[DEBUG] 数组名作为地址,转换为指针");
std::cout << "[DEBUG] 数组名作为地址,转换为指针" << std::endl;
if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
if (arr_type->GetElementType()->IsInt32()) {
result.type = ir::Type::GetPtrInt32Type();
@ -674,7 +605,7 @@ public:
// 主表达式
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
DEBUG_MSG("[DEBUG] visitPrimaryExp: " << ctx->getText());
std::cout << "[DEBUG] visitPrimaryExp: " << ctx->getText() << std::endl;
ExprInfo result;
if (ctx->lVal()) { // 左值表达式
result = CheckLValue(ctx->lVal());
@ -706,14 +637,14 @@ public:
// 一元表达式
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
DEBUG_MSG("[DEBUG] visitUnaryExp: " << ctx->getText());
std::cout << "[DEBUG] visitUnaryExp: " << ctx->getText() << std::endl;
ExprInfo result;
if (ctx->primaryExp()) {
ctx->primaryExp()->accept(this);
auto* info = sema_.GetExprType(ctx->primaryExp());
if (info) result = *info;
} else if (ctx->Ident() && ctx->L_PAREN()) { // 函数调用
DEBUG_MSG("[DEBUG] 函数调用: " << ctx->Ident()->getText());
std::cout << "[DEBUG] 函数调用: " << ctx->Ident()->getText() << std::endl;
result = CheckFuncCall(ctx);
} else if (ctx->unaryOp()) { // 一元运算
ctx->unaryExp()->accept(this);
@ -1073,27 +1004,9 @@ public:
sema_.SetExprType(ctx, result);
return {};
}
// Added: hands the symbol table to the caller. The return value is built from
// std::move(table_), so table_ should not be relied on afterwards.
SymbolTable TakeSymbolTable() { return std::move(table_); }
// Hands the semantic context to the caller. The return value is built from
// std::move(sema_), so sema_ should not be relied on afterwards.
SemanticContext TakeSemanticContext() { return std::move(sema_); }
// Added: returns the semantic context and the symbol table together.
// Both members are moved into the returned SemaResult, context first and
// symbol table second; the scope count is logged before and after the move.
SemaResult TakeResult() {
  DEBUG_MSG("[DEBUG] TakeResult 前: 符号表作用域数量 = "
            << table_.getScopeCount());
  // Optional while debugging: dump the symbol table contents here.
  // table_.dump();

  SemaResult bundle;
  bundle.context = std::move(sema_);
  bundle.symbol_table = std::move(table_);

  DEBUG_MSG("[DEBUG] TakeResult 后: 符号表作用域数量 = "
            << bundle.symbol_table.getScopeCount());
  return bundle;
}
// Returns the semantic context. The return value is built from
// std::move(sema_), so sema_ should not be relied on afterwards.
SemanticContext TakeSemanticContext() { return std::move(sema_); }
private:
SymbolTable table_;
@ -1107,11 +1020,12 @@ private:
bool current_func_has_return_ = false;
// ==================== 辅助函数 ====================
ExprInfo CheckExp(SysYParser::ExpContext* ctx) {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("sema", "无效表达式"));
}
DEBUG_MSG("[DEBUG] CheckExp: " << ctx->getText());
std::cout << "[DEBUG] CheckExp: " << ctx->getText() << std::endl;
ctx->addExp()->accept(this);
auto* info = sema_.GetExprType(ctx->addExp());
if (!info) {
@ -1162,21 +1076,21 @@ private:
if (!sym) {
throw std::runtime_error(FormatError("sema", "未定义的变量: " + name));
}
DEBUG_MSG("CheckLValue: found sym->name = " << sym->name
<< ", sym->kind = " << (int)sym->kind);
std::cout << "CheckLValue: found sym->name = " << sym->name
<< ", sym->kind = " << (int)sym->kind << std::endl;
if (sym->kind == SymbolKind::Variable && sym->var_def_ctx) {
sema_.BindVarUse(ctx, sym->var_def_ctx);
DEBUG_MSG("绑定变量: " << name << " -> VarDefContext");
std::cout << "绑定变量: " << name << " -> VarDefContext" << std::endl;
}
else if (sym->kind == SymbolKind::Constant && sym->const_def_ctx) {
sema_.BindConstUse(ctx, sym->const_def_ctx);
DEBUG_MSG("绑定常量: " << name << " -> ConstDefContext");
std::cout << "绑定常量: " << name << " -> ConstDefContext" << std::endl;
}
DEBUG_MSG("CheckLValue 绑定变量: " << name
std::cout << "CheckLValue 绑定变量: " << name
<< ", sym->kind: " << (int)sym->kind
<< ", sym->var_def_ctx: " << sym->var_def_ctx
<< ", sym->const_def_ctx: " << sym->const_def_ctx);
<< ", sym->const_def_ctx: " << sym->const_def_ctx << std::endl;
bool is_array_access = !ctx->exp().empty();
bool is_const = (sym->kind == SymbolKind::Constant);
@ -1208,8 +1122,9 @@ private:
} else if (sym->type->IsPtrFloat()) {
elem_type = ir::Type::GetFloatType();
}
DEBUG_MSG("数组参数维度: " << dim_count << " 维, dims: ");
for (int d : dims) DEBUG_MSG(d << " ");
std::cout << "数组参数维度: " << dim_count << " 维, dims: ";
for (int d : dims) std::cout << d << " ";
std::cout << std::endl;
} else if (sym->type && (sym->type->IsPtrInt32() || sym->type->IsPtrFloat())) {
// 普通指针,只能有一个下标
dim_count = 1;
@ -1222,7 +1137,7 @@ private:
size_t subscript_count = ctx->exp().size();
DEBUG_MSG("dim_count: " << dim_count << ", subscript_count: " << subscript_count);
std::cout << "dim_count: " << dim_count << ", subscript_count: " << subscript_count << std::endl;
if (dim_count > 0 || sym->is_array_param || sym->type->IsArray() ||
sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) {
@ -1243,11 +1158,11 @@ private:
if (subscript_count == dim_count) {
// 完全索引,返回元素类型
DEBUG_MSG("完全索引,返回元素类型");
std::cout << "完全索引,返回元素类型" << std::endl;
return {elem_type, true, false};
} else {
// 部分索引,返回子数组的指针类型
DEBUG_MSG("部分索引,返回指针类型");
std::cout << "部分索引,返回指针类型" << std::endl;
// 计算剩余维度的指针类型
if (elem_type->IsInt32()) {
return {ir::Type::GetPtrInt32Type(), false, false};
@ -1261,7 +1176,7 @@ private:
// 没有下标访问
if (sym->type && sym->type->IsArray()) {
// 数组名作为地址
DEBUG_MSG("数组名作为地址");
std::cout << "数组名作为地址" << std::endl;
if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
if (arr_type->GetElementType()->IsInt32()) {
return {ir::Type::GetPtrInt32Type(), false, true};
@ -1272,7 +1187,7 @@ private:
return {ir::Type::GetPtrInt32Type(), false, true};
} else if (sym->is_array_param) {
// 数组参数名作为地址
DEBUG_MSG("数组参数名作为地址");
std::cout << "数组参数名作为地址" << std::endl;
if (sym->type->IsPtrInt32()) {
return {ir::Type::GetPtrInt32Type(), false, true};
} else {
@ -1296,14 +1211,14 @@ private:
throw std::runtime_error(FormatError("sema", "非法函数调用"));
}
std::string func_name = ctx->Ident()->getText();
DEBUG_MSG("[DEBUG] CheckFuncCall: " << func_name);
std::cout << "[DEBUG] CheckFuncCall: " << func_name << std::endl;
auto* func_sym = table_.lookup(func_name);
if (!func_sym || func_sym->kind != SymbolKind::Function) {
throw std::runtime_error(FormatError("sema", "未定义的函数: " + func_name));
}
std::vector<ExprInfo> args;
if (ctx->funcRParams()) {
DEBUG_MSG("[DEBUG] 处理函数调用参数:");
std::cout << "[DEBUG] 处理函数调用参数:" << std::endl;
for (auto* exp : ctx->funcRParams()->exp()) {
if (exp) {
args.push_back(CheckExp(exp));
@ -1314,8 +1229,8 @@ private:
throw std::runtime_error(FormatError("sema", "参数个数不匹配"));
}
for (size_t i = 0; i < std::min(args.size(), func_sym->param_types.size()); ++i) {
DEBUG_MSG("[DEBUG] 检查参数 " << i << ": 实参类型 " << (int)args[i].type->GetKind()
<< " 形参类型 " << (int)func_sym->param_types[i]->GetKind());
std::cout << "[DEBUG] 检查参数 " << i << ": 实参类型 " << (int)args[i].type->GetKind()
<< " 形参类型 " << (int)func_sym->param_types[i]->GetKind() << std::endl;
if (!IsTypeCompatible(args[i].type, func_sym->param_types[i])) {
throw std::runtime_error(FormatError("sema", "参数类型不匹配"));
}
@ -1515,8 +1430,10 @@ private:
sym.array_dims = dims;
table_.addSymbol(sym);
DEBUG_MSG("[DEBUG] 添加参数: " << name << " type_kind: " << (int)param_type->GetKind());
for (int d : dims) DEBUG_MSG(d << " ");
std::cout << "[DEBUG] 添加参数: " << name << " type_kind: " << (int)param_type->GetKind()
<< " is_array: " << is_array << " dims: ";
for (int d : dims) std::cout << d << " ";
std::cout << std::endl;
}
}
@ -1580,10 +1497,9 @@ private:
} // namespace
// 修改 RunSema 函数,使其返回 SemaResult 结构体,包含符号表和语义上下文
SemaResult RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor;
comp_unit.accept(&visitor);
// 直接返回 TakeResult(),利用移动语义
return visitor.TakeResult();
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor;
comp_unit.accept(&visitor);
SemanticContext ctx = visitor.TakeSemanticContext();
return ctx;
}

@ -4,9 +4,8 @@
#include <stdexcept>
#include <string>
#include <cmath>
#include <functional>
//#define DEBUG_SYMBOL_TABLE
#define DEBUG_SYMBOL_TABLE
#ifdef DEBUG_SYMBOL_TABLE
#include <iostream>
@ -18,39 +17,35 @@
// ---------- 构造函数 ----------
// Sets up a fresh symbol table: one initial (global) scope, marked active,
// followed by registration of the built-in library functions.
SymbolTable::SymbolTable() {
scopes_.emplace_back(); // Create the global scope.
active_scope_stack_.push_back(0); // Record scope index 0 as the active scope.
registerBuiltinFunctions(); // Register the built-in library functions.
}
// ---------- 作用域管理 ----------
// Opens a new scope and makes it the innermost active one.
void SymbolTable::enterScope() {
scopes_.emplace_back(); // Append a new scope to the scope list.
active_scope_stack_.push_back(scopes_.size() - 1); // Its index is the new last position.
}
void SymbolTable::exitScope() {
if (active_scope_stack_.size() > 1) {
active_scope_stack_.pop_back();
if (scopes_.size() > 1) {
scopes_.pop_back();
}
// 不能退出全局作用域
}
// ---------- 符号添加与查找 ----------
bool SymbolTable::addSymbol(const Symbol& sym) {
auto& current_scope = scopes_[active_scope_stack_.back()];
auto& current_scope = scopes_.back();
if (current_scope.find(sym.name) != current_scope.end()) {
return false; // 重复定义
}
Symbol stored_sym = sym;
stored_sym.scope_level = currentScopeLevel();
current_scope[sym.name] = stored_sym;
current_scope[sym.name] = sym;
// 立即验证存储的符号
const auto& stored = current_scope[sym.name];
DEBUG_MSG("SymbolTable::addSymbol: stored " << sym.name
std::cout << "SymbolTable::addSymbol: stored " << sym.name
<< " with kind=" << (int)stored.kind
<< ", const_def_ctx=" << stored.const_def_ctx);
<< ", const_def_ctx=" << stored.const_def_ctx
<< std::endl;
return true;
}
@ -64,14 +59,16 @@ Symbol* SymbolTable::lookupCurrent(const std::string& name) {
}
const Symbol* SymbolTable::lookup(const std::string& name) const {
for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) {
const auto& scope = scopes_[*it];
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
const auto& scope = *it;
auto found = scope.find(name);
if (found != scope.end()) {
DEBUG_MSG("SymbolTable::lookup: found " << name
<< " in active scope index " << *it
std::cout << "SymbolTable::lookup: found " << name
<< " in scope level " << (scopes_.rend() - it - 1)
<< ", kind=" << (int)found->second.kind
<< ", const_def_ctx=" << found->second.const_def_ctx);
<< ", const_def_ctx=" << found->second.const_def_ctx
<< std::endl;
return &found->second;
}
}
@ -79,7 +76,7 @@ const Symbol* SymbolTable::lookup(const std::string& name) const {
}
const Symbol* SymbolTable::lookupCurrent(const std::string& name) const {
const auto& current_scope = scopes_[active_scope_stack_.back()];
const auto& current_scope = scopes_.back();
auto it = current_scope.find(name);
if (it != current_scope.end()) {
return &it->second;
@ -87,40 +84,6 @@ const Symbol* SymbolTable::lookupCurrent(const std::string& name) const {
return nullptr;
}
// Looks a name up in every scope held in scopes_, walking from the last scope
// back to the first.
// Returns a pointer to the first matching symbol, or nullptr when no scope
// contains the name.
const Symbol* SymbolTable::lookupAll(const std::string& name) const {
    for (size_t remaining = scopes_.size(); remaining > 0; --remaining) {
        const auto& scope = scopes_[remaining - 1];
        const auto hit = scope.find(name);
        if (hit == scope.end()) {
            continue;
        }
        return &hit->second;
    }
    return nullptr;
}
// Finds the symbol whose var_def_ctx is the given VarDef parse node.
// Scopes are scanned from the last one in scopes_ back to the first, and every
// entry of each scope is compared by pointer.
// Returns nullptr when decl is null or no symbol refers to it.
const Symbol* SymbolTable::lookupByVarDef(const SysYParser::VarDefContext* decl) const {
    if (decl == nullptr) {
        return nullptr;
    }
    for (size_t remaining = scopes_.size(); remaining > 0; --remaining) {
        for (const auto& entry : scopes_[remaining - 1]) {
            if (entry.second.var_def_ctx == decl) {
                return &entry.second;
            }
        }
    }
    return nullptr;
}
// Finds the symbol whose const_def_ctx is the given ConstDef parse node.
// Scopes are scanned from the last one in scopes_ back to the first, and every
// entry of each scope is compared by pointer.
// Returns nullptr when decl is null or no symbol refers to it.
const Symbol* SymbolTable::lookupByConstDef(const SysYParser::ConstDefContext* decl) const {
    if (decl == nullptr) {
        return nullptr;
    }
    for (size_t remaining = scopes_.size(); remaining > 0; --remaining) {
        for (const auto& entry : scopes_[remaining - 1]) {
            if (entry.second.const_def_ctx == decl) {
                return &entry.second;
            }
        }
    }
    return nullptr;
}
// ---------- 兼容原接口 ----------
void SymbolTable::Add(const std::string& name, SysYParser::VarDefContext* decl) {
Symbol sym;
@ -133,9 +96,9 @@ void SymbolTable::Add(const std::string& name, SysYParser::VarDefContext* decl)
}
bool SymbolTable::Contains(const std::string& name) const {
for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) {
const auto& scope = scopes_[*it];
if (scope.find(name) != scope.end()) {
// const 方法不能修改 scopes_我们模拟查找
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
if (it->find(name) != it->end()) {
return true;
}
}
@ -143,10 +106,9 @@ bool SymbolTable::Contains(const std::string& name) const {
}
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) {
const auto& scope = scopes_[*it];
auto found = scope.find(name);
if (found != scope.end()) {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(name);
if (found != it->end()) {
// 只返回变量定义的上下文(函数等其他符号返回 nullptr
if (found->second.kind == SymbolKind::Variable) {
return found->second.var_def_ctx;
@ -421,7 +383,10 @@ SymbolTable::ConstValue SymbolTable::EvaluatePrimaryExp(SysYParser::PrimaryExpCo
auto lval = ctx->lVal();
if (!lval->Ident()) throw std::runtime_error("常量表达式求值:无效左值");
std::string name = lval->Ident()->getText();
DEBUG_MSG(" 左值标识符: " << name);
const Symbol* sym = lookup(name);
DEBUG_MSG(" 找到符号: kind=" << (int)sym->kind << ", value="
<< (sym->is_int_const ? std::to_string(sym->const_value.i32) : std::to_string(sym->const_value.f32)));
if (!sym) throw std::runtime_error("常量表达式求值:未定义的标识符 " + name);
if (sym->kind != SymbolKind::Constant)
throw std::runtime_error("常量表达式求值:标识符 " + name + " 不是常量");
@ -443,6 +408,7 @@ SymbolTable::ConstValue SymbolTable::EvaluatePrimaryExp(SysYParser::PrimaryExpCo
ConstValue val;
val.kind = ConstValue::FLOAT;
val.float_val = ParseFloatLiteral(text);
DEBUG_MSG(" 浮点字面量: " << text << " -> " << val.float_val);
return val;
}
else if (ctx->HEX_INT() || ctx->OCTAL_INT() || ctx->DECIMAL_INT() || ctx->ZERO()) {
@ -454,6 +420,7 @@ SymbolTable::ConstValue SymbolTable::EvaluatePrimaryExp(SysYParser::PrimaryExpCo
ConstValue val;
val.kind = ConstValue::INT;
val.int_val = static_cast<int>(ParseIntegerLiteral(text));
DEBUG_MSG(" 整数字面量: " << text << " -> " << val.int_val);
return val;
}
else if (ctx->exp()) {
@ -473,6 +440,8 @@ SymbolTable::ConstValue SymbolTable::EvaluateUnaryExp(SysYParser::UnaryExpContex
else if (ctx->unaryOp()) {
ConstValue operand = EvaluateUnaryExp(ctx->unaryExp());
std::string op = ctx->unaryOp()->getText();
DEBUG_MSG("EvaluateUnaryExp: 操作符=" << op);
DEBUG_MSG(" 操作数=" << (operand.kind==ConstValue::INT ? std::to_string(operand.int_val) : std::to_string(operand.float_val)));
if (op == "+") {
return operand;
@ -508,10 +477,15 @@ SymbolTable::ConstValue SymbolTable::EvaluateMulExp(SysYParser::MulExpContext* c
if (!ctx) throw std::runtime_error("常量表达式求值:无效 MulExp");
if (ctx->mulExp()) {
DEBUG_MSG("EvaluateMulExp: 左子表达式");
ConstValue left = EvaluateMulExp(ctx->mulExp());
DEBUG_MSG(" 左值=" << (left.kind==ConstValue::INT ? std::to_string(left.int_val) : std::to_string(left.float_val)));
ConstValue right = EvaluateUnaryExp(ctx->unaryExp());
std::string op;
DEBUG_MSG(" 运算符=" << op);
DEBUG_MSG(" 右值=" << (right.kind==ConstValue::INT ? std::to_string(right.int_val) : std::to_string(right.float_val)));
if (ctx->MulOp()) op = "*";
else if (ctx->DivOp()) op = "/";
else if (ctx->QuoOp()) op = "%";
@ -588,7 +562,9 @@ SymbolTable::ConstValue SymbolTable::EvaluateAddExp(SysYParser::AddExpContext* c
int SymbolTable::EvaluateConstExp(SysYParser::ConstExpContext* ctx) const {
if (!ctx || !ctx->addExp())
throw std::runtime_error("常量表达式求值:无效 ConstExp");
DEBUG_MSG("EvaluateConstExp: 表达式文本=" << ctx->getText());
ConstValue val = EvaluateAddExp(ctx->addExp());
DEBUG_MSG(" 求值结果: " << (val.kind==ConstValue::INT ? std::to_string(val.int_val) : std::to_string(val.float_val)));
if (val.kind == ConstValue::INT) {
return val.int_val;
} else {
@ -604,7 +580,9 @@ int SymbolTable::EvaluateConstExp(SysYParser::ConstExpContext* ctx) const {
float SymbolTable::EvaluateConstExpFloat(SysYParser::ConstExpContext* ctx) const {
if (!ctx || !ctx->addExp())
throw std::runtime_error("常量表达式求值:无效 ConstExp");
DEBUG_MSG("EvaluateConstExpFloat: 表达式文本=" << ctx->getText());
ConstValue val = EvaluateAddExp(ctx->addExp());
DEBUG_MSG(" 求值结果: " << (val.kind==ConstValue::INT ? std::to_string(val.int_val) : std::to_string(val.float_val)));
if (val.kind == ConstValue::INT) {
return static_cast<float>(val.int_val);
} else {
@ -612,47 +590,97 @@ float SymbolTable::EvaluateConstExpFloat(SysYParser::ConstExpContext* ctx) const
}
}
void SymbolTable::flattenInit(SysYParser::ConstInitValContext* ctx,
std::vector<ConstValue>& out,
std::shared_ptr<ir::Type> base_type) const {
if (!ctx) return;
// 获取当前初始化列表的文本(用于调试)
std::string ctxText;
if (ctx->constExp()) {
ctxText = ctx->constExp()->getText();
} else {
ctxText = "{ ... }";
}
if (ctx->constExp()) {
ConstValue val = EvaluateAddExp(ctx->constExp()->addExp());
DEBUG_MSG("处理常量表达式: " << ctxText
<< " 类型=" << (val.kind == ConstValue::INT ? "INT" : "FLOAT")
<< " 值=" << (val.kind == ConstValue::INT ? std::to_string(val.int_val) : std::to_string(val.float_val))
<< " 目标类型=" << (base_type->IsInt32() ? "Int32" : "Float"));
// 整型数组不能接受浮点常量
if (base_type->IsInt32() && val.kind == ConstValue::FLOAT) {
DEBUG_MSG("错误:整型数组遇到浮点常量,值=" << val.float_val);
throw std::runtime_error("常量初始化:整型数组不能使用浮点常量");
// 递归填充函数,按维度填充值
void SymbolTable::fillArray(
std::vector<ConstValue>& values, // 存储最终所有元素(行优先)
size_t& index, // 当前填充到的位置
SysYParser::ConstInitValContext* ctx, // 当前初始化列表节点
const std::vector<int>& dims, // 剩余维度
size_t dim_idx, // 当前维度索引
std::shared_ptr<ir::Type> base_type) const // 元素基本类型
{
DEBUG_MSG("fillArray: 进入dim_idx=" << dim_idx
<< ", 剩余维度=" << (dims.size() - dim_idx)
<< ", 当前index=" << index);
// 如果已经是最内层(单个元素)
if (dim_idx == dims.size()) {
// 必须是单个表达式
if (!ctx || !ctx->constExp()) {
throw std::runtime_error("初始化值不是标量");
}
// 浮点数组接受整型常量,并隐式转换
if (base_type->IsFloat() && val.kind == ConstValue::INT) {
DEBUG_MSG("浮点数组接收整型常量,隐式转换为浮点: " << val.int_val);
val.kind = ConstValue::FLOAT;
val.float_val = static_cast<float>(val.int_val);
ConstValue val = EvaluateAddExp(ctx->constExp()->addExp());
// 类型检查和转换...
DEBUG_MSG(" 填充标量值: index=" << index
<< ", 值=" << (val.kind == ConstValue::INT ? std::to_string(val.int_val) : std::to_string(val.float_val)));
values[index++] = val;
return;
}
// 当前维度的元素个数
size_t cur_dim_size = dims[dim_idx];
DEBUG_MSG(" 当前维度大小=" << cur_dim_size << ", 是否是花括号列表=" << (ctx && !ctx->constExp()));
// 如果是花括号列表,则按子项填充
if (ctx && !ctx->constExp()) { // 花括号
auto sub_vals = ctx->constInitVal();
DEBUG_MSG(" 花括号列表,子项数量=" << sub_vals.size());
// 对于每个子项,填充一个子数组
for (size_t i = 0; i < cur_dim_size; ++i) {
DEBUG_MSG(" 处理子项 " << i << " / " << cur_dim_size);
if (i < sub_vals.size()) {
fillArray(values, index, sub_vals[i], dims, dim_idx + 1, base_type);
} else {
// 子项不足,填充零
DEBUG_MSG(" 子项不足,填充零");
fillZero(values, index, dims, dim_idx + 1, base_type);
}
}
out.push_back(val);
} else {
DEBUG_MSG("进入花括号初始化列表: " << ctxText);
// 花括号初始化列表:递归展开所有子项
for (auto* sub : ctx->constInitVal()) {
flattenInit(sub, out, base_type);
// 不是花括号,即单个值,应视为对当前维度第一个元素的初始化,其余补零
// 第一个子数组
DEBUG_MSG(" 单个值(非花括号),将填充第一个子数组,其余补零");
if (ctx && ctx->constExp()) {
fillArray(values, index, ctx, dims, dim_idx + 1, base_type);
} else {
DEBUG_MSG(" 无初始化值,第一个子数组也补零");
fillZero(values, index, dims, dim_idx + 1, base_type);
}
// 剩余子数组补零
for (size_t i = 1; i < cur_dim_size; ++i) {
DEBUG_MSG(" 填充第" << i << "个子数组为零");
fillZero(values, index, dims, dim_idx + 1, base_type);
}
DEBUG_MSG("退出花括号初始化列表");
}
DEBUG_MSG("fillArray: 退出index=" << index);
}
// 填充指定数量的零
void SymbolTable::fillZero(std::vector<ConstValue>& values, size_t& index,
const std::vector<int>& dims, size_t dim_idx,
std::shared_ptr<ir::Type> base_type) const
{
DEBUG_MSG("fillZero: 进入dim_idx=" << dim_idx << ", 当前index=" << index);
if (dim_idx == dims.size()) {
ConstValue zero;
if (base_type->IsInt32()) {
zero.kind = ConstValue::INT;
zero.int_val = 0;
} else {
zero.kind = ConstValue::FLOAT;
zero.float_val = 0.0f;
}
DEBUG_MSG(" 填充零值: index=" << index);
values[index++] = zero;
return;
}
size_t cur_dim_size = dims[dim_idx];
DEBUG_MSG(" fillZero 当前维度大小=" << cur_dim_size);
for (size_t i = 0; i < cur_dim_size; ++i) {
DEBUG_MSG(" fillZero 子项 " << i);
fillZero(values, index, dims, dim_idx + 1, base_type);
}
DEBUG_MSG("fillZero: 退出");
}
std::vector<SymbolTable::ConstValue> SymbolTable::EvaluateConstInitVal(
@ -662,6 +690,7 @@ std::vector<SymbolTable::ConstValue> SymbolTable::EvaluateConstInitVal(
// ========== 1. 标量常量dims 为空)==========
if (dims.empty()) {
DEBUG_MSG(" 标量常量初始化,表达式=" << ctx->getText());
if (!ctx || !ctx->constExp()) {
throw std::runtime_error("标量常量初始化必须使用单个表达式");
}
@ -676,7 +705,7 @@ std::vector<SymbolTable::ConstValue> SymbolTable::EvaluateConstInitVal(
// 隐式类型转换
if (base_type->IsInt32() && val.kind == ConstValue::FLOAT) {
val.kind = ConstValue::INT;
val.int_val = static_cast<int>(val.float_val);
val.float_val = static_cast<int>(val.int_val);
}
if (base_type->IsFloat() && val.kind == ConstValue::INT) {
val.kind = ConstValue::FLOAT;
@ -686,89 +715,23 @@ std::vector<SymbolTable::ConstValue> SymbolTable::EvaluateConstInitVal(
}
// ========== 2. 数组常量dims 非空)==========
// 计算数组总元素个数
DEBUG_MSG("EvaluateConstInitVal: 开始,维度=" << dims.size());
size_t total = 1;
for (int d : dims) total *= d;
DEBUG_MSG(" 数组常量初始化,总元素数=" << total);
ConstValue zero;
if (base_type->IsInt32()) {
zero.kind = ConstValue::INT;
zero.int_val = 0;
} else {
zero.kind = ConstValue::FLOAT;
zero.float_val = 0.0f;
std::vector<ConstValue> values(total);
size_t index = 0;
fillArray(values, index, ctx, dims, 0, base_type);
DEBUG_MSG("EvaluateConstInitVal: 填充完成最终index=" << index);
// 可选:打印所有填充的值
for (size_t i = 0; i < values.size(); ++i) {
DEBUG_MSG(" values[" << i << "] = " << (values[i].kind == ConstValue::INT ? std::to_string(values[i].int_val) : std::to_string(values[i].float_val)));
}
// 先整体补零,再按 C 语言花括号规则覆盖显式初始化项。
std::vector<ConstValue> flat(total, zero);
auto convert_value = [&](ConstValue v) -> ConstValue {
if (base_type->IsInt32()) {
if (v.kind == ConstValue::FLOAT) {
throw std::runtime_error("常量初始化:整型数组不能使用浮点常量");
}
return v;
}
if (v.kind == ConstValue::INT) {
ConstValue t;
t.kind = ConstValue::FLOAT;
t.float_val = static_cast<float>(v.int_val);
return t;
}
return v;
};
auto subarray_span = [&](size_t depth) -> size_t {
size_t span = 1;
for (size_t i = depth + 1; i < dims.size(); ++i) span *= static_cast<size_t>(dims[i]);
return span;
};
std::function<size_t(SysYParser::ConstInitValContext*, size_t, size_t, size_t)> fill;
fill = [&](SysYParser::ConstInitValContext* node,
size_t depth,
size_t begin,
size_t end) -> size_t {
if (!node || begin >= end) return begin;
// 标量初始化项
if (node->constExp()) {
ConstValue v = convert_value(EvaluateAddExp(node->constExp()->addExp()));
if (begin < flat.size()) flat[begin] = v;
return std::min(begin + 1, end);
}
size_t cursor = begin;
for (auto* child : node->constInitVal()) {
if (cursor >= end) break;
if (child->constExp()) {
ConstValue v = convert_value(EvaluateAddExp(child->constExp()->addExp()));
if (cursor < flat.size()) flat[cursor] = v;
++cursor;
continue;
}
// 花括号子列表:在非最内层需要按子聚合边界对齐。
if (depth + 1 < dims.size()) {
const size_t span = subarray_span(depth);
const size_t rel = (cursor - begin) % span;
if (rel != 0) cursor += (span - rel);
if (cursor >= end) break;
const size_t sub_end = std::min(cursor + span, end);
fill(child, depth + 1, cursor, sub_end);
// 一个带花括号的子初始化器会消费当前层的一个子聚合。
cursor = sub_end;
} else {
// 最内层(标量数组)遇到额外花括号,按同层顺序展开。
cursor = fill(child, depth, cursor, end);
}
}
return cursor;
};
fill(ctx, 0, 0, total);
return flat;
return values;
}
int SymbolTable::EvaluateConstExpression(SysYParser::ExpContext* ctx) const {

@ -1,5 +0,0 @@
编译libsysy.a
```
aarch64-linux-gnu-gcc -c sylib.c -o sylib.o
aarch64-linux-gnu-ar rcs libsysy.a sylib.o
```

@ -1,162 +1,4 @@
#include "sylib.h"
#include <math.h>
#include <stdlib.h>
extern int scanf(const char* format, ...);
extern int printf(const char* format, ...);
extern int getchar(void);
extern int putchar(int c);
// Reads one decimal integer from standard input with scanf("%d").
// Returns the parsed value; when scanf stores nothing the initial 0 is returned.
int getint(void) {
  int value = 0;
  scanf("%d", &value);
  return value;
}
// Reads one character from standard input.
// Returns exactly what getchar() returns, including its end-of-input value.
int getch(void) {
  const int ch = getchar();
  return ch;
}
// Reads an element count followed by that many integers from standard input
// into a[].
//
// a : destination buffer; it must have room for every element announced by
//     the count (this function cannot check the buffer's capacity).
//
// Returns the announced count on success, 0 when the count itself cannot be
// read, or the number of elements actually stored when an element fails to
// parse. This is the same contract as getfarray.
int getarray(int a[]) {
  int n = 0;
  int i = 0;
  // Without this check a missing or malformed count would leave n
  // unassigned and the loop bound below would be an indeterminate value.
  if (scanf("%d", &n) != 1) {
    return 0;
  }
  for (; i < n; ++i) {
    // Stop at the first unreadable element instead of reporting slots that
    // were never written.
    if (scanf("%d", &a[i]) != 1) {
      return i;
    }
  }
  return n;
}
// Reads one floating-point number from standard input with scanf("%f").
// Returns the parsed value; when scanf stores nothing the initial 0.0f is returned.
float getfloat(void) {
  float value = 0.0f;
  scanf("%f", &value);
  return value;
}
// Reads an element count followed by that many floats from standard input
// into a[].
//
// a : destination buffer; it must have room for every announced element.
//
// Returns the announced count on success, 0 when the count cannot be read,
// or the number of elements stored so far when an element fails to parse.
int getfarray(float a[]) {
  int count = 0;
  int idx;
  if (scanf("%d", &count) != 1) {
    return 0;
  }
  idx = 0;
  while (idx < count) {
    if (scanf("%f", &a[idx]) != 1) {
      return idx;
    }
    ++idx;
  }
  return count;
}
// Writes x to standard output as a decimal integer, with no separator or
// newline added.
void putint(int x) {
  const int value = x;
  printf("%d", value);
}
// Writes the character whose code is x to standard output via putchar().
void putch(int x) {
  const int ch = x;
  putchar(ch);
}
// Prints an integer array on one line as "<n>: e0 e1 ...": the count and a
// colon, then each of the first n elements preceded by a space, then a newline.
void putarray(int n, int a[]) {
  int idx;
  printf("%d:", n);
  for (idx = 0; idx < n; idx++) {
    printf(" %d", a[idx]);
  }
  putchar('\n');
}
// Writes x to standard output using printf's "%a" conversion, with no
// separator or newline added.
void putfloat(float x) {
  // A float argument to printf is promoted to double anyway; the promotion
  // is simply spelled out here.
  const double widened = x;
  printf("%a", widened);
}
// Prints a float array on one line: the count and a colon, then each of the
// first n elements preceded by a space and formatted with "%a", then a newline.
void putfarray(int n, float a[]) {
  int idx;
  printf("%d:", n);
  for (idx = 0; idx < n; idx++) {
    printf(" %a", a[idx]);
  }
  putchar('\n');
}
// Writes a zero-terminated sequence of character codes (one per int) to
// standard output, one putchar() per element. A null pointer prints nothing,
// and no trailing newline is added.
void puts(int s[]) {
  const int* cursor;
  if (s == 0) {
    return;
  }
  for (cursor = s; *cursor != 0; ++cursor) {
    putchar(*cursor);
  }
}
// Timer-start hook. This implementation is a stub: it does nothing, and
// lineno is explicitly discarded.
void _sysy_starttime(int lineno) {
  (void)lineno;  // intentionally unused
}
// Timer-stop hook. This implementation is a stub: it does nothing, and
// lineno is explicitly discarded.
void _sysy_stoptime(int lineno) {
  (void)lineno;  // intentionally unused
}
// Convenience wrapper: forwards to _sysy_starttime with line number 0.
void starttime(void) {
  const int no_line = 0;
  _sysy_starttime(no_line);
}
// Convenience wrapper: forwards to _sysy_stoptime with line number 0.
void stoptime(void) {
  const int no_line = 0;
  _sysy_stoptime(no_line);
}
// Byte-wise fill with an int-based signature.
//
// ptr   : start of the region to fill (not checked for null).
// value : only its low 8 bits are used as the fill byte.
// count : number of BYTES to write; zero or negative writes nothing.
//
// Returns ptr unchanged.
//
// NOTE(review): this function is itself named memset; an optimizing compiler
// that recognizes fill loops might lower the loop below to a memset call.
// Confirm the flags this library is built with.
int* memset(int* ptr, int value, int count) {
  unsigned char* cursor = (unsigned char*)ptr;
  const unsigned char fill = (unsigned char)(value & 0xFF);
  int remaining = count;
  while (remaining > 0) {
    *cursor++ = fill;
    --remaining;
  }
  return ptr;
}
// Allocates room for `count` ints with malloc.
// Returns a null pointer when count is zero or negative; otherwise returns
// whatever malloc returns for count * sizeof(int) bytes.
int* sysy_alloc_i32(int count) {
  size_t bytes;
  if (count < 1) {
    return 0;
  }
  bytes = (size_t)count * sizeof(int);
  return (int*)malloc(bytes);
}
// Allocates room for `count` floats with malloc.
// Returns a null pointer when count is zero or negative; otherwise returns
// whatever malloc returns for count * sizeof(float) bytes.
float* sysy_alloc_f32(int count) {
  size_t bytes;
  if (count < 1) {
    return 0;
  }
  bytes = (size_t)count * sizeof(float);
  return (float*)malloc(bytes);
}
// Releases an int buffer with free(). A null pointer is ignored.
void sysy_free_i32(int* ptr) {
  if (ptr != 0) {
    free(ptr);
  }
}
// Releases a float buffer with free(). A null pointer is ignored.
void sysy_free_f32(float* ptr) {
  if (ptr != 0) {
    free(ptr);
  }
}
// Sets the first `count` ints starting at ptr to 0.
// Does nothing when ptr is null or count is zero or negative.
void sysy_zero_i32(int* ptr, int count) {
  int* end;
  if (ptr == 0 || count < 1) {
    return;
  }
  for (end = ptr + count; ptr != end; ++ptr) {
    *ptr = 0;
  }
}
// Sets the first `count` floats starting at ptr to 0.0f.
// Does nothing when ptr is null or count is zero or negative.
void sysy_zero_f32(float* ptr, int count) {
  int remaining;
  if (ptr == 0 || count < 1) {
    return;
  }
  remaining = count;
  while (remaining > 0) {
    *ptr++ = 0.0f;
    --remaining;
  }
}
// SysY 运行库实现:
// - 按实验/评测规范提供 I/O 等函数实现
// - 与编译器生成的目标代码链接,支撑运行时行为

@ -1,29 +1,4 @@
#pragma once
// ---- Input: read values from standard input ----
// getarray/getfarray read a count followed by that many elements into a[]
// and return an element count.
int getint(void);
int getch(void);
int getarray(int a[]);
float getfloat(void);
int getfarray(float a[]);
// ---- Output: write values to standard output ----
// putarray/putfarray print the count n followed by the first n elements.
void putint(int x);
void putch(int x);
void putarray(int n, int a[]);
void putfloat(float x);
void putfarray(int n, float a[]);
// Prints a zero-terminated sequence of character codes stored one per int.
void puts(int s[]);
// ---- Timing hooks ----
// starttime/stoptime take no argument; the _sysy_ variants receive a line number.
void _sysy_starttime(int lineno);
void _sysy_stoptime(int lineno);
void starttime(void);
void stoptime(void);
// NOTE(review): no definition of read_map is visible in the sylib.c that
// accompanies this header; confirm it is provided elsewhere.
int read_map(void);
// ---- Memory helpers ----
// Byte-wise fill; note the int-based signature differs from the C library's memset.
int* memset(int* ptr, int value, int count);
// Allocate / release / zero buffers of `count` 32-bit ints or floats.
int* sysy_alloc_i32(int count);
float* sysy_alloc_f32(int count);
void sysy_free_i32(int* ptr);
void sysy_free_f32(float* ptr);
void sysy_zero_i32(int* ptr, int count);
void sysy_zero_f32(float* ptr, int count);
// SysY 运行库头文件:
// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用)
// - 与 sylib.c 配套,按规范逐步补齐声明

Loading…
Cancel
Save