<feature/ir:语义分析模块修改,实现符号表与IR生成板块信息互通。常量,浮点,数组支持。大数组堆分配,alloca栈分配提到入口块以提升性能避免栈溢出,all passed。测试脚本见/script/test_compiler.sh,由/script/verify_ir.sh衍生而来.可改进:可删除很多为了便于调试而插入的print语句>

feature/ir-final
LuoHello 1 week ago
parent c8f40ea09a
commit ec56841167

@ -109,18 +109,25 @@ class Context {
std::string NextTemp();
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
// 数组常量缓存需要添加到类中
struct ArrayKey {
std::shared_ptr<ArrayType> type;
std::vector<ConstantValue*> elements;
bool operator==(const ArrayKey& other) const;
};
// 浮点常量:使用整数表示浮点数位模式作为键(避免浮点精度问题)
std::unordered_map<uint32_t, std::unique_ptr<ConstantFloat>> const_floats_;
struct ArrayKeyHash {
size_t operator()(const ArrayKey& key) const;
};
std::unordered_map<ArrayKey, std::unique_ptr<ConstantArray>, ArrayKeyHash> array_cache_;
// 零常量缓存(按类型指针)
// 其他现有成员...
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<uint32_t, std::unique_ptr<ConstantFloat>> const_floats_;
std::unordered_map<Type*, std::unique_ptr<ConstantZero>> zero_constants_;
std::unordered_map<Type*, std::unique_ptr<ConstantAggregateZero>> aggregate_zeros_;
// 数组常量简单存储,不去重(因为数组常量通常组合多样,去重成本高)
std::vector<std::unique_ptr<ConstantArray>> const_arrays_;
int temp_index_ = -1;
};
@ -357,18 +364,18 @@ class User : public Value {
// GlobalValue 是全局值/全局变量体系的空壳占位类。
// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。
// ir/IR.h - 修正 GlobalValue 定义
// ir/IR.h - 修正 GlobalValue 定义
// ir/IR.h - GlobalValue 定义需要添加这些方法
class GlobalValue : public User {
private:
std::vector<ConstantValue*> initializer_; // 初始化值列表
bool is_constant_ = false; // 是否为常量如const变量
bool is_extern_ = false; // 是否为外部声明
std::vector<ConstantValue*> initializer_;
bool is_constant_ = false;
bool is_extern_ = false;
public:
GlobalValue(std::shared_ptr<Type> ty, std::string name);
// 初始化器相关 - 使用 ConstantValue*
// 初始化器相关
void SetInitializer(ConstantValue* init);
void SetInitializer(const std::vector<ConstantValue*>& init);
const std::vector<ConstantValue*>& GetInitializer() const { return initializer_; }
@ -382,17 +389,28 @@ public:
void SetExtern(bool is_extern) { is_extern_ = is_extern; }
bool IsExtern() const { return is_extern_; }
// 类型判断 - 使用 Type 的方法
// 类型判断
bool IsArray() const { return GetType()->IsArray(); }
bool IsScalar() const { return GetType()->IsInt32() || GetType()->IsFloat(); }
// 数组常量相关方法
bool IsArrayConstant() const;
ConstantValue* GetArrayElement(size_t index) const;
size_t GetArraySize() const;
// 获取数组大小(如果是数组类型)
int GetArraySize() const {
int GetArraySizeInElements() const {
if (auto* array_ty = dynamic_cast<ArrayType*>(GetType().get())) {
return array_ty->GetElementCount();
}
return 0;
}
private:
// 辅助方法
std::shared_ptr<Type> GetValueType() const;
bool CheckTypeCompatibility(std::shared_ptr<Type> value_type,
ConstantValue* init) const;
};
class Instruction : public User {
@ -742,6 +760,20 @@ class BasicBlock : public Value {
return ptr;
}
// Constructs a T from `args`, sets this block as its parent, and inserts it
// directly before the last instruction when HasTerminator() is true, or at
// the end of the instruction list otherwise.
// Returns a non-owning pointer to the newly inserted instruction; ownership
// stays with instructions_.
template <typename T, typename... Args>
T* InsertBeforeTerminator(Args&&... args) {
  auto owned = std::make_unique<T>(std::forward<Args>(args)...);
  T* const raw = owned.get();
  raw->SetParent(this);
  // The terminator, when present, is the last element, so the slot just
  // before end() keeps it last after the insertion.
  const auto insert_at =
      HasTerminator() ? instructions_.end() - 1 : instructions_.end();
  instructions_.insert(insert_at, std::move(owned));
  return raw;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
@ -812,6 +844,7 @@ class IRBuilder {
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
AllocaInst* CreateAlloca(std::shared_ptr<Type> ty, const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaFloat(const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);

@ -7,6 +7,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
@ -22,7 +23,10 @@ class Value;
class IRGenImpl final : public SysYBaseVisitor {
public:
IRGenImpl(ir::Module& module, const SemanticContext& sema);
// 修改构造函数,添加 SymbolTable 参数
IRGenImpl(ir::Module& module,
const SemanticContext& sema,
const SymbolTable& sym_table); // 新增
// 顶层
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
@ -67,9 +71,21 @@ public:
ir::Value* EvalCond(SysYParser::CondContext& cond);
std::any visitCallExp(SysYParser::UnaryExpContext* ctx);
std::vector<ir::Value*> ProcessNestedInitVals(SysYParser::InitValContext* ctx);
// 带维度感知的展平:按 C 语言花括号对齐规则填充 total_size 个槽位
// dims[0] 是最外层维度dims.back() 是最内层维度(元素层)
// 返回已展平并补零的 total_size 大小的向量
std::vector<ir::Value*> FlattenInitVal(SysYParser::InitValContext* ctx,
const std::vector<int>& dims,
bool is_float);
int TryEvaluateConstInt(SysYParser::ConstExpContext* ctx);
void AddRuntimeFunctions();
ir::Function* CreateRuntimeFunctionDecl(const std::string& funcName);
ir::BasicBlock* EnsureCleanupBlock();
void RegisterCleanup(ir::Function* free_func, ir::Value* ptr);
ir::AllocaInst* CreateEntryAlloca(std::shared_ptr<ir::Type> ty,
const std::string& name);
ir::AllocaInst* CreateEntryAllocaI32(const std::string& name);
ir::AllocaInst* CreateEntryAllocaFloat(const std::string& name);
private:
// 辅助函数声明
enum class BlockFlow{
@ -108,6 +124,7 @@ private:
ir::Module& module_;
const SemanticContext& sema_;
const SymbolTable& symbol_table_; // 新增成员
ir::Function* func_;
ir::IRBuilder builder_;
ir::Value* EvalAssign(SysYParser::StmtContext* ctx);
@ -119,6 +136,8 @@ private:
std::unordered_map<std::string, ir::Value*> local_var_map_; // 局部变量
std::unordered_map<std::string, ir::GlobalValue*> global_map_; // 全局变量
std::unordered_map<std::string, ir::Value*> param_map_; // 函数参数
std::unordered_set<std::string> pointer_param_names_; // 指针/数组形参名
std::unordered_set<std::string> heap_local_array_names_; // 堆分配的局部数组名
// 常量映射:常量名 -> 常量值(标量常量)
std::unordered_map<std::string, ir::ConstantValue*> const_value_map_;
@ -131,21 +150,23 @@ private:
std::unordered_map<SysYParser::VarDefContext*, ArrayInfo> array_info_map_;
std::string current_function_name_;
bool current_function_is_recursive_ = false;
ir::AllocaInst* function_return_slot_ = nullptr;
ir::BasicBlock* function_cleanup_block_ = nullptr;
std::vector<std::pair<ir::Function*, ir::Value*>> function_cleanup_actions_;
// 新增:处理全局和局部变量的辅助函数
// 修改处理函数的签名,使用 Symbol* 参数
std::any HandleGlobalVariable(SysYParser::VarDefContext* ctx,
const std::string& varName,
bool is_array);
std::any HandleLocalVariable(SysYParser::VarDefContext* ctx,
const std::string& varName,
bool is_array);
const std::string& varName,
const Symbol* sym);
// 常量求值辅助函数
int EvaluateConstAddExp(SysYParser::AddExpContext* ctx);
int EvaluateConstMulExp(SysYParser::MulExpContext* ctx);
int EvaluateConstUnaryExp(SysYParser::UnaryExpContext* ctx);
int EvaluateConstPrimaryExp(SysYParser::PrimaryExpContext* ctx);
int EvaluateConstExp(SysYParser::ExpContext* ctx);
std::any HandleLocalVariable(SysYParser::VarDefContext* ctx,
const std::string& varName,
const Symbol* sym);
};
// 修改 GenerateIR 函数签名
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema);
const SemaResult& sema_result);

@ -7,7 +7,7 @@
#include "SysYParser.h"
#include "ir/IR.h"
#include "sem/SymbolTable.h"
// 表达式信息结构
struct ExprInfo {
std::shared_ptr<ir::Type> type = nullptr;
@ -91,4 +91,12 @@ private:
// 目前仅检查:
// - 变量先声明后使用
// - 局部变量不允许重复定义
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);
// SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);
// Result bundle of semantic analysis: groups the SemanticContext with the
// SymbolTable so both can be handed to later stages together.
struct SemaResult {
// Semantic context produced by the analysis.
SemanticContext context;
// Symbol table produced by the analysis.
SymbolTable symbol_table;
};
// 修改 RunSema 的返回类型
SemaResult RunSema(SysYParser::CompUnitContext& comp_unit);

@ -1,6 +1,7 @@
// 极简符号表:记录局部变量定义点。
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
@ -17,46 +18,113 @@ enum class SymbolKind {
Constant
};
// 符号条目
// 符号条目
struct Symbol {
// 基本信息
std::string name;
SymbolKind kind;
std::shared_ptr<ir::Type> type; // 指向 Type 对象的智能指针
int scope_level = 0; // 定义时的作用域深度
int stack_offset = -1; // 局部变量/参数栈偏移(全局变量为 -1
bool is_initialized = false; // 是否已初始化
bool is_builtin = false; // 是否为库函数
std::shared_ptr<ir::Type> type;
int scope_level = 0;
int stack_offset = -1;
bool is_initialized = false;
bool is_builtin = false;
// 对于数组参数,存储维度信息
std::vector<int> array_dims; // 数组各维长度参数数组的第一维可能为0表示省略
bool is_array_param = false; // 是否是数组参数
// 数组参数相关
std::vector<int> array_dims;
bool is_array_param = false;
// 对于函数,额外存储参数列表(类型已包含在函数类型中,这里仅用于快速访问)
// 函数相关
std::vector<std::shared_ptr<ir::Type>> param_types;
// 对于常量,存储常量值(这里支持 int32 和 float
// 常量值存储
union ConstantValue {
int i32;
float f32;
} const_value;
bool is_int_const = true; // 标记常量类型,用于区分 int 和 float
};
// 标量常量
bool is_int_const = true;
ConstantValue const_value;
// 数组常量(扁平化存储)
bool is_array_const = false;
std::vector<ConstantValue> array_const_values;
// 关联的语法树节点(用于报错位置或进一步分析)
// 语法树节点
SysYParser::VarDefContext* var_def_ctx = nullptr;
SysYParser::ConstDefContext* const_def_ctx = nullptr;
SysYParser::FuncFParamContext* param_def_ctx = nullptr;
SysYParser::FuncDefContext* func_def_ctx = nullptr;
// 辅助方法
// True when this symbol is a Constant whose type is not an array.
// `type` is dereferenced only for Constant symbols.
bool IsScalarConstant() const {
  if (kind != SymbolKind::Constant) {
    return false;
  }
  return !type->IsArray();
}
// True when this symbol is a Constant whose type is an array.
// `type` is dereferenced only for Constant symbols.
bool IsArrayConstant() const {
  if (kind != SymbolKind::Constant) {
    return false;
  }
  return type->IsArray();
}
// Returns the stored int value of a scalar integer constant.
// Throws std::runtime_error when the symbol is not a scalar constant, or
// when it is a scalar constant that is not flagged as an int constant
// (checked in that order).
int GetIntConstant() const {
  const bool scalar = IsScalarConstant();
  if (scalar && is_int_const) {
    return const_value.i32;
  }
  throw std::runtime_error(scalar ? "不是整型常量" : "不是标量常量");
}
// Returns the value of a scalar constant as a float. An int constant is
// converted with static_cast<float>; otherwise the stored float is returned.
// Throws std::runtime_error when the symbol is not a scalar constant.
float GetFloatConstant() const {
  if (!IsScalarConstant()) {
    throw std::runtime_error("不是标量常量");
  }
  return is_int_const ? static_cast<float>(const_value.i32)
                      : const_value.f32;
}
// Returns the element at `index` of the flattened array-constant values.
// Throws std::runtime_error when the symbol is not an array constant, or
// when `index` is not smaller than the number of stored values (checked in
// that order).
ConstantValue GetArrayElement(size_t index) const {
  const bool is_array = IsArrayConstant();
  if (is_array && index < array_const_values.size()) {
    return array_const_values[index];
  }
  throw std::runtime_error(is_array ? "数组下标越界" : "不是数组常量");
}
// Number of stored array-constant values; 0 when the symbol is not an
// array constant.
size_t GetArraySize() const {
  return IsArrayConstant() ? array_const_values.size() : 0;
}
};
class SymbolTable {
public:
SymbolTable();
~SymbolTable() = default;
// Debug helper: returns the number of entries in active_scope_stack_.
size_t getScopeCount() const { return active_scope_stack_.size(); }
void dump() const {
std::cerr << "=== SymbolTable Dump ===" << std::endl;
for (size_t i = 0; i < scopes_.size(); ++i) {
std::cerr << "Scope " << i << " (depth=" << i << ")";
bool active = std::find(active_scope_stack_.begin(), active_scope_stack_.end(), i) != active_scope_stack_.end();
std::cerr << (active ? " [active]" : " [inactive]") << std::endl;
for (const auto& [name, sym] : scopes_[i]) {
std::cerr << " " << name
<< " (kind=" << (int)sym.kind
<< ", level=" << sym.scope_level << ")" << std::endl;
}
}
}
// ----- 作用域管理 -----
void enterScope(); // 进入新作用域
void exitScope(); // 退出当前作用域
int currentScopeLevel() const { return static_cast<int>(scopes_.size()) - 1; }
int currentScopeLevel() const { return static_cast<int>(active_scope_stack_.size()) - 1; }
// ----- 符号操作(推荐使用)-----
bool addSymbol(const Symbol& sym); // 添加符号到当前作用域
@ -64,6 +132,9 @@ class SymbolTable {
Symbol* lookupCurrent(const std::string& name); // 仅在当前作用域查找
const Symbol* lookup(const std::string& name) const;
const Symbol* lookupCurrent(const std::string& name) const;
const Symbol* lookupAll(const std::string& name) const; // 所有作用域查找,包括已结束的作用域
const Symbol* lookupByVarDef(const SysYParser::VarDefContext* decl) const; // 通过定义节点查找符号
const Symbol* lookupByConstDef(const SysYParser::ConstDefContext* decl) const; // 通过常量定义节点查找符号
// ----- 与原接口兼容(保留原有功能)-----
void Add(const std::string& name, SysYParser::VarDefContext* decl);
@ -103,6 +174,7 @@ class SymbolTable {
private:
// 作用域栈:每个元素是一个从名字到符号的映射
std::vector<std::unordered_map<std::string, Symbol>> scopes_;
std::vector<size_t> active_scope_stack_;
static constexpr int GLOBAL_SCOPE = 0; // 全局作用域索引

Binary file not shown.

@ -0,0 +1,492 @@
; ModuleID = 'optimized.bc'
source_filename = "./build/test_compiler/performance/03_sort1.ll"
@a = global [30000010 x i32] zeroinitializer
@ans = local_unnamed_addr global i32 0
declare i32 @getarray(ptr) local_unnamed_addr
declare void @putint(i32) local_unnamed_addr
declare void @putch(i32) local_unnamed_addr
declare void @starttime() local_unnamed_addr
declare void @stoptime() local_unnamed_addr
declare ptr @sysy_alloc_i32(i32) local_unnamed_addr
declare void @sysy_free_i32(ptr) local_unnamed_addr
declare void @sysy_zero_i32(ptr, i32) local_unnamed_addr
; getMaxNum(n, arr): returns the signed maximum of 0 and arr[0 .. n-1].
; When n <= 0 the loop is skipped and 0 is returned without reading arr.
; Function Attrs: nofree norecurse nosync nounwind memory(argmem: read)
define i32 @getMaxNum(i32 %n, ptr nocapture readonly %arr) local_unnamed_addr #0 {
entry:
%t95 = icmp sgt i32 %n, 0
br i1 %t95, label %while.body.t5, label %while.exit.t6
; Loop: %t3_i.07 is the index, %t2_ret.06 the running maximum (starts at 0).
while.body.t5: ; preds = %entry, %while.body.t5
%t3_i.07 = phi i32 [ %t21, %while.body.t5 ], [ 0, %entry ]
%t2_ret.06 = phi i32 [ %spec.select, %while.body.t5 ], [ 0, %entry ]
%0 = zext nneg i32 %t3_i.07 to i64
%t13 = getelementptr i32, ptr %arr, i64 %0
%t14 = load i32, ptr %t13, align 4
%spec.select = tail call i32 @llvm.smax.i32(i32 %t14, i32 %t2_ret.06)
%t21 = add nuw nsw i32 %t3_i.07, 1
%t9 = icmp slt i32 %t21, %n
br i1 %t9, label %while.body.t5, label %while.exit.t6
while.exit.t6: ; preds = %while.body.t5, %entry
%t2_ret.0.lcssa = phi i32 [ 0, %entry ], [ %spec.select, %while.body.t5 ]
ret i32 %t2_ret.0.lcssa
}
; getNumPos(num, pos): signed-divides num by 16 `pos` times, then returns the
; signed remainder of that value modulo 16. When pos <= 0 no division is
; performed and the result is num srem 16.
; Function Attrs: nofree norecurse nosync nounwind memory(none)
define i32 @getNumPos(i32 %num, i32 %pos) local_unnamed_addr #1 {
entry:
%t333 = icmp sgt i32 %pos, 0
br i1 %t333, label %while.body.t29, label %while.exit.t30
; Loop: %t27_i.05 counts iterations, %t24.04 is the value divided so far.
while.body.t29: ; preds = %entry, %while.body.t29
%t27_i.05 = phi i32 [ %t37, %while.body.t29 ], [ 0, %entry ]
%t24.04 = phi i32 [ %t35, %while.body.t29 ], [ %num, %entry ]
%t35 = sdiv i32 %t24.04, 16
%t37 = add nuw nsw i32 %t27_i.05, 1
%t33 = icmp slt i32 %t37, %pos
br i1 %t33, label %while.body.t29, label %while.exit.t30
while.exit.t30: ; preds = %while.body.t29, %entry
%t24.0.lcssa = phi i32 [ %num, %entry ], [ %t35, %while.body.t29 ]
%t39 = srem i32 %t24.0.lcssa, 16
ret i32 %t39
}
; radixSort(bitround, a, l, r): recursive base-16 bucket pass over a[l .. r).
; Three 16-slot i32 scratch buffers are obtained from sysy_alloc_i32 and
; zeroed with sysy_zero_i32:
;   %t48 - per-digit counts,
;   %t43 - running start offset of each bucket,
;   %t46 - end offset of each bucket.
; The function returns right away (through cleanup.t44) when bitround == -1
; or l + 1 >= r. Otherwise it counts the elements by the digit selected by
; bitround (the getNumPos loop appears inlined several times), derives the
; bucket offsets, moves elements of %a into their buckets, and calls itself
; with bitround - 1 on each of the 16 buckets.
; The only `ret` is in cleanup.t44, which frees the three buffers.
define void @radixSort(i32 %bitround, ptr nocapture %a, i32 %l, i32 %r) local_unnamed_addr {
entry:
%t43 = tail call ptr @sysy_alloc_i32(i32 16)
tail call void @sysy_zero_i32(ptr %t43, i32 16)
%t46 = tail call ptr @sysy_alloc_i32(i32 16)
tail call void @sysy_zero_i32(ptr %t46, i32 16)
%t48 = tail call ptr @sysy_alloc_i32(i32 16)
tail call void @sysy_zero_i32(ptr %t48, i32 16)
%t54 = icmp eq i32 %bitround, -1
%t56 = add i32 %l, 1
%t58 = icmp sge i32 %t56, %r
%t59 = or i1 %t54, %t58
br i1 %t59, label %cleanup.t44, label %while.cond.t62.preheader
while.cond.t62.preheader: ; preds = %entry
%t6796 = icmp slt i32 %l, %r
br i1 %t6796, label %while.body.t63.lr.ph, label %while.exit.t64
while.body.t63.lr.ph: ; preds = %while.cond.t62.preheader
%t333.i = icmp sgt i32 %bitround, 0
br label %while.body.t63
; Single exit: release the scratch buffers in reverse allocation order.
cleanup.t44: ; preds = %merge.t196, %entry
tail call void @sysy_free_i32(ptr %t48)
tail call void @sysy_free_i32(ptr %t46)
tail call void @sysy_free_i32(ptr %t43)
ret void
; Counting loop: for each index in [l, r) load a[i], compute its digit and
; increment the matching slot of %t48.
while.body.t63: ; preds = %while.body.t63.lr.ph, %getNumPos.exit28
%storemerge97 = phi i32 [ %l, %while.body.t63.lr.ph ], [ %t83, %getNumPos.exit28 ]
%0 = sext i32 %storemerge97 to i64
%t69 = getelementptr i32, ptr %a, i64 %0
%t70 = load i32, ptr %t69, align 4
br i1 %t333.i, label %while.body.t29.i, label %getNumPos.exit.thread
getNumPos.exit.thread: ; preds = %while.body.t63
%t39.i80 = srem i32 %t70, 16
%1 = sext i32 %t39.i80 to i64
%t7381 = getelementptr i32, ptr %t48, i64 %1
%t7482 = load i32, ptr %t7381, align 4
br label %getNumPos.exit28
while.body.t29.i: ; preds = %while.body.t63, %while.body.t29.i
%t27_i.05.i = phi i32 [ %t37.i, %while.body.t29.i ], [ 0, %while.body.t63 ]
%t24.04.i = phi i32 [ %t35.i, %while.body.t29.i ], [ %t70, %while.body.t63 ]
%t35.i = sdiv i32 %t24.04.i, 16
%t37.i = add nuw nsw i32 %t27_i.05.i, 1
%t33.i = icmp slt i32 %t37.i, %bitround
br i1 %t33.i, label %while.body.t29.i, label %getNumPos.exit
getNumPos.exit: ; preds = %while.body.t29.i
%t39.i = srem i32 %t35.i, 16
%2 = sext i32 %t39.i to i64
%t73 = getelementptr i32, ptr %t48, i64 %2
%t74 = load i32, ptr %t73, align 4
br label %while.body.t29.i22
while.body.t29.i22: ; preds = %getNumPos.exit, %while.body.t29.i22
%t27_i.05.i23 = phi i32 [ %t37.i26, %while.body.t29.i22 ], [ 0, %getNumPos.exit ]
%t24.04.i24 = phi i32 [ %t35.i25, %while.body.t29.i22 ], [ %t70, %getNumPos.exit ]
%t35.i25 = sdiv i32 %t24.04.i24, 16
%t37.i26 = add nuw nsw i32 %t27_i.05.i23, 1
%t33.i27 = icmp slt i32 %t37.i26, %bitround
br i1 %t33.i27, label %while.body.t29.i22, label %getNumPos.exit28.loopexit
getNumPos.exit28.loopexit: ; preds = %while.body.t29.i22
%.pre114 = srem i32 %t35.i25, 16
%.pre115 = sext i32 %.pre114 to i64
br label %getNumPos.exit28
getNumPos.exit28: ; preds = %getNumPos.exit28.loopexit, %getNumPos.exit.thread
%.pre-phi116 = phi i64 [ %.pre115, %getNumPos.exit28.loopexit ], [ %1, %getNumPos.exit.thread ]
%t7584.in = phi i32 [ %t74, %getNumPos.exit28.loopexit ], [ %t7482, %getNumPos.exit.thread ]
%t7584 = add i32 %t7584.in, 1
%t81 = getelementptr i32, ptr %t48, i64 %.pre-phi116
store i32 %t7584, ptr %t81, align 4
%t83 = add nsw i32 %storemerge97, 1
%t67 = icmp slt i32 %t83, %r
br i1 %t67, label %while.body.t63, label %while.exit.t64
; Bucket offsets, fully unrolled for the 16 digits:
;   %t43[0] = l, %t46[0] = l + %t48[0],
;   %t43[i] = %t46[i-1], %t46[i] = %t43[i] + %t48[i] for i = 1 .. 15.
while.exit.t64: ; preds = %getNumPos.exit28, %while.cond.t62.preheader
store i32 %l, ptr %t43, align 4
%t88 = load i32, ptr %t48, align 4
%t89 = add i32 %t88, %l
store i32 %t89, ptr %t46, align 4
%invariant.gep = getelementptr i32, ptr %t46, i64 -1
%t101 = getelementptr i32, ptr %t43, i64 1
store i32 %t89, ptr %t101, align 4
%t106 = getelementptr i32, ptr %t48, i64 1
%t107 = load i32, ptr %t106, align 4
%t108 = add i32 %t107, %t89
%t110 = getelementptr i32, ptr %t46, i64 1
store i32 %t108, ptr %t110, align 4
%t101.1 = getelementptr i32, ptr %t43, i64 2
store i32 %t108, ptr %t101.1, align 4
%t106.1 = getelementptr i32, ptr %t48, i64 2
%t107.1 = load i32, ptr %t106.1, align 4
%t108.1 = add i32 %t107.1, %t108
%t110.1 = getelementptr i32, ptr %t46, i64 2
store i32 %t108.1, ptr %t110.1, align 4
%t101.2 = getelementptr i32, ptr %t43, i64 3
store i32 %t108.1, ptr %t101.2, align 4
%t106.2 = getelementptr i32, ptr %t48, i64 3
%t107.2 = load i32, ptr %t106.2, align 4
%t108.2 = add i32 %t107.2, %t108.1
%t110.2 = getelementptr i32, ptr %t46, i64 3
store i32 %t108.2, ptr %t110.2, align 4
%t101.3 = getelementptr i32, ptr %t43, i64 4
store i32 %t108.2, ptr %t101.3, align 4
%t106.3 = getelementptr i32, ptr %t48, i64 4
%t107.3 = load i32, ptr %t106.3, align 4
%t108.3 = add i32 %t107.3, %t108.2
%t110.3 = getelementptr i32, ptr %t46, i64 4
store i32 %t108.3, ptr %t110.3, align 4
%t101.4 = getelementptr i32, ptr %t43, i64 5
store i32 %t108.3, ptr %t101.4, align 4
%t106.4 = getelementptr i32, ptr %t48, i64 5
%t107.4 = load i32, ptr %t106.4, align 4
%t108.4 = add i32 %t107.4, %t108.3
%t110.4 = getelementptr i32, ptr %t46, i64 5
store i32 %t108.4, ptr %t110.4, align 4
%t101.5 = getelementptr i32, ptr %t43, i64 6
store i32 %t108.4, ptr %t101.5, align 4
%t106.5 = getelementptr i32, ptr %t48, i64 6
%t107.5 = load i32, ptr %t106.5, align 4
%t108.5 = add i32 %t107.5, %t108.4
%t110.5 = getelementptr i32, ptr %t46, i64 6
store i32 %t108.5, ptr %t110.5, align 4
%t101.6 = getelementptr i32, ptr %t43, i64 7
store i32 %t108.5, ptr %t101.6, align 4
%t106.6 = getelementptr i32, ptr %t48, i64 7
%t107.6 = load i32, ptr %t106.6, align 4
%t108.6 = add i32 %t107.6, %t108.5
%t110.6 = getelementptr i32, ptr %t46, i64 7
store i32 %t108.6, ptr %t110.6, align 4
%t101.7 = getelementptr i32, ptr %t43, i64 8
store i32 %t108.6, ptr %t101.7, align 4
%t106.7 = getelementptr i32, ptr %t48, i64 8
%t107.7 = load i32, ptr %t106.7, align 4
%t108.7 = add i32 %t107.7, %t108.6
%t110.7 = getelementptr i32, ptr %t46, i64 8
store i32 %t108.7, ptr %t110.7, align 4
%t101.8 = getelementptr i32, ptr %t43, i64 9
store i32 %t108.7, ptr %t101.8, align 4
%t106.8 = getelementptr i32, ptr %t48, i64 9
%t107.8 = load i32, ptr %t106.8, align 4
%t108.8 = add i32 %t107.8, %t108.7
%t110.8 = getelementptr i32, ptr %t46, i64 9
store i32 %t108.8, ptr %t110.8, align 4
%t101.9 = getelementptr i32, ptr %t43, i64 10
store i32 %t108.8, ptr %t101.9, align 4
%t106.9 = getelementptr i32, ptr %t48, i64 10
%t107.9 = load i32, ptr %t106.9, align 4
%t108.9 = add i32 %t107.9, %t108.8
%t110.9 = getelementptr i32, ptr %t46, i64 10
store i32 %t108.9, ptr %t110.9, align 4
%t101.10 = getelementptr i32, ptr %t43, i64 11
store i32 %t108.9, ptr %t101.10, align 4
%t106.10 = getelementptr i32, ptr %t48, i64 11
%t107.10 = load i32, ptr %t106.10, align 4
%t108.10 = add i32 %t107.10, %t108.9
%t110.10 = getelementptr i32, ptr %t46, i64 11
store i32 %t108.10, ptr %t110.10, align 4
%t101.11 = getelementptr i32, ptr %t43, i64 12
store i32 %t108.10, ptr %t101.11, align 4
%t106.11 = getelementptr i32, ptr %t48, i64 12
%t107.11 = load i32, ptr %t106.11, align 4
%t108.11 = add i32 %t107.11, %t108.10
%t110.11 = getelementptr i32, ptr %t46, i64 12
store i32 %t108.11, ptr %t110.11, align 4
%t101.12 = getelementptr i32, ptr %t43, i64 13
store i32 %t108.11, ptr %t101.12, align 4
%t106.12 = getelementptr i32, ptr %t48, i64 13
%t107.12 = load i32, ptr %t106.12, align 4
%t108.12 = add i32 %t107.12, %t108.11
%t110.12 = getelementptr i32, ptr %t46, i64 13
store i32 %t108.12, ptr %t110.12, align 4
%t101.13 = getelementptr i32, ptr %t43, i64 14
store i32 %t108.12, ptr %t101.13, align 4
%t106.13 = getelementptr i32, ptr %t48, i64 14
%t107.13 = load i32, ptr %t106.13, align 4
%t108.13 = add i32 %t107.13, %t108.12
%t110.13 = getelementptr i32, ptr %t46, i64 14
store i32 %t108.13, ptr %t110.13, align 4
%t101.14 = getelementptr i32, ptr %t43, i64 15
store i32 %t108.13, ptr %t101.14, align 4
%t106.14 = getelementptr i32, ptr %t48, i64 15
%t107.14 = load i32, ptr %t106.14, align 4
%t108.14 = add i32 %t107.14, %t108.13
%t110.14 = getelementptr i32, ptr %t46, i64 15
store i32 %t108.14, ptr %t110.14, align 4
%t333.i29 = icmp sgt i32 %bitround, 0
br label %while.cond.t118.preheader
; Placement phase: outer loop over the 16 buckets (%storemerge17104); a bucket
; is processed while its start offset %t43[i] is below its end offset %t46[i].
while.cond.t118.preheader: ; preds = %while.exit.t64, %while.exit.t120
%storemerge17104 = phi i32 [ 0, %while.exit.t64 ], [ %t180, %while.exit.t120 ]
%3 = zext nneg i32 %storemerge17104 to i64
%t122 = getelementptr i32, ptr %t43, i64 %3
%t125 = getelementptr i32, ptr %t46, i64 %3
%t123100 = load i32, ptr %t122, align 4
%t126101 = load i32, ptr %t125, align 4
%t127102 = icmp slt i32 %t123100, %t126101
br i1 %t127102, label %while.body.t119, label %while.exit.t120
; Recursion, peeled first iteration (bucket 0): rebuild %t43[0] / %t46[0] and
; recurse with bitround - 1.
while.body.t191.peel.next: ; preds = %while.exit.t120
store i32 %l, ptr %t43, align 4
%t187 = load i32, ptr %t48, align 4
%t188 = add i32 %t187, %l
store i32 %t188, ptr %t46, align 4
%t215 = add i32 %bitround, -1
%t218.peel.pre = load i32, ptr %t43, align 4
tail call void @radixSort(i32 %t215, ptr %a, i32 %t218.peel.pre, i32 %t188)
br label %merge.t196
while.body.t119: ; preds = %while.cond.t118.preheader, %while.exit.t136
%t123103 = phi i32 [ %t176, %while.exit.t136 ], [ %t123100, %while.cond.t118.preheader ]
%4 = sext i32 %t123103 to i64
%t132 = getelementptr i32, ptr %a, i64 %4
%t133 = load i32, ptr %t132, align 4
br label %while.cond.t134
while.exit.t120: ; preds = %while.exit.t136, %while.cond.t118.preheader
%t180 = add nuw nsw i32 %storemerge17104, 1
%t117 = icmp ult i32 %storemerge17104, 15
br i1 %t117, label %while.cond.t118.preheader, label %while.body.t191.peel.next
; Inner loop: %t15099 is the value currently being placed; it loops until the
; value's digit equals the current bucket index.
while.cond.t134: ; preds = %getNumPos.exit78, %while.body.t119
%t15099 = phi i32 [ %t150129, %getNumPos.exit78 ], [ %t133, %while.body.t119 ]
br i1 %t333.i29, label %while.body.t29.i32, label %getNumPos.exit38.thread
while.body.t29.i32: ; preds = %while.cond.t134, %while.body.t29.i32
%t27_i.05.i33 = phi i32 [ %t37.i36, %while.body.t29.i32 ], [ 0, %while.cond.t134 ]
%t24.04.i34 = phi i32 [ %t35.i35, %while.body.t29.i32 ], [ %t15099, %while.cond.t134 ]
%t35.i35 = sdiv i32 %t24.04.i34, 16
%t37.i36 = add nuw nsw i32 %t27_i.05.i33, 1
%t33.i37 = icmp slt i32 %t37.i36, %bitround
br i1 %t33.i37, label %while.body.t29.i32, label %getNumPos.exit38
getNumPos.exit38: ; preds = %while.body.t29.i32
%t39.i31 = srem i32 %t35.i35, 16
%t141.not = icmp eq i32 %t39.i31, %storemerge17104
br i1 %t141.not, label %while.exit.t136, label %while.body.t29.i42
getNumPos.exit38.thread: ; preds = %while.cond.t134
%t39.i3186 = srem i32 %t15099, 16
%t141.not87 = icmp eq i32 %t39.i3186, %storemerge17104
br i1 %t141.not87, label %while.exit.t136, label %getNumPos.exit48.thread
getNumPos.exit48.thread: ; preds = %getNumPos.exit38.thread
%5 = sext i32 %t39.i3186 to i64
%t147125 = getelementptr i32, ptr %t43, i64 %5
%t148126 = load i32, ptr %t147125, align 4
%6 = sext i32 %t148126 to i64
%t149127 = getelementptr i32, ptr %a, i64 %6
%t150128 = load i32, ptr %t149127, align 4
br label %getNumPos.exit68.thread
while.body.t29.i42: ; preds = %getNumPos.exit38, %while.body.t29.i42
%t27_i.05.i43 = phi i32 [ %t37.i46, %while.body.t29.i42 ], [ 0, %getNumPos.exit38 ]
%t24.04.i44 = phi i32 [ %t35.i45, %while.body.t29.i42 ], [ %t15099, %getNumPos.exit38 ]
%t35.i45 = sdiv i32 %t24.04.i44, 16
%t37.i46 = add nuw nsw i32 %t27_i.05.i43, 1
%t33.i47 = icmp slt i32 %t37.i46, %bitround
br i1 %t33.i47, label %while.body.t29.i42, label %getNumPos.exit48
getNumPos.exit48: ; preds = %while.body.t29.i42
%.pre117 = srem i32 %t35.i45, 16
%7 = sext i32 %.pre117 to i64
%t147 = getelementptr i32, ptr %t43, i64 %7
%t148 = load i32, ptr %t147, align 4
%8 = sext i32 %t148 to i64
%t149 = getelementptr i32, ptr %a, i64 %8
%t150 = load i32, ptr %t149, align 4
br i1 %t333.i29, label %while.body.t29.i52, label %getNumPos.exit68.thread
while.body.t29.i52: ; preds = %getNumPos.exit48, %while.body.t29.i52
%t27_i.05.i53 = phi i32 [ %t37.i56, %while.body.t29.i52 ], [ 0, %getNumPos.exit48 ]
%t24.04.i54 = phi i32 [ %t35.i55, %while.body.t29.i52 ], [ %t15099, %getNumPos.exit48 ]
%t35.i55 = sdiv i32 %t24.04.i54, 16
%t37.i56 = add nuw nsw i32 %t27_i.05.i53, 1
%t33.i57 = icmp slt i32 %t37.i56, %bitround
br i1 %t33.i57, label %while.body.t29.i52, label %while.body.t29.i62.preheader
while.body.t29.i62.preheader: ; preds = %while.body.t29.i52
%t39.i51 = srem i32 %t35.i55, 16
%9 = sext i32 %t39.i51 to i64
%t155 = getelementptr i32, ptr %t43, i64 %9
%t156 = load i32, ptr %t155, align 4
%10 = sext i32 %t156 to i64
%t157 = getelementptr i32, ptr %a, i64 %10
store i32 %t15099, ptr %t157, align 4
br label %while.body.t29.i62
getNumPos.exit68.thread: ; preds = %getNumPos.exit48, %getNumPos.exit48.thread
%t150130 = phi i32 [ %t150128, %getNumPos.exit48.thread ], [ %t150, %getNumPos.exit48 ]
%t39.i51.c = srem i32 %t15099, 16
%11 = sext i32 %t39.i51.c to i64
%t155.c = getelementptr i32, ptr %t43, i64 %11
%t156.c = load i32, ptr %t155.c, align 4
%12 = sext i32 %t156.c to i64
%t157.c = getelementptr i32, ptr %a, i64 %12
store i32 %t15099, ptr %t157.c, align 4
%t16191 = getelementptr i32, ptr %t43, i64 %11
%t16292 = load i32, ptr %t16191, align 4
br label %getNumPos.exit78
while.body.t29.i62: ; preds = %while.body.t29.i62.preheader, %while.body.t29.i62
%t27_i.05.i63 = phi i32 [ %t37.i66, %while.body.t29.i62 ], [ 0, %while.body.t29.i62.preheader ]
%t24.04.i64 = phi i32 [ %t35.i65, %while.body.t29.i62 ], [ %t15099, %while.body.t29.i62.preheader ]
%t35.i65 = sdiv i32 %t24.04.i64, 16
%t37.i66 = add nuw nsw i32 %t27_i.05.i63, 1
%t33.i67 = icmp slt i32 %t37.i66, %bitround
br i1 %t33.i67, label %while.body.t29.i62, label %getNumPos.exit68
getNumPos.exit68: ; preds = %while.body.t29.i62
%t39.i61 = srem i32 %t35.i65, 16
%13 = sext i32 %t39.i61 to i64
%t161 = getelementptr i32, ptr %t43, i64 %13
%t162 = load i32, ptr %t161, align 4
br label %while.body.t29.i72
while.body.t29.i72: ; preds = %getNumPos.exit68, %while.body.t29.i72
%t27_i.05.i73 = phi i32 [ %t37.i76, %while.body.t29.i72 ], [ 0, %getNumPos.exit68 ]
%t24.04.i74 = phi i32 [ %t35.i75, %while.body.t29.i72 ], [ %t15099, %getNumPos.exit68 ]
%t35.i75 = sdiv i32 %t24.04.i74, 16
%t37.i76 = add nuw nsw i32 %t27_i.05.i73, 1
%t33.i77 = icmp slt i32 %t37.i76, %bitround
br i1 %t33.i77, label %while.body.t29.i72, label %getNumPos.exit78.loopexit
getNumPos.exit78.loopexit: ; preds = %while.body.t29.i72
%.pre118 = srem i32 %t35.i75, 16
%.pre119 = sext i32 %.pre118 to i64
br label %getNumPos.exit78
getNumPos.exit78: ; preds = %getNumPos.exit78.loopexit, %getNumPos.exit68.thread
%t150129 = phi i32 [ %t150, %getNumPos.exit78.loopexit ], [ %t150130, %getNumPos.exit68.thread ]
%.pre-phi120 = phi i64 [ %.pre119, %getNumPos.exit78.loopexit ], [ %11, %getNumPos.exit68.thread ]
%t16394.in = phi i32 [ %t162, %getNumPos.exit78.loopexit ], [ %t16292, %getNumPos.exit68.thread ]
%t16394 = add i32 %t16394.in, 1
%t167 = getelementptr i32, ptr %t43, i64 %.pre-phi120
store i32 %t16394, ptr %t167, align 4
br label %while.cond.t134
; The carried value belongs to the current bucket: store it at the bucket's
; start offset and advance that offset by one.
while.exit.t136: ; preds = %getNumPos.exit38.thread, %getNumPos.exit38
%t171 = load i32, ptr %t122, align 4
%14 = sext i32 %t171 to i64
%t172 = getelementptr i32, ptr %a, i64 %14
store i32 %t15099, ptr %t172, align 4
%t175 = load i32, ptr %t122, align 4
%t176 = add i32 %t175, 1
store i32 %t176, ptr %t122, align 4
%t126 = load i32, ptr %t125, align 4
%t127 = icmp slt i32 %t176, %t126
br i1 %t127, label %while.body.t119, label %while.exit.t120
; Recursion for buckets 1 .. 15: %t43[i] = %t46[i-1] (read through
; %invariant.gep), %t46[i] = %t43[i] + %t48[i], then recurse on that range.
merge.t196: ; preds = %while.body.t191.peel.next, %merge.t196
%storemerge18107 = phi i32 [ 1, %while.body.t191.peel.next ], [ %t224, %merge.t196 ]
%15 = zext nneg i32 %storemerge18107 to i64
%gep106 = getelementptr i32, ptr %invariant.gep, i64 %15
%t202 = load i32, ptr %gep106, align 4
%t204 = getelementptr i32, ptr %t43, i64 %15
store i32 %t202, ptr %t204, align 4
%t209 = getelementptr i32, ptr %t48, i64 %15
%t210 = load i32, ptr %t209, align 4
%t211 = add i32 %t210, %t202
%t213 = getelementptr i32, ptr %t46, i64 %15
store i32 %t211, ptr %t213, align 4
%t218.pre = load i32, ptr %t204, align 4
tail call void @radixSort(i32 %t215, ptr %a, i32 %t218.pre, i32 %t211)
%t224 = add nuw nsw i32 %storemerge18107, 1
%t194 = icmp ult i32 %storemerge18107, 15
br i1 %t194, label %merge.t196, label %cleanup.t44, !llvm.loop !0
}
; main: fills @a through getarray (its return value %t231 is used as the
; element count), calls starttime, runs radixSort(8, @a, 0, count), then adds
; i * (a[i] srem (i + 2)) for every i in [0, count) onto @ans. A negative
; total is negated. Finally it calls stoptime, prints @ans with putint
; followed by putch(10), and returns 0.
define noundef i32 @main() local_unnamed_addr {
entry:
%t231 = tail call i32 @getarray(ptr nonnull @a)
tail call void @starttime()
tail call void @radixSort(i32 8, ptr nonnull @a, i32 0, i32 %t231)
%ans.promoted = load i32, ptr @ans, align 4
%t2427 = icmp sgt i32 %t231, 0
br i1 %t2427, label %while.body.t238, label %while.exit.t239
; Checksum loop: %t236_i.09 is the index, %t25268 the running total.
while.body.t238: ; preds = %entry, %while.body.t238
%t236_i.09 = phi i32 [ %t254, %while.body.t238 ], [ 0, %entry ]
%t25268 = phi i32 [ %t252, %while.body.t238 ], [ %ans.promoted, %entry ]
%0 = zext nneg i32 %t236_i.09 to i64
%t246 = getelementptr [30000010 x i32], ptr @a, i64 0, i64 %0
%t247 = load i32, ptr %t246, align 4
%t249 = add nuw i32 %t236_i.09, 2
%t250 = srem i32 %t247, %t249
%t251 = mul i32 %t250, %t236_i.09
%t252 = add i32 %t251, %t25268
%t254 = add nuw nsw i32 %t236_i.09, 1
%t242 = icmp slt i32 %t254, %t231
br i1 %t242, label %while.body.t238, label %while.cond.t237.while.exit.t239_crit_edge
; The total is written back to @ans once, after the loop.
while.cond.t237.while.exit.t239_crit_edge: ; preds = %while.body.t238
store i32 %t252, ptr @ans, align 4
br label %while.exit.t239
while.exit.t239: ; preds = %while.cond.t237.while.exit.t239_crit_edge, %entry
%t257 = phi i32 [ %t252, %while.cond.t237.while.exit.t239_crit_edge ], [ %ans.promoted, %entry ]
%t258 = icmp slt i32 %t257, 0
br i1 %t258, label %then.t255, label %merge.t256
; Negative total: store 0 - total into @ans.
then.t255: ; preds = %while.exit.t239
%t260 = sub i32 0, %t257
store i32 %t260, ptr @ans, align 4
br label %merge.t256
merge.t256: ; preds = %then.t255, %while.exit.t239
tail call void @stoptime()
%t262 = load i32, ptr @ans, align 4
tail call void @putint(i32 %t262)
tail call void @putch(i32 10)
ret i32 0
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.smax.i32(i32, i32) #2
attributes #0 = { nofree norecurse nosync nounwind memory(argmem: read) }
attributes #1 = { nofree norecurse nosync nounwind memory(none) }
attributes #2 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!0 = distinct !{!0, !1}
!1 = !{!"llvm.loop.peeled.count", i32 1}

@ -5,6 +5,11 @@ ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
COMPILER="$ROOT_DIR/build/bin/compiler"
TMP_DIR="$ROOT_DIR/build/test_compiler"
TEST_DIRS=("$ROOT_DIR/test/test_case/functional" "$ROOT_DIR/test/test_case/performance")
CC_BIN="${CC:-cc}"
RUNTIME_SRC="$ROOT_DIR/sylib/sylib.c"
RUNTIME_OBJ="$TMP_DIR/sylib.o"
LLC_BIN="${LLC:-llc}"
CLANG_BIN="${CLANG:-clang}"
if [[ ! -x "$COMPILER" ]]; then
echo "未找到编译器: $COMPILER"
@ -14,6 +19,30 @@ fi
mkdir -p "$TMP_DIR"
if ! command -v "$LLC_BIN" >/dev/null 2>&1; then
echo "未找到 llc: $LLC_BIN"
echo "请安装 LLVM或通过 LLC 环境变量指定 llc 路径"
exit 1
fi
if ! command -v "$CLANG_BIN" >/dev/null 2>&1; then
echo "未找到 clang: $CLANG_BIN"
echo "请安装 Clang或通过 CLANG 环境变量指定 clang 路径"
exit 1
fi
# 编译运行库(供链接生成的可执行文件)
runtime_ready=0
if [[ -f "$RUNTIME_SRC" ]]; then
if "$CC_BIN" -c "$RUNTIME_SRC" -o "$RUNTIME_OBJ" >/dev/null 2>&1; then
runtime_ready=1
else
echo "[WARN] 运行库编译失败,生成的可执行文件将不链接 sylib: $RUNTIME_SRC"
fi
else
echo "[WARN] 未找到运行库源码: $RUNTIME_SRC"
fi
ir_total=0
ir_pass=0
result_total=0
@ -65,10 +94,15 @@ for test_dir in "${TEST_DIRS[@]}"; do
continue
fi
# 检查是否生成了有效的函数定义(在过滤后的内容中检查)
# 先过滤一下看看是否有define
filtered_content=$(sed -E '/^\[DEBUG\]|^SymbolTable::|^Check|^绑定|^保存|^dim_count:/d' "$raw_ll")
if ! echo "$filtered_content" | grep -qE '^define '; then
# 从混杂输出中提取 IR
# - 顶层实体define/declare/@global
# - 基本块标签
# - 缩进的指令行
# - 函数结束花括号
grep -E '^(define |declare |@|[[:space:]]|})|^[A-Za-z_.$%][A-Za-z0-9_.$%]*:$' "$raw_ll" > "$ll_file"
# 检查是否生成了有效函数定义
if ! grep -qE '^define ' "$ll_file"; then
echo " [IR] 失败: 未生成有效函数定义"
ir_failures+=("$input: invalid IR output")
# 失败:保留原始输出
@ -76,17 +110,7 @@ for test_dir in "${TEST_DIRS[@]}"; do
rm -f "$raw_ll"
continue
fi
# 编译成功过滤掉所有调试输出只保留IR
# 过滤规则:
# 1. 以 [DEBUG] 开头的行
# 2. SymbolTable:: 开头的行
# 3. CheckLValue: 开头的行
# 4. 绑定变量: 开头的行
# 5. dim_count: 开头的行
# 6. 空行(可选)
sed -E '/^(\[DEBUG|SymbolTable::|Check|绑定|保存|dim_)/d' "$raw_ll" > "$ll_file"
# 可选:删除多余的空行
sed -i '/^$/N;/\n$/D' "$ll_file"
@ -96,30 +120,72 @@ for test_dir in "${TEST_DIRS[@]}"; do
echo " [IR] 生成成功 (IR已保存到: $ll_file)"
# 运行测试
# 运行测试部分
if [[ -f "$expected_file" ]]; then
result_total=$((result_total+1))
# 运行LLVM IR
# 运行生成的可执行文件(优先链接运行库)
run_status=0
obj_file="$out_dir/$stem.o"
exe_file="$out_dir/$stem"
if ! "$LLC_BIN" -filetype=obj "$ll_file" -o "$obj_file" > "$stdout_file" 2>&1; then
echo " [RUN] llc 失败"
result_failures+=("$input: llc failed")
continue
fi
if [[ $runtime_ready -eq 1 ]]; then
if ! "$CLANG_BIN" "$obj_file" "$RUNTIME_OBJ" -o "$exe_file" >> "$stdout_file" 2>&1; then
echo " [RUN] clang 链接失败"
result_failures+=("$input: clang link failed")
continue
fi
else
if ! "$CLANG_BIN" "$obj_file" -o "$exe_file" >> "$stdout_file" 2>&1; then
echo " [RUN] clang 链接失败"
result_failures+=("$input: clang link failed")
continue
fi
fi
if [[ -f "$stdin_file" ]]; then
lli "$ll_file" < "$stdin_file" > "$stdout_file" 2>&1 || run_status=$?
"$exe_file" < "$stdin_file" > "$stdout_file" 2>&1 || run_status=$?
else
lli "$ll_file" > "$stdout_file" 2>&1 || run_status=$?
"$exe_file" > "$stdout_file" 2>&1 || run_status=$?
fi
# 读取预期返回值
expected=$(normalize_file "$expected_file")
# 读取预期文件内容
expected_content=$(normalize_file "$expected_file")
# 比较返回值
if [[ "$run_status" -eq "$expected" ]]; then
result_pass=$((result_pass+1))
echo " [RUN] 返回值匹配: $run_status"
# 成功:保留已清理的.ll文件删除输出文件
rm -f "$stdout_file"
# 判断预期文件是只包含退出码,还是包含输出+退出码
if [[ "$expected_content" =~ ^[0-9]+$ ]]; then
# 只包含退出码
expected=$expected_content
if [[ "$run_status" -eq "$expected" ]]; then
result_pass=$((result_pass+1))
echo " [RUN] 返回值匹配: $run_status"
rm -f "$stdout_file"
else
echo " [RUN] 返回值不匹配: got $run_status, expected $expected"
result_failures+=("$input: exit code mismatch (got $run_status, expected $expected)")
fi
else
echo " [RUN] 返回值不匹配: got $run_status, expected $expected"
result_failures+=("$input: exit code mismatch (got $run_status, expected $expected)")
# 失败:.ll文件已经保留输出文件也保留用于调试
# 包含输出和退出码(最后一行是退出码)
expected_output=$(head -n -1 <<< "$expected_content")
expected_exit=$(tail -n 1 <<< "$expected_content")
actual_output=$(cat "$stdout_file")
if [[ "$run_status" -eq "$expected_exit" ]] && [[ "$actual_output" == "$expected_output" ]]; then
result_pass=$((result_pass+1))
echo " [RUN] 成功: 退出码和输出都匹配"
rm -f "$stdout_file"
else
echo " [RUN] 不匹配: 退出码 got $run_status, expected $expected_exit"
if [[ "$actual_output" != "$expected_output" ]]; then
echo " 输出不匹配"
fi
result_failures+=("$input: mismatch")
fi
fi
else
echo " [RUN] 未找到预期返回值文件 $expected_file,跳过结果验证"

@ -1,9 +1,8 @@
// 管理基础类型、整型常量池和临时名生成。
// ir/IR.cpp
#include "ir/IR.h"
#include <cstring> // for memcpy
#include <cstring>
#include <sstream>
#include <functional>
namespace ir {
@ -17,9 +16,7 @@ ConstantInt* Context::GetConstInt(int v) {
return inserted->second.get();
}
// 新增:获取浮点常量
ConstantFloat* Context::GetConstFloat(float v) {
// 使用浮点数的二进制表示作为键,避免精度问题
uint32_t key;
std::memcpy(&key, &v, sizeof(float));
@ -35,16 +32,68 @@ ConstantFloat* Context::GetConstFloat(float v) {
return ptr;
}
// Returns the ConstantArray of type `ty` holding `elements`, building it on the
// first request and handing back the same object for later identical requests.
//
// Parameters:
//   ty        array type of the constant; must not be null.
//   elements  flattened element constants. When fewer than the array holds,
//             the tail is padded with zero constants of the element type.
// Throws std::runtime_error when `ty` is null, when more elements are supplied
// than the array holds, or when padding is needed for an element type that is
// neither i32 nor float.
//
// The cache is a function-local static: it owns every ConstantArray created
// here and is shared by all Context instances.
ConstantArray* Context::GetConstArray(std::shared_ptr<ArrayType> ty,
                                      std::vector<ConstantValue*> elements) {
  if (!ty) {
    throw std::runtime_error("Array constant requires a type");
  }

  const size_t expected_size = static_cast<size_t>(ty->GetElementCount());
  if (elements.size() > expected_size) {
    throw std::runtime_error("Array constant size mismatch");
  }
  if (elements.size() < expected_size) {
    // Pick the padding constant once. Refusing unsupported element types here
    // replaces a loop that could never make progress for them.
    auto elem_type = ty->GetElementType();
    ConstantValue* zero = nullptr;
    if (elem_type->IsInt32()) {
      zero = GetConstInt(0);
    } else if (elem_type->IsFloat()) {
      zero = GetConstFloat(0.0f);
    }
    if (!zero) {
      throw std::runtime_error(
          "Array constant padding requires an i32 or float element type");
    }
    elements.resize(expected_size, zero);
  }

  // Cache key: identity of the array type plus identity of every element.
  struct ArrayKey {
    std::shared_ptr<ArrayType> type;
    std::vector<ConstantValue*> elements;

    bool operator==(const ArrayKey& other) const {
      return type == other.type && elements == other.elements;
    }
  };

  struct ArrayKeyHash {
    size_t operator()(const ArrayKey& key) const {
      size_t hash = std::hash<Type*>{}(key.type.get());
      for (auto* elem : key.elements) {
        hash ^= std::hash<ConstantValue*>{}(elem) + 0x9e3779b9 + (hash << 6) +
                (hash >> 2);
      }
      return hash;
    }
  };

  static std::unordered_map<ArrayKey, std::unique_ptr<ConstantArray>,
                            ArrayKeyHash>
      cache;

  // The key keeps its own copy of the element list; the constant takes the
  // original by move below.
  ArrayKey key{ty, elements};
  if (auto it = cache.find(key); it != cache.end()) {
    return it->second.get();
  }

  auto constant = std::make_unique<ConstantArray>(ty, std::move(elements));
  auto* ptr = constant.get();
  // Ownership is transferred exactly once, into the cache. Moving the same
  // unique_ptr into a second container would leave a null entry behind and
  // make every later cache hit return nullptr.
  cache.emplace(std::move(key), std::move(constant));
  return ptr;
}
// 新增:获取零常量
ConstantZero* Context::GetZeroConstant(std::shared_ptr<Type> ty) {
auto it = zero_constants_.find(ty.get());
if (it != zero_constants_.end()) {
@ -57,7 +106,6 @@ ConstantZero* Context::GetZeroConstant(std::shared_ptr<Type> ty) {
return ptr;
}
// 新增:获取聚合类型的零常量
ConstantAggregateZero* Context::GetAggregateZero(std::shared_ptr<Type> ty) {
auto it = aggregate_zeros_.find(ty.get());
if (it != aggregate_zeros_.end()) {
@ -76,5 +124,4 @@ std::string Context::NextTemp() {
return oss.str();
}
} // namespace ir

@ -1,9 +1,30 @@
// ir/GlobalValue.cpp
#include "ir/IR.h"
#include <stdexcept>
namespace ir {
namespace {
// Returns a shared zero constant matching the scalar `type` (i32, float or
// i1), or nullptr for any other type. Each constant is created on first use
// and intentionally never freed, so the returned pointer stays valid for the
// rest of the program.
ConstantValue* GetScalarZeroConstant(const Type& type) {
  if (type.IsInt32()) {
    static ConstantInt& i32_zero = *new ConstantInt(Type::GetInt32Type(), 0);
    return &i32_zero;
  }

  if (type.IsFloat()) {
    static ConstantFloat& f32_zero =
        *new ConstantFloat(Type::GetFloatType(), 0.0f);
    return &f32_zero;
  }

  if (type.IsInt1()) {
    static ConstantInt& i1_zero = *new ConstantInt(Type::GetInt1Type(), 0);
    return &i1_zero;
  }

  return nullptr;
}
} // namespace
// Constructs a global by forwarding its type and name to the User base class.
// The constructor does not touch the initializer, so it starts out empty.
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {}
@ -13,42 +34,10 @@ void GlobalValue::SetInitializer(ConstantValue* init) {
}
// 获取实际的值类型(用于类型检查)
std::shared_ptr<Type> value_type;
// 如果当前类型是指针,获取指向的值类型
if (GetType()->IsPtrInt32()) {
value_type = Type::GetInt32Type();
} else if (GetType()->IsPtrFloat()) {
value_type = Type::GetFloatType();
} else if (GetType()->IsPtrInt1()) {
value_type = Type::GetInt1Type();
} else {
// 非指针类型:直接使用当前类型
value_type = GetType();
}
std::shared_ptr<Type> value_type = GetValueType();
// 类型检查
bool type_match = false;
// 检查标量类型
if (value_type->IsInt32() && init->GetType()->IsInt32()) {
type_match = true;
} else if (value_type->IsFloat() && init->GetType()->IsFloat()) {
type_match = true;
} else if (value_type->IsInt1() && init->GetType()->IsInt1()) {
type_match = true;
}
// 检查数组类型:允许用单个标量初始化整个数组
else if (value_type->IsArray()) {
auto* array_ty = static_cast<ArrayType*>(value_type.get());
auto* elem_type = array_ty->GetElementType().get();
if (elem_type->IsInt32() && init->GetType()->IsInt32()) {
type_match = true;
} else if (elem_type->IsFloat() && init->GetType()->IsFloat()) {
type_match = true;
}
}
bool type_match = CheckTypeCompatibility(value_type, init);
if (!type_match) {
throw std::runtime_error("GlobalValue::SetInitializer: type mismatch");
@ -60,23 +49,14 @@ void GlobalValue::SetInitializer(ConstantValue* init) {
void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
if (init.empty()) {
initializer_.clear();
return;
}
// 获取实际的值类型
std::shared_ptr<Type> value_type;
std::shared_ptr<Type> value_type = GetValueType();
if (GetType()->IsPtrInt32()) {
value_type = Type::GetInt32Type();
} else if (GetType()->IsPtrFloat()) {
value_type = Type::GetFloatType();
} else if (GetType()->IsPtrInt1()) {
value_type = Type::GetInt1Type();
} else {
value_type = GetType();
}
// 检查类型
// 类型检查
if (value_type->IsArray()) {
auto* array_ty = static_cast<ArrayType*>(value_type.get());
size_t array_size = array_ty->GetElementCount();
@ -87,16 +67,23 @@ void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
// 检查每个初始化值的类型
auto* elem_type = array_ty->GetElementType().get();
for (auto* elem : init) {
for (size_t i = 0; i < init.size(); ++i) {
auto* elem = init[i];
if (!elem) {
throw std::runtime_error("GlobalValue::SetInitializer: null initializer at index " + std::to_string(i));
}
bool elem_match = false;
if (elem_type->IsInt32() && elem->GetType()->IsInt32()) {
elem_match = true;
} else if (elem_type->IsFloat() && elem->GetType()->IsFloat()) {
elem_match = true;
} else if (elem_type->IsInt1() && elem->GetType()->IsInt1()) {
elem_match = true;
}
if (!elem_match) {
throw std::runtime_error("GlobalValue::SetInitializer: element type mismatch");
throw std::runtime_error("GlobalValue::SetInitializer: element type mismatch at index " + std::to_string(i));
}
}
}
@ -105,6 +92,10 @@ void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
throw std::runtime_error("GlobalValue::SetInitializer: scalar requires exactly one initializer");
}
if (!init[0]) {
throw std::runtime_error("GlobalValue::SetInitializer: null initializer");
}
if ((value_type->IsInt32() && !init[0]->GetType()->IsInt32()) ||
(value_type->IsFloat() && !init[0]->GetType()->IsFloat()) ||
(value_type->IsInt1() && !init[0]->GetType()->IsInt1())) {
@ -118,4 +109,87 @@ void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
initializer_ = init;
}
// Helper: resolves the type of the value stored in this global. The scalar
// pointer kinds (i32*, float*, i1*) map to the scalar they point to; every
// other type is returned unchanged.
std::shared_ptr<Type> GlobalValue::GetValueType() const {
  const auto& declared = GetType();

  if (declared->IsPtrInt32()) {
    return Type::GetInt32Type();
  }
  if (declared->IsPtrFloat()) {
    return Type::GetFloatType();
  }
  if (declared->IsPtrInt1()) {
    return Type::GetInt1Type();
  }
  return declared;
}
// 辅助方法:检查类型兼容性
bool GlobalValue::CheckTypeCompatibility(std::shared_ptr<Type> value_type,
ConstantValue* init) const {
// 检查标量类型
if (value_type->IsInt32() && init->GetType()->IsInt32()) {
return true;
} else if (value_type->IsFloat() && init->GetType()->IsFloat()) {
return true;
} else if (value_type->IsInt1() && init->GetType()->IsInt1()) {
return true;
}
// 检查数组类型:允许用单个标量初始化整个数组
else if (value_type->IsArray()) {
auto* array_ty = static_cast<ArrayType*>(value_type.get());
auto* elem_type = array_ty->GetElementType().get();
if (elem_type->IsInt32() && init->GetType()->IsInt32()) {
return true;
} else if (elem_type->IsFloat() && init->GetType()->IsFloat()) {
return true;
} else if (elem_type->IsInt1() && init->GetType()->IsInt1()) {
return true;
}
// 也可以允许 ConstantArray 作为初始化器
else if (init->GetType()->IsArray()) {
auto* init_array = static_cast<ConstantArray*>(init);
return init_array->IsValid();
}
}
// 检查指针类型(用于数组参数)
else if (value_type->IsPtrInt32() && init->GetType()->IsInt32()) {
return true;
} else if (value_type->IsPtrFloat() && init->GetType()->IsFloat()) {
return true;
}
return false;
}
// Convenience accessor for one flattened element of an array global.
// Returns nullptr when this global's type is not an array or `index` is past
// the array's element count. For an in-range index that has no explicit
// initializer entry, returns the shared zero constant of the element type.
ConstantValue* GlobalValue::GetArrayElement(size_t index) const {
  const auto& self_type = GetType();
  auto* as_array = self_type->IsArray()
                       ? dynamic_cast<ArrayType*>(self_type.get())
                       : nullptr;
  if (!as_array) {
    return nullptr;
  }

  const bool in_bounds =
      index < static_cast<size_t>(as_array->GetElementCount());
  if (!in_bounds) {
    return nullptr;
  }

  return index < initializer_.size()
             ? initializer_[index]
             : GetScalarZeroConstant(*as_array->GetElementType());
}
// Number of explicit initializer entries of an array global; 0 when this is
// not an initialized array (see IsArrayConstant).
size_t GlobalValue::GetArraySize() const {
  return IsArrayConstant() ? initializer_.size() : size_t{0};
}
// True when this global has an array type and a non-empty initializer.
bool GlobalValue::IsArrayConstant() const {
  const bool has_array_type = GetType()->IsArray();
  return has_array_type && !initializer_.empty();
}
} // namespace ir

@ -119,6 +119,17 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
return CreateBinary(Opcode::Add, lhs, rhs, name);
}
// Appends an alloca of type `ty` named `name` to the current insertion block.
// Throws std::runtime_error when no insertion block is set or `ty` is null.
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> ty,
                                    const std::string& name) {
  // The insertion point is checked first, then the type.
  const char* problem = nullptr;
  if (!insert_block_) {
    problem = "IRBuilder 未设置插入点";
  } else if (!ty) {
    problem = "IRBuilder::CreateAlloca 缺少类型";
  }
  if (problem) {
    throw std::runtime_error(FormatError("ir", problem));
  }

  return insert_block_->Append<AllocaInst>(ty, name);
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -190,18 +201,21 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
}
} else if (ptr_ty->IsPtrFloat()) {
if (!val_ty->IsFloat()) {
throw std::runtime_error(FormatError("ir", "存储类型不匹配:期望 float"));
throw std::runtime_error(
FormatError("ir", "存储类型不匹配:期望 float, 实际 kind=" +
std::to_string(static_cast<int>(val_ty->GetKind()))));
}
} else if (ptr_ty->IsArray()) {
// 数组存储:检查元素类型
auto* array_ty = dynamic_cast<ArrayType*>(ptr_ty.get());
if (array_ty) {
auto elem_ty = array_ty->GetElementType();
if (elem_ty->IsInt32() && !val_ty->IsInt32()) {
throw std::runtime_error(FormatError("ir", "数组元素类型不匹配:期望 int32"));
} else if (elem_ty->IsFloat() && !val_ty->IsFloat()) {
throw std::runtime_error(FormatError("ir", "数组元素类型不匹配:期望 float"));
}
// 数组存储支持两种形式:
// 1. 标量元素写入(通常配合 GEP 后落到元素指针,不会走到这里)
// 2. 聚合数组整体写入,例如 `store [16 x i32] zeroinitializer, [16 x i32]* %arr`
if (!val_ty->IsArray()) {
throw std::runtime_error(
FormatError("ir", "数组地址仅支持聚合数组整体存储"));
}
if (val_ty->GetKind() != ptr_ty->GetKind()) {
throw std::runtime_error(
FormatError("ir", "聚合数组存储类型不匹配"));
}
}
@ -212,10 +226,6 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!v) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateRet 缺少返回值"));
}
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
}
@ -386,7 +396,8 @@ ZExtInst* IRBuilder::CreateZExt(Value* value, std::shared_ptr<Type> target_ty,
FormatError("ir", "ZExt 目标类型必须是整数类型"));
}
return insert_block_->Append<ZExtInst>(value, target_ty, name);
const std::string inst_name = name.empty() ? ctx_.NextTemp() : name;
return insert_block_->Append<ZExtInst>(value, target_ty, inst_name);
}
// 创建截断指令
@ -416,7 +427,8 @@ TruncInst* IRBuilder::CreateTrunc(Value* value, std::shared_ptr<Type> target_ty,
FormatError("ir", "Trunc 目标类型必须是整数类型"));
}
return insert_block_->Append<TruncInst>(value, target_ty, name);
const std::string inst_name = name.empty() ? ctx_.NextTemp() : name;
return insert_block_->Append<TruncInst>(value, target_ty, inst_name);
}
// 便捷方法i1 转 i32
@ -466,7 +478,9 @@ BinaryInst* IRBuilder::CreateAnd(Value* lhs, Value* rhs, const std::string& name
if (!rhs) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateAnd 缺少 rhs"));
}
return insert_block_->Append<BinaryInst>(Opcode::And, Type::GetInt32Type(), lhs, rhs, name);
auto result_ty = lhs->GetType()->IsInt1() ? Type::GetInt1Type()
: Type::GetInt32Type();
return insert_block_->Append<BinaryInst>(Opcode::And, result_ty, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name) {
@ -479,7 +493,9 @@ BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name)
if (!rhs) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateOr 缺少 rhs"));
}
return insert_block_->Append<BinaryInst>(Opcode::Or, Type::GetInt32Type(), lhs, rhs, name);
auto result_ty = lhs->GetType()->IsInt1() ? Type::GetInt1Type()
: Type::GetInt32Type();
return insert_block_->Append<BinaryInst>(Opcode::Or, result_ty, lhs, rhs, name);
}
IcmpInst* IRBuilder::CreateNot(Value* val, const std::string& name) {
@ -489,7 +505,12 @@ IcmpInst* IRBuilder::CreateNot(Value* val, const std::string& name) {
if (!val) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateNot 缺少 operand"));
}
auto zero = CreateConstInt(0);
if (val->GetType()->IsInt1()) {
auto* ext = CreateZExtI1ToI32(val, "");
auto* zero = CreateConstInt(0);
return CreateICmpEQ(ext, zero, name);
}
auto* zero = CreateConstInt(0);
return CreateICmpEQ(val, zero, name);
}
@ -511,8 +532,29 @@ GEPInst* IRBuilder::CreateGEP(Value* base,
}
}
// GEP返回指针类型假设与base类型相同
return insert_block_->Append<GEPInst>(base->GetType(), base, indices, name);
// 结果类型推断:
// - 对 i32*/float* 基址,结果仍分别为 i32*/float*
// - 对数组基址,按多索引向下剥离元素类型;若到达标量则返回对应标量指针
// (本项目没有“指向数组的指针类型”,未完全剥离时退回数组类型)
std::shared_ptr<Type> result_ty = base->GetType();
if (base->GetType()->IsPtrInt32()) {
result_ty = Type::GetPtrInt32Type();
} else if (base->GetType()->IsPtrFloat()) {
result_ty = Type::GetPtrFloatType();
} else if (base->GetType()->IsArray()) {
std::shared_ptr<Type> cur = base->GetType();
for (size_t i = 1; i < indices.size(); ++i) {
auto* at = dynamic_cast<ArrayType*>(cur.get());
if (!at) break;
cur = at->GetElementType();
}
if (cur->IsInt32()) result_ty = Type::GetPtrInt32Type();
else if (cur->IsFloat()) result_ty = Type::GetPtrFloatType();
else result_ty = cur;
}
return insert_block_->Append<GEPInst>(result_ty, base, indices, name);
}
@ -609,12 +651,20 @@ FcmpInst* IRBuilder::CreateFCmpOGE(Value* lhs, Value* rhs, const std::string& na
// 类型转换
// Appends a signed-integer-to-floating-point conversion of `value` to
// `target_ty` in the current insertion block. When `name` is empty a fresh
// temporary name is taken from the context.
// Throws std::runtime_error when no insertion block is set or `value` is null.
SIToFPInst* IRBuilder::CreateSIToFP(Value* value, std::shared_ptr<Type> target_ty,
                                    const std::string& name) {
  // The guards must run before anything is appended; an early unconditional
  // return would skip them and the temp-name fallback.
  if (!insert_block_) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }
  if (!value) {
    throw std::runtime_error(
        FormatError("ir", "IRBuilder::CreateSIToFP 缺少 value"));
  }
  const std::string inst_name = name.empty() ? ctx_.NextTemp() : name;
  return insert_block_->Append<SIToFPInst>(value, target_ty, inst_name);
}
// Appends a floating-point-to-signed-integer conversion of `value` to
// `target_ty` in the current insertion block. When `name` is empty a fresh
// temporary name is taken from the context.
// Throws std::runtime_error when no insertion block is set or `value` is null.
FPToSIInst* IRBuilder::CreateFPToSI(Value* value, std::shared_ptr<Type> target_ty,
                                    const std::string& name) {
  // The guards must run before anything is appended; an early unconditional
  // return would skip them and the temp-name fallback.
  if (!insert_block_) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }
  if (!value) {
    throw std::runtime_error(
        FormatError("ir", "IRBuilder::CreateFPToSI 缺少 value"));
  }
  const std::string inst_name = name.empty() ? ctx_.NextTemp() : name;
  return insert_block_->Append<FPToSIInst>(value, target_ty, inst_name);
}
} // namespace ir

@ -4,6 +4,9 @@
#include "ir/IR.h"
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <ostream>
#include <stdexcept>
#include <string>
@ -12,7 +15,109 @@
namespace ir {
static const char* TypeToString(const Type& ty) {
static std::string TypeToString(const Type& ty);
// Renders the array type formed by dimensions dims[begin..] over `base_ty`,
// e.g. base i32 with dims {4, 2} and begin 0 gives "[4 x [2 x i32]]".
// With `begin` at or past dims.size() the result is just the base type text.
static std::string ArrayTypeToStringFrom(const Type& base_ty,
                                         const std::vector<int>& dims,
                                         size_t begin) {
  std::string text = TypeToString(base_ty);

  // Wrap from the innermost dimension outwards.
  size_t remaining = dims.size();
  while (remaining > begin) {
    --remaining;
    text.insert(0, "[" + std::to_string(dims[remaining]) + " x ");
    text += "]";
  }
  return text;
}
// Reports whether `value` can be emitted as an all-zero-bits constant.
//   - nullptr (no initializer entry) counts as zero;
//   - ConstantInt is zero when its value is 0;
//   - ConstantFloat is zero only for +0.0, i.e. an all-zero bit pattern;
//   - ConstantZero and ConstantAggregateZero are always zero;
//   - ConstantArray is zero when every element is zero (checked recursively);
//   - any other constant kind is treated as non-zero.
static bool IsZeroConstant(const ConstantValue* value) {
  if (!value) {
    return true;
  }
  if (auto* ci = dynamic_cast<const ConstantInt*>(value)) {
    return ci->GetValue() == 0;
  }
  if (auto* cf = dynamic_cast<const ConstantFloat*>(value)) {
    // Compare the bit pattern, not the value: -0.0f == 0.0f is true, but
    // folding -0.0 into zeroinitializer would silently drop its sign bit.
    const float f = cf->GetValue();
    uint32_t bits = 0;
    static_assert(sizeof(bits) == sizeof(f), "float is expected to be 32-bit");
    std::memcpy(&bits, &f, sizeof(bits));
    return bits == 0;
  }
  if (dynamic_cast<const ConstantZero*>(value) ||
      dynamic_cast<const ConstantAggregateZero*>(value)) {
    return true;
  }
  if (auto* arr = dynamic_cast<const ConstantArray*>(value)) {
    for (auto* elem : arr->GetElements()) {
      if (!IsZeroConstant(elem)) {
        return false;
      }
    }
    return true;
  }
  return false;
}
// Number of scalar elements covered by one aggregate at nesting depth
// `level`: the product of dims[level..]. Returns 1 when `level` is at or past
// the end of `dims`.
static size_t AggregateSpan(const std::vector<int>& dims, size_t level) {
  size_t product = 1;
  size_t cursor = level;
  while (cursor < dims.size()) {
    product *= static_cast<size_t>(dims[cursor]);
    ++cursor;
  }
  return product;
}
// True when every initializer entry in the flat range [begin, begin + count)
// is a zero constant. Positions past the end of `init` have no entry and are
// treated as zero.
static bool IsZeroRange(const std::vector<ConstantValue*>& init,
                        size_t begin,
                        size_t count) {
  for (size_t offset = 0; offset != count; ++offset) {
    const size_t pos = begin + offset;
    if (pos < init.size() && !IsZeroConstant(init[pos])) {
      return false;
    }
  }
  return true;
}
// Prints the initializer body of the (sub-)array at nesting depth `level`,
// reading scalars from the flattened list `init` starting at `flat_index`.
//
// Parameters:
//   os          output stream.
//   base_ty     scalar element type of the whole array.
//   dims        all array dimensions, outermost first; `level` must index it.
//   level       depth of the aggregate being printed.
//   init        flattened initializer; missing or null entries count as zero.
//   flat_index  in/out cursor into `init`, advanced by the number of scalars
//               this aggregate covers.
//
// An aggregate whose whole range is zero is printed as `zeroinitializer`.
// Otherwise each element is printed as `<type> <value>`, recursing for nested
// dimensions.
static void PrintFlatArrayBody(std::ostream& os,
                               const Type& base_ty,
                               const std::vector<int>& dims,
                               size_t level,
                               const std::vector<ConstantValue*>& init,
                               size_t& flat_index) {
  const size_t span = AggregateSpan(dims, level);
  if (IsZeroRange(init, flat_index, span)) {
    os << "zeroinitializer";
    flat_index += span;
    return;
  }

  // Emits a float as the 16-hex-digit bit pattern of the value widened to
  // double. Streaming the float directly would print forms such as `1` or
  // `0.1`, which are not valid float literals in LLVM textual IR.
  const auto print_float = [&os](float f) {
    const double widened = static_cast<double>(f);
    uint64_t bits = 0;
    std::memcpy(&bits, &widened, sizeof(bits));
    char buf[20];
    std::snprintf(buf, sizeof(buf), "0x%016llX",
                  static_cast<unsigned long long>(bits));
    os << buf;
  };

  // The element type text is the same for every element at this level.
  const bool innermost = level + 1 >= dims.size();
  const std::string elem_type_text =
      innermost ? TypeToString(base_ty)
                : ArrayTypeToStringFrom(base_ty, dims, level + 1);

  os << "[";
  for (int i = 0; i < dims[level]; ++i) {
    if (i > 0) os << ", ";
    os << elem_type_text << " ";

    if (!innermost) {
      PrintFlatArrayBody(os, base_ty, dims, level + 1, init, flat_index);
      continue;
    }

    const ConstantValue* elem =
        flat_index < init.size() ? init[flat_index] : nullptr;
    if (auto* ci = dynamic_cast<const ConstantInt*>(elem)) {
      os << ci->GetValue();
    } else if (auto* cf = dynamic_cast<const ConstantFloat*>(elem)) {
      print_float(cf->GetValue());
    } else if (base_ty.IsFloat()) {
      // Missing entry or another zero-like constant in a float array.
      print_float(0.0f);
    } else {
      os << "0";
    }
    ++flat_index;
  }
  os << "]";
}
static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void: return "void";
case Type::Kind::Int32: return "i32";
@ -20,10 +125,23 @@ static const char* TypeToString(const Type& ty) {
case Type::Kind::PtrInt32: return "i32*";
case Type::Kind::PtrFloat: return "float*";
case Type::Kind::Label: return "label";
case Type::Kind::Array: return "array";
case Type::Kind::Function: return "function";
case Type::Kind::Int1: return "i1";
case Type::Kind::PtrInt1: return "i1*";
case Type::Kind::Array: {
// 打印数组类型为 LLVM 风格,如 [4 x [2 x i32]]
auto* at = dynamic_cast<const ArrayType*>(&ty);
if (!at) return "array";
// 递归构建类型字符串
std::string elem = TypeToString(*at->GetElementType());
const auto& dims = at->GetDimensions();
// 从外到内构建
std::string s = elem;
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
s = "[" + std::to_string(*it) + " x " + s + "]";
}
return s;
}
default: return "unknown";
}
throw std::runtime_error(FormatError("ir", "未知类型"));
@ -54,9 +172,9 @@ static const char* OpcodeToString(Opcode op) {
case Opcode::Icmp:
return "icmp";
case Opcode::Div:
return "div";
return "sdiv";
case Opcode::Mod:
return "mod";
return "srem";
case Opcode::ZExt:
return "zext";
case Opcode::Trunc:
@ -82,12 +200,29 @@ static const char* OpcodeToString(Opcode op) {
return "?";
}
// Formats a float the way LLVM IR accepts it: the value is widened to double
// and its 64-bit pattern is written as "0x" followed by 16 uppercase hex
// digits.
static std::string FloatToLLVMHex(float f) {
  const double widened = static_cast<double>(f);

  uint64_t pattern;
  std::memcpy(&pattern, &widened, sizeof(pattern));

  // "0x" + 16 digits + terminating NUL fits in 19 bytes.
  char text[20];
  std::snprintf(text, sizeof(text), "0x%016llX",
                static_cast<unsigned long long>(pattern));
  return std::string(text);
}
static std::string ValueToString(const Value* v) {
if (!v) {
return "<null>";
}
if (dynamic_cast<const ConstantZero*>(v) ||
dynamic_cast<const ConstantAggregateZero*>(v)) {
return "zeroinitializer";
}
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
if (!v) {
return "<null>";
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return FloatToLLVMHex(cf->GetValue());
}
const auto& name = v->GetName();
if (name.empty()) {
@ -102,27 +237,107 @@ static std::string ValueToString(const Value* v) {
return "%" + name;
}
// Text of the type used for the address operand of a load/store: the plain
// type text, with "*" appended when the type is an array.
static std::string MemoryTypeToString(const Type& ty) {
  std::string rendered = TypeToString(ty);
  const char* suffix = ty.IsArray() ? "*" : "";
  rendered += suffix;
  return rendered;
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& global : module.GetGlobals()) {
if (!global) continue;
os << "@" << global->GetName() << " = global ";
os << "@" << global->GetName() << " = "
<< (global->IsConstant() ? "constant " : "global ");
if (global->GetType()->IsPtrInt32()) {
os << "i32 0\n";
} else if (global->GetType()->IsPtrFloat()) {
os << "float 0.0\n";
} else {
os << TypeToString(*global->GetType()) << " zeroinitializer\n";
os << "i32 ";
if (global->HasInitializer()) {
auto* ci = dynamic_cast<const ConstantInt*>(global->GetInitializer().front());
os << (ci ? ci->GetValue() : 0);
} else {
os << "0";
}
os << "\n";
continue;
}
if (global->GetType()->IsPtrFloat()) {
os << "float ";
if (global->HasInitializer()) {
auto* cf = dynamic_cast<const ConstantFloat*>(global->GetInitializer().front());
os << (cf ? ValueToString(cf) : FloatToLLVMHex(0.0f));
} else {
os << FloatToLLVMHex(0.0f);
}
os << "\n";
continue;
}
if (global->GetType()->IsArray()) {
auto* at = dynamic_cast<const ArrayType*>(global->GetType().get());
os << TypeToString(*global->GetType()) << " ";
if (!at || !global->HasInitializer() ||
IsZeroRange(global->GetInitializer(), 0, AggregateSpan(at->GetDimensions(), 0))) {
os << "zeroinitializer\n";
continue;
}
size_t flat_index = 0;
PrintFlatArrayBody(os,
*at->GetElementType(),
at->GetDimensions(),
0,
global->GetInitializer(),
flat_index);
os << "\n";
continue;
}
os << TypeToString(*global->GetType()) << " zeroinitializer\n";
}
for (const auto& func : module.GetFunctions()) {
auto* func_ty = static_cast<const FunctionType*>(func->GetType().get());
os << "define " << TypeToString(*func_ty->GetReturnType()) << " @" << func->GetName() << "(";
auto print_func_params = [&](const Function* func,
const FunctionType* func_ty) {
bool first = true;
for (const auto& arg : func->GetArguments()) {
if (!func->GetArguments().empty()) {
for (const auto& arg : func->GetArguments()) {
if (!first) os << ", ";
first = false;
os << TypeToString(*arg->GetType()) << " %" << arg->GetName();
}
return;
}
for (const auto& pty : func_ty->GetParamTypes()) {
if (!first) os << ", ";
first = false;
os << TypeToString(*arg->GetType()) << " %" << arg->GetName();
os << TypeToString(*pty);
}
};
auto is_declaration_only = [](const Function* func) {
const auto& blocks = func->GetBlocks();
if (blocks.size() != 1) return false;
const auto& only = blocks.front();
if (!only) return false;
return only->GetInstructions().empty();
};
for (const auto& func : module.GetFunctions()) {
auto* func_ty = static_cast<const FunctionType*>(func->GetType().get());
if (is_declaration_only(func.get())) {
os << "declare " << TypeToString(*func_ty->GetReturnType()) << " @"
<< func->GetName() << "(";
print_func_params(func.get(), func_ty);
os << ")\n";
continue;
}
os << "define " << TypeToString(*func_ty->GetReturnType()) << " @"
<< func->GetName() << "(";
print_func_params(func.get(), func_ty);
os << ") {\n";
for (const auto& bb : func->GetBlocks()) {
if (!bb) {
@ -139,7 +354,11 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
case Opcode::Mod:
case Opcode::And:
case Opcode::Not:
case Opcode::Or:
case Opcode::Or:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
{
auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = "
@ -166,7 +385,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< MemoryTypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
@ -174,21 +393,29 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
auto* store = static_cast<const StoreInst*>(inst);
os << " store " << TypeToString(*store->GetValue()->GetType()) << " "
<< ValueToString(store->GetValue())
<< ", " << TypeToString(*store->GetPtr()->GetType()) << " "
<< ", " << MemoryTypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n";
if (!ret->GetValue()) {
os << " ret void\n";
} else {
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n";
}
break;
}
// CallInst类在 include/ir/IR.h 中定义
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
os << " " << call->GetName() << " = call "
<< TypeToString(*call->GetType()) << " @" << call->GetCallee()->GetName() << "(";
os << " ";
if (!call->GetType()->IsVoid()) {
os << call->GetName() << " = ";
}
os << "call " << TypeToString(*call->GetType()) << " @"
<< call->GetCallee()->GetName() << "(";
bool first = true;
for (auto* arg : call->GetArgs()) {
if (!first) os << ", ";
@ -248,16 +475,81 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
break;
}
case Opcode::GEP:{
// 简化打印:只打印基本信息和操作数数量
// 打印为类似 LLVM 的 getelementptr 形式:
// getelementptr <elem_ty>, <base_ty> <base>, i32 <idx0>, i32 <idx1>, ...
os << " " << inst->GetName() << " = getelementptr ";
os << TypeToString(*inst->GetType()) << " (";
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (i > 0) os << ", ";
os << ValueToString(inst->GetOperand(i));
// 基地址类型使用第一个操作数的类型
Value* base = inst->GetOperand(0);
// GEP 的第一个类型参数应是基址指向的元素类型pointee
std::string elem_ty;
if (base->GetType()->IsPtrInt32()) elem_ty = "i32";
else if (base->GetType()->IsPtrFloat()) elem_ty = "float";
else if (base->GetType()->IsArray()) elem_ty = TypeToString(*base->GetType());
else elem_ty = TypeToString(*inst->GetType());
std::string base_ty = TypeToString(*base->GetType());
if (base->GetType()->IsArray()) {
base_ty += "*";
}
os << ")\n";
os << elem_ty << ", " << base_ty << " " << ValueToString(base);
// 后续操作数为索引,按照 i32 打印
// 特殊处理:如果 base 是标量指针i32*/float*)且第一个索引是常量 0
// 且后续还有索引,则丢弃第一个 0对 T* 来说多余且会导致无效 IR
size_t start_idx = 1;
if ((base->GetType()->IsPtrInt32() || base->GetType()->IsPtrFloat()) &&
inst->GetNumOperands() >= 3) {
// 检查第一个索引是否为常量 0
auto* first_idx = inst->GetOperand(1);
if (auto* ci = dynamic_cast<const ConstantInt*>(first_idx)) {
if (ci->GetValue() == 0) {
start_idx = 2; // 跳过第一个 0
}
}
}
for (size_t i = start_idx; i < inst->GetNumOperands(); ++i) {
os << ", i32 " << ValueToString(inst->GetOperand(i));
}
os << "\n";
break;
}
case Opcode::FCmp: {
auto* fcmp = static_cast<const FcmpInst*>(inst);
os << " " << fcmp->GetName() << " = fcmp ";
switch (fcmp->GetPredicate()) {
case FcmpInst::Predicate::OEQ: os << "oeq"; break;
case FcmpInst::Predicate::ONE: os << "one"; break;
case FcmpInst::Predicate::OLT: os << "olt"; break;
case FcmpInst::Predicate::OLE: os << "ole"; break;
case FcmpInst::Predicate::OGT: os << "ogt"; break;
case FcmpInst::Predicate::OGE: os << "oge"; break;
default: os << "oeq"; break;
}
os << " " << TypeToString(*fcmp->GetLhs()->GetType())
<< " " << ValueToString(fcmp->GetLhs())
<< ", " << ValueToString(fcmp->GetRhs()) << "\n";
break;
}
case Opcode::SIToFP: {
auto* sitofp = static_cast<const SIToFPInst*>(inst);
os << " " << sitofp->GetName() << " = sitofp "
<< TypeToString(*sitofp->GetValue()->GetType()) << " "
<< ValueToString(sitofp->GetValue()) << " to "
<< TypeToString(*sitofp->GetType()) << "\n";
break;
}
case Opcode::FPToSI: {
auto* fptosi = static_cast<const FPToSIInst*>(inst);
os << " " << fptosi->GetName() << " = fptosi "
<< TypeToString(*fptosi->GetValue()->GetType()) << " "
<< ValueToString(fptosi->GetValue()) << " to "
<< TypeToString(*fptosi->GetType()) << "\n";
break;
}
default: {
// 处理未知操作码
os << " ; 未知指令: " << OpcodeToString(inst->GetOpcode()) << "\n";
@ -286,10 +578,10 @@ void IRPrinter::PrintConstant(const ConstantValue* constant, std::ostream& os) {
}
os << "]";
}
else if (auto* zero = dynamic_cast<const ConstantZero*>(constant)) {
else if (dynamic_cast<const ConstantZero*>(constant)) {
os << "zero";
}
else if (auto* agg_zero = dynamic_cast<const ConstantAggregateZero*>(constant)) {
else if (dynamic_cast<const ConstantAggregateZero*>(constant)) {
os << "zeroinitializer";
}
}

@ -73,6 +73,10 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
case Opcode::Mod:
case Opcode::And:
case Opcode::Or:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
// 有效的二元操作符
break;
case Opcode::Not:
@ -96,21 +100,26 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
throw std::runtime_error(FormatError("ir", "BinaryInst 操作数类型不匹配"));
}
bool is_logical = (op == Opcode::And || op == Opcode::Or);
// 检查操作数类型是否支持
if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsFloat()) {
throw std::runtime_error(
FormatError("ir", "BinaryInst 只支持 int32 和 float 类型"));
if (is_logical) {
if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsInt1()) {
throw std::runtime_error(
FormatError("ir", "逻辑运算仅支持 i32/i1"));
}
} else {
if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsFloat()) {
throw std::runtime_error(
FormatError("ir", "BinaryInst 只支持 int32 和 float 类型"));
}
}
// 对于算术运算,结果类型应与操作数类型相同
bool is_logical = (op == Opcode::And || op == Opcode::Or);
if (is_logical) {
// 比较和逻辑运算的结果应该是整数类型
if (!type_->IsInt32()) {
// 逻辑运算结果类型应与操作数一致i1 或 i32
if (type_->GetKind() != lhs->GetType()->GetKind()) {
throw std::runtime_error(
FormatError("ir", "比较和逻辑运算的结果类型必须是 int32"));
FormatError("ir", "逻辑运算结果类型与操作数类型不匹配"));
}
} else {
// 算术运算的结果类型应与操作数类型相同
@ -130,21 +139,27 @@ Value* BinaryInst::GetRhs() const { return GetOperand(1); }
// Builds a `ret` instruction (Opcode::Ret, empty name). The instruction's own
// type must be void; `val`, when attached, becomes the instruction's operand.
// NOTE(review): both a "missing return value" throw and a later `if (val)` guard
// appear below; they look like the old and new form of the same check shown
// together - confirm which one the real file keeps.
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!val) {
throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值"));
}
// Reject a null or non-void instruction type.
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
}
AddOperand(val);
// Only attach an operand when a value was supplied.
if (val) {
AddOperand(val);
}
}
Value* ReturnInst::GetValue() const { return GetOperand(0); }
// Returns the value carried by this `ret`, or nullptr when the instruction
// has no operands.
Value* ReturnInst::GetValue() const {
  const bool has_value = GetNumOperands() != 0;
  return has_value ? GetOperand(0) : nullptr;
}
// Builds an alloca instruction (Opcode::Alloca) with the given result type and name.
// Throws std::runtime_error when the result type is null or not an accepted kind.
// NOTE(review): two type checks appear below (i32* only, then i32*/float*/array);
// they look like the old and new form of the same check shown together -
// confirm which one the real file keeps.
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
if (!type_ || !type_->IsPtrInt32()) {
throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*"));
// Accepts i32*, float* and array types; everything else is rejected.
if (!type_ ||
(!type_->IsPtrInt32() && !type_->IsPtrFloat() && !type_->IsArray())) {
throw std::runtime_error(
FormatError("ir", "AllocaInst 仅支持 i32* / float* / array"));
}
}
@ -153,12 +168,15 @@ LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
}
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32"));
if (!type_ || (!type_->IsInt32() && !type_->IsFloat() && !type_->IsInt1())) {
throw std::runtime_error(
FormatError("ir", "LoadInst 仅支持加载 i32/float/i1"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
if (!ptr->GetType() ||
(!ptr->GetType()->IsPtrInt32() && !ptr->GetType()->IsPtrFloat() &&
!ptr->GetType()->IsArray() && !ptr->GetType()->IsPtrInt1())) {
throw std::runtime_error(
FormatError("ir", "LoadInst 当前只支持从 i32* 加载"));
FormatError("ir", "LoadInst 仅支持从指针或数组地址加载"));
}
AddOperand(ptr);
}
@ -176,13 +194,25 @@ StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
}
if (!val->GetType() || !val->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
if (!val->GetType() ||
(!val->GetType()->IsInt32() && !val->GetType()->IsFloat() &&
!val->GetType()->IsInt1() && !val->GetType()->IsArray())) {
throw std::runtime_error(
FormatError("ir", "StoreInst 当前只支持写入 i32*"));
FormatError("ir", "StoreInst 仅支持存储 i32/float/i1/array"));
}
if (!ptr->GetType() ||
(!ptr->GetType()->IsPtrInt32() && !ptr->GetType()->IsPtrFloat() &&
!ptr->GetType()->IsArray() && !ptr->GetType()->IsPtrInt1())) {
throw std::runtime_error(FormatError("ir", "StoreInst 仅支持写入指针或数组地址"));
}
if (ptr->GetType()->IsArray()) {
if (!val->GetType()->IsArray() ||
val->GetType()->GetKind() != ptr->GetType()->GetKind()) {
throw std::runtime_error(
FormatError("ir", "StoreInst 聚合存储要求 value/ptr 具有相同数组类型"));
}
}
AddOperand(val);
AddOperand(ptr);
}

File diff suppressed because it is too large Load Diff

@ -6,10 +6,11 @@
#include "ir/IR.h"
#include "utils/Log.h"
// GenerateIR (modified): creates a fresh ir::Module, runs an IRGenImpl visitor
// over the parse tree `tree`, and returns the populated module.
// NOTE(review): two parameter lines and two `IRGenImpl gen(...)` lines appear
// below (SemanticContext vs. SemaResult); they look like the old and new
// signature shown together - confirm which one the real file keeps.
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema) {
const SemaResult& sema_result) {
auto module = std::make_unique<ir::Module>();
IRGenImpl gen(*module, sema);
// The SemaResult form hands both the semantic context and the symbol table to the generator.
IRGenImpl gen(*module, sema_result.context, sema_result.symbol_table);
tree.accept(&gen);
return module;
}

@ -21,8 +21,9 @@
// - 条件与比较表达式
// - ...
// 表达式生成
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
std::cout << "[DEBUG IRGEN] EvalExpr: " << expr.getText() << std::endl;
std::cerr << "[DEBUG IRGEN] EvalExpr: 开始处理表达式 " << expr.getText() << std::endl;
try {
auto result_any = expr.accept(this);
@ -38,15 +39,7 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
} catch (const std::bad_any_cast& e) {
std::cerr << "[ERROR] EvalExpr: bad any_cast - " << e.what() << std::endl;
std::cerr << " Type info: " << result_any.type().name() << std::endl;
// 尝试其他可能的类型
try {
// 检查是否是无值的any可能来自visit函数返回{}
std::cerr << "[DEBUG] EvalExpr: Trying to handle empty any" << std::endl;
return nullptr;
} catch (...) {
throw std::runtime_error(FormatError("irgen", "表达式求值返回了错误的类型"));
}
throw std::runtime_error(FormatError("irgen", "表达式求值返回了错误的类型"));
}
} catch (const std::exception& e) {
std::cerr << "[ERROR] Exception in EvalExpr: " << e.what() << std::endl;
@ -54,16 +47,14 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
}
}
// Lowers a condition: logs the condition text, dispatches the node to this
// visitor, and unwraps the visitor's std::any result as an ir::Value*.
// std::any_cast throws std::bad_any_cast when the result holds another type.
ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) {
  std::cerr << "[DEBUG IRGEN] EvalCond: 开始处理条件表达式 " << cond.getText() << std::endl;
  auto visited = cond.accept(this);
  return std::any_cast<ir::Value*>(visited);
}
// 基本表达式:数字、变量、括号表达式
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
std::cout << "[DEBUG IRGEN] visitPrimaryExp: " << (ctx ? ctx->getText() : "<null>") << std::endl;
std::cerr << "[DEBUG IRGEN] visitPrimaryExp: 开始处理基本表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少基本表达式"));
}
@ -82,9 +73,7 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
if (ctx->HEX_FLOAT()) {
std::string hex_float_str = ctx->HEX_FLOAT()->getText();
float value = 0.0f;
// 解析十六进制浮点数
try {
// C++11 的 std::stof 支持十六进制浮点数表示
value = std::stof(hex_float_str);
} catch (const std::exception& e) {
std::cerr << "[WARNING] 无法解析十六进制浮点数: " << hex_float_str
@ -97,7 +86,6 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
return static_cast<ir::Value*>(const_float);
}
// 处理十进制浮点常量
if (ctx->DEC_FLOAT()) {
std::string dec_float_str = ctx->DEC_FLOAT()->getText();
float value = 0.0f;
@ -118,6 +106,8 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
std::string hex = ctx->HEX_INT()->getText();
int value = std::stoi(hex, nullptr, 16);
ir::Value* const_int = builder_.CreateConstInt(value);
std::cerr << "[DEBUG] visitPrimaryExp: constant hex int " << value
<< " created as " << (void*)const_int << std::endl;
return static_cast<ir::Value*>(const_int);
}
@ -125,11 +115,14 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
std::string oct = ctx->OCTAL_INT()->getText();
int value = std::stoi(oct, nullptr, 8);
ir::Value* const_int = builder_.CreateConstInt(value);
std::cerr << "[DEBUG] visitPrimaryExp: constant octal int " << value
<< " created as " << (void*)const_int << std::endl;
return static_cast<ir::Value*>(const_int);
}
if (ctx->ZERO()) {
ir::Value* const_int = builder_.CreateConstInt(0);
std::cerr << "[DEBUG] visitPrimaryExp: constant zero int created" << std::endl;
return static_cast<ir::Value*>(const_int);
}
@ -149,12 +142,9 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型"));
}
// 左值(变量)处理
// 1. 先通过语义分析结果把变量使用绑定回声明;
// 2. 再通过 storage_map_ 找到该声明对应的栈槽位;
// 3. 最后生成 load把内存中的值读出来。
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
std::cerr << "[DEBUG IRGEN] visitLVal: 开始处理左值 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
}
@ -162,42 +152,95 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
std::string varName = ctx->Ident()->getText();
std::cerr << "[DEBUG] visitLVal: " << varName << std::endl;
// 优先检查是否是常量
auto const_it = const_value_map_.find(varName);
if (const_it != const_value_map_.end()) {
// 常量直接返回值不需要load
std::cerr << "[DEBUG] visitLVal: constant " << varName << std::endl;
return static_cast<ir::Value*>(const_it->second);
}
// 检查全局常量
auto const_global_it = const_global_map_.find(varName);
if (const_global_it != const_global_map_.end()) {
// 全局常量需要load
ir::Value* ptr = const_global_it->second;
if (!ctx->exp().empty()) {
// 数组访问
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* exp : ctx->exp()) {
ir::Value* index = EvalExpr(*exp);
indices.push_back(index);
// 先检查语义分析中常量绑定
const SysYParser::ConstDefContext* const_decl = sema_.ResolveConstUse(ctx);
const Symbol* sym = nullptr;
if (const_decl) {
sym = symbol_table_.lookupByConstDef(const_decl);
if (!sym) {
sym = symbol_table_.lookupAll(varName);
}
} else {
sym = symbol_table_.lookup(varName);
}
// 如果是常量,直接返回常量值
if (sym && sym->kind == SymbolKind::Constant) {
std::cerr << "[DEBUG] visitLVal: 找到常量 " << varName << std::endl;
if (sym->IsScalarConstant()) {
if (sym->type->IsInt32()) {
ir::ConstantValue* const_val = builder_.CreateConstInt(sym->GetIntConstant());
return static_cast<ir::Value*>(const_val);
} else if (sym->type->IsFloat()) {
ir::ConstantValue* const_val = builder_.CreateConstFloat(sym->GetFloatConstant());
return static_cast<ir::Value*>(const_val);
}
} else if (sym->IsArrayConstant()) {
auto it = const_global_map_.find(varName);
if (it != const_global_map_.end()) {
ir::GlobalValue* global_array = it->second;
// 尝试获取类型信息,用于维度判断与下标线性化
auto* array_ty = dynamic_cast<ir::ArrayType*>(sym->type.get());
if (!array_ty) {
// 无法获取数组类型,退回返回全局对象
return static_cast<ir::Value*>(global_array);
}
size_t ndims = array_ty->GetDimensions().size();
// 有下标访问
if (!ctx->exp().empty()) {
size_t provided = ctx->exp().size();
// 完全索引(所有维度都有下标)——直接返回常量元素,不生成 Load
if (provided == ndims) {
std::vector<int> idxs;
idxs.reserve(provided);
for (auto* exp : ctx->exp()) {
ir::Value* v = EvalExpr(*exp);
if (!v || !v->IsConstant()) {
throw std::runtime_error(FormatError("irgen", "常量数组索引必须为常量整数: " + varName));
}
auto* ci = dynamic_cast<ir::ConstantInt*>(v);
if (!ci) {
throw std::runtime_error(FormatError("irgen", "常量数组索引非整型常量: " + varName));
}
idxs.push_back(ci->GetValue());
}
// 计算线性下标(行主序)
const auto& dims = array_ty->GetDimensions();
int flat = idxs[0];
for (size_t i = 1; i < ndims; ++i) {
flat = flat * dims[i] + idxs[i];
}
ir::ConstantValue* elem = global_array->GetArrayElement(static_cast<size_t>(flat));
return static_cast<ir::Value*>(elem);
}
// 部分索引:返回指针(不做 Load由上层按需处理
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* exp : ctx->exp()) {
indices.push_back(EvalExpr(*exp));
}
return static_cast<ir::Value*>(
builder_.CreateGEP(global_array, indices, module_.GetContext().NextTemp()));
} else {
// 无下标,直接返回全局常量对象
return static_cast<ir::Value*>(global_array);
}
}
ir::Value* elem_ptr = builder_.CreateGEP(
ptr, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(
builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp()));
} else {
return static_cast<ir::Value*>(
builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
}
}
// 不是常量,按正常变量处理
// ... 原有的变量查找代码 ...
auto* decl = sema_.ResolveVarUse(ctx);
ir::Value* ptr = nullptr;
if (decl) {
auto it = storage_map_.find(decl);
if (it != storage_map_.end()) {
@ -234,27 +277,124 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
// 检查是否有数组下标
bool is_array_access = !ctx->exp().empty();
if (is_array_access) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
// 收集下标表达式不含前导0
std::vector<ir::Value*> idx_vals;
for (auto* exp : ctx->exp()) {
ir::Value* index = EvalExpr(*exp);
indices.push_back(index);
idx_vals.push_back(index);
}
ir::Value* elem_ptr = builder_.CreateGEP(
ptr, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(
builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp()));
const Symbol* var_sym = sym;
if (!var_sym) {
var_sym = symbol_table_.lookup(varName);
}
if (!var_sym && decl) {
var_sym = symbol_table_.lookupByVarDef(decl);
}
if (!var_sym) {
var_sym = symbol_table_.lookupAll(varName);
}
std::vector<int> dims;
if (var_sym) {
if (var_sym->is_array_param && !var_sym->array_dims.empty()) {
dims = var_sym->array_dims;
} else if (var_sym->type && var_sym->type->IsArray()) {
auto* at = dynamic_cast<ir::ArrayType*>(var_sym->type.get());
if (at) dims = at->GetDimensions();
}
}
if (dims.empty() && ptr->GetType()->IsArray()) {
if (auto* at = dynamic_cast<ir::ArrayType*>(ptr->GetType().get())) {
dims = at->GetDimensions();
}
}
// 兜底:从语法树声明提取维度,避免作用域关闭后符号查询不完整。
if (dims.empty() && const_decl) {
auto* mutable_const_decl = const_cast<SysYParser::ConstDefContext*>(const_decl);
for (auto* cexp : mutable_const_decl->constExp()) {
dims.push_back(symbol_table_.EvaluateConstExp(cexp));
}
}
if (dims.empty() && decl) {
for (auto* cexp : decl->constExp()) {
dims.push_back(symbol_table_.EvaluateConstExp(cexp));
}
}
const bool is_partial_array_access =
!dims.empty() && idx_vals.size() < dims.size();
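// When base is a scalar pointer (e.g. a flattened local array or an array parameter),
// the multi-dimensional subscripts must be folded into one linear index, and the GEP then uses that single index.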
if (ptr->GetType()->IsPtrInt32() || ptr->GetType()->IsPtrFloat()) {
// 如果没有维度信息,仍尝试用运行时算术合并下标(按后维乘积)
// flat = idx0 * (prod dims[1..]) + idx1 * (prod dims[2..]) + ...
ir::Value* flat = nullptr;
for (size_t i = 0; i < idx_vals.size(); ++i) {
ir::Value* term = idx_vals[i];
if (!term) continue;
// 计算乘数(后续维度乘积)
int mult = 1;
if (!dims.empty() && i + 1 < dims.size()) {
for (size_t j = i + 1; j < dims.size(); ++j) {
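// An array parameter's leading dimension may be 0 (meaning it was omitted); non-positive dimensions do not contribute to the multiplier.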
if (dims[j] > 0) mult *= dims[j];
}
}
if (mult != 1) {
auto* mval = builder_.CreateConstInt(mult);
term = builder_.CreateMul(term, mval, module_.GetContext().NextTemp());
}
if (!flat) flat = term;
else flat = builder_.CreateAdd(flat, term, module_.GetContext().NextTemp());
}
if (!flat) flat = builder_.CreateConstInt(0);
// 使用单一索引创建 GEP
std::vector<ir::Value*> gep_indices = { flat };
ir::Value* elem_ptr = builder_.CreateGEP(ptr, gep_indices, module_.GetContext().NextTemp());
if (is_partial_array_access) {
return elem_ptr;
}
return static_cast<ir::Value*>(builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp()));
}
std::vector<ir::Value*> indices;
// A scalar pointer (T*) takes the subscripts as-is; an array object gets a leading 0 to step into its first level.
if (ptr->GetType()->IsPtrInt32() || ptr->GetType()->IsPtrFloat()) {
for (auto* v : idx_vals) indices.push_back(v);
} else {
indices.push_back(builder_.CreateConstInt(0));
for (auto* v : idx_vals) indices.push_back(v);
}
ir::Value* elem_ptr = builder_.CreateGEP(ptr, indices, module_.GetContext().NextTemp());
if (is_partial_array_access) {
return elem_ptr;
}
return static_cast<ir::Value*>(builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp()));
} else {
return static_cast<ir::Value*>(
builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
if ((sym && sym->is_array_param) ||
pointer_param_names_.find(varName) != pointer_param_names_.end() ||
heap_local_array_names_.find(varName) != heap_local_array_names_.end()) {
return ptr;
}
if (ptr->GetType()->IsArray()) {
return ptr;
}
return static_cast<ir::Value*>(builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
}
}
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
std::cout << "[DEBUG IRGEN] visitAddExp: " << (ctx ? ctx->getText() : "<null>") << std::endl;
std::cerr << "[DEBUG IRGEN] visitAddExp: 开始处理加法表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
@ -318,7 +458,7 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
std::cout << "[DEBUG IRGEN] visitMulExp: " << (ctx ? ctx->getText() : "<null>") << std::endl;
std::cerr << "[DEBUG IRGEN] visitMulExp: 开始处理乘法表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
@ -392,6 +532,7 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
// 逻辑与
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
std::cerr << "[DEBUG IRGEN] visitLAndExp: 开始处理逻辑与表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
if (!ctx->lAndExp()) {
@ -400,14 +541,28 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
ir::Value* left = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
ir::Value* right = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
auto zero = builder_.CreateConstInt(0);
auto left_bool = builder_.CreateICmpNE(left, zero, module_.GetContext().NextTemp());
auto right_bool = builder_.CreateICmpNE(right, zero, module_.GetContext().NextTemp());
return builder_.CreateAnd(left_bool, right_bool, module_.GetContext().NextTemp());
auto to_bool = [&](ir::Value* v) -> ir::Value* {
if (v->GetType()->IsInt1()) {
return v;
}
if (v->GetType()->IsFloat()) {
return builder_.CreateFCmpONE(v, builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
return builder_.CreateICmpNE(v, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
};
auto* left_bool = to_bool(left);
auto* right_bool = to_bool(right);
return static_cast<ir::Value*>(
builder_.CreateAnd(left_bool, right_bool, module_.GetContext().NextTemp()));
}
// 逻辑或
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
std::cerr << "[DEBUG IRGEN] visitLOrExp: 开始处理逻辑或表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
if (!ctx->lOrExp()) {
@ -416,23 +571,39 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
ir::Value* left = std::any_cast<ir::Value*>(ctx->lOrExp()->accept(this));
ir::Value* right = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
auto zero = builder_.CreateConstInt(0);
auto left_bool = builder_.CreateICmpNE(left, zero, module_.GetContext().NextTemp());
auto right_bool = builder_.CreateICmpNE(right, zero, module_.GetContext().NextTemp());
return builder_.CreateOr(left_bool, right_bool, module_.GetContext().NextTemp());
auto to_bool = [&](ir::Value* v) -> ir::Value* {
if (v->GetType()->IsInt1()) {
return v;
}
if (v->GetType()->IsFloat()) {
return builder_.CreateFCmpONE(v, builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
return builder_.CreateICmpNE(v, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
};
auto* left_bool = to_bool(left);
auto* right_bool = to_bool(right);
return static_cast<ir::Value*>(
builder_.CreateOr(left_bool, right_bool, module_.GetContext().NextTemp()));
}
// Visits an `exp` node by delegating to its addExp child and returning that
// child's result unchanged.
// Throws std::runtime_error when `ctx` is null or has no addExp child (e.g. a
// malformed parse tree) instead of dereferencing a null pointer.
std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) {
  std::cerr << "[DEBUG IRGEN] visitExp: 开始处理表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
  // Check the child before using it; the accessor may return nullptr.
  if (!ctx || !ctx->addExp()) {
    throw std::runtime_error(FormatError("irgen", "非法表达式"));
  }
  return ctx->addExp()->accept(this);
}
// Visits a `cond` node by delegating to its lOrExp child and returning that
// child's result unchanged.
// Throws std::runtime_error when `ctx` is null or has no lOrExp child (e.g. a
// malformed parse tree) instead of dereferencing a null pointer.
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
  std::cerr << "[DEBUG IRGEN] visitCond: 开始处理条件 " << (ctx ? ctx->getText() : "<null>") << std::endl;
  // Check the child before using it; the accessor may return nullptr.
  if (!ctx || !ctx->lOrExp()) {
    throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
  }
  return ctx->lOrExp()->accept(this);
}
std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) {
std::cerr << "[DEBUG IRGEN] visitCallExp: 开始处理函数调用 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "非法函数调用"));
}
@ -466,6 +637,33 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) {
}
}
// 按形参类型修正实参(数组衰减为指针等)。
if (auto* fty = dynamic_cast<ir::FunctionType*>(callee->GetType().get())) {
const auto& param_tys = fty->GetParamTypes();
size_t n = std::min(param_tys.size(), args.size());
for (size_t i = 0; i < n; ++i) {
if (!args[i] || !param_tys[i]) continue;
// 数组实参传给指针形参时,执行数组到指针衰减。
if (args[i]->GetType()->IsArray() &&
(param_tys[i]->IsPtrInt32() || param_tys[i]->IsPtrFloat())) {
std::vector<ir::Value*> idx;
idx.push_back(builder_.CreateConstInt(0));
idx.push_back(builder_.CreateConstInt(0));
args[i] = builder_.CreateGEP(args[i], idx, module_.GetContext().NextTemp());
}
// 标量实参的隐式类型转换int <-> float
if (param_tys[i]->IsFloat() && args[i]->GetType()->IsInt32()) {
args[i] = builder_.CreateSIToFP(args[i], ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
} else if (param_tys[i]->IsInt32() && args[i]->GetType()->IsFloat()) {
args[i] = builder_.CreateFPToSI(args[i], ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
}
}
// 生成调用指令
ir::Value* callResult = builder_.CreateCall(callee, args, module_.GetContext().NextTemp());
@ -481,7 +679,7 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) {
// 动态创建运行时函数声明的辅助函数
ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) {
std::cout << "[DEBUG IRGEN] CreateRuntimeFunctionDecl: " << funcName << std::endl;
std::cerr << "[DEBUG IRGEN] CreateRuntimeFunctionDecl: 开始创建运行时函数声明 " << funcName << std::endl;
// 根据常见运行时函数名创建对应的函数类型
if (funcName == "getint" || funcName == "getch") {
@ -498,7 +696,7 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName)
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetInt32Type(),
{ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()}));
{ir::Type::GetPtrInt32Type()}));
}
else if (funcName == "putarray") {
return module_.CreateFunction(funcName,
@ -522,15 +720,27 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName)
ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()}));
}
else if (funcName == "read_map") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {}));
else if (funcName == "getfloat") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(ir::Type::GetFloatType(), {}));
}
else if (funcName == "float_eq") {
return module_.CreateFunction(funcName,
else if (funcName == "putfloat") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetFloatType()}));
}
else if (funcName == "getfarray") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetInt32Type(),
{ir::Type::GetFloatType(), ir::Type::GetFloatType()}));
{ir::Type::GetPtrFloatType()}));
}
else if (funcName == "putfarray") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()}));
}
else if (funcName == "memset") {
return module_.CreateFunction(funcName,
@ -540,12 +750,49 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName)
ir::Type::GetInt32Type(),
ir::Type::GetInt32Type()}));
}
else if (funcName == "sysy_alloc_i32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetPtrInt32Type(),
{ir::Type::GetInt32Type()}));
}
else if (funcName == "sysy_alloc_f32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetPtrFloatType(),
{ir::Type::GetInt32Type()}));
}
else if (funcName == "sysy_free_i32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrInt32Type()}));
}
else if (funcName == "sysy_free_f32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrFloatType()}));
}
else if (funcName == "sysy_zero_i32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()}));
}
else if (funcName == "sysy_zero_f32") {
return module_.CreateFunction(funcName,
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()}));
}
// 其他函数不支持动态创建
return nullptr;
}
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
std::cerr << "[DEBUG IRGEN] visitUnaryExp: 开始处理一元表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
}
@ -587,14 +834,15 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
ir::Value* zero;
if (operand->GetType()->IsFloat()) {
zero = builder_.CreateConstFloat(0.0f);
// 浮点比较:不等于0
ir::Value* cmp = builder_.CreateFCmpONE(operand, zero, module_.GetContext().NextTemp());
// 浮点逻辑非x == 0.0
ir::Value* cmp = builder_.CreateFCmpOEQ(operand, zero, module_.GetContext().NextTemp());
// 将bool转换为int
return static_cast<ir::Value*>(
builder_.CreateZExt(cmp, ir::Type::GetInt32Type()));
} else {
zero = builder_.CreateConstInt(0);
return builder_.CreateNot(operand, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(
builder_.CreateNot(operand, module_.GetContext().NextTemp()));
}
}
}
@ -604,6 +852,7 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// 实现函数调用
std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
std::cerr << "[DEBUG IRGEN] visitFuncRParams: 开始处理函数参数 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) return std::vector<ir::Value*>{};
std::vector<ir::Value*> args;
for (auto* exp : ctx->exp()) {
@ -612,67 +861,37 @@ std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
return args;
}
// visitConstExp (modified to support constant-expression evaluation):
// handles a constant expression by visiting its addExp child and returning
// the result as an ir::Value*.
// NOTE(review): two implementations appear interleaved below - one that throws
// on a missing child / empty result / wrong result type, and one that falls
// back to a constant 0 - they look like the new and old form shown together;
// confirm which one the real file keeps.
std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) {
if (!ctx) {
std::cerr << "[DEBUG IRGEN] visitConstExp: 开始处理常量表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
// Reject a null context or a context without an addExp child.
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("irgen", "非法常量表达式"));
}
auto result = ctx->addExp()->accept(this);
// An empty std::any means the child visitor produced nothing.
if (!result.has_value()) {
throw std::runtime_error(FormatError("irgen", "常量表达式求值失败"));
}
try {
if (ctx->addExp()) {
// Try to obtain the value.
auto result = ctx->addExp()->accept(this);
if (result.has_value()) {
try {
ir::Value* value = std::any_cast<ir::Value*>(result);
// Could check here whether this is a ConstantInt.
// Simplified for now: return the IR value as-is.
return static_cast<ir::Value*>(value);
} catch (const std::bad_any_cast&) {
// The result may hold some other type.
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
}
}
}
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
} catch (const std::exception& e) {
std::cerr << "[WARNING] visitConstExp: 常量表达式求值失败: " << e.what()
<< "返回0" << std::endl;
// If evaluating the expression failed, return 0.
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
return std::any_cast<ir::Value*>(result);
} catch (const std::bad_any_cast& e) {
// A result that is not an ir::Value* is reported as an error.
throw std::runtime_error(FormatError("irgen",
"常量表达式返回类型错误: " + std::string(e.what())));
}
}
// visitConstInitVal - 处理常量初始化值
std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
std::cerr << "[DEBUG IRGEN] visitConstInitVal: 开始处理常量初始化值 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法常量初始化值"));
}
// 如果是单个常量表达式
if (ctx->constExp()) {
try {
auto result = ctx->constExp()->accept(this);
if (result.has_value()) {
try {
ir::Value* value = std::any_cast<ir::Value*>(result);
// 尝试提取常量值
if (auto* const_int = dynamic_cast<ir::ConstantInt*>(value)) {
return static_cast<ir::Value*>(const_int);
} else {
// 如果不是常量,尝试计算数值
int int_val = TryEvaluateConstInt(ctx->constExp());
return static_cast<ir::Value*>(builder_.CreateConstInt(int_val));
}
} catch (const std::bad_any_cast&) {
int int_val = TryEvaluateConstInt(ctx->constExp());
return static_cast<ir::Value*>(builder_.CreateConstInt(int_val));
}
}
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
} catch (const std::exception& e) {
std::cerr << "[WARNING] visitConstInitVal: " << e.what() << std::endl;
return static_cast<ir::Value*>(builder_.CreateConstInt(0));
}
return ctx->constExp()->accept(this);
}
// 如果是聚合初始化(花括号列表)
else if (!ctx->constInitVal().empty()) {
@ -680,22 +899,24 @@ std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
for (auto* init_val : ctx->constInitVal()) {
auto result = init_val->accept(this);
if (result.has_value()) {
if (!result.has_value()) {
throw std::runtime_error(FormatError("irgen", "常量初始化值求值失败"));
}
try {
// 尝试获取单个常量值
ir::Value* value = std::any_cast<ir::Value*>(result);
all_values.push_back(value);
} catch (const std::bad_any_cast&) {
try {
// 尝试获取单个常量值
ir::Value* value = std::any_cast<ir::Value*>(result);
all_values.push_back(value);
} catch (const std::bad_any_cast&) {
try {
// 尝试获取值列表(嵌套情况)
std::vector<ir::Value*> nested_values =
std::any_cast<std::vector<ir::Value*>>(result);
all_values.insert(all_values.end(),
nested_values.begin(), nested_values.end());
} catch (const std::bad_any_cast&) {
throw std::runtime_error(
FormatError("irgen", "不支持的常量初始化值类型"));
}
// 尝试获取值列表(嵌套情况)
std::vector<ir::Value*> nested_values =
std::any_cast<std::vector<ir::Value*>>(result);
all_values.insert(all_values.end(),
nested_values.begin(), nested_values.end());
} catch (const std::bad_any_cast& e) {
throw std::runtime_error(FormatError("irgen",
"不支持的常量初始化值类型: " + std::string(e.what())));
}
}
}
@ -708,6 +929,7 @@ std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
std::cerr << "[DEBUG IRGEN] visitRelExp: 开始处理关系表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
@ -782,6 +1004,7 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
}
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
std::cerr << "[DEBUG IRGEN] visitEqExp: 开始处理相等表达式 " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
@ -839,6 +1062,7 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) {
std::cerr << "[DEBUG IRGEN] EvalAssign: 开始处理赋值语句 " << (ctx ? ctx->getText() : "<null>") << std::endl;
std::cout << "[DEBUG IRGEN] visitCond: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx || !ctx->lVal() || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法赋值语句"));
@ -864,15 +1088,29 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) {
ir::Value* base_ptr = it->second;
auto convert_for_store = [&](ir::Value* value, ir::Value* ptr) -> ir::Value* {
if (ptr->GetType()->IsPtrFloat() && value->GetType()->IsInt32()) {
return builder_.CreateSIToFP(value, ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
}
if (ptr->GetType()->IsPtrInt32() && value->GetType()->IsFloat()) {
return builder_.CreateFPToSI(value, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
return value;
};
// 检查是否有数组下标
auto exp_list = lval->exp();
if (!exp_list.empty()) {
// 这是数组元素赋值需要生成GEP指令
std::vector<ir::Value*> indices;
// 第一个索引是0假设一维数组
indices.push_back(builder_.CreateConstInt(0));
// 标量指针参数T*不应添加前导0数组对象需要前导0。
if (!(base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat())) {
indices.push_back(builder_.CreateConstInt(0));
}
// 添加用户提供的下标
for (auto* exp : exp_list) {
ir::Value* index = EvalExpr(*exp);
@ -884,6 +1122,7 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) {
base_ptr, indices, module_.GetContext().NextTemp());
// 生成store指令
rhs = convert_for_store(rhs, elem_ptr);
builder_.CreateStore(rhs, elem_ptr);
} else {
// 普通标量赋值
@ -891,11 +1130,12 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) {
std::cerr << "[DEBUG] base_ptr type: " << base_ptr->GetType() << std::endl;
std::cerr << "[DEBUG] rhs type: " << rhs->GetType()<< std::endl;
// 如果 base_ptr 不是指针类型,可能需要特殊处理
if (!base_ptr->GetType()->IsPtrInt32()) {
// 如果 base_ptr 不是标量指针类型,可能需要特殊处理
if (!base_ptr->GetType()->IsPtrInt32() && !base_ptr->GetType()->IsPtrFloat()) {
std::cerr << "[ERROR] base_ptr is not a pointer type!" << std::endl;
throw std::runtime_error("尝试存储到非指针类型");
}
rhs = convert_for_store(rhs, base_ptr);
builder_.CreateStore(rhs, base_ptr);
}
} else {

@ -20,18 +20,41 @@ void VerifyFunctionStructure(const ir::Function& func) {
}
}
} // namespace
// Reports whether the parse subtree rooted at `node` contains a
// UnaryExpContext whose Ident token spells `func_name`.
// (Presumably the Ident-bearing UnaryExp alternative is the function-call
// form of the grammar - confirm against SysY.g4.)
//
// node:      root of the subtree to scan; a null pointer yields false.
// func_name: identifier text to look for.
//
// The search is a depth-first walk over `children`; children that are not
// ParserRuleContext instances (the dynamic_cast fails) are skipped.
bool HasDirectSelfCall(antlr4::ParserRuleContext* node,
                       const std::string& func_name) {
  if (!node) {
    return false;
  }
  // Check the node itself before descending.
  if (auto* unary = dynamic_cast<SysYParser::UnaryExpContext*>(node)) {
    if (unary->Ident() && unary->Ident()->getText() == func_name) {
      return true;
    }
  }
  // Recurse into every rule-context child and stop at the first match.
  for (auto* child : node->children) {
    if (auto* rule = dynamic_cast<antlr4::ParserRuleContext*>(child)) {
      if (HasDirectSelfCall(rule, func_name)) {
        return true;
      }
    }
  }
  return false;
}
} // namespace
// Constructs the IR generator from the target module, the semantic context
// and the symbol table produced by semantic analysis. The builder is created
// with a null second argument and no current function is set; the runtime
// library declarations are added to the module right away.
IRGenImpl::IRGenImpl(ir::Module& module,
                     const SemanticContext& sema,
                     const SymbolTable& sym_table)
    : module_(module),
      sema_(sema),
      symbol_table_(sym_table),
      builder_(module.GetContext(), nullptr),
      func_(nullptr) {
  AddRuntimeFunctions();
}
void IRGenImpl::AddRuntimeFunctions() {
std::cout << "[DEBUG IRGEN] 添加运行时库函数声明" << std::endl;
std::cerr << "[DEBUG IRGEN] 添加运行时库函数声明" << std::endl;
// 输入函数(返回 int
module_.CreateFunction("getint",
@ -43,7 +66,7 @@ void IRGenImpl::AddRuntimeFunctions() {
module_.CreateFunction("getarray",
ir::Type::GetFunctionType(
ir::Type::GetInt32Type(),
{ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()}));
{ir::Type::GetPtrInt32Type()}));
// 输出函数(返回 void
module_.CreateFunction("putint",
@ -83,16 +106,22 @@ void IRGenImpl::AddRuntimeFunctions() {
module_.CreateFunction("stoptime",
ir::Type::GetFunctionType(ir::Type::GetVoidType(), {}));
// 其他可能需要的函数
module_.CreateFunction("read_map",
ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {}));
// 浮点数
module_.CreateFunction("float_eq",
// 浮点 I/O
module_.CreateFunction("getfloat",
ir::Type::GetFunctionType(ir::Type::GetFloatType(), {}));
module_.CreateFunction("putfloat",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetFloatType()}));
module_.CreateFunction("getfarray",
ir::Type::GetFunctionType(
ir::Type::GetInt32Type(),
{ir::Type::GetFloatType(), ir::Type::GetFloatType()}));
{ir::Type::GetPtrFloatType()}));
module_.CreateFunction("putfarray",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()}));
// 内存操作函数
module_.CreateFunction("memset",
ir::Type::GetFunctionType(
@ -100,13 +129,48 @@ void IRGenImpl::AddRuntimeFunctions() {
{ir::Type::GetPtrInt32Type(),
ir::Type::GetInt32Type(),
ir::Type::GetInt32Type()}));
module_.CreateFunction("sysy_alloc_i32",
ir::Type::GetFunctionType(
ir::Type::GetPtrInt32Type(),
{ir::Type::GetInt32Type()}));
module_.CreateFunction("sysy_alloc_f32",
ir::Type::GetFunctionType(
ir::Type::GetPtrFloatType(),
{ir::Type::GetInt32Type()}));
module_.CreateFunction("sysy_free_i32",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrInt32Type()}));
module_.CreateFunction("sysy_free_f32",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrFloatType()}));
module_.CreateFunction("sysy_zero_i32",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrInt32Type(), ir::Type::GetInt32Type()}));
module_.CreateFunction("sysy_zero_f32",
ir::Type::GetFunctionType(
ir::Type::GetVoidType(),
{ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()}));
std::cout << "[DEBUG IRGEN] 运行时库函数声明完成" << std::endl;
std::cerr << "[DEBUG IRGEN] 运行时库函数声明完成" << std::endl;
}
// Fix: there is no mainFuncDef rule, so main is located by its function name.
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
std::cout << "[DEBUG IRGEN] visitCompUnit" << std::endl;
std::cerr << "[DEBUG IRGEN] visitCompUnit" << std::endl;
std::cerr << "[DEBUG] IRGen: 符号表地址 = " << &symbol_table_ << std::endl;
std::cerr << "[DEBUG] IRGen: 开始生成 IR" << std::endl;
// 尝试查找 main 函数
const Symbol* main_sym = symbol_table_.lookup("main");
if (main_sym) {
std::cerr << "[DEBUG] IRGen: 找到 main 函数符号" << std::endl;
} else {
std::cerr << "[DEBUG] IRGen: 未找到 main 函数符号" << std::endl;
}
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
}
@ -129,7 +193,7 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
}
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
std::cout << "[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "<null>") << std::endl;
std::cerr << "[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
@ -216,6 +280,19 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
builder_.SetInsertPoint(entry_block);
storage_map_.clear();
param_map_.clear();
pointer_param_names_.clear();
heap_local_array_names_.clear();
current_function_name_ = funcName;
current_function_is_recursive_ = HasDirectSelfCall(ctx->block(), funcName);
function_cleanup_block_ = nullptr;
function_cleanup_actions_.clear();
function_return_slot_ = nullptr;
if (ret_type->IsInt32()) {
function_return_slot_ = CreateEntryAllocaI32(module_.GetContext().NextTemp() + ".retval");
} else if (ret_type->IsFloat()) {
function_return_slot_ = CreateEntryAllocaFloat(module_.GetContext().NextTemp() + ".retval");
}
// 函数参数处理
if (ctx->funcFParams()) {
@ -271,23 +348,29 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
throw std::runtime_error(FormatError("irgen", "添加参数失败: " + name));
}
// 为参数创建存储槽位
ir::AllocaInst* slot = nullptr;
if (param_ty->IsInt32() || param_ty->IsPtrInt32()) {
slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
} else if (param_ty->IsFloat() || param_ty->IsPtrFloat()) {
slot = builder_.CreateAllocaFloat(module_.GetContext().NextTemp());
// 标量参数:入栈到本地槽位;数组参数(指针)直接作为地址使用。
if (param_ty->IsPtrInt32() || param_ty->IsPtrFloat()) {
param_map_[name] = added_arg;
pointer_param_names_.insert(name);
} else {
throw std::runtime_error(FormatError("irgen", "不支持的参数类型"));
}
if (!slot) {
throw std::runtime_error(FormatError("irgen", "创建参数存储槽位失败: " + name));
ir::AllocaInst* slot = nullptr;
if (param_ty->IsInt32()) {
slot = CreateEntryAllocaI32(module_.GetContext().NextTemp());
} else if (param_ty->IsFloat()) {
slot = CreateEntryAllocaFloat(module_.GetContext().NextTemp());
} else {
throw std::runtime_error(FormatError("irgen", "不支持的参数类型"));
}
if (!slot) {
throw std::runtime_error(FormatError("irgen", "创建参数存储槽位失败: " + name));
}
builder_.CreateStore(added_arg, slot);
param_map_[name] = slot;
pointer_param_names_.erase(name);
}
builder_.CreateStore(added_arg, slot);
param_map_[name] = slot;
std::cerr << "[DEBUG] visitFuncDef: 参数 " << name << " 处理完成" << std::endl;
}
}
@ -296,11 +379,37 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
std::cerr << "[DEBUG] visitFuncDef: 开始生成函数体" << std::endl;
ctx->block()->accept(this);
// 如果函数没有终止指令,添加默认返回
if (!func_->GetEntry()->HasTerminator()) {
// 如果当前插入块没有终止指令,添加默认返回
if (auto* cur = builder_.GetInsertBlock(); cur && !cur->HasTerminator()) {
std::cerr << "[DEBUG] visitFuncDef: 函数体没有终止指令,添加默认返回" << std::endl;
auto retVal = builder_.CreateConstInt(0);
builder_.CreateRet(retVal);
if (function_cleanup_block_) {
if (ret_type->IsFloat()) {
builder_.CreateStore(builder_.CreateConstFloat(0.0f), function_return_slot_);
} else if (ret_type->IsInt32()) {
builder_.CreateStore(builder_.CreateConstInt(0), function_return_slot_);
}
builder_.CreateBr(function_cleanup_block_);
} else if (ret_type->IsVoid()) {
builder_.CreateRet(nullptr);
} else if (ret_type->IsFloat()) {
builder_.CreateRet(builder_.CreateConstFloat(0.0f));
} else {
builder_.CreateRet(builder_.CreateConstInt(0));
}
}
if (function_cleanup_block_ && !function_cleanup_block_->HasTerminator()) {
builder_.SetInsertPoint(function_cleanup_block_);
for (auto it = function_cleanup_actions_.rbegin();
it != function_cleanup_actions_.rend(); ++it) {
builder_.CreateCall(it->first, {it->second}, module_.GetContext().NextTemp());
}
if (ret_type->IsVoid()) {
builder_.CreateRet(nullptr);
} else {
builder_.CreateRet(builder_.CreateLoad(function_return_slot_, module_.GetContext().NextTemp()));
}
}
// 验证函数结构
@ -313,12 +422,52 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
std::cerr << "[DEBUG] visitFuncDef: 函数 " << funcName << " 生成完成" << std::endl;
func_ = nullptr;
current_function_name_.clear();
current_function_is_recursive_ = false;
function_return_slot_ = nullptr;
function_cleanup_block_ = nullptr;
function_cleanup_actions_.clear();
return {};
}
// Returns the cleanup block of the function currently being generated,
// creating it on first use.
//
// The block is named "cleanup.<temp>", where <temp> is a fresh temporary name
// from the context with a leading '%' removed.
//
// Throws std::runtime_error when no function is being generated (func_ is
// null) and the block would have to be created.
ir::BasicBlock* IRGenImpl::EnsureCleanupBlock() {
  if (function_cleanup_block_) {
    return function_cleanup_block_;
  }
  // Same guard as CreateEntryAlloca: without a current function there is
  // nothing to attach the block to.
  if (!func_) {
    throw std::runtime_error(FormatError("irgen", "缺少当前函数,无法创建清理块"));
  }
  std::string suffix = module_.GetContext().NextTemp();
  if (!suffix.empty() && suffix[0] == '%') {
    suffix.erase(0, 1);
  }
  function_cleanup_block_ = func_->CreateBlock("cleanup." + suffix);
  return function_cleanup_block_;
}
// Queues a cleanup action: the pair (free_func, ptr) is appended to
// function_cleanup_actions_ and the cleanup block is created if it does not
// exist yet. A null `free_func` or null `ptr` makes the call a no-op.
void IRGenImpl::RegisterCleanup(ir::Function* free_func, ir::Value* ptr) {
  const bool has_action = free_func != nullptr && ptr != nullptr;
  if (has_action) {
    EnsureCleanupBlock();
    function_cleanup_actions_.push_back({free_func, ptr});
  }
}
// Creates an alloca named `name` with type `ty` in the entry block of the
// current function, inserted before that block's terminator.
//
// Throws std::runtime_error when there is no current function or it has no
// entry block.
ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr<ir::Type> ty,
                                             const std::string& name) {
  auto* entry = func_ ? func_->GetEntry() : nullptr;
  if (entry == nullptr) {
    throw std::runtime_error(FormatError("irgen", "缺少函数入口块,无法创建入口栈槽位"));
  }
  return entry->InsertBeforeTerminator<ir::AllocaInst>(ty, name);
}
// Entry-block alloca whose type is ir::Type::GetPtrInt32Type().
ir::AllocaInst* IRGenImpl::CreateEntryAllocaI32(const std::string& name) {
  auto slot_type = ir::Type::GetPtrInt32Type();
  return CreateEntryAlloca(slot_type, name);
}
// Entry-block alloca whose type is ir::Type::GetPtrFloatType().
ir::AllocaInst* IRGenImpl::CreateEntryAllocaFloat(const std::string& name) {
  auto slot_type = ir::Type::GetPtrFloatType();
  return CreateEntryAlloca(slot_type, name);
}
std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
std::cout << "[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "<null>") << std::endl;
std::cerr << "[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
}
@ -333,7 +482,7 @@ std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
}
auto* cur = builder_.GetInsertBlock();
std::cout << "[DEBUG] current insert block: "
std::cerr << "[DEBUG] current insert block: "
<< (cur ? cur->GetName() : "<null>") << std::endl;
if (cur && cur->HasTerminator()) {
break;
@ -351,7 +500,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
}
// 用于遍历块内项,返回是否继续访问后续项(如遇到 return/break/continue 则终止访问)
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
std::cout << "[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "<null>") << std::endl;
std::cerr << "[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少块内项"));
}

@ -16,7 +16,7 @@
// - 空语句、块语句嵌套分发之外的更多语句形态
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
std::cout << "[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
std::cerr << "[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句"));
}
@ -65,7 +65,7 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
// 修改 HandleReturnStmt 函数
IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
std::cout << "[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
std::cerr << "[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少 return 语句"));
}
@ -88,8 +88,12 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
// 表达式被忽略(可计算但不使用)
EvalExpr(*ctx->exp());
}
// 对于void函数创建返回指令不传参数
builder_.CreateRet(nullptr);
if (function_cleanup_block_) {
builder_.CreateBr(function_cleanup_block_);
} else {
// 对于void函数创建返回指令不传参数
builder_.CreateRet(nullptr);
}
} else {
ir::Value* retValue = nullptr;
if (ctx->exp()) {
@ -115,7 +119,12 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
retValue = builder_.CreateConstInt(0); // fallback
}
}
builder_.CreateRet(retValue);
if (function_cleanup_block_) {
builder_.CreateStore(retValue, function_return_slot_);
builder_.CreateBr(function_cleanup_block_);
} else {
builder_.CreateRet(retValue);
}
}
return BlockFlow::Terminated;
}
@ -123,49 +132,63 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) {
// Lowers an if / if-else statement to then/else/merge basic blocks joined by a conditional branch.
IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) {
std::cout << "[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
std::cerr << "[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "<null>") << std::endl;
auto* cond = ctx->cond();
auto* thenStmt = ctx->stmt(0);
auto* elseStmt = ctx->stmt(1);
// 创建基本块
auto* thenBlock = func_->CreateBlock("then");
auto* elseBlock = (ctx->Else() && elseStmt) ? func_->CreateBlock("else") : nullptr;
auto* mergeBlock = func_->CreateBlock("merge");
std::cout << "[DEBUG IF] thenBlock: " << thenBlock->GetName() << std::endl;
if (elseBlock) std::cout << "[DEBUG IF] elseBlock: " << elseBlock->GetName() << std::endl;
std::cout << "[DEBUG IF] mergeBlock: " << mergeBlock->GetName() << std::endl;
std::cout << "[DEBUG IF] current insert block before cond: "
// 创建基本块(使用唯一名称,避免同名标签)
auto uniq = [&](const std::string& prefix) {
std::string t = module_.GetContext().NextTemp();
if (!t.empty() && t[0] == '%') t.erase(0, 1);
return prefix + "." + t;
};
auto* thenBlock = func_->CreateBlock(uniq("then"));
auto* elseBlock = (ctx->Else() && elseStmt) ? func_->CreateBlock(uniq("else")) : nullptr;
auto* mergeBlock = func_->CreateBlock(uniq("merge"));
std::cerr << "[DEBUG IF] thenBlock: " << thenBlock->GetName() << std::endl;
if (elseBlock) std::cerr << "[DEBUG IF] elseBlock: " << elseBlock->GetName() << std::endl;
std::cerr << "[DEBUG IF] mergeBlock: " << mergeBlock->GetName() << std::endl;
std::cerr << "[DEBUG IF] current insert block before cond: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
// 生成条件
auto* condValue = EvalCond(*cond);
if (!condValue->GetType()->IsInt1()) {
if (condValue->GetType()->IsFloat()) {
condValue = builder_.CreateFCmpONE(
condValue, builder_.CreateConstFloat(0.0f), module_.GetContext().NextTemp());
} else {
condValue = builder_.CreateICmpNE(
condValue, builder_.CreateConstInt(0), module_.GetContext().NextTemp());
}
}
// 创建条件跳转
if (elseBlock) {
std::cout << "[DEBUG IF] Creating condbr: " << condValue->GetName()
std::cerr << "[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << elseBlock->GetName() << std::endl;
builder_.CreateCondBr(condValue, thenBlock, elseBlock);
} else {
std::cout << "[DEBUG IF] Creating condbr: " << condValue->GetName()
std::cerr << "[DEBUG IF] Creating condbr: " << condValue->GetName()
<< " -> " << thenBlock->GetName() << ", " << mergeBlock->GetName() << std::endl;
builder_.CreateCondBr(condValue, thenBlock, mergeBlock);
}
// 生成 then 分支
std::cout << "[DEBUG IF] Generating then branch in block: " << thenBlock->GetName() << std::endl;
std::cerr << "[DEBUG IF] Generating then branch in block: " << thenBlock->GetName() << std::endl;
builder_.SetInsertPoint(thenBlock);
auto thenResult = thenStmt->accept(this);
bool thenTerminated = (std::any_cast<BlockFlow>(thenResult) == BlockFlow::Terminated);
std::cout << "[DEBUG IF] then branch terminated: " << thenTerminated << std::endl;
std::cerr << "[DEBUG IF] then branch terminated: " << thenTerminated << std::endl;
if (!thenTerminated) {
std::cout << "[DEBUG IF] Adding br to merge block from then" << std::endl;
std::cerr << "[DEBUG IF] Adding br to merge block from then" << std::endl;
builder_.CreateBr(mergeBlock);
}
std::cout << "[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator() << std::endl;
std::cerr << "[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator() << std::endl;
// 生成 else 分支
bool elseTerminated = false;
@ -188,16 +211,9 @@ IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) {
<< ", elseTerminated=" << elseTerminated << std::endl;
if (elseBlock) {
if (thenTerminated && elseTerminated) {
auto* afterIfBlock = func_->CreateBlock("after.if");
std::cout << "[DEBUG IF] Both branches terminated, creating new block: "
<< afterIfBlock->GetName() << std::endl;
builder_.SetInsertPoint(afterIfBlock);
} else {
std::cout << "[DEBUG IF] Setting insert point to merge block: "
<< mergeBlock->GetName() << std::endl;
builder_.SetInsertPoint(mergeBlock);
}
std::cout << "[DEBUG IF] Setting insert point to merge block: "
<< mergeBlock->GetName() << std::endl;
builder_.SetInsertPoint(mergeBlock);
} else {
std::cout << "[DEBUG IF] No else, setting insert point to merge block: "
<< mergeBlock->GetName() << std::endl;
@ -221,9 +237,14 @@ IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) {
std::cout << "[DEBUG WHILE] Current insert block before while: "
<< builder_.GetInsertBlock()->GetName() << std::endl;
auto* condBlock = func_->CreateBlock("while.cond");
auto* bodyBlock = func_->CreateBlock("while.body");
auto* exitBlock = func_->CreateBlock("while.exit");
auto uniq = [&](const std::string& prefix) {
std::string t = module_.GetContext().NextTemp();
if (!t.empty() && t[0] == '%') t.erase(0, 1);
return prefix + "." + t;
};
auto* condBlock = func_->CreateBlock(uniq("while.cond"));
auto* bodyBlock = func_->CreateBlock(uniq("while.body"));
auto* exitBlock = func_->CreateBlock(uniq("while.exit"));
std::cout << "[DEBUG WHILE] condBlock: " << condBlock->GetName() << std::endl;
std::cout << "[DEBUG WHILE] bodyBlock: " << bodyBlock->GetName() << std::endl;
@ -239,6 +260,15 @@ IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) {
std::cout << "[DEBUG WHILE] Generating condition in block: " << condBlock->GetName() << std::endl;
builder_.SetInsertPoint(condBlock);
auto* condValue = EvalCond(*ctx->cond());
if (!condValue->GetType()->IsInt1()) {
if (condValue->GetType()->IsFloat()) {
condValue = builder_.CreateFCmpONE(
condValue, builder_.CreateConstFloat(0.0f), module_.GetContext().NextTemp());
} else {
condValue = builder_.CreateICmpNE(
condValue, builder_.CreateConstInt(0), module_.GetContext().NextTemp());
}
}
builder_.CreateCondBr(condValue, bodyBlock, exitBlock);
std::cout << "[DEBUG WHILE] condBlock has terminator: " << condBlock->HasTerminator() << std::endl;
@ -387,17 +417,83 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
auto exp_list = lval->exp();
if (!exp_list.empty()) {
// 数组元素赋值
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
std::vector<ir::Value*> idx_vals;
for (auto* exp : exp_list) {
ir::Value* index = EvalExpr(*exp);
indices.push_back(index);
idx_vals.push_back(index);
}
ir::Value* elem_ptr = nullptr;
// 扁平数组/数组参数T*)的多维访问:先线性化,再单索引 GEP。
if ((base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat()) && idx_vals.size() > 1) {
const Symbol* var_sym = symbol_table_.lookup(varName);
if (!var_sym && var_decl) {
var_sym = symbol_table_.lookupByVarDef(var_decl);
}
if (!var_sym) {
var_sym = symbol_table_.lookupAll(varName);
}
std::vector<int> dims;
if (var_sym) {
if (var_sym->is_array_param && !var_sym->array_dims.empty()) {
dims = var_sym->array_dims;
} else if (var_sym->type && var_sym->type->IsArray()) {
auto* at = dynamic_cast<ir::ArrayType*>(var_sym->type.get());
if (at) dims = at->GetDimensions();
}
}
if (dims.empty() && var_decl) {
for (auto* cexp : var_decl->constExp()) {
dims.push_back(symbol_table_.EvaluateConstExp(cexp));
}
}
ir::Value* flat = nullptr;
for (size_t i = 0; i < idx_vals.size(); ++i) {
ir::Value* term = idx_vals[i];
if (!term) continue;
int mult = 1;
if (!dims.empty() && i + 1 < dims.size()) {
for (size_t j = i + 1; j < dims.size(); ++j) {
if (dims[j] > 0) mult *= dims[j];
}
}
if (mult != 1) {
auto* mval = builder_.CreateConstInt(mult);
term = builder_.CreateMul(term, mval, module_.GetContext().NextTemp());
}
if (!flat) flat = term;
else flat = builder_.CreateAdd(flat, term, module_.GetContext().NextTemp());
}
if (!flat) flat = builder_.CreateConstInt(0);
std::vector<ir::Value*> gep_indices = {flat};
elem_ptr = builder_.CreateGEP(base_ptr, gep_indices, module_.GetContext().NextTemp());
} else {
std::vector<ir::Value*> indices;
if (base_ptr->GetType()->IsPtrInt32() || base_ptr->GetType()->IsPtrFloat()) {
for (auto* v : idx_vals) indices.push_back(v);
} else {
indices.push_back(builder_.CreateConstInt(0));
for (auto* v : idx_vals) indices.push_back(v);
}
elem_ptr = builder_.CreateGEP(base_ptr, indices, module_.GetContext().NextTemp());
}
ir::Value* elem_ptr = builder_.CreateGEP(
base_ptr, indices, module_.GetContext().NextTemp());
if (elem_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) {
rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
} else if (elem_ptr->GetType()->IsPtrInt32() && rhs->GetType()->IsFloat()) {
rhs = builder_.CreateFPToSI(rhs, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
builder_.CreateStore(rhs, elem_ptr);
} else {
// 普通标量赋值
@ -417,7 +513,14 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) {
std::cerr << "[DEBUG] Value is int32: " << rhs->GetType()->IsInt32() << std::endl;
}
builder_.CreateStore(rhs, base_ptr);
if (base_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) {
rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
} else if (base_ptr->GetType()->IsPtrInt32() && rhs->GetType()->IsFloat()) {
rhs = builder_.CreateFPToSI(rhs, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
builder_.CreateStore(rhs, base_ptr);
}
return BlockFlow::Continue;

@ -35,7 +35,6 @@ public:
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元"));
}
table_.enterScope(); // 创建全局作用域
for (auto* func : ctx->funcDef()) { // 收集所有函数声明(处理互相调用)
CollectFunctionDeclaration(func);
}
@ -46,7 +45,6 @@ public:
if (func) func->accept(this);
}
CheckMainFunction(); // 检查 main 函数存在且正确
table_.exitScope(); // 退出全局作用域
return {};
}
@ -238,6 +236,157 @@ public:
<< std::endl;
}
// Checks one constant definition and records it in the symbol table.
//
// ctx:       the constDef parse node (identifier, optional dimensions,
//            optional initializer).
// base_type: the element/scalar type taken from the enclosing declaration.
//
// Steps: reject a null/unnamed node and a redefinition in the current scope;
// evaluate the dimensions (each must be > 0) and build the array type; visit
// the dimension and initializer expressions; evaluate the initializer; pad an
// array initializer with zeros up to the element count; then fill a Symbol
// (kind Constant) with either the flattened array values or the scalar value
// and add it to table_.
//
// Throws std::runtime_error for: a null ctx or missing identifier, a
// duplicate name in the current scope, a non-positive dimension, more
// initializer values than elements, an int constant initialised by a
// non-integral float, and a scalar constant without an initializer value.
// Progress is traced on std::cout.
void CheckConstDef(SysYParser::ConstDefContext* ctx,
std::shared_ptr<ir::Type> base_type) {
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("sema", "非法常量定义"));
}
std::string name = ctx->Ident()->getText();
if (table_.lookupCurrent(name)) {
throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
}
// Determine the symbol type: the base type, or an array of it when
// dimensions are present.
std::shared_ptr<ir::Type> type = base_type;
std::vector<int> dims;
bool is_array = !ctx->constExp().empty();
std::cout << "[DEBUG] CheckConstDef: " << name
<< " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown")
<< " is_array: " << is_array
<< " dim_count: " << ctx->constExp().size() << std::endl;
if (is_array) {
for (auto* dim_exp : ctx->constExp()) {
int dim = table_.EvaluateConstExp(dim_exp);
if (dim <= 0) {
throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
}
dims.push_back(dim);
std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl;
}
type = ir::Type::GetArrayType(base_type, dims);
std::cout << "[DEBUG] 创建数组类型完成IsArray: " << type->IsArray() << std::endl;
}
// ========== Bind the dimension expressions ==========
// Visiting happens after the dimensions were already evaluated above.
for (auto* dim_exp : ctx->constExp()) {
dim_exp->addExp()->accept(this);
}
// Evaluate the initializer.
std::vector<SymbolTable::ConstValue> init_values;
if (ctx->constInitVal()) {
// ========== Bind the initializer expressions ==========
BindConstInitVal(ctx->constInitVal());
init_values = table_.EvaluateConstInitVal(ctx->constInitVal(), dims, base_type);
std::cout << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl;
}
// Compute the expected element count (1 for a scalar, product of the
// dimensions for an array).
size_t expected_count = 1;
if (is_array) {
expected_count = 1;
for (int d : dims) expected_count *= d;
std::cout << "[DEBUG] 期望元素数量: " << expected_count << std::endl;
}
// Pad with zeros when an array initializer supplies too few values.
if (is_array && init_values.size() < expected_count) {
std::cout << "[DEBUG] 初始化值不足,补零" << std::endl;
SymbolTable::ConstValue zero;
if (base_type->IsInt32()) {
zero.kind = SymbolTable::ConstValue::INT;
zero.int_val = 0;
} else {
zero.kind = SymbolTable::ConstValue::FLOAT;
zero.float_val = 0.0f;
}
init_values.resize(expected_count, zero);
}
// Reject an initializer with more values than elements.
if (init_values.size() > expected_count) {
throw std::runtime_error(FormatError("sema", "初始化值过多"));
}
// Build the symbol.
Symbol sym;
sym.name = name;
sym.kind = SymbolKind::Constant;
std::cout << "CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind << std::endl;
sym.type = type;
sym.scope_level = table_.currentScopeLevel();
sym.is_initialized = true;
sym.var_def_ctx = nullptr;
sym.const_def_ctx = ctx;
std::cout << "保存常量定义上下文: " << name << ", ctx: " << ctx << std::endl;
// ========== Store the constant value(s) ==========
if (is_array) {
// Array constant: values are stored flattened, one entry per element.
// Each entry is written to i32 or f32 according to the evaluated value's
// kind, not according to base_type.
sym.is_array_const = true;
sym.array_const_values.clear();
for (const auto& val : init_values) {
Symbol::ConstantValue cv;
if (val.kind == SymbolTable::ConstValue::INT) {
cv.i32 = val.int_val;
} else {
cv.f32 = val.float_val;
}
sym.array_const_values.push_back(cv);
}
std::cout << "[DEBUG] 存储数组常量,共 " << sym.array_const_values.size()
<< " 个元素" << std::endl;
} else if (!init_values.empty()) {
// Scalar constant: only the first evaluated value is used.
if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::INT) {
sym.is_int_const = true;
sym.const_value.i32 = init_values[0].int_val;
std::cout << "[DEBUG] 存储整型常量: " << init_values[0].int_val << std::endl;
} else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) {
sym.is_int_const = false;
sym.const_value.f32 = init_values[0].float_val;
std::cout << "[DEBUG] 存储浮点常量: " << init_values[0].float_val << std::endl;
} else if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) {
// Int constant initialised by a float: accepted only when the float is
// within 1e-6 of its truncated integer value.
float f = init_values[0].float_val;
int i = static_cast<int>(f);
if (std::abs(f - i) > 1e-6) {
throw std::runtime_error(FormatError("sema",
"整型常量不能用非整数值的浮点数初始化: " + std::to_string(f)));
}
sym.is_int_const = true;
sym.const_value.i32 = i;
std::cout << "[DEBUG] 浮点转整型常量: " << f << " -> " << i << std::endl;
} else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::INT) {
// Float constant initialised by an int: converted implicitly.
sym.is_int_const = false;
sym.const_value.f32 = static_cast<float>(init_values[0].int_val);
std::cout << "[DEBUG] 整型转浮点常量: " << init_values[0].int_val
<< " -> " << static_cast<float>(init_values[0].int_val) << std::endl;
}
} else {
// No initializer value: an error for a scalar constant.
// NOTE(review): this branch is only reached when is_array is false, so the
// check below always throws and the trace line after it is unreachable.
if (!is_array) {
throw std::runtime_error(FormatError("sema", "常量必须有初始化值: " + name));
}
std::cout << "[DEBUG] 数组常量无初始化器,将全部补零" << std::endl;
}
// NOTE(review): the result of addSymbol is not checked and `stored` is
// dereferenced without a null check.
table_.addSymbol(sym);
std::cout << "CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind << std::endl;
auto* stored = table_.lookup(name);
std::cout << "CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx << std::endl;
std::cout << "[DEBUG] 常量符号添加完成: " << name
<< " is_array_const: " << sym.is_array_const
<< " element_count: " << sym.array_const_values.size() << std::endl;
}
// ==================== 常量声明 ====================
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
if (!ctx || !ctx->bType()) {
@ -252,91 +401,6 @@ public:
return {};
}
// Checks one constant definition and records it in the symbol table.
// This version stores a value only for scalar constants whose evaluated kind
// matches the base type; array constants get no stored values and no
// zero padding.
//
// NOTE(review): a definition with the same name and signature appears
// earlier in this file (the one that also stores array constants). Two
// definitions cannot coexist in one class - confirm which is live and remove
// the other.
//
// Throws std::runtime_error for: a null ctx or missing identifier, a
// duplicate name in the current scope, a non-positive dimension, and more
// initializer values than elements.
void CheckConstDef(SysYParser::ConstDefContext* ctx,
std::shared_ptr<ir::Type> base_type) {
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("sema", "非法常量定义"));
}
std::string name = ctx->Ident()->getText();
if (table_.lookupCurrent(name)) {
throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
}
// Determine the symbol type: the base type, or an array of it when
// dimensions are present.
std::shared_ptr<ir::Type> type = base_type;
std::vector<int> dims;
bool is_array = !ctx->constExp().empty();
std::cout << "[DEBUG] CheckConstDef: " << name
<< " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown")
<< " is_array: " << is_array
<< " dim_count: " << ctx->constExp().size() << std::endl;
if (is_array) {
for (auto* dim_exp : ctx->constExp()) {
int dim = table_.EvaluateConstExp(dim_exp);
if (dim <= 0) {
throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
}
dims.push_back(dim);
std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl;
}
type = ir::Type::GetArrayType(base_type, dims);
std::cout << "[DEBUG] 创建数组类型完成IsArray: " << type->IsArray() << std::endl;
}
// ========== Bind the dimension expressions ==========
for (auto* dim_exp : ctx->constExp()) {
dim_exp->addExp()->accept(this);
}
// Evaluate the initializer.
std::vector<SymbolTable::ConstValue> init_values;
if (ctx->constInitVal()) {
// ========== Bind the initializer expressions ==========
BindConstInitVal(ctx->constInitVal());
init_values = table_.EvaluateConstInitVal(ctx->constInitVal(), dims, base_type);
std::cout << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl;
}
// Check the number of initializer values against the element count
// (1 for a scalar, product of the dimensions for an array).
size_t expected_count = 1;
if (is_array) {
expected_count = 1;
for (int d : dims) expected_count *= d;
std::cout << "[DEBUG] 期望元素数量: " << expected_count << std::endl;
}
if (init_values.size() > expected_count) {
throw std::runtime_error(FormatError("sema", "初始化值过多"));
}
Symbol sym;
sym.name = name;
sym.kind = SymbolKind::Constant;
std::cout << "CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind << std::endl;
sym.type = type;
sym.scope_level = table_.currentScopeLevel();
sym.is_initialized = true;
sym.var_def_ctx = nullptr;
sym.const_def_ctx = ctx;
sym.const_def_ctx = ctx;
std::cout << "保存常量定义上下文: " << name << ", ctx: " << ctx << std::endl;
// Store the constant value (scalars only). A scalar whose evaluated kind
// differs from the base type falls through both branches and stores nothing.
if (!is_array && !init_values.empty()) {
if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::INT) {
sym.is_int_const = true;
sym.const_value.i32 = init_values[0].int_val;
} else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) {
sym.is_int_const = false;
sym.const_value.f32 = init_values[0].float_val;
}
} else if (is_array) {
std::cout << "[DEBUG] 数组常量,不存储单个常量值" << std::endl;
}
// NOTE(review): the result of addSymbol is not checked and `stored` is
// dereferenced without a null check.
table_.addSymbol(sym);
std::cout << "CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind << std::endl;
auto* stored = table_.lookup(name);
std::cout << "CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx << std::endl;
std::cout << "[DEBUG] 常量符号添加完成" << std::endl;
}
// ==================== 语句语义检查 ====================
// 处理所有语句 - 通过运行时类型判断
@ -1004,9 +1068,27 @@ public:
sema_.SetExprType(ctx, result);
return {};
}
// 获取语义上下文
// Moves the symbol table out of this visitor and returns it by value.
// table_ is left in a moved-from state afterwards.
SymbolTable TakeSymbolTable() {
  return std::move(table_);
}
// Moves the semantic context out of this visitor and returns it by value.
// sema_ is left in a moved-from state afterwards.
SemanticContext TakeSemanticContext() {
  return std::move(sema_);
}
// Moves both the semantic context and the symbol table out of this visitor
// and returns them together. The scope count is traced on std::cerr before
// and after the move.
SemaResult TakeResult() {
  std::cerr << "[DEBUG] TakeResult 前: 符号表作用域数量 = "
            << table_.getScopeCount() << std::endl;
  // Optional while debugging: dump the table here with table_.dump().
  SemaResult bundle;
  bundle.context = std::move(sema_);
  bundle.symbol_table = std::move(table_);
  std::cerr << "[DEBUG] TakeResult 后: 符号表作用域数量 = "
            << bundle.symbol_table.getScopeCount() << std::endl;
  return bundle;
}
private:
SymbolTable table_;
@ -1020,7 +1102,6 @@ private:
bool current_func_has_return_ = false;
// ==================== 辅助函数 ====================
ExprInfo CheckExp(SysYParser::ExpContext* ctx) {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("sema", "无效表达式"));
@ -1497,9 +1578,10 @@ private:
} // namespace
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor;
comp_unit.accept(&visitor);
SemanticContext ctx = visitor.TakeSemanticContext();
return ctx;
// Runs semantic analysis over the compilation unit and returns a SemaResult
// that carries both the symbol table and the semantic context.
SemaResult RunSema(SysYParser::CompUnitContext& comp_unit) {
  SemaVisitor sema_visitor;
  comp_unit.accept(&sema_visitor);

  // TakeResult() returns by value, so the result is handed back without an
  // intermediate named copy.
  return sema_visitor.TakeResult();
}

@ -4,6 +4,7 @@
#include <stdexcept>
#include <string>
#include <cmath>
#include <functional>
#define DEBUG_SYMBOL_TABLE
@ -17,28 +18,33 @@
// ---------- 构造函数 ----------
// Builds a table that starts with one scope (the global scope, index 0),
// marks that scope as active, then registers the built-in library functions.
SymbolTable::SymbolTable() {
scopes_.emplace_back(); // create the global scope
// Index 0 refers to the scope appended on the previous line.
active_scope_stack_.push_back(0);
registerBuiltinFunctions(); // register the built-in library functions
}
// ---------- 作用域管理 ----------
// Opens a new scope: appends an empty scope to scopes_ and pushes its index
// onto the active-scope stack so that it becomes the current scope.
void SymbolTable::enterScope() {
  // The index of the scope about to be appended equals the current size.
  const auto fresh_index = scopes_.size();
  scopes_.emplace_back();
  active_scope_stack_.push_back(fresh_index);
}
void SymbolTable::exitScope() {
if (scopes_.size() > 1) {
scopes_.pop_back();
if (active_scope_stack_.size() > 1) {
active_scope_stack_.pop_back();
}
// 不能退出全局作用域
}
// ---------- 符号添加与查找 ----------
bool SymbolTable::addSymbol(const Symbol& sym) {
auto& current_scope = scopes_.back();
auto& current_scope = scopes_[active_scope_stack_.back()];
if (current_scope.find(sym.name) != current_scope.end()) {
return false; // 重复定义
}
current_scope[sym.name] = sym;
Symbol stored_sym = sym;
stored_sym.scope_level = currentScopeLevel();
current_scope[sym.name] = stored_sym;
// 立即验证存储的符号
const auto& stored = current_scope[sym.name];
@ -59,16 +65,15 @@ Symbol* SymbolTable::lookupCurrent(const std::string& name) {
}
const Symbol* SymbolTable::lookup(const std::string& name) const {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
const auto& scope = *it;
for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) {
const auto& scope = scopes_[*it];
auto found = scope.find(name);
if (found != scope.end()) {
std::cout << "SymbolTable::lookup: found " << name
<< " in scope level " << (scopes_.rend() - it - 1)
<< " in active scope index " << *it
<< ", kind=" << (int)found->second.kind
<< ", const_def_ctx=" << found->second.const_def_ctx
<< std::endl;
return &found->second;
}
}
@ -76,7 +81,7 @@ const Symbol* SymbolTable::lookup(const std::string& name) const {
}
const Symbol* SymbolTable::lookupCurrent(const std::string& name) const {
const auto& current_scope = scopes_.back();
const auto& current_scope = scopes_[active_scope_stack_.back()];
auto it = current_scope.find(name);
if (it != current_scope.end()) {
return &it->second;
@ -84,6 +89,40 @@ const Symbol* SymbolTable::lookupCurrent(const std::string& name) const {
return nullptr;
}
// Searches every stored scope for `name`, not only the scopes on the active
// stack, scanning from the most recently created scope back to the first one.
// Returns a pointer to the first match, or nullptr when no scope has the name.
const Symbol* SymbolTable::lookupAll(const std::string& name) const {
  for (auto remaining = scopes_.size(); remaining > 0; --remaining) {
    const auto& scope = scopes_[remaining - 1];
    const auto hit = scope.find(name);
    if (hit != scope.end()) {
      return &hit->second;
    }
  }
  return nullptr;
}
// Finds the symbol whose var_def_ctx is the given parse-tree node. Every
// stored scope is examined, newest first. Returns nullptr when `decl` is null
// or no symbol references it.
const Symbol* SymbolTable::lookupByVarDef(const SysYParser::VarDefContext* decl) const {
  if (decl == nullptr) {
    return nullptr;
  }
  for (auto remaining = scopes_.size(); remaining > 0; --remaining) {
    const auto& scope = scopes_[remaining - 1];
    for (auto entry = scope.begin(); entry != scope.end(); ++entry) {
      if (entry->second.var_def_ctx == decl) {
        return &entry->second;
      }
    }
  }
  return nullptr;
}
// Finds the symbol whose const_def_ctx is the given parse-tree node. Every
// stored scope is examined, newest first. Returns nullptr when `decl` is null
// or no symbol references it.
const Symbol* SymbolTable::lookupByConstDef(const SysYParser::ConstDefContext* decl) const {
  if (decl == nullptr) {
    return nullptr;
  }
  for (auto remaining = scopes_.size(); remaining > 0; --remaining) {
    const auto& scope = scopes_[remaining - 1];
    for (auto entry = scope.begin(); entry != scope.end(); ++entry) {
      if (entry->second.const_def_ctx == decl) {
        return &entry->second;
      }
    }
  }
  return nullptr;
}
// ---------- 兼容原接口 ----------
void SymbolTable::Add(const std::string& name, SysYParser::VarDefContext* decl) {
Symbol sym;
@ -96,9 +135,9 @@ void SymbolTable::Add(const std::string& name, SysYParser::VarDefContext* decl)
}
bool SymbolTable::Contains(const std::string& name) const {
// const 方法不能修改 scopes_我们模拟查找
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
if (it->find(name) != it->end()) {
for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) {
const auto& scope = scopes_[*it];
if (scope.find(name) != scope.end()) {
return true;
}
}
@ -106,9 +145,10 @@ bool SymbolTable::Contains(const std::string& name) const {
}
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(name);
if (found != it->end()) {
for (auto it = active_scope_stack_.rbegin(); it != active_scope_stack_.rend(); ++it) {
const auto& scope = scopes_[*it];
auto found = scope.find(name);
if (found != scope.end()) {
// 只返回变量定义的上下文(函数等其他符号返回 nullptr
if (found->second.kind == SymbolKind::Variable) {
return found->second.var_def_ctx;
@ -638,7 +678,7 @@ std::vector<SymbolTable::ConstValue> SymbolTable::EvaluateConstInitVal(
// 隐式类型转换
if (base_type->IsInt32() && val.kind == ConstValue::FLOAT) {
val.kind = ConstValue::INT;
val.float_val = static_cast<int>(val.int_val);
val.int_val = static_cast<int>(val.float_val);
}
if (base_type->IsFloat() && val.kind == ConstValue::INT) {
val.kind = ConstValue::FLOAT;
@ -648,32 +688,88 @@ std::vector<SymbolTable::ConstValue> SymbolTable::EvaluateConstInitVal(
}
// ========== 2. 数组常量dims 非空)==========
// 计算数组总元素个数
size_t total = 1;
for (int d : dims) total *= d;
// 展平初始化列表(递归处理花括号)
std::vector<ConstValue> flat;
flattenInit(ctx, flat, base_type);
// 检查数量是否超过数组容量
if (flat.size() > total) {
throw std::runtime_error("常量初始化:提供的初始值数量超过数组元素总数");
ConstValue zero;
if (base_type->IsInt32()) {
zero.kind = ConstValue::INT;
zero.int_val = 0;
} else {
zero.kind = ConstValue::FLOAT;
zero.float_val = 0.0f;
}
// 不足的部分补零
if (flat.size() < total) {
ConstValue zero;
// 先整体补零,再按 C 语言花括号规则覆盖显式初始化项。
std::vector<ConstValue> flat(total, zero);
auto convert_value = [&](ConstValue v) -> ConstValue {
if (base_type->IsInt32()) {
zero.kind = ConstValue::INT;
zero.int_val = 0;
} else {
zero.kind = ConstValue::FLOAT;
zero.float_val = 0.0f;
if (v.kind == ConstValue::FLOAT) {
throw std::runtime_error("常量初始化:整型数组不能使用浮点常量");
}
return v;
}
flat.resize(total, zero);
}
if (v.kind == ConstValue::INT) {
ConstValue t;
t.kind = ConstValue::FLOAT;
t.float_val = static_cast<float>(v.int_val);
return t;
}
return v;
};
auto subarray_span = [&](size_t depth) -> size_t {
size_t span = 1;
for (size_t i = depth + 1; i < dims.size(); ++i) span *= static_cast<size_t>(dims[i]);
return span;
};
std::function<size_t(SysYParser::ConstInitValContext*, size_t, size_t, size_t)> fill;
fill = [&](SysYParser::ConstInitValContext* node,
size_t depth,
size_t begin,
size_t end) -> size_t {
if (!node || begin >= end) return begin;
// 标量初始化项
if (node->constExp()) {
ConstValue v = convert_value(EvaluateAddExp(node->constExp()->addExp()));
if (begin < flat.size()) flat[begin] = v;
return std::min(begin + 1, end);
}
size_t cursor = begin;
for (auto* child : node->constInitVal()) {
if (cursor >= end) break;
if (child->constExp()) {
ConstValue v = convert_value(EvaluateAddExp(child->constExp()->addExp()));
if (cursor < flat.size()) flat[cursor] = v;
++cursor;
continue;
}
// 花括号子列表:在非最内层需要按子聚合边界对齐。
if (depth + 1 < dims.size()) {
const size_t span = subarray_span(depth);
const size_t rel = (cursor - begin) % span;
if (rel != 0) cursor += (span - rel);
if (cursor >= end) break;
const size_t sub_end = std::min(cursor + span, end);
fill(child, depth + 1, cursor, sub_end);
// 一个带花括号的子初始化器会消费当前层的一个子聚合。
cursor = sub_end;
} else {
// 最内层(标量数组)遇到额外花括号,按同层顺序展开。
cursor = fill(child, depth, cursor, end);
}
}
return cursor;
};
fill(ctx, 0, 0, total);
return flat;
}

@ -1,4 +1,162 @@
// SysY 运行库实现:
// - 按实验/评测规范提供 I/O 等函数实现
// - 与编译器生成的目标代码链接,支撑运行时行为
#include "sylib.h"
#include <math.h>
#include <stdlib.h>
extern int scanf(const char* format, ...);
extern int printf(const char* format, ...);
extern int getchar(void);
extern int putchar(int c);
/* Reads one decimal integer from stdin. The scanf result is not checked, so
 * the initial value 0 is returned when nothing could be parsed. */
int getint(void) {
  int value = 0;
  (void)scanf("%d", &value);
  return value;
}
/* Returns the next character from stdin exactly as getchar() reports it
 * (including its end-of-input value). */
int getch(void) {
  const int ch = getchar();
  return ch;
}
/* Reads an element count and then that many integers from stdin into a.
 *
 * Returns the count that was read. If the count itself cannot be parsed, 0 is
 * returned and a is left untouched (previously an uninitialised value was
 * returned and used as the loop bound). If an element cannot be parsed,
 * reading stops and the number of elements stored so far is returned. This
 * matches the failure handling of getfarray.
 *
 * The caller must supply storage for at least the announced number of
 * elements; no bound is checked here. */
int getarray(int a[]) {
  int n = 0;
  int i = 0;
  if (scanf("%d", &n) != 1) {
    return 0;
  }
  for (; i < n; ++i) {
    if (scanf("%d", &a[i]) != 1) {
      return i;
    }
  }
  return n;
}
/* Reads one floating-point value from stdin. The scanf result is not checked,
 * so the initial value 0.0f is returned when nothing could be parsed. */
float getfloat(void) {
  float value = 0.0f;
  (void)scanf("%f", &value);
  return value;
}
/* Reads an element count and then that many floats from stdin into a.
 *
 * Returns the count on success, 0 when the count cannot be parsed, or the
 * number of elements stored so far when an element cannot be parsed. The
 * caller must supply enough storage; no bound is checked here. */
int getfarray(float a[]) {
  int total = 0;
  int filled = 0;
  if (scanf("%d", &total) != 1) {
    return 0;
  }
  while (filled < total) {
    if (scanf("%f", &a[filled]) != 1) {
      return filled;
    }
    ++filled;
  }
  return total;
}
/* Writes x to stdout in decimal, with no separator or newline after it. */
void putint(int x) {
printf("%d", x);
}
/* Writes x to stdout as a single character via putchar. */
void putch(int x) {
putchar(x);
}
/* Prints "n:" followed by the first n elements of a, each preceded by a
 * space, and ends the line with a newline. */
void putarray(int n, int a[]) {
  int idx;
  printf("%d:", n);
  for (idx = 0; idx < n; idx++) {
    printf(" %d", a[idx]);
  }
  putchar('\n');
}
/* Writes x to stdout using the "%a" conversion (hexadecimal floating-point
 * notation), with no separator or newline after it. */
void putfloat(float x) {
printf("%a", x);
}
/* Prints "n:" followed by the first n elements of a, each preceded by a
 * space and formatted with "%a", and ends the line with a newline. */
void putfarray(int n, float a[]) {
  int idx;
  printf("%d:", n);
  for (idx = 0; idx < n; idx++) {
    printf(" %a", a[idx]);
  }
  putchar('\n');
}
/* Writes a zero-terminated sequence of int-encoded characters to stdout, one
 * putchar call per element. A null pointer is ignored and no trailing newline
 * is written. */
void puts(int s[]) {
  const int* cursor = s;
  if (cursor == 0) {
    return;
  }
  for (; *cursor != 0; ++cursor) {
    putchar(*cursor);
  }
}
/* Timing hook that currently does nothing: no time is recorded here. */
void _sysy_starttime(int lineno) {
(void)lineno; /* deliberately unused; the cast silences the warning */
}
/* Timing hook that currently does nothing: no time is recorded here. */
void _sysy_stoptime(int lineno) {
(void)lineno; /* deliberately unused; the cast silences the warning */
}
/* Argument-free wrapper that forwards to _sysy_starttime with line number 0. */
void starttime(void) {
_sysy_starttime(0);
}
/* Argument-free wrapper that forwards to _sysy_stoptime with line number 0. */
void stoptime(void) {
_sysy_stoptime(0);
}
/* Fills the first `count` BYTES starting at ptr with the low 8 bits of value
 * and returns ptr. A count of zero or less writes nothing. Note that count is
 * a byte count, not an element count, even though ptr is typed int*. */
int* memset(int* ptr, int value, int count) {
  unsigned char* dst = (unsigned char*)ptr;
  const unsigned char fill = (unsigned char)(value & 0xFF);
  int remaining = count;
  while (remaining > 0) {
    *dst = fill;
    ++dst;
    --remaining;
  }
  return ptr;
}
/* Allocates uninitialised storage for `count` ints with malloc. Returns a
 * null pointer when count is zero or negative; otherwise returns whatever
 * malloc returns. */
int* sysy_alloc_i32(int count) {
  size_t bytes;
  if (count < 1) {
    return 0;
  }
  bytes = (size_t)count * sizeof(int);
  return (int*)malloc(bytes);
}
/* Allocates uninitialised storage for `count` floats with malloc. Returns a
 * null pointer when count is zero or negative; otherwise returns whatever
 * malloc returns. */
float* sysy_alloc_f32(int count) {
  size_t bytes;
  if (count < 1) {
    return 0;
  }
  bytes = (size_t)count * sizeof(float);
  return (float*)malloc(bytes);
}
/* Releases a buffer with free(). A null pointer is accepted and ignored. */
void sysy_free_i32(int* ptr) {
  if (ptr != 0) {
    free(ptr);
  }
}
/* Releases a buffer with free(). A null pointer is accepted and ignored. */
void sysy_free_f32(float* ptr) {
  if (ptr != 0) {
    free(ptr);
  }
}
/* Sets the first `count` ints at ptr to 0. Does nothing when ptr is null or
 * count is zero or negative. */
void sysy_zero_i32(int* ptr, int count) {
  int* stop;
  if (ptr == 0 || count < 1) {
    return;
  }
  stop = ptr + count;
  while (ptr != stop) {
    *ptr = 0;
    ++ptr;
  }
}
/* Sets the first `count` floats at ptr to 0.0f. Does nothing when ptr is null
 * or count is zero or negative. */
void sysy_zero_f32(float* ptr, int count) {
  float* stop;
  if (ptr == 0 || count < 1) {
    return;
  }
  stop = ptr + count;
  while (ptr != stop) {
    *ptr = 0.0f;
    ++ptr;
  }
}

@ -1,4 +1,29 @@
// SysY 运行库头文件:
// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用)
// - 与 sylib.c 配套,按规范逐步补齐声明
#pragma once
int getint(void);
int getch(void);
int getarray(int a[]);
float getfloat(void);
int getfarray(float a[]);
void putint(int x);
void putch(int x);
void putarray(int n, int a[]);
void putfloat(float x);
void putfarray(int n, float a[]);
void puts(int s[]);
void _sysy_starttime(int lineno);
void _sysy_stoptime(int lineno);
void starttime(void);
void stoptime(void);
int read_map(void);
int* memset(int* ptr, int value, int count);
int* sysy_alloc_i32(int count);
float* sysy_alloc_f32(int count);
void sysy_free_i32(int* ptr);
void sysy_free_f32(float* ptr);
void sysy_zero_i32(int* ptr, int count);
void sysy_zero_f32(float* ptr, int count);

Loading…
Cancel
Save