Compare commits

...

8 Commits

@ -37,6 +37,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <variant>
namespace ir {
@ -45,17 +46,15 @@ class Value;
class User;
class ConstantValue;
class ConstantInt;
class ConstantFloat;
class ConstantArray;
class GlobalValue;
class Instruction;
class BasicBlock;
class Function;
class Module;
// Use 表示一个 Value 的一次使用记录。
// 当前实现设计:
// - value被使用的值
// - user使用该值的 User
// - operand_index该值在 user 操作数列表中的位置
// ======================== Use 类 ========================
class Use {
public:
Use() = default;
@ -76,54 +75,130 @@ class Use {
size_t operand_index_ = 0;
};
// IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。
// ======================== Context 类 ========================
class Context {
public:
Context() = default;
~Context();
// 去重创建 i32 常量。
// 常量创建
ConstantInt* GetConstInt(int v);
ConstantFloat* GetConstFloat(float v);
// 临时变量名生成
std::string NextTemp();
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
int temp_index_ = -1;
};
// ======================== Type 类型体系 ========================
// 类型基类,支持参数化类型
class Type {
public:
enum class Kind { Void, Int32, PtrInt32 };
enum class Kind {
Void,
Int32,
Float32,
Pointer,
Array,
Function,
Label
};
explicit Type(Kind k);
// 使用静态共享对象获取类型。
// 同一类型可直接比较返回值是否相等,例如:
// Type::GetInt32Type() == Type::GetInt32Type()
static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetPtrInt32Type();
virtual ~Type() = default;
Kind GetKind() const;
bool IsVoid() const;
bool IsInt32() const;
bool IsPtrInt32() const;
bool IsFloat32() const;
bool IsPointer() const;
bool IsArray() const;
bool IsFunction() const;
bool IsLabel() const;
bool IsPtrInt32() const; // 兼容旧接口
bool IsPtrFloat32() const; // 判断是否为 float32*
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const;
// 静态单例获取基础类型
static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetFloat32Type();
static const std::shared_ptr<Type>& GetLabelType();
static const std::shared_ptr<Type>& GetPtrInt32Type();
// 复合类型工厂方法
static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> pointee);
static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> elem, size_t size);
static std::shared_ptr<Type> GetFunctionType(std::shared_ptr<Type> ret,
std::vector<std::shared_ptr<Type>> params);
private:
Kind kind_;
};
// 指针类型
class PointerType : public Type {
public:
PointerType(std::shared_ptr<Type> pointee)
: Type(Type::Kind::Pointer), pointee_(std::move(pointee)) {}
const std::shared_ptr<Type>& GetPointeeType() const { return pointee_; }
private:
std::shared_ptr<Type> pointee_;
};
// 数组类型
class ArrayType : public Type {
public:
ArrayType(std::shared_ptr<Type> elem, size_t size)
: Type(Type::Kind::Array), elem_type_(std::move(elem)), size_(size) {}
const std::shared_ptr<Type>& GetElementType() const { return elem_type_; }
size_t GetSize() const { return size_; }
private:
std::shared_ptr<Type> elem_type_;
size_t size_;
};
// 函数类型
class FunctionType : public Type {
public:
FunctionType(std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params)
: Type(Type::Kind::Function), ret_type_(std::move(ret)), param_types_(std::move(params)) {}
const std::shared_ptr<Type>& GetReturnType() const { return ret_type_; }
const std::vector<std::shared_ptr<Type>>& GetParamTypes() const { return param_types_; }
private:
std::shared_ptr<Type> ret_type_;
std::vector<std::shared_ptr<Type>> param_types_;
};
// ======================== Value 类 ========================
class Value {
public:
Value(std::shared_ptr<Type> ty, std::string name);
virtual ~Value() = default;
const std::shared_ptr<Type>& GetType() const;
const std::string& GetName() const;
void SetName(std::string n);
bool IsVoid() const;
bool IsInt32() const;
bool IsPtrInt32() const;
bool IsFloat32() const;
bool IsPtrInt32() const; // 兼容旧接口,实际上判断是否为 i32*
bool IsPtrFloat32() const; // 判断是否为 float32*
bool IsConstant() const;
bool IsInstruction() const;
bool IsUser() const;
bool IsFunction() const;
bool IsGlobalValue() const;
void AddUse(User* user, size_t operand_index);
void RemoveUse(User* user, size_t operand_index);
const std::vector<Use>& GetUses() const;
@ -135,8 +210,7 @@ class Value {
std::vector<Use> uses_;
};
// ConstantValue 是常量体系的基类。
// 当前只实现了 ConstantInt后续可继续扩展更多常量种类。
// ======================== 常量体系 ========================
class ConstantValue : public Value {
public:
ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
@ -151,11 +225,26 @@ class ConstantInt : public ConstantValue {
int value_{};
};
// 后续还需要扩展更多指令类型。
enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret };
class ConstantFloat : public ConstantValue {
public:
ConstantFloat(std::shared_ptr<Type> ty, float v);
float GetValue() const { return value_; }
private:
float value_{};
};
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
// 当前实现中只有 Instruction 继承自 User。
// 常量数组(简单聚合,可存储常量元素)
class ConstantArray : public ConstantValue {
public:
ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elems);
const std::vector<ConstantValue*>& GetElements() const { return elements_; }
private:
std::vector<ConstantValue*> elements_;
};
// ======================== User 类 ========================
class User : public Value {
public:
User(std::shared_ptr<Type> ty, std::string name);
@ -164,20 +253,44 @@ class User : public Value {
void SetOperand(size_t index, Value* value);
protected:
// 统一的 operand 入口。
void AddOperand(Value* value);
private:
std::vector<Value*> operands_;
};
// GlobalValue 是全局值/全局变量体系的空壳占位类。
// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。
// ======================== GlobalValue 类 ========================
class GlobalValue : public User {
public:
GlobalValue(std::shared_ptr<Type> ty, std::string name);
ConstantValue* GetInitializer() const { return init_; }
void SetInitializer(ConstantValue* init) { init_ = init; }
private:
ConstantValue* init_ = nullptr;
};
// ======================== 指令操作码 ========================
enum class Opcode {
// 算术
Add, Sub, Mul, Div, Mod,
// 位运算
And, Or, Xor, Shl, LShr, AShr,
// 比较
ICmp, FCmp,
// 内存
Alloca, Load, Store,
// 控制流
Ret, Br, CondBr,
// 函数调用
Call,
// 数组访问
GEP,
// Phi
Phi
};
// ======================== Instruction 类 ========================
class Instruction : public User {
public:
Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
@ -191,31 +304,94 @@ class Instruction : public User {
BasicBlock* parent_ = nullptr;
};
// 二元运算指令
class BinaryInst : public Instruction {
public:
BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
std::string name);
Value* GetLhs() const;
Value* GetRhs() const;
Value* GetRhs() const;
};
// 比较指令icmp/fcmp
class CmpInst : public Instruction {
public:
enum Predicate {
EQ, NE, LT, LE, GT, GE
};
CmpInst(Opcode op, Predicate pred, Value* lhs, Value* rhs, std::string name);
Predicate GetPredicate() const { return pred_; }
Value* GetLhs() const { return lhs_; }
Value* GetRhs() const { return rhs_; }
private:
Predicate pred_;
Value* lhs_;
Value* rhs_;
};
// 返回指令
class ReturnInst : public Instruction {
public:
ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
ReturnInst(std::shared_ptr<Type> void_ty, Value* val = nullptr);
Value* GetValue() const;
};
// 无条件分支
class BranchInst : public Instruction {
public:
BranchInst(BasicBlock* target);
BasicBlock* GetTarget() const;
};
// 条件分支
class CondBranchInst : public Instruction {
public:
CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb);
Value* GetCond() const;
BasicBlock* GetTrueBlock() const;
BasicBlock* GetFalseBlock() const;
};
// 函数调用
class CallInst : public Instruction {
public:
CallInst(Function* callee, std::vector<Value*> args, std::string name);
Function* GetCallee() const;
const std::vector<Value*>& GetArgs() const;
};
// Phi 指令(用于 SSA
class PhiInst : public Instruction {
public:
PhiInst(std::shared_ptr<Type> ty, std::string name);
void AddIncoming(Value* val, BasicBlock* block);
const std::vector<std::pair<Value*, BasicBlock*>>& GetIncomings() const;
};
// GetElementPtr 指令(数组/结构体指针计算)
class GetElementPtrInst : public Instruction {
public:
GetElementPtrInst(std::shared_ptr<Type> ty, Value* ptr,
std::vector<Value*> indices, std::string name);
Value* GetPtr() const;
const std::vector<Value*>& GetIndices() const;
};
// 分配栈内存指令
class AllocaInst : public Instruction {
public:
AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name);
};
// 加载指令
class LoadInst : public Instruction {
public:
LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name);
Value* GetPtr() const;
};
// 存储指令
class StoreInst : public Instruction {
public:
StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr);
@ -223,8 +399,7 @@ class StoreInst : public Instruction {
Value* GetPtr() const;
};
// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。
// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。
// ======================== BasicBlock 类 ========================
class BasicBlock : public Value {
public:
explicit BasicBlock(std::string name);
@ -234,6 +409,11 @@ class BasicBlock : public Value {
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const;
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
void RemovePredecessor(BasicBlock* pred);
void RemoveSuccessor(BasicBlock* succ);
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
@ -254,65 +434,119 @@ class BasicBlock : public Value {
std::vector<BasicBlock*> successors_;
};
// Function 当前也采用了最小实现。
// 需要特别注意:由于项目里还没有单独的 FunctionType
// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”,
// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。
// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、
// 形参和调用,通常需要引入专门的函数类型表示。
// ======================== Function 类 ========================
class Function : public Value {
public:
// 当前构造函数接收的也是返回类型,而不是完整函数类型。
Function(std::string name, std::shared_ptr<Type> ret_type);
// 构造函数,接收函数名、返回类型和参数类型列表(可选)
Function(std::string name, std::shared_ptr<Type> ret_type,
std::vector<std::shared_ptr<Type>> param_types = {});
BasicBlock* CreateBlock(const std::string& name);
BasicBlock* GetEntry();
const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
// 参数管理
const std::vector<Value*>& GetParams() const { return params_; }
void AddParam(Value* param);
// 函数类型(完整签名)
std::shared_ptr<FunctionType> GetFunctionType() const;
private:
BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_;
std::vector<Value*> params_; // 参数值(通常是 Argument 类型,后续可定义)
// Owned parameter storage to keep argument Values alive
std::vector<std::unique_ptr<Value>> owned_params_;
std::shared_ptr<FunctionType> func_type_; // 缓存函数类型
};
// ======================== Module 类 ========================
class Module {
public:
Module() = default;
Context& GetContext();
const Context& GetContext() const;
// 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。
// 创建函数,支持参数类型
Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type);
std::shared_ptr<Type> ret_type,
std::vector<std::shared_ptr<Type>> param_types = {});
// 创建全局变量
GlobalValue* CreateGlobalVariable(const std::string& name,
std::shared_ptr<Type> ty,
ConstantValue* init = nullptr);
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
const std::vector<std::unique_ptr<GlobalValue>>& GetGlobalVariables() const;
private:
Context context_;
std::vector<std::unique_ptr<Function>> functions_;
std::vector<std::unique_ptr<GlobalValue>> global_vars_;
};
// ======================== IRBuilder 类 ========================
class IRBuilder {
public:
IRBuilder(Context& ctx, BasicBlock* bb);
IRBuilder(Context& ctx, BasicBlock* bb = nullptr);
void SetInsertPoint(BasicBlock* bb);
BasicBlock* GetInsertBlock() const;
// 构造常量、二元运算、返回指令的最小集合。
// 常量创建
ConstantInt* CreateConstInt(int v);
ConstantFloat* CreateConstFloat(float v);
// 算术指令
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name);
// 比较指令
CmpInst* CreateICmp(CmpInst::Predicate pred, Value* lhs, Value* rhs,
const std::string& name);
CmpInst* CreateFCmp(CmpInst::Predicate pred, Value* lhs, Value* rhs,
const std::string& name);
// 内存指令
AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaFloat(const std::string& name);
AllocaInst* CreateAlloca(std::shared_ptr<Type> ty, const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);
// 控制流指令
ReturnInst* CreateRet(Value* v);
BranchInst* CreateBr(BasicBlock* target);
CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_bb,
BasicBlock* false_bb);
// 函数调用
CallInst* CreateCall(Function* callee, std::vector<Value*> args,
const std::string& name);
// 数组访问
GetElementPtrInst* CreateGEP(Value* ptr, std::vector<Value*> indices,
const std::string& name);
// Phi 指令
PhiInst* CreatePhi(std::shared_ptr<Type> ty, const std::string& name);
private:
Context& ctx_;
BasicBlock* insert_block_;
};
// ======================== IRPrinter 类 ========================
class IRPrinter {
public:
void Print(const Module& module, std::ostream& os);
};
} // namespace ir
} // namespace ir

@ -22,20 +22,26 @@ class Value;
class IRGenImpl final : public SysYBaseVisitor {
public:
IRGenImpl(ir::Module& module, const SemanticContext& sema);
IRGenImpl(ir::Module& module, IRGenContext& sema);
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlock(SysYParser::BlockContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitLVal(SysYParser::LValContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
std::any visitConstDef(SysYParser::ConstDefContext* ctx) override;
private:
enum class BlockFlow {
@ -47,12 +53,20 @@ class IRGenImpl final : public SysYBaseVisitor {
ir::Value* EvalExpr(SysYParser::ExpContext& expr);
ir::Module& module_;
const SemanticContext& sema_;
IRGenContext& sema_;
ir::Function* func_;
ir::IRBuilder builder_;
// 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
std::unordered_map<antlr4::ParserRuleContext*, ir::Value*> storage_map_;
// 额外增加按名称的快速映射,以防有时无法直接通过声明节点指针匹配。
std::unordered_map<std::string, ir::Value*> name_map_;
// 常量名称到整数值的快速映射(供数组维度解析使用)
std::unordered_map<std::string, long> const_values_;
// 当前正在处理的声明基础类型(由 visitDecl 设置visitVarDef/visitConstDef 使用)
std::string current_btype_;
std::vector<ir::BasicBlock*> break_targets_;
std::vector<ir::BasicBlock*> continue_targets_;
};
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema);
IRGenContext& sema);

@ -1,30 +1,181 @@
// 基于语法树的语义检查与名称绑定。
#pragma once
#ifndef SEMANTIC_ANALYSIS_H
#define SEMANTIC_ANALYSIS_H
#include "SymbolTable.h"
#include "SysYBaseVisitor.h"
#include <vector>
#include <string>
#include <sstream>
#include <unordered_map>
#include <any>
#include <memory>
#include "SysYParser.h"
class SemanticContext {
public:
void BindVarUse(SysYParser::VarContext* use,
SysYParser::VarDefContext* decl) {
var_uses_[use] = decl;
}
SysYParser::VarDefContext* ResolveVarUse(
const SysYParser::VarContext* use) const {
auto it = var_uses_.find(use);
return it == var_uses_.end() ? nullptr : it->second;
}
private:
std::unordered_map<const SysYParser::VarContext*,
SysYParser::VarDefContext*>
var_uses_;
// 错误信息结构体
struct ErrorMsg {
std::string msg;
int line;
int column;
ErrorMsg(std::string m, int l, int c) : msg(std::move(m)), line(l), column(c) {}
};
// 目前仅检查:
// - 变量先声明后使用
// - 局部变量不允许重复定义
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);
// 前向声明
namespace antlr4 {
class ParserRuleContext;
namespace tree {
class ParseTree;
}
}
// 语义/IR生成上下文核心类
class IRGenContext {
public:
// 错误管理
void RecordError(const ErrorMsg& err) { errors_.push_back(err); }
const std::vector<ErrorMsg>& GetErrors() const { return errors_; }
bool HasError() const { return !errors_.empty(); }
void ClearErrors() { errors_.clear(); }
// 类型绑定/查询 - 使用 void* 以兼容测试代码
void SetType(void* ctx, SymbolType type) {
node_type_map_[ctx] = type;
}
SymbolType GetType(void* ctx) const {
auto it = node_type_map_.find(ctx);
return it == node_type_map_.end() ? SymbolType::TYPE_UNKNOWN : it->second;
}
// 常量值绑定/查询 - 使用 void* 以兼容测试代码
void SetConstVal(void* ctx, const std::any& val) {
const_val_map_[ctx] = val;
}
std::any GetConstVal(void* ctx) const {
auto it = const_val_map_.find(ctx);
return it == const_val_map_.end() ? std::any() : it->second;
}
// 循环状态管理
void EnterLoop() { sym_table_.EnterLoop(); }
void ExitLoop() { sym_table_.ExitLoop(); }
bool InLoop() const { return sym_table_.InLoop(); }
// 类型判断工具函数
bool IsIntType(const std::any& val) const {
return val.type() == typeid(long) || val.type() == typeid(int);
}
bool IsFloatType(const std::any& val) const {
return val.type() == typeid(double) || val.type() == typeid(float);
}
// 当前函数返回类型
SymbolType GetCurrentFuncReturnType() const {
return current_func_ret_type_;
}
void SetCurrentFuncReturnType(SymbolType type) {
current_func_ret_type_ = type;
}
// 符号表访问
SymbolTable& GetSymbolTable() { return sym_table_; }
const SymbolTable& GetSymbolTable() const { return sym_table_; }
// 作用域管理
void EnterScope() { sym_table_.EnterScope(); }
void LeaveScope() { sym_table_.LeaveScope(); }
size_t GetScopeDepth() const { return sym_table_.GetScopeDepth(); }
private:
SymbolTable sym_table_;
std::unordered_map<void*, SymbolType> node_type_map_;
std::unordered_map<void*, std::any> const_val_map_;
std::vector<ErrorMsg> errors_;
SymbolType current_func_ret_type_ = SymbolType::TYPE_UNKNOWN;
};
// 错误信息格式化工具函数
inline std::string FormatErrMsg(const std::string& msg, int line, int col) {
std::ostringstream oss;
oss << "[行:" << line << ",列:" << col << "] " << msg;
return oss.str();
}
// 语义分析访问器 - 继承自生成的基类
class SemaVisitor : public SysYBaseVisitor {
public:
explicit SemaVisitor(IRGenContext& ctx) : ir_ctx_(ctx) {}
// 必须实现的 ANTLR4 接口
std::any visit(antlr4::tree::ParseTree* tree) override {
if (tree) {
return tree->accept(this);
}
return std::any();
}
std::any visitTerminal(antlr4::tree::TerminalNode* node) override {
return std::any();
}
std::any visitErrorNode(antlr4::tree::ErrorNode* node) override {
if (node) {
int line = node->getSymbol()->getLine();
int col = node->getSymbol()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("语法错误节点", line, col));
}
return std::any();
}
// 核心访问方法
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override;
std::any visitBlock(SysYParser::BlockContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitLVal(SysYParser::LValContext* ctx) override;
std::any visitExp(SysYParser::ExpContext* ctx) override;
std::any visitCond(SysYParser::CondContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override;
// 通用子节点访问
std::any visitChildren(antlr4::tree::ParseTree* node) override {
std::any result;
if (node) {
for (auto* child : node->children) {
if (child) {
result = child->accept(this);
}
}
}
return result;
}
// 获取上下文引用
IRGenContext& GetContext() { return ir_ctx_; }
const IRGenContext& GetContext() const { return ir_ctx_; }
private:
IRGenContext& ir_ctx_;
};
// 语义分析入口函数
void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx);
// 兼容主程序的简单包装:返回构造好的 IRGenContext按值
IRGenContext RunSema(SysYParser::CompUnitContext& ctx);
#endif // SEMANTIC_ANALYSIS_H

@ -1,17 +1,201 @@
// 极简符号表:记录局部变量定义点。
#pragma once
#ifndef SYMBOL_TABLE_H
#define SYMBOL_TABLE_H
#include <any>
#include <string>
#include <vector>
#include <unordered_map>
#include <stack>
#include <utility>
#include "SysYParser.h"
// 核心类型枚举
enum class SymbolType {
TYPE_UNKNOWN, // 未知类型
TYPE_INT, // 整型
TYPE_FLOAT, // 浮点型
TYPE_VOID, // 空类型
TYPE_ARRAY, // 数组类型
TYPE_FUNCTION // 函数类型
};
// 获取类型名称字符串
inline const char* SymbolTypeToString(SymbolType type) {
switch (type) {
case SymbolType::TYPE_INT: return "int";
case SymbolType::TYPE_FLOAT: return "float";
case SymbolType::TYPE_VOID: return "void";
case SymbolType::TYPE_ARRAY: return "array";
case SymbolType::TYPE_FUNCTION: return "function";
default: return "unknown";
}
}
// 变量信息结构体
struct VarInfo {
SymbolType type = SymbolType::TYPE_UNKNOWN;
bool is_const = false;
std::any const_val;
std::vector<int> array_dims; // 数组维度,空表示非数组
void* decl_ctx = nullptr; // 关联的语法节点
// 检查是否为数组类型
bool IsArray() const { return !array_dims.empty(); }
// 获取数组元素总数
int GetArrayElementCount() const {
int count = 1;
for (int dim : array_dims) {
count *= dim;
}
return count;
}
};
// 函数信息结构体
struct FuncInfo {
SymbolType ret_type = SymbolType::TYPE_UNKNOWN;
std::string name;
std::vector<SymbolType> param_types; // 参数类型列表
void* decl_ctx = nullptr; // 关联的语法节点
// 检查参数匹配
bool CheckParams(const std::vector<SymbolType>& actual_params) const {
if (actual_params.size() != param_types.size()) {
return false;
}
for (size_t i = 0; i < param_types.size(); ++i) {
if (param_types[i] != actual_params[i] &&
param_types[i] != SymbolType::TYPE_UNKNOWN &&
actual_params[i] != SymbolType::TYPE_UNKNOWN) {
return false;
}
}
return true;
}
};
// 作用域条目结构体
struct ScopeEntry {
// 变量符号表:符号名 -> (符号信息, 声明节点)
std::unordered_map<std::string, std::pair<VarInfo, void*>> var_symbols;
// 函数符号表:符号名 -> (函数信息, 声明节点)
std::unordered_map<std::string, std::pair<FuncInfo, void*>> func_symbols;
// 清空作用域
void Clear() {
var_symbols.clear();
func_symbols.clear();
}
};
// 符号表核心类
class SymbolTable {
public:
void Add(const std::string& name, SysYParser::VarDefContext* decl);
bool Contains(const std::string& name) const;
SysYParser::VarDefContext* Lookup(const std::string& name) const;
public:
// ========== 作用域管理 ==========
// 进入新作用域
void EnterScope();
// 离开当前作用域
void LeaveScope();
// 获取当前作用域深度
size_t GetScopeDepth() const { return scopes_.size(); }
// 检查作用域栈是否为空
bool IsEmpty() const { return scopes_.empty(); }
// ========== 变量符号管理 ==========
// 检查当前作用域是否包含指定变量
bool CurrentScopeHasVar(const std::string& name) const;
// 绑定变量到当前作用域
void BindVar(const std::string& name, const VarInfo& info, void* decl_ctx);
// 查找变量(从当前作用域向上遍历)
bool LookupVar(const std::string& name, VarInfo& out_info, void*& out_decl_ctx) const;
// 快速查找变量(不获取详细信息)
bool HasVar(const std::string& name) const {
VarInfo info;
void* ctx;
return LookupVar(name, info, ctx);
}
// ========== 函数符号管理 ==========
// 检查当前作用域是否包含指定函数
bool CurrentScopeHasFunc(const std::string& name) const;
// 绑定函数到当前作用域
void BindFunc(const std::string& name, const FuncInfo& info, void* decl_ctx);
// 查找函数(从当前作用域向上遍历)
bool LookupFunc(const std::string& name, FuncInfo& out_info, void*& out_decl_ctx) const;
// 快速查找函数(不获取详细信息)
bool HasFunc(const std::string& name) const {
FuncInfo info;
void* ctx;
return LookupFunc(name, info, ctx);
}
// ========== 循环状态管理 ==========
// 进入循环
void EnterLoop();
// 离开循环
void ExitLoop();
// 检查是否在循环内
bool InLoop() const;
// 获取循环嵌套深度
int GetLoopDepth() const { return loop_depth_; }
// ========== 辅助功能 ==========
// 清空所有作用域和状态
void Clear();
// 获取当前作用域中所有变量名
std::vector<std::string> GetCurrentScopeVarNames() const;
// 获取当前作用域中所有函数名
std::vector<std::string> GetCurrentScopeFuncNames() const;
// 调试:打印符号表内容
void Dump() const;
private:
std::unordered_map<std::string, SysYParser::VarDefContext*> table_;
private:
// 作用域栈
std::stack<ScopeEntry> scopes_;
// 循环嵌套深度
int loop_depth_ = 0;
};
// 类型兼容性检查函数
inline bool IsTypeCompatible(SymbolType expected, SymbolType actual) {
if (expected == SymbolType::TYPE_UNKNOWN || actual == SymbolType::TYPE_UNKNOWN) {
return true; // 未知类型视为兼容
}
// 基本类型兼容规则
if (expected == actual) {
return true;
}
// int 可以隐式转换为 float
if (expected == SymbolType::TYPE_FLOAT && actual == SymbolType::TYPE_INT) {
return true;
}
return false;
}
#endif // SYMBOL_TABLE_H

@ -0,0 +1,39 @@
import os
import subprocess
COMPILER = "./build/bin/compiler"
TEST_DIR = "./test/test_case/functional"
pass_cnt = 0
fail_cnt = 0
print("===== SysY Batch Test Start =====")
for file in os.listdir(TEST_DIR):
if not file.endswith(".sy"):
continue
path = os.path.join(TEST_DIR, file)
print(f"[TEST] {file} ... ", end="")
result = subprocess.run(
[COMPILER, "--emit-parse-tree", path],
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE
)
if result.returncode == 0:
print("PASS")
pass_cnt += 1
else:
print("FAIL")
fail_cnt += 1
print("---- Error ----")
print(result.stderr.decode())
print("---------------")
print("===============================")
print(f"Total: {pass_cnt + fail_cnt}")
print(f"PASS : {pass_cnt}")
print(f"FAIL : {fail_cnt}")
print("===============================")

@ -0,0 +1,60 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT=$(cd "$(dirname "$0")/.." && pwd)
FUNC_DIR="$ROOT/test/test_case/functional"
OUT_BASE="$ROOT/test/test_result/function/ir"
LOG_DIR="$ROOT/test/test_result/function/ir_logs"
VERIFY="$ROOT/scripts/verify_ir.sh"
mkdir -p "$OUT_BASE"
mkdir -p "$LOG_DIR"
if [ ! -x "$VERIFY" ]; then
echo "verify script not executable, trying to run with bash: $VERIFY"
fi
files=("$FUNC_DIR"/*.sy)
if [ ${#files[@]} -eq 0 ]; then
echo "No .sy files found in $FUNC_DIR" >&2
exit 1
fi
total=0
pass=0
fail=0
failed_list=()
for f in "${files[@]}"; do
((total++))
name=$(basename "$f")
echo "=== Test: $name ==="
log="$LOG_DIR/${name%.sy}.log"
set +e
bash "$VERIFY" "$f" "$OUT_BASE" --run >"$log" 2>&1
rc=$?
set -e
if [ $rc -eq 0 ]; then
echo "PASS: $name"
((pass++))
else
echo "FAIL: $name (log: $log)"
failed_list+=("$name")
((fail++))
fi
done
echo
echo "Summary: total=$total pass=$pass fail=$fail"
if [ $fail -ne 0 ]; then
echo "Failed tests:"; for t in "${failed_list[@]}"; do echo " - $t"; done
echo "Tail of failure logs (last 200 lines each):"
for t in "${failed_list[@]}"; do
logfile="$LOG_DIR/${t%.sy}.log"
echo
echo "--- $t ---"
tail -n 200 "$logfile" || true
done
fi
exit $fail

@ -1,68 +1,70 @@
// SysY 子集语法:支持形如
// int main() { int a = 1; int b = 2; return a + b; }
// 的最小返回表达式编译。
// 后续需要自行添加
grammar SysY;
/*===-------------------------------------------===*/
/* Lexer rules */
/*===-------------------------------------------===*/
INT: 'int';
RETURN: 'return';
ASSIGN: '=';
ADD: '+';
// ======================
// Parser Rules
// ======================
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
SEMICOLON: ';';
compUnit
: (decl | funcDef)+
;
ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+;
decl
: constDecl
| varDecl
;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
constDecl
: 'const' bType constDef (',' constDef)* ';'
;
/*===-------------------------------------------===*/
/* Syntax rules */
/*===-------------------------------------------===*/
bType
: 'int'
| 'float'
;
compUnit
: funcDef EOF
constDef
: Ident ('[' constExp ']')* '=' constInitVal
;
decl
: btype varDef SEMICOLON
constInitVal
: constExp
| '{' (constInitVal (',' constInitVal)*)? '}'
;
btype
: INT
varDecl
: bType varDef (',' varDef)* ';'
;
varDef
: lValue (ASSIGN initValue)?
: Ident ('[' constExp ']')*
| Ident ('[' constExp ']')* '=' initVal
;
initValue
initVal
: exp
| '{' (initVal (',' initVal)*)? '}'
;
funcDef
: funcType ID LPAREN RPAREN blockStmt
: funcType Ident '(' funcFParams? ')' block
;
funcType
: INT
: 'void'
| 'int'
| 'float'
;
blockStmt
: LBRACE blockItem* RBRACE
funcFParams
: funcFParam (',' funcFParam)*
;
funcFParam
: bType Ident ('[' ']' ('[' exp ']')*)?
;
block
: '{' blockItem* '}'
;
blockItem
@ -71,28 +73,129 @@ blockItem
;
stmt
: returnStmt
: lVal '=' exp ';'
| exp? ';'
| block
| 'if' '(' cond ')' stmt ('else' stmt)?
| 'while' '(' cond ')' stmt
| 'break' ';'
| 'continue' ';'
| 'return' exp? ';'
;
returnStmt
: RETURN exp SEMICOLON
exp
: addExp
;
exp
: LPAREN exp RPAREN # parenExp
| var # varExp
| number # numberExp
| exp ADD exp # additiveExp
cond
: lOrExp
;
var
: ID
lVal
: Ident ('[' exp ']')*
;
lValue
: ID
primaryExp
: '(' exp ')'
| lVal
| number
;
number
: ILITERAL
: FloatConst
| IntConst
;
unaryExp
: primaryExp
| Ident '(' funcRParams? ')'
| unaryOp unaryExp
;
unaryOp
: '+'
| '-'
| '!'
;
funcRParams
: exp (',' exp)*
;
mulExp
: unaryExp (('*' | '/' | '%') unaryExp)*
;
addExp
: mulExp (('+' | '-') mulExp)*
;
relExp
: addExp (('<' | '>' | '<=' | '>=') addExp)*
;
eqExp
: relExp (('==' | '!=') relExp)*
;
lAndExp
: eqExp ('&&' eqExp)*
;
lOrExp
: lAndExp ('||' lAndExp)*
;
constExp
: addExp
;
// ======================
// Lexer Rules
// ======================
fragment DIGIT : [0-9] ;
fragment HEXDIGIT : [0-9a-fA-F] ;
fragment EXP : [eE][+-]? DIGIT+ ;
fragment PEXP : [pP][+-]? DIGIT+ ;
// Float含 hex float
FloatConst
: ('0x' | '0X')
(
HEXDIGIT+ '.' HEXDIGIT*
| '.' HEXDIGIT+
| HEXDIGIT+
)
PEXP
| '.' DIGIT+ EXP?
| DIGIT+ '.' DIGIT* EXP?
| DIGIT+ EXP
;
// Int完整三种
IntConst
: '0'
| [1-9][0-9]* // decimal
| '0'[0-7]+ // octal
| ('0x' | '0X')[0-9a-fA-F]+ // hex
;
// ---------- 标识符 ----------
Ident
: [a-zA-Z_][a-zA-Z0-9_]*
;
// ---------- 空白 ----------
WS
: [ \t\r\n]+ -> skip
;
// ---------- 注释 ----------
LINE_COMMENT
: '//' ~[\r\n]* -> skip
;
BLOCK_COMMENT
: '/*' .*? '*/' -> skip
;

@ -10,6 +10,7 @@
#include "ir/IR.h"
#include <utility>
#include <algorithm> // 用于 std::find
namespace ir {
@ -21,7 +22,6 @@ Function* BasicBlock::GetParent() const { return parent_; }
void BasicBlock::SetParent(Function* parent) { parent_ = parent; }
bool BasicBlock::HasTerminator() const {
return !instructions_.empty() && instructions_.back()->IsTerminator();
}
@ -42,4 +42,36 @@ const std::vector<BasicBlock*>& BasicBlock::GetSuccessors() const {
return successors_;
}
} // namespace ir
// 添加前驱基本块(避免重复)
void BasicBlock::AddPredecessor(BasicBlock* pred) {
if (std::find(predecessors_.begin(), predecessors_.end(), pred) ==
predecessors_.end()) {
predecessors_.push_back(pred);
}
}
// 添加后继基本块(避免重复)
void BasicBlock::AddSuccessor(BasicBlock* succ) {
if (std::find(successors_.begin(), successors_.end(), succ) ==
successors_.end()) {
successors_.push_back(succ);
}
}
// 移除前驱基本块
void BasicBlock::RemovePredecessor(BasicBlock* pred) {
auto it = std::find(predecessors_.begin(), predecessors_.end(), pred);
if (it != predecessors_.end()) {
predecessors_.erase(it);
}
}
// 移除后继基本块
void BasicBlock::RemoveSuccessor(BasicBlock* succ) {
auto it = std::find(successors_.begin(), successors_.end(), succ);
if (it != successors_.end()) {
successors_.erase(it);
}
}
} // namespace ir

@ -15,9 +15,17 @@ ConstantInt* Context::GetConstInt(int v) {
return inserted->second.get();
}
ConstantFloat* Context::GetConstFloat(float v) {
auto it = const_floats_.find(v);
if (it != const_floats_.end()) return it->second.get();
auto inserted = const_floats_.emplace(
v, std::make_unique<ConstantFloat>(Type::GetFloat32Type(), v)).first;
return inserted->second.get();
}
std::string Context::NextTemp() {
std::ostringstream oss;
oss << "%" << ++temp_index_;
oss << ++temp_index_;
return oss.str();
}

@ -5,9 +5,18 @@
namespace ir {
Function::Function(std::string name, std::shared_ptr<Type> ret_type)
Function::Function(std::string name, std::shared_ptr<Type> ret_type,
std::vector<std::shared_ptr<Type>> param_types)
: Value(std::move(ret_type), std::move(name)) {
func_type_ = std::static_pointer_cast<FunctionType>(
Type::GetFunctionType(GetType(), param_types));
entry_ = CreateBlock("entry");
// Create arguments
for (size_t i = 0; i < param_types.size(); ++i) {
owned_params_.push_back(std::make_unique<Value>(param_types[i], "arg" + std::to_string(i)));
params_.push_back(owned_params_.back().get());
// Note: arguments are owned in owned_params_ to ensure lifetime
}
}
BasicBlock* Function::CreateBlock(const std::string& name) {
@ -29,4 +38,15 @@ const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_;
}
void Function::AddParam(Value* param) {
if (!param) {
throw std::runtime_error("Function::AddParam cannot add null param");
}
params_.push_back(param);
}
std::shared_ptr<FunctionType> Function::GetFunctionType() const {
return func_type_;
}
} // namespace ir

@ -21,6 +21,10 @@ ConstantInt* IRBuilder::CreateConstInt(int v) {
return ctx_.GetConstInt(v);
}
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
return ctx_.GetConstFloat(v);
}
BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
@ -42,6 +46,7 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
return CreateBinary(Opcode::Add, lhs, rhs, name);
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -49,15 +54,35 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name);
}
AllocaInst* IRBuilder::CreateAllocaFloat(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(Type::GetPointerType(Type::GetFloat32Type()), name);
}
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> ty, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!ty || !ty->IsPointer()) {
throw std::runtime_error(FormatError("ir", "CreateAlloca 仅支持指针类型"));
}
return insert_block_->Append<AllocaInst>(ty, name);
}
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!ptr) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateLoad ptr 必须为指针类型"));
}
return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name);
const auto* ptr_ty = static_cast<const PointerType*>(ptr->GetType().get());
return insert_block_->Append<LoadInst>(ptr_ty->GetPointeeType(), ptr, name);
}
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
@ -79,11 +104,100 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!v) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateRet 缺少返回值"));
}
// ReturnInst expects its own type to be void; the returned value is an
// operand. Always use void as the instruction type.
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
}
BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Sub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Mul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Div, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMod(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Mod, lhs, rhs, name);
}
CmpInst* IRBuilder::CreateICmp(CmpInst::Predicate pred, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CmpInst>(Opcode::ICmp, pred, lhs, rhs, name);
}
CmpInst* IRBuilder::CreateFCmp(CmpInst::Predicate pred, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CmpInst>(Opcode::FCmp, pred, lhs, rhs, name);
}
BranchInst* IRBuilder::CreateBr(BasicBlock* target) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<BranchInst>(target);
}
CondBranchInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CondBranchInst>(cond, true_bb, false_bb);
}
CallInst* IRBuilder::CreateCall(Function* callee, std::vector<Value*> args, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CallInst>(callee, args, name);
}
GetElementPtrInst* IRBuilder::CreateGEP(Value* ptr, std::vector<Value*> indices, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
// 计算结果类型:根据传入的 indices 逐步从 pointee 类型走到目标元素类型。
// 例如 ptr 是指向数组的指针GEP 使用一个索引应返回指向数组元素的指针。
std::shared_ptr<Type> current;
if (ptr->GetType() && ptr->GetType()->IsPointer()) {
const PointerType* pty = static_cast<const PointerType*>(ptr->GetType().get());
current = pty->GetPointeeType();
} else {
current = ptr->GetType();
}
// 根据每个索引推进类型层次:数组 -> 元素类型,指针 -> 指向类型
for (size_t i = 0; i < indices.size(); ++i) {
if (!current) break;
if (current->IsArray()) {
const ArrayType* aty = static_cast<const ArrayType*>(current.get());
current = aty->GetElementType();
} else if (current->IsPointer()) {
const PointerType* ppty = static_cast<const PointerType*>(current.get());
current = ppty->GetPointeeType();
} else {
// 非数组/指针类型,无法继续下钻,保持当前类型
break;
}
}
auto result_ty = Type::GetPointerType(current);
return insert_block_->Append<GetElementPtrInst>(result_ty, ptr, indices, name);
}
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> ty, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<PhiInst>(ty, name);
}
} // namespace ir

@ -12,49 +12,140 @@
namespace ir {
static const char* TypeToString(const Type& ty) {
static std::string PredicateToString(CmpInst::Predicate pred, bool is_float) {
if (is_float) {
switch (pred) {
case CmpInst::EQ: return "oeq";
case CmpInst::NE: return "one";
case CmpInst::LT: return "olt";
case CmpInst::LE: return "ole";
case CmpInst::GT: return "ogt";
case CmpInst::GE: return "oge";
}
} else {
switch (pred) {
case CmpInst::EQ: return "eq";
case CmpInst::NE: return "ne";
case CmpInst::LT: return "slt";
case CmpInst::LE: return "sle";
case CmpInst::GT: return "sgt";
case CmpInst::GE: return "sge";
}
}
return "unknown";
}
static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void:
return "void";
case Type::Kind::Int32:
return "i32";
case Type::Kind::PtrInt32:
return "i32*";
case Type::Kind::Float32:
return "float";
case Type::Kind::Pointer: {
const PointerType* p = static_cast<const PointerType*>(&ty);
return TypeToString(*p->GetPointeeType()) + "*";
}
case Type::Kind::Array: {
const ArrayType* a = static_cast<const ArrayType*>(&ty);
return std::string("[") + std::to_string(a->GetSize()) + " x " + TypeToString(*a->GetElementType()) + "]";
}
case Type::Kind::Function:
return "[function]";
case Type::Kind::Label:
return "label";
}
throw std::runtime_error(FormatError("ir", "未知类型"));
}
static const char* OpcodeToString(Opcode op) {
switch (op) {
case Opcode::Add:
return "add";
case Opcode::Sub:
return "sub";
case Opcode::Mul:
return "mul";
case Opcode::Alloca:
return "alloca";
case Opcode::Load:
return "load";
case Opcode::Store:
return "store";
case Opcode::Ret:
return "ret";
case Opcode::Add: return "add";
case Opcode::Sub: return "sub";
case Opcode::Mul: return "mul";
case Opcode::Div: return "sdiv";
case Opcode::Mod: return "srem";
case Opcode::And: return "and";
case Opcode::Or: return "or";
case Opcode::Xor: return "xor";
case Opcode::Shl: return "shl";
case Opcode::LShr: return "lshr";
case Opcode::AShr: return "ashr";
case Opcode::ICmp: return "icmp";
case Opcode::FCmp: return "fcmp";
case Opcode::Alloca: return "alloca";
case Opcode::Load: return "load";
case Opcode::Store: return "store";
case Opcode::Ret: return "ret";
case Opcode::Br: return "br";
case Opcode::CondBr: return "br";
case Opcode::Call: return "call";
case Opcode::GEP: return "getelementptr";
case Opcode::Phi: return "phi";
}
return "?";
}
static std::string ConstantValueToString(const ConstantValue* cv);
static std::string ValueToString(const Value* v) {
if (!v) return "<null>";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
return v ? v->GetName() : "<null>";
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
// simple float literal
return std::to_string(cf->GetValue());
}
if (auto* ca = dynamic_cast<const ConstantArray*>(v)) {
return ConstantValueToString(ca);
}
// fallback to name for instructions/alloca/vars — prefix with '%'
return std::string("%") + v->GetName();
}
static std::string ConstantValueToString(const ConstantValue* cv) {
if (!cv) return "<null-const>";
if (auto* ci = dynamic_cast<const ConstantInt*>(cv)) return std::to_string(ci->GetValue());
if (auto* cf = dynamic_cast<const ConstantFloat*>(cv)) {
std::string s = std::to_string(cf->GetValue());
size_t dot = s.find('.');
if (dot != std::string::npos) {
size_t e = s.find('e');
if (e == std::string::npos) e = s.size();
while (e > dot + 1 && s[e-1] == '0') e--;
if (e == dot + 1) s = s.substr(0, dot + 1) + "0";
else s = s.substr(0, e);
}
return s;
}
if (auto* ca = dynamic_cast<const ConstantArray*>(cv)) {
// format: [ <elem_ty> <elem>, <elem_ty> <elem>, ... ]
const auto& elems = ca->GetElements();
std::string out = "[";
for (size_t i = 0; i < elems.size(); ++i) {
if (i) out += ", ";
// each element should be printed with its type and value
auto* e = elems[i];
std::string etype = TypeToString(*e->GetType());
out += etype + " " + ConstantValueToString(e);
}
out += "]";
return out;
}
return "<const-unk>";
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& func : module.GetFunctions()) {
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "() {\n";
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() << "(";
const auto& params = func->GetParams();
for (size_t i = 0; i < params.size(); ++i) {
if (i) os << ", ";
os << TypeToString(*params[i]->GetType()) << " " << ValueToString(params[i]);
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) {
if (!bb) {
continue;
@ -65,30 +156,66 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul: {
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::LShr:
case Opcode::AShr: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
// choose opcode name: integer ops use e.g. 'add','sub', float ops use 'fadd','fsub', etc.
std::string op_name = OpcodeToString(bin->GetOpcode());
bool is_float = bin->GetLhs()->GetType()->IsFloat32();
if (is_float) {
switch (bin->GetOpcode()) {
case Opcode::Add: op_name = "fadd"; break;
case Opcode::Sub: op_name = "fsub"; break;
case Opcode::Mul: op_name = "fmul"; break;
case Opcode::Div: op_name = "fdiv"; break;
case Opcode::Mod: op_name = "frem"; break;
default: break;
}
}
os << " %" << bin->GetName() << " = "
<< op_name << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::ICmp:
case Opcode::FCmp: {
auto* cmp = static_cast<const CmpInst*>(inst);
os << " %" << cmp->GetName() << " = "
<< OpcodeToString(cmp->GetOpcode()) << " " << PredicateToString(cmp->GetPredicate(), cmp->GetOpcode() == Opcode::FCmp) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << alloca->GetName() << " = alloca i32\n";
os << " %" << alloca->GetName() << " = alloca "
<< TypeToString(*static_cast<const PointerType*>(alloca->GetType().get())->GetPointeeType()) << "\n";
break;
}
case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load i32, i32* "
<< ValueToString(load->GetPtr()) << "\n";
os << " %" << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst);
os << " store i32 " << ValueToString(store->GetValue())
<< ", i32* " << ValueToString(store->GetPtr()) << "\n";
os << " store " << TypeToString(*store->GetValue()->GetType()) << " "
<< ValueToString(store->GetValue()) << ", "
<< TypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
@ -97,6 +224,59 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
<< ValueToString(ret->GetValue()) << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BranchInst*>(inst);
os << " br label %" << br->GetTarget()->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* condbr = static_cast<const CondBranchInst*>(inst);
os << " br i1 " << ValueToString(condbr->GetCond())
<< ", label %" << condbr->GetTrueBlock()->GetName()
<< ", label %" << condbr->GetFalseBlock()->GetName() << "\n";
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
os << " ";
if (!call->GetName().empty()) {
os << "%" << call->GetName() << " = ";
}
os << "call " << TypeToString(*call->GetCallee()->GetType()) << " @"
<< call->GetCallee()->GetName() << "(";
for (size_t i = 0; i < call->GetArgs().size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*call->GetArgs()[i]->GetType()) << " "
<< ValueToString(call->GetArgs()[i]);
}
os << ")\n";
break;
}
case Opcode::GEP: {
auto* gep = static_cast<const GetElementPtrInst*>(inst);
os << " %" << gep->GetName() << " = getelementptr ";
// Print element type first, then the pointer type and pointer value
const auto ptrType = gep->GetPtr()->GetType();
const PointerType* pty = static_cast<const PointerType*>(ptrType.get());
os << TypeToString(*pty->GetPointeeType()) << ", "
<< TypeToString(*ptrType) << " " << ValueToString(gep->GetPtr());
for (auto* idx : gep->GetIndices()) {
os << ", " << TypeToString(*idx->GetType()) << " " << ValueToString(idx);
}
os << "\n";
break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " %" << phi->GetName() << " = phi "
<< TypeToString(*phi->GetType());
for (const auto& incoming : phi->GetIncomings()) {
os << " [ " << ValueToString(incoming.first) << ", %"
<< incoming.second->GetName() << " ]";
}
os << "\n";
break;
}
}
}
}

@ -52,7 +52,9 @@ Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
Opcode Instruction::GetOpcode() const { return opcode_; }
bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; }
bool Instruction::IsTerminator() const {
return opcode_ == Opcode::Ret || opcode_ == Opcode::Br || opcode_ == Opcode::CondBr;
}
BasicBlock* Instruction::GetParent() const { return parent_; }
@ -61,8 +63,8 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Add) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add"));
if (op != Opcode::Add && op != Opcode::Sub && op != Opcode::Mul && op != Opcode::Div && op != Opcode::Mod) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持算术操作"));
}
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
@ -74,8 +76,8 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
type_->GetKind() != lhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
}
if (!type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32"));
if (!type_->IsInt32() && !type_->IsFloat32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32 和 float"));
}
AddOperand(lhs);
AddOperand(rhs);
@ -87,21 +89,22 @@ Value* BinaryInst::GetRhs() const { return GetOperand(1); }
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!val) {
throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值"));
}
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
}
AddOperand(val);
// val may be nullptr to represent a void return; only add operand when
// a returned value is present.
if (val) {
AddOperand(val);
}
}
Value* ReturnInst::GetValue() const { return GetOperand(0); }
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
if (!type_ || !type_->IsPtrInt32()) {
throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*"));
if (!type_ || !type_->IsPointer()) {
throw std::runtime_error(FormatError("ir", "AllocaInst 类型必须为指针"));
}
}
@ -110,12 +113,12 @@ LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
}
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32"));
if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "LoadInst ptr 必须为指针类型"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "LoadInst 当前只支持从 i32* 加载"));
const auto* ptr_ty = static_cast<const PointerType*>(ptr->GetType().get());
if (!type_ || *type_ != *ptr_ty->GetPointeeType()) {
throw std::runtime_error(FormatError("ir", "LoadInst 类型不匹配"));
}
AddOperand(ptr);
}
@ -133,12 +136,12 @@ StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
}
if (!val->GetType() || !val->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32"));
if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "StoreInst ptr 必须为指针类型"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "StoreInst 当前只支持写入 i32*"));
const auto* ptr_ty = static_cast<const PointerType*>(ptr->GetType().get());
if (!val->GetType() || *val->GetType() != *ptr_ty->GetPointeeType()) {
throw std::runtime_error(FormatError("ir", "StoreInst 类型不匹配"));
}
AddOperand(val);
AddOperand(ptr);
@ -148,4 +151,120 @@ Value* StoreInst::GetValue() const { return GetOperand(0); }
Value* StoreInst::GetPtr() const { return GetOperand(1); }
CmpInst::CmpInst(Opcode op, Predicate pred, Value* lhs, Value* rhs, std::string name)
: Instruction(op, Type::GetInt32Type(), std::move(name)), pred_(pred), lhs_(lhs), rhs_(rhs) {
if (op != Opcode::ICmp && op != Opcode::FCmp) {
throw std::runtime_error(FormatError("ir", "CmpInst 仅支持 ICmp 和 FCmp"));
}
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "CmpInst 缺少操作数"));
}
AddOperand(lhs);
AddOperand(rhs);
}
BranchInst::BranchInst(BasicBlock* target)
: Instruction(Opcode::Br, Type::GetVoidType(), "") {
if (!target) {
throw std::runtime_error(FormatError("ir", "BranchInst 缺少目标基本块"));
}
AddOperand(target);
}
BasicBlock* BranchInst::GetTarget() const { return static_cast<BasicBlock*>(GetOperand(0)); }
CondBranchInst::CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb)
: Instruction(Opcode::CondBr, Type::GetVoidType(), "") {
if (!cond || !true_bb || !false_bb) {
throw std::runtime_error(FormatError("ir", "CondBranchInst 缺少操作数"));
}
AddOperand(cond);
AddOperand(true_bb);
AddOperand(false_bb);
}
Value* CondBranchInst::GetCond() const { return GetOperand(0); }
BasicBlock* CondBranchInst::GetTrueBlock() const { return static_cast<BasicBlock*>(GetOperand(1)); }
BasicBlock* CondBranchInst::GetFalseBlock() const { return static_cast<BasicBlock*>(GetOperand(2)); }
CallInst::CallInst(Function* callee, std::vector<Value*> args, std::string name)
: Instruction(Opcode::Call, callee->GetType(), std::move(name)) {
if (!callee) {
throw std::runtime_error(FormatError("ir", "CallInst 缺少被调用函数"));
}
AddOperand(callee);
for (auto* arg : args) {
if (!arg) {
throw std::runtime_error(FormatError("ir", "CallInst 参数不能为空"));
}
AddOperand(arg);
}
}
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
void PhiInst::AddIncoming(Value* val, BasicBlock* block) {
if (!val || !block) {
throw std::runtime_error(FormatError("ir", "PhiInst AddIncoming 参数不能为空"));
}
AddOperand(val);
AddOperand(block);
}
GetElementPtrInst::GetElementPtrInst(std::shared_ptr<Type> ty, Value* ptr,
std::vector<Value*> indices, std::string name)
: Instruction(Opcode::GEP, std::move(ty), std::move(name)) {
if (!ptr) {
throw std::runtime_error(FormatError("ir", "GetElementPtrInst 缺少指针"));
}
AddOperand(ptr);
for (auto* idx : indices) {
if (!idx) {
throw std::runtime_error(FormatError("ir", "GetElementPtrInst 索引不能为空"));
}
AddOperand(idx);
}
}
Function* CallInst::GetCallee() const {
return static_cast<Function*>(GetOperand(0));
}
const std::vector<Value*>& CallInst::GetArgs() const {
// 返回参数列表(跳过被调用函数)
static std::vector<Value*> args;
args.clear();
for (size_t i = 1; i < GetNumOperands(); ++i) {
args.push_back(GetOperand(i));
}
return args;
}
const std::vector<std::pair<Value*, BasicBlock*>>& PhiInst::GetIncomings() const {
// Phi 指令的操作数是成对的:值和基本块
static std::vector<std::pair<Value*, BasicBlock*>> incomings;
incomings.clear();
for (size_t i = 0; i < GetNumOperands(); i += 2) {
Value* val = GetOperand(i);
BasicBlock* block = static_cast<BasicBlock*>(GetOperand(i + 1));
incomings.emplace_back(val, block);
}
return incomings;
}
Value* GetElementPtrInst::GetPtr() const {
return GetOperand(0);
}
const std::vector<Value*>& GetElementPtrInst::GetIndices() const {
// 返回索引列表(跳过指针)
static std::vector<Value*> indices;
indices.clear();
for (size_t i = 1; i < GetNumOperands(); ++i) {
indices.push_back(GetOperand(i));
}
return indices;
}
} // namespace ir

@ -9,13 +9,30 @@ Context& Module::GetContext() { return context_; }
const Context& Module::GetContext() const { return context_; }
Function* Module::CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type) {
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type)));
std::shared_ptr<Type> ret_type,
std::vector<std::shared_ptr<Type>> param_types) {
functions_.push_back(
std::make_unique<Function>(name, std::move(ret_type), std::move(param_types)));
return functions_.back().get();
}
GlobalValue* Module::CreateGlobalVariable(const std::string& name,
std::shared_ptr<Type> ty,
ConstantValue* init) {
auto gv = std::make_unique<GlobalValue>(std::move(ty), name);
if (init) {
gv->SetInitializer(init);
}
global_vars_.push_back(std::move(gv));
return global_vars_.back().get();
}
const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
return functions_;
}
const std::vector<std::unique_ptr<GlobalValue>>& Module::GetGlobalVariables() const {
return global_vars_;
}
} // namespace ir

@ -1,8 +1,20 @@
// 当前仅支持 void、i32 和 i32*。
#include "ir/IR.h"
#include <unordered_map>
#include <functional>
namespace ir {
// 用于缓存复合类型的静态映射(简单实现)
static std::unordered_map<std::size_t, std::shared_ptr<Type>> pointer_cache;
static std::unordered_map<std::size_t, std::shared_ptr<Type>> array_cache;
static std::unordered_map<std::size_t, std::shared_ptr<Type>> function_cache;
// 简单哈希组合函数
static std::size_t hash_combine(std::size_t seed, std::size_t v) {
return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}
Type::Type(Kind k) : kind_(k) {}
const std::shared_ptr<Type>& Type::GetVoidType() {
@ -15,17 +27,133 @@ const std::shared_ptr<Type>& Type::GetInt32Type() {
return type;
}
const std::shared_ptr<Type>& Type::GetFloat32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Float32);
return type;
}
const std::shared_ptr<Type>& Type::GetLabelType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Label);
return type;
}
// 兼容旧的 i32* 类型,返回指向 i32 的指针类型
const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrInt32);
static const std::shared_ptr<Type> type = GetPointerType(GetInt32Type());
return type;
}
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> pointee) {
// 简单缓存:使用 pointee 的地址作为键(实际应使用更可靠的标识,但作为演示足够)
std::size_t key = reinterpret_cast<std::size_t>(pointee.get());
auto it = pointer_cache.find(key);
if (it != pointer_cache.end()) {
return it->second;
}
auto ptr_type = std::make_shared<PointerType>(pointee);
pointer_cache[key] = ptr_type;
return ptr_type;
}
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> elem, size_t size) {
// 使用元素类型指针和大小组合哈希
std::size_t seed = 0;
seed = hash_combine(seed, reinterpret_cast<std::size_t>(elem.get()));
seed = hash_combine(seed, size);
auto it = array_cache.find(seed);
if (it != array_cache.end()) {
return it->second;
}
auto arr_type = std::make_shared<ArrayType>(elem, size);
array_cache[seed] = arr_type;
return arr_type;
}
std::shared_ptr<Type> Type::GetFunctionType(std::shared_ptr<Type> ret,
std::vector<std::shared_ptr<Type>> params) {
// 哈希组合:返回类型 + 参数类型列表
std::size_t seed = reinterpret_cast<std::size_t>(ret.get());
for (const auto& p : params) {
seed = hash_combine(seed, reinterpret_cast<std::size_t>(p.get()));
}
auto it = function_cache.find(seed);
if (it != function_cache.end()) {
return it->second;
}
auto func_type = std::make_shared<FunctionType>(ret, std::move(params));
function_cache[seed] = func_type;
return func_type;
}
Type::Kind Type::GetKind() const { return kind_; }
bool Type::IsVoid() const { return kind_ == Kind::Void; }
bool Type::IsInt32() const { return kind_ == Kind::Int32; }
bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; }
bool Type::IsFloat32() const { return kind_ == Kind::Float32; }
bool Type::IsPointer() const { return kind_ == Kind::Pointer; }
bool Type::IsArray() const { return kind_ == Kind::Array; }
bool Type::IsFunction() const { return kind_ == Kind::Function; }
bool Type::IsLabel() const { return kind_ == Kind::Label; }
// 兼容旧代码,检查是否为 i32* 类型
bool Type::IsPtrInt32() const {
if (!IsPointer()) return false;
const auto* ptr_ty = static_cast<const PointerType*>(this);
return ptr_ty->GetPointeeType()->IsInt32();
}
// 检查是否为 float32* 类型
bool Type::IsPtrFloat32() const {
if (!IsPointer()) return false;
const auto* ptr_ty = static_cast<const PointerType*>(this);
return ptr_ty->GetPointeeType()->IsFloat32();
}
bool Type::operator==(const Type& other) const {
if (kind_ != other.kind_) return false;
switch (kind_) {
case Kind::Void:
case Kind::Int32:
case Kind::Float32:
case Kind::Label:
return true;
case Kind::Pointer: {
const auto* this_ptr = static_cast<const PointerType*>(this);
const auto* other_ptr = static_cast<const PointerType*>(&other);
return *this_ptr->GetPointeeType() == *other_ptr->GetPointeeType();
}
case Kind::Array: {
const auto* this_arr = static_cast<const ArrayType*>(this);
const auto* other_arr = static_cast<const ArrayType*>(&other);
return this_arr->GetSize() == other_arr->GetSize() &&
*this_arr->GetElementType() == *other_arr->GetElementType();
}
case Kind::Function: {
const auto* this_func = static_cast<const FunctionType*>(this);
const auto* other_func = static_cast<const FunctionType*>(&other);
if (*this_func->GetReturnType() != *other_func->GetReturnType()) return false;
const auto& this_params = this_func->GetParamTypes();
const auto& other_params = other_func->GetParamTypes();
if (this_params.size() != other_params.size()) return false;
for (size_t i = 0; i < this_params.size(); ++i) {
if (*this_params[i] != *other_params[i]) return false;
}
return true;
}
default:
return false;
}
}
bool Type::operator!=(const Type& other) const {
return !(*this == other);
}
} // namespace ir
} // namespace ir

@ -20,7 +20,19 @@ bool Value::IsVoid() const { return type_ && type_->IsVoid(); }
bool Value::IsInt32() const { return type_ && type_->IsInt32(); }
bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); }
bool Value::IsPtrInt32() const {
if (!type_ || !type_->IsPointer()) return false;
const auto* ptr_ty = static_cast<const PointerType*>(type_.get());
return ptr_ty->GetPointeeType()->IsInt32();
}
bool Value::IsPtrFloat32() const {
if (!type_ || !type_->IsPointer()) return false;
const auto* ptr_ty = static_cast<const PointerType*>(type_.get());
return ptr_ty->GetPointeeType()->IsFloat32();
}
bool Value::IsFloat32() const { return type_ && type_->IsFloat32(); }
bool Value::IsConstant() const {
return dynamic_cast<const ConstantValue*>(this) != nullptr;
@ -38,6 +50,10 @@ bool Value::IsFunction() const {
return dynamic_cast<const Function*>(this) != nullptr;
}
bool Value::IsGlobalValue() const {
return dynamic_cast<const GlobalValue*>(this) != nullptr;
}
void Value::AddUse(User* user, size_t operand_index) {
if (!user) return;
uses_.push_back(Use(this, user, operand_index));
@ -80,4 +96,11 @@ ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
: ConstantValue(std::move(ty), ""), value_(v) {}
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float v)
: ConstantValue(std::move(ty), ""), value_(v) {}
ConstantArray::ConstantArray(std::shared_ptr<Type> ty,
std::vector<ConstantValue*> elems)
: ConstantValue(std::move(ty), ""), elements_(std::move(elems)) {}
} // namespace ir

@ -5,26 +5,17 @@
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
#include <functional>
namespace {
// helper functions removed; VarDef uses Ident() directly per current grammar.
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
}
return lvalue.ID()->getText();
}
} // namespace
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
}
for (auto* item : ctx->blockItem()) {
if (item) {
if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
// 当前语法要求 return 为块内最后一条语句;命中后可停止生成。
break;
}
}
@ -32,6 +23,52 @@ std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
return {};
}
std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "缺少常量定义"));
if (!ctx->Ident()) throw std::runtime_error(FormatError("irgen", "常量声明缺少名称"));
if (!ctx->constInitVal()) throw std::runtime_error(FormatError("irgen", "常量必须初始化"));
if (storage_map_.find(ctx) != storage_map_.end()) throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
auto* slot = (current_btype_ == "float") ?
static_cast<ir::AllocaInst*>(builder_.CreateAllocaFloat(module_.GetContext().NextTemp())) :
static_cast<ir::AllocaInst*>(builder_.CreateAllocaI32(module_.GetContext().NextTemp()));
storage_map_[ctx] = slot;
name_map_[ctx->Ident()->getText()] = slot;
// Try to evaluate a scalar const initializer
ir::ConstantValue* cinit = nullptr;
try {
auto* initval = ctx->constInitVal();
if (initval && initval->constExp() && initval->constExp()->addExp()) {
if (current_btype_ == "float") {
auto* add = initval->constExp()->addExp();
float fv = std::stof(add->getText());
cinit = module_.GetContext().GetConstFloat(fv);
} else {
auto* add = initval->constExp()->addExp();
int iv = std::stoi(add->getText());
cinit = module_.GetContext().GetConstInt(iv);
}
}
} catch(...) {
// fallback: try evaluate via visitor
try {
auto* add = ctx->constInitVal()->constExp()->addExp();
ir::Value* v = std::any_cast<ir::Value*>(add->accept(this));
if (auto* cv = dynamic_cast<ir::ConstantValue*>(v)) cinit = cv;
} catch(...) {}
}
if (cinit) builder_.CreateStore(cinit, slot);
else builder_.CreateStore((current_btype_=="float"? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0)), slot);
// record simple integer consts for dimension evaluation
try {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(cinit)) {
const_values_[ctx->Ident()->getText()] = ci->GetValue();
}
} catch(...) {}
return {};
}
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
SysYParser::BlockItemContext& item) {
return std::any_cast<BlockFlow>(item.accept(this));
@ -63,15 +100,26 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
if (ctx->varDecl()) {
auto* vdecl = ctx->varDecl();
if (!vdecl->bType() || vdecl->bType()->getText() != "int") {
throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
}
for (auto* var_def : vdecl->varDef()) {
if (var_def) var_def->accept(this);
}
return {};
}
auto* var_def = ctx->varDef();
if (!var_def) {
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
if (ctx->constDecl()) {
auto* cdecl = ctx->constDecl();
if (!cdecl->bType()) throw std::runtime_error(FormatError("irgen", "缺少常量基类型"));
current_btype_ = cdecl->bType()->getText();
if (current_btype_ != "int" && current_btype_ != "float") throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int/float 常量声明"));
for (auto* const_def : cdecl->constDef()) if (const_def) const_def->accept(this);
current_btype_.clear();
return {};
}
var_def->accept(this);
return {};
throw std::runtime_error(FormatError("irgen", "暂不支持的声明类型"));
}
@ -83,18 +131,164 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量定义"));
}
if (!ctx->lValue()) {
if (!ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
}
GetLValueName(*ctx->lValue());
if (storage_map_.find(ctx) != storage_map_.end()) {
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
}
// check if this is an array declaration (has constExp dimensions)
if (!ctx->constExp().empty()) {
// parse dims
std::vector<int> dims;
for (auto* ce : ctx->constExp()) {
try {
int v = 0;
auto anyv = sema_.GetConstVal(ce);
if (anyv.has_value()) {
if (anyv.type() == typeid(int)) v = std::any_cast<int>(anyv);
else if (anyv.type() == typeid(long)) v = (int)std::any_cast<long>(anyv);
else throw std::runtime_error("not-const-int");
} else {
// try simple patterns like NUM or IDENT+NUM or NUM+IDENT
std::string s = ce->addExp()->getText();
s.erase(std::remove_if(s.begin(), s.end(), ::isspace), s.end());
auto pos = s.find('+');
if (pos == std::string::npos) {
// plain number or identifier
try { v = std::stoi(s); }
catch(...) {
// try lookup identifier in recorded consts or symbol table
auto it = const_values_.find(s);
if (it != const_values_.end()) v = (int)it->second;
else {
VarInfo vi; void* declctx = nullptr;
if (sema_.GetSymbolTable().LookupVar(s, vi, declctx) && vi.const_val.has_value()) {
if (vi.const_val.type() == typeid(int)) v = std::any_cast<int>(vi.const_val);
else if (vi.const_val.type() == typeid(long)) v = (int)std::any_cast<long>(vi.const_val);
else throw std::runtime_error("not-const-int");
} else throw std::runtime_error("not-const-int");
}
}
} else {
// form A+B where A or B may be ident or number
std::string L = s.substr(0, pos);
std::string R = s.substr(pos + 1);
int lv = 0, rv = 0; bool ok = false;
// try left
try { lv = std::stoi(L); ok = true; } catch(...) {
auto it = const_values_.find(L);
if (it != const_values_.end()) { lv = (int)it->second; ok = true; }
else {
VarInfo vi; void* declctx = nullptr;
if (sema_.GetSymbolTable().LookupVar(L, vi, declctx) && vi.const_val.has_value()) {
if (vi.const_val.type() == typeid(int)) lv = std::any_cast<int>(vi.const_val);
else if (vi.const_val.type() == typeid(long)) lv = (int)std::any_cast<long>(vi.const_val);
ok = true;
}
}
}
// try right
try { rv = std::stoi(R); ok = ok && true; } catch(...) {
auto it2 = const_values_.find(R);
if (it2 != const_values_.end()) { rv = (int)it2->second; ok = ok && true; }
else {
VarInfo vi2; void* declctx2 = nullptr;
if (sema_.GetSymbolTable().LookupVar(R, vi2, declctx2) && vi2.const_val.has_value()) {
if (vi2.const_val.type() == typeid(int)) rv = std::any_cast<int>(vi2.const_val);
else if (vi2.const_val.type() == typeid(long)) rv = (int)std::any_cast<long>(vi2.const_val);
ok = ok && true;
} else ok = false;
}
}
if (!ok) throw std::runtime_error("not-const-int");
v = lv + rv;
}
}
dims.push_back(v);
} catch (...) {
throw std::runtime_error(FormatError("irgen", "数组维度必须为常量整数"));
}
}
std::shared_ptr<ir::Type> elemTy = (current_btype_ == "float") ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type();
// build nested array type
std::function<std::shared_ptr<ir::Type>(size_t)> makeArrayType = [&](size_t level) -> std::shared_ptr<ir::Type> {
if (level + 1 >= dims.size()) return ir::Type::GetArrayType(elemTy, dims[level]);
auto sub = makeArrayType(level + 1);
return ir::Type::GetArrayType(sub, dims[level]);
};
auto fullArrayTy = makeArrayType(0);
auto arr_ptr_ty = ir::Type::GetPointerType(fullArrayTy);
auto* array_slot = builder_.CreateAlloca(arr_ptr_ty, module_.GetContext().NextTemp());
storage_map_[ctx] = array_slot;
name_map_[ctx->Ident()->getText()] = array_slot;
// compute spans and total scalar slots
int nlevels = (int)dims.size();
std::vector<int> span(nlevels);
int total = 1;
for (int i = nlevels - 1; i >= 0; --i) {
if (i == nlevels - 1) span[i] = 1;
else span[i] = span[i + 1] * dims[i + 1];
total *= dims[i];
}
ir::Value* zero = elemTy->IsFloat32() ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0);
std::vector<ir::Value*> slots(total, zero);
// process initializer (if any) into linear slots
if (auto* init_value = ctx->initVal()) {
std::function<void(SysYParser::InitValContext*, int, int&)> process_group;
process_group = [&](SysYParser::InitValContext* init, int level, int& pos) {
if (level >= nlevels) return;
int sub_span = span[level];
int elems = dims[level];
if (!init) { pos += elems * sub_span; return; }
for (auto* child : init->initVal()) {
if (pos >= total) break;
if (!child) { pos += 1; continue; }
if (!child->initVal().empty()) {
int subpos = pos;
int inner = subpos;
process_group(child, level + 1, inner);
pos = subpos + sub_span;
} else if (child->exp()) {
try { ir::Value* v = EvalExpr(*child->exp()); if (pos < total) slots[pos] = v; } catch(...) {}
pos += 1;
} else { pos += 1; }
}
};
int pos0 = 0;
process_group(init_value, 0, pos0);
}
// emit stores for each scalar slot in row-major order
for (int idx = 0; idx < total; ++idx) {
std::vector<int> indices;
int rem = idx;
for (int L = 0; L < nlevels; ++L) {
int ind = rem / span[L];
indices.push_back(ind % dims[L]);
rem = rem % span[L];
}
std::vector<ir::Value*> gep_inds;
gep_inds.push_back(module_.GetContext().GetConstInt(0));
for (int v : indices) gep_inds.push_back(module_.GetContext().GetConstInt(v));
while (gep_inds.size() < (size_t)(1 + nlevels)) gep_inds.push_back(module_.GetContext().GetConstInt(0));
auto* gep = builder_.CreateGEP(array_slot, gep_inds, module_.GetContext().NextTemp());
builder_.CreateStore(slots[idx], gep);
}
return {};
}
// scalar variable
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
ir::Value* init = nullptr;
if (auto* init_value = ctx->initValue()) {
if (auto* init_value = ctx->initVal()) {
if (!init_value->exp()) {
throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化"));
}

@ -7,7 +7,7 @@
#include "utils/Log.h"
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema) {
IRGenContext& sema) {
auto module = std::make_unique<ir::Module>();
IRGenImpl gen(*module, sema);
tree.accept(&gen);

@ -25,20 +25,24 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
}
std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法括号表达式"));
}
return EvalExpr(*ctx->exp());
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 primary 表达式"));
if (ctx->exp()) return EvalExpr(*ctx->exp());
if (ctx->lVal()) return ctx->lVal()->accept(this);
if (ctx->number()) return ctx->number()->accept(this);
throw std::runtime_error(FormatError("irgen", "不支持的 primary 表达式"));
}
std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量"));
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法数字字面量"));
if (ctx->IntConst()) {
return static_cast<ir::Value*>(builder_.CreateConstInt(std::stoi(ctx->getText())));
}
return static_cast<ir::Value*>(
builder_.CreateConstInt(std::stoi(ctx->number()->getText())));
if (ctx->FloatConst()) {
return static_cast<ir::Value*>(builder_.CreateConstFloat(std::stof(ctx->getText())));
}
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量或浮点字面量"));
}
// 变量使用的处理流程:
@ -47,34 +51,258 @@ std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
// 3. 最后生成 load把内存中的值读出来。
//
// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。
std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) {
if (!ctx || !ctx->var() || !ctx->var()->ID()) {
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
if (!ctx || !ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
}
auto* decl = sema_.ResolveVarUse(ctx->var());
if (!decl) {
throw std::runtime_error(
FormatError("irgen",
"变量使用缺少语义绑定: " + ctx->var()->ID()->getText()));
// find storage by matching declaration node stored in Sema context
// Sema stores types/decl contexts in IRGenContext maps; here we search storage_map_ by name
std::string name = ctx->Ident()->getText();
// 优先使用按名称的快速映射
auto nit = name_map_.find(name);
if (nit != name_map_.end()) {
// 支持下标访问:若有索引表达式列表,则生成 GEP + load
if (ctx->exp().size() > 0) {
std::vector<ir::Value*> indices;
// 首个索引用于穿过数组对象
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->exp()) {
indices.push_back(EvalExpr(*e));
}
auto* gep = builder_.CreateGEP(nit->second, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
}
// 如果映射到的是常量,直接返回常量值;否则按原来行为从槽位 load
if (nit->second->IsConstant()) return nit->second;
return static_cast<ir::Value*>(builder_.CreateLoad(nit->second, module_.GetContext().NextTemp()));
}
auto it = storage_map_.find(decl);
if (it == storage_map_.end()) {
throw std::runtime_error(
FormatError("irgen",
"变量声明缺少存储槽位: " + ctx->var()->ID()->getText()));
for (auto& kv : storage_map_) {
if (!kv.first) continue;
if (auto* vdef = dynamic_cast<SysYParser::VarDefContext*>(kv.first)) {
if (vdef->Ident() && vdef->Ident()->getText() == name) {
if (ctx->exp().size() > 0) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e));
auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
}
} else if (auto* fparam = dynamic_cast<SysYParser::FuncFParamContext*>(kv.first)) {
if (fparam->Ident() && fparam->Ident()->getText() == name) {
if (ctx->exp().size() > 0) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e));
auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
}
} else if (auto* cdef = dynamic_cast<SysYParser::ConstDefContext*>(kv.first)) {
if (cdef->Ident() && cdef->Ident()->getText() == name) {
if (ctx->exp().size() > 0) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e));
auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
}
}
}
return static_cast<ir::Value*>(
builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
throw std::runtime_error(FormatError("irgen", "变量声明缺少存储槽位: " + name));
}
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
try {
// left-associative: fold across all mulExp operands
if (ctx->mulExp().size() == 1) return ctx->mulExp(0)->accept(this);
ir::Value* cur = std::any_cast<ir::Value*>(ctx->mulExp(0)->accept(this));
// extract operator sequence from text (in-order '+' or '-')
std::string text = ctx->getText();
std::vector<char> ops;
for (char c : text) if (c == '+' || c == '-') ops.push_back(c);
for (size_t i = 1; i < ctx->mulExp().size(); ++i) {
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp(i)->accept(this));
char opch = (i - 1 < ops.size()) ? ops[i - 1] : '+';
ir::Opcode op = (opch == '-') ? ir::Opcode::Sub : ir::Opcode::Add;
cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp());
}
return static_cast<ir::Value*>(cur);
} catch (const std::exception& e) {
LogInfo(std::string("[irgen] exception in visitAddExp text=") + ctx->getText() + ", err=" + e.what(), std::cerr);
throw;
}
}
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
if (ctx->unaryExp().size() == 1) return ctx->unaryExp(0)->accept(this);
ir::Value* cur = std::any_cast<ir::Value*>(ctx->unaryExp(0)->accept(this));
// extract operator sequence for '*', '/', '%'
std::string text = ctx->getText();
std::vector<char> ops;
for (char c : text) if (c == '*' || c == '/' || c == '%') ops.push_back(c);
for (size_t i = 1; i < ctx->unaryExp().size(); ++i) {
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp(i)->accept(this));
char opch = (i - 1 < ops.size()) ? ops[i - 1] : '*';
ir::Opcode op = ir::Opcode::Mul;
if (opch == '/') op = ir::Opcode::Div;
else if (opch == '%') op = ir::Opcode::Mod;
cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp());
}
return static_cast<ir::Value*>(cur);
}
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
// function call: Ident '(' funcRParams? ')'
if (ctx->Ident() && ctx->getText().find("(") != std::string::npos) {
std::string fname = ctx->Ident()->getText();
std::vector<ir::Value*> args;
if (ctx->funcRParams()) {
for (auto* e : ctx->funcRParams()->exp()) {
args.push_back(EvalExpr(*e));
}
}
// find existing function or create an external declaration (assume int return)
ir::Function* callee = nullptr;
for (auto &fup : module_.GetFunctions()) {
if (fup && fup->GetName() == fname) { callee = fup.get(); break; }
}
if (!callee) {
std::vector<std::shared_ptr<ir::Type>> param_types;
for (auto* a : args) {
if (a && a->IsFloat32()) param_types.push_back(ir::Type::GetFloat32Type());
else param_types.push_back(ir::Type::GetInt32Type());
}
callee = module_.CreateFunction(fname, ir::Type::GetInt32Type(), param_types);
}
return static_cast<ir::Value*>(builder_.CreateCall(callee, args, module_.GetContext().NextTemp()));
}
if (ctx->unaryExp()) {
ir::Value* val = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
if (ctx->unaryOp() && ctx->unaryOp()->getText() == "+") return static_cast<ir::Value*>(val);
else if (ctx->unaryOp() && ctx->unaryOp()->getText() == "-") {
// 负号0 - val区分整型/浮点
if (val->IsFloat32()) {
ir::Value* zero = builder_.CreateConstFloat(0.0f);
return static_cast<ir::Value*>(builder_.CreateSub(zero, val, module_.GetContext().NextTemp()));
} else {
ir::Value* zero = builder_.CreateConstInt(0);
return static_cast<ir::Value*>(builder_.CreateSub(zero, val, module_.GetContext().NextTemp()));
}
}
if (ctx->unaryOp() && ctx->unaryOp()->getText() == "!") {
// logical not: produce int 1 if val == 0, else 0
if (val->IsFloat32()) {
ir::Value* zerof = builder_.CreateConstFloat(0.0f);
return static_cast<ir::Value*>(builder_.CreateFCmp(ir::CmpInst::EQ, val, zerof, module_.GetContext().NextTemp()));
} else {
ir::Value* zero = builder_.CreateConstInt(0);
return static_cast<ir::Value*>(builder_.CreateICmp(ir::CmpInst::EQ, val, zero, module_.GetContext().NextTemp()));
}
}
}
throw std::runtime_error(FormatError("irgen", "不支持的一元运算"));
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
if (ctx->addExp().size() == 1) return ctx->addExp(0)->accept(this);
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->addExp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->addExp(1)->accept(this));
// 类型提升
if (lhs->IsFloat32() && rhs->IsInt32()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(rhs)) {
rhs = builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
} else {
throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
}
} else if (rhs->IsFloat32() && lhs->IsInt32()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(lhs)) {
lhs = builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
} else {
throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
}
}
ir::CmpInst::Predicate pred = ir::CmpInst::EQ;
std::string text = ctx->getText();
if (text.find("<=") != std::string::npos) pred = ir::CmpInst::LE;
else if (text.find(">=") != std::string::npos) pred = ir::CmpInst::GE;
else if (text.find("<") != std::string::npos) pred = ir::CmpInst::LT;
else if (text.find(">") != std::string::npos) pred = ir::CmpInst::GT;
if (lhs->IsFloat32() || rhs->IsFloat32()) {
return static_cast<ir::Value*>(
builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(
builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
if (ctx->relExp().size() == 1) return ctx->relExp(0)->accept(this);
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->relExp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->relExp(1)->accept(this));
// 类型提升
if (lhs->IsFloat32() && rhs->IsInt32()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(rhs)) {
rhs = builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
} else {
throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
}
} else if (rhs->IsFloat32() && lhs->IsInt32()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(lhs)) {
lhs = builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
} else {
throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
}
}
ir::CmpInst::Predicate pred = ir::CmpInst::EQ;
std::string text = ctx->getText();
if (text.find("==") != std::string::npos) pred = ir::CmpInst::EQ;
else if (text.find("!=") != std::string::npos) pred = ir::CmpInst::NE;
if (lhs->IsFloat32() || rhs->IsFloat32()) {
return static_cast<ir::Value*>(
builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()));
}
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
return static_cast<ir::Value*>(
builder_.CreateBinary(ir::Opcode::Add, lhs, rhs,
module_.GetContext().NextTemp()));
builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
if (ctx->eqExp().size() == 1) return ctx->eqExp(0)->accept(this);
// For simplicity, treat as int (0 or 1)
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->eqExp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->eqExp(1)->accept(this));
// lhs && rhs : (lhs != 0) && (rhs != 0)
ir::Value* zero = builder_.CreateConstInt(0);
ir::Value* lhs_ne = builder_.CreateICmp(ir::CmpInst::NE, lhs, zero, module_.GetContext().NextTemp());
ir::Value* rhs_ne = builder_.CreateICmp(ir::CmpInst::NE, rhs, zero, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(
builder_.CreateMul(lhs_ne, rhs_ne, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
if (ctx->lAndExp().size() == 1) return ctx->lAndExp(0)->accept(this);
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->lAndExp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->lAndExp(1)->accept(this));
// lhs || rhs : (lhs != 0) || (rhs != 0)
ir::Value* zero = builder_.CreateConstInt(0);
ir::Value* lhs_ne = builder_.CreateICmp(ir::CmpInst::NE, lhs, zero, module_.GetContext().NextTemp());
ir::Value* rhs_ne = builder_.CreateICmp(ir::CmpInst::NE, rhs, zero, module_.GetContext().NextTemp());
ir::Value* or_val = builder_.CreateAdd(lhs_ne, rhs_ne, module_.GetContext().NextTemp());
ir::Value* one = builder_.CreateConstInt(1);
return static_cast<ir::Value*>(
builder_.CreateICmp(ir::CmpInst::GE, or_val, one, module_.GetContext().NextTemp()));
}

@ -5,6 +5,7 @@
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
#include <functional>
namespace {
@ -21,7 +22,7 @@ void VerifyFunctionStructure(const ir::Function& func) {
} // namespace
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
IRGenImpl::IRGenImpl(ir::Module& module, IRGenContext& sema)
: module_(module),
sema_(sema),
func_(nullptr),
@ -38,11 +39,123 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
}
auto* func = ctx->funcDef();
if (!func) {
// for simplicity take first function definition
if (ctx->funcDef().empty()) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
func->accept(this);
// 先处理顶层声明(仅支持简单的 const int 初始化作为全局常量)
for (auto* decl : ctx->decl()) {
if (!decl) continue;
if (decl->constDecl()) {
auto* cdecl = decl->constDecl();
for (auto* cdef : cdecl->constDef()) {
if (!cdef || !cdef->Ident() || !cdef->constInitVal() || !cdef->constInitVal()->constExp()) continue;
// 仅支持形如: const int a = 10; 的简单常量初始化(字面量)
auto* add = cdef->constInitVal()->constExp()->addExp();
if (!add) continue;
try {
int v = std::stoi(add->getText());
auto* cval = module_.GetContext().GetConstInt(v);
name_map_[cdef->Ident()->getText()] = cval;
} catch (...) {
// 无法解析则跳过,全局复杂常量暂不支持
}
}
}
// 支持简单的全局变量声明(数组或标量),初始化为零
if (decl->varDecl()) {
auto* vdecl = decl->varDecl();
if (!vdecl->bType()) continue;
std::string btype = vdecl->bType()->getText();
for (auto* vdef : vdecl->varDef()) {
if (!vdef) continue;
LogInfo(std::string("[irgen] global varDef text=") + vdef->getText() + std::string(" ident=") + (vdef->Ident() ? vdef->Ident()->getText() : std::string("<none>")) + std::string(" dims=") + std::to_string((int)vdef->constExp().size()), std::cerr);
if (!vdef || !vdef->Ident()) continue;
std::string name = vdef->Ident()->getText();
// array globals
if (!vdef->constExp().empty()) {
std::vector<int> dims;
bool ok = true;
for (auto* ce : vdef->constExp()) {
try {
int val = 0;
auto anyv = sema_.GetConstVal(ce);
if (anyv.has_value()) {
if (anyv.type() == typeid(int)) val = std::any_cast<int>(anyv);
else if (anyv.type() == typeid(long)) val = (int)std::any_cast<long>(anyv);
else throw std::runtime_error("not-const-int");
} else {
// try literal parse
try {
val = std::stoi(ce->addExp()->getText());
} catch (...) {
// try lookup in name_map_ for previously created const
std::string t = ce->addExp()->getText();
auto it = name_map_.find(t);
if (it != name_map_.end() && it->second && it->second->IsConstant()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(it->second)) {
val = ci->GetValue();
} else {
ok = false; break;
}
} else {
ok = false; break;
}
}
}
dims.push_back(val);
} catch (...) { ok = false; break; }
}
if (!ok) continue;
// build zero constant array similar to visitVarDef
std::function<ir::ConstantValue*(const std::vector<int>&, size_t, std::shared_ptr<ir::Type>)> buildZero;
buildZero = [&](const std::vector<int>& ds, size_t idx, std::shared_ptr<ir::Type> elemTy) -> ir::ConstantValue* {
if (idx >= ds.size()) return nullptr;
std::vector<ir::ConstantValue*> elems;
if (idx + 1 == ds.size()) {
for (int i = 0; i < ds[idx]; ++i) {
if (elemTy->IsFloat32()) elems.push_back(module_.GetContext().GetConstFloat(0.0f));
else elems.push_back(module_.GetContext().GetConstInt(0));
}
} else {
for (int i = 0; i < ds[idx]; ++i) {
ir::ConstantValue* sub = buildZero(ds, idx + 1, elemTy);
if (sub) elems.push_back(sub);
else elems.push_back(module_.GetContext().GetConstInt(0));
}
}
std::function<std::shared_ptr<ir::Type>(size_t)> makeArrayType = [&](size_t level) -> std::shared_ptr<ir::Type> {
if (level + 1 >= ds.size()) return ir::Type::GetArrayType(elemTy, ds[level]);
auto sub = makeArrayType(level + 1);
return ir::Type::GetArrayType(sub, ds[level]);
};
auto at_real = makeArrayType(idx);
return new ir::ConstantArray(at_real, elems);
};
std::shared_ptr<ir::Type> elemTy = (btype == "float") ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type();
ir::ConstantValue* zero = buildZero(dims, 0, elemTy);
auto gvty = ir::Type::GetPointerType(zero ? zero->GetType() : ir::Type::GetPointerType(elemTy));
ir::GlobalValue* gv = module_.CreateGlobalVariable(name, gvty, zero);
name_map_[name] = gv;
LogInfo(std::string("[irgen] created global ") + name, std::cerr);
} else {
// scalar global
std::shared_ptr<ir::Type> elemTy = (btype == "float") ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type();
ir::ConstantValue* init = nullptr;
if (btype == "float") init = module_.GetContext().GetConstFloat(0.0f);
else init = module_.GetContext().GetConstInt(0);
ir::GlobalValue* gv = module_.CreateGlobalVariable(name, ir::Type::GetPointerType(elemTy), init);
name_map_[name] = gv;
LogInfo(std::string("[irgen] created global ") + name, std::cerr);
}
}
}
}
// 生成编译单元中所有函数定义(之前只生成第一个函数)
for (size_t i = 0; i < ctx->funcDef().size(); ++i) {
if (ctx->funcDef(i)) ctx->funcDef(i)->accept(this);
}
return {};
}
@ -61,26 +174,62 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
// - 入口块中的参数初始化逻辑。
// ...
// 因此这里目前只支持最小的“无参 int 函数”生成。
// 因此这里目前只支持最小的“无参 int 函数”生成。
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
if (!ctx->blockStmt()) {
if (!ctx->block()) {
throw std::runtime_error(FormatError("irgen", "函数体为空"));
}
if (!ctx->ID()) {
if (!ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "缺少函数名"));
}
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数"));
if (!ctx->funcType()) {
throw std::runtime_error(FormatError("irgen", "缺少函数返回类型"));
}
std::shared_ptr<ir::Type> ret_type;
if (ctx->funcType()->getText() == "int") ret_type = ir::Type::GetInt32Type();
else if (ctx->funcType()->getText() == "float") ret_type = ir::Type::GetFloat32Type();
else if (ctx->funcType()->getText() == "void") ret_type = ir::Type::GetVoidType();
else throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float/void 函数"));
std::vector<std::shared_ptr<ir::Type>> param_types;
if (ctx->funcFParams()) {
for (auto* param : ctx->funcFParams()->funcFParam()) {
if (param->bType()->getText() == "int") param_types.push_back(ir::Type::GetInt32Type());
else if (param->bType()->getText() == "float") param_types.push_back(ir::Type::GetFloat32Type());
else throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float 参数"));
}
}
func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type());
func_ = module_.CreateFunction(ctx->Ident()->getText(), ret_type, param_types);
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
ctx->blockStmt()->accept(this);
// Allocate storage for parameters
if (ctx->funcFParams()) {
int idx = 0;
for (auto* param : ctx->funcFParams()->funcFParam()) {
std::string param_name = param->Ident()->getText();
ir::AllocaInst* alloca = nullptr;
if (param->bType()->getText() == "float") alloca = builder_.CreateAllocaFloat(param_name);
else alloca = builder_.CreateAllocaI32(param_name);
storage_map_[param] = alloca;
name_map_[param_name] = alloca;
// Store the argument value
auto* arg = func_->GetParams()[idx];
builder_.CreateStore(arg, alloca);
idx++;
}
}
ctx->block()->accept(this);
// 如果函数体末尾没有显式终结(如 void 函数没有 return补一个隐式 return
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateRet(nullptr);
}
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。
VerifyFunctionStructure(*func_);
return {};

@ -19,21 +19,170 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句"));
}
if (ctx->returnStmt()) {
return ctx->returnStmt()->accept(this);
std::string text = ctx->getText();
LogInfo("[irgen] visitStmt text='" + text + "' break_size=" + std::to_string(break_targets_.size()) + " cont_size=" + std::to_string(continue_targets_.size()), std::cerr);
// return
if (ctx->getStart()->getText() == "return") {
if (ctx->exp()) {
ir::Value* v = EvalExpr(*ctx->exp());
builder_.CreateRet(v);
} else {
builder_.CreateRet(nullptr);
}
return BlockFlow::Terminated;
}
// assignment: lVal '=' exp
if (ctx->lVal() && text.find("=") != std::string::npos) {
ir::Value* val = EvalExpr(*ctx->exp());
std::string name = ctx->lVal()->Ident()->getText();
// 优先检查按名称的快速映射(支持全局变量)
auto nit = name_map_.find(name);
if (nit != name_map_.end()) {
// 支持带索引的赋值
if (ctx->lVal()->exp().size() > 0) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->lVal()->exp()) indices.push_back(EvalExpr(*e));
auto* gep = builder_.CreateGEP(nit->second, indices, module_.GetContext().NextTemp());
builder_.CreateStore(val, gep);
return BlockFlow::Continue;
}
builder_.CreateStore(val, nit->second);
return BlockFlow::Continue;
}
for (auto& kv : storage_map_) {
if (!kv.first) continue;
if (auto* vdef = dynamic_cast<SysYParser::VarDefContext*>(kv.first)) {
if (vdef->Ident() && vdef->Ident()->getText() == name) {
builder_.CreateStore(val, kv.second);
return BlockFlow::Continue;
}
} else if (auto* fparam = dynamic_cast<SysYParser::FuncFParamContext*>(kv.first)) {
if (fparam->Ident() && fparam->Ident()->getText() == name) {
builder_.CreateStore(val, kv.second);
return BlockFlow::Continue;
}
} else if (auto* cdef = dynamic_cast<SysYParser::ConstDefContext*>(kv.first)) {
if (cdef->Ident() && cdef->Ident()->getText() == name) {
builder_.CreateStore(val, kv.second);
return BlockFlow::Continue;
}
}
}
throw std::runtime_error(FormatError("irgen", "变量未声明: " + name));
}
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}
// if
if (ctx->getStart()->getText() == "if" && ctx->cond()) {
ir::Value* condv = std::any_cast<ir::Value*>(ctx->cond()->lOrExp()->accept(this));
ir::BasicBlock* then_bb = func_->CreateBlock("if.then");
ir::BasicBlock* else_bb = (ctx->stmt().size() > 1) ? func_->CreateBlock("if.else") : nullptr;
ir::BasicBlock* merge_bb = func_->CreateBlock("if.merge");
std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少 return 语句"));
if (else_bb) builder_.CreateCondBr(condv, then_bb, else_bb);
else builder_.CreateCondBr(condv, then_bb, merge_bb);
// then
builder_.SetInsertPoint(then_bb);
ctx->stmt(0)->accept(this);
if (!builder_.GetInsertBlock()->HasTerminator()) builder_.CreateBr(merge_bb);
// else
if (else_bb) {
builder_.SetInsertPoint(else_bb);
ctx->stmt(1)->accept(this);
if (!builder_.GetInsertBlock()->HasTerminator()) builder_.CreateBr(merge_bb);
}
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
// while
if (ctx->getStart()->getText() == "while" && ctx->cond()) {
ir::BasicBlock* cond_bb = func_->CreateBlock("while.cond");
ir::BasicBlock* body_bb = func_->CreateBlock("while.body");
ir::BasicBlock* after_bb = func_->CreateBlock("while.after");
builder_.CreateBr(cond_bb);
// cond
builder_.SetInsertPoint(cond_bb);
ir::Value* condv = std::any_cast<ir::Value*>(ctx->cond()->lOrExp()->accept(this));
builder_.CreateCondBr(condv, body_bb, after_bb);
// body
builder_.SetInsertPoint(body_bb);
LogInfo("[irgen] while body about to push targets, before sizes: break=" + std::to_string(break_targets_.size()) + ", cont=" + std::to_string(continue_targets_.size()), std::cerr);
break_targets_.push_back(after_bb);
continue_targets_.push_back(cond_bb);
LogInfo("[irgen] after push: break_targets size=" + std::to_string(break_targets_.size()) + ", continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr);
ctx->stmt(0)->accept(this);
LogInfo("[irgen] before pop: break_targets size=" + std::to_string(break_targets_.size()) + ", continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr);
continue_targets_.pop_back();
break_targets_.pop_back();
LogInfo("[irgen] after pop: break_targets size=" + std::to_string(break_targets_.size()) + ", continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr);
if (!builder_.GetInsertBlock()->HasTerminator()) builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(after_bb);
return BlockFlow::Continue;
}
// break
if (ctx->getStart()->getText() == "break") {
if (break_targets_.empty()) {
// fallback: 尝试通过函数块名找目标(不依赖 sema兼容因栈丢失导致的情况
ir::BasicBlock* fallback = nullptr;
for (auto &bb_up : func_->GetBlocks()) {
auto *bb = bb_up.get();
if (!bb) continue;
if (bb->GetName().find("while.after") != std::string::npos) fallback = bb;
}
if (fallback) {
LogInfo("[irgen] emit break (fallback), target=" + fallback->GetName(), std::cerr);
builder_.CreateBr(fallback);
return BlockFlow::Terminated;
}
throw std::runtime_error(FormatError("irgen", "break 不在循环内"));
}
LogInfo("[irgen] emit break, break_targets size=" + std::to_string(break_targets_.size()), std::cerr);
builder_.CreateBr(break_targets_.back());
return BlockFlow::Terminated;
}
// continue
if (ctx->getStart()->getText() == "continue") {
if (continue_targets_.empty()) {
ir::BasicBlock* fallback = nullptr;
for (auto &bb_up : func_->GetBlocks()) {
auto *bb = bb_up.get();
if (!bb) continue;
if (bb->GetName().find("while.cond") != std::string::npos) fallback = bb;
}
if (fallback) {
LogInfo("[irgen] emit continue (fallback), target=" + fallback->GetName(), std::cerr);
builder_.CreateBr(fallback);
return BlockFlow::Terminated;
}
throw std::runtime_error(FormatError("irgen", "continue 不在循环内"));
}
LogInfo("[irgen] emit continue, continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr);
builder_.CreateBr(continue_targets_.back());
return BlockFlow::Terminated;
}
// block
if (ctx->block()) {
return ctx->block()->accept(this);
}
if (!ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
// expression statement
if (ctx->exp()) {
EvalExpr(*ctx->exp());
return BlockFlow::Continue;
}
ir::Value* v = EvalExpr(*ctx->exp());
builder_.CreateRet(v);
return BlockFlow::Terminated;
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}

@ -1,200 +1,446 @@
#include "sem/Sema.h"
#include <any>
#include "../../include/sem/Sema.h"
#include "SysYParser.h"
#include <stdexcept>
#include <string>
#include <algorithm>
#include <iostream>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
using namespace antlr4;
namespace {
// ===================== 核心访问器实现 =====================
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("sema", "非法左值"));
}
return lvalue.ID()->getText();
// 1. 编译单元节点访问
std::any SemaVisitor::visitCompUnit(SysYParser::CompUnitContext* ctx) {
// 分析编译单元中的所有子节点
return visitChildren(ctx);
}
class SemaVisitor final : public SysYBaseVisitor {
public:
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元"));
}
auto* func = ctx->funcDef();
if (!func || !func->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
// 2. 函数定义节点访问
std::any SemaVisitor::visitFuncDef(SysYParser::FuncDefContext* ctx) {
FuncInfo info;
// 通过funcType()获取函数类型
if (ctx->funcType()) {
std::string func_type_text = ctx->funcType()->getText();
if (func_type_text == "void") {
info.ret_type = SymbolType::TYPE_VOID;
} else if (func_type_text == "int") {
info.ret_type = SymbolType::TYPE_INT;
} else if (func_type_text == "float") {
info.ret_type = SymbolType::TYPE_FLOAT;
}
}
if (!func->ID() || func->ID()->getText() != "main") {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
func->accept(this);
if (!seen_return_) {
throw std::runtime_error(
FormatError("sema", "main 函数必须包含 return 语句"));
// 绑定函数名和返回类型
if (ctx->Ident()) {
info.name = ctx->Ident()->getText();
}
return {};
}
ir_ctx_.SetCurrentFuncReturnType(info.ret_type);
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
// 递归分析函数体
if (ctx->block()) {
visit(ctx->block());
}
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持 int main"));
return std::any();
}
// 3. 声明节点访问
std::any SemaVisitor::visitDecl(SysYParser::DeclContext* ctx) {
return visitChildren(ctx);
}
// 4. 常量声明节点访问
std::any SemaVisitor::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
return visitChildren(ctx);
}
// 5. 变量声明节点访问
std::any SemaVisitor::visitVarDecl(SysYParser::VarDeclContext* ctx) {
return visitChildren(ctx);
}
// 6. 代码块节点访问
std::any SemaVisitor::visitBlock(SysYParser::BlockContext* ctx) {
// 进入新的作用域
ir_ctx_.EnterScope();
// 访问块内的语句
std::any result = visitChildren(ctx);
// 离开作用域
ir_ctx_.LeaveScope();
return result;
}
// 7. 语句节点访问
std::any SemaVisitor::visitStmt(SysYParser::StmtContext* ctx) {
// 赋值语句lVal = exp;
if (ctx->lVal() && ctx->exp()) {
auto l_val_ctx = ctx->lVal();
auto exp_ctx = ctx->exp();
// 解析左右值类型
SymbolType l_type = ir_ctx_.GetType(l_val_ctx);
SymbolType r_type = ir_ctx_.GetType(exp_ctx);
// 类型不匹配报错
if (l_type != r_type && l_type != SymbolType::TYPE_UNKNOWN && r_type != SymbolType::TYPE_UNKNOWN) {
std::string l_type_str = (l_type == SymbolType::TYPE_INT ? "int" : "float");
std::string r_type_str = (r_type == SymbolType::TYPE_INT ? "int" : "float");
std::string err_msg = "赋值类型不匹配,左值为" + l_type_str + ",右值为" + r_type_str;
int line = ctx->getStart()->getLine();
int col = ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg(err_msg, line, col));
}
// 绑定左值类型(同步右值类型)
ir_ctx_.SetType(l_val_ctx, r_type);
}
const auto& items = ctx->blockStmt()->blockItem();
if (items.empty()) {
throw std::runtime_error(
FormatError("sema", "main 函数不能为空,且必须以 return 结束"));
// IF语句
else if (ctx->cond() && ctx->stmt().size() >= 1) {
auto cond_ctx = ctx->cond();
// IF条件必须为整型
SymbolType cond_type = ir_ctx_.GetType(cond_ctx);
if (cond_type != SymbolType::TYPE_INT && cond_type != SymbolType::TYPE_UNKNOWN) {
int line = cond_ctx->getStart()->getLine();
int col = cond_ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("if条件表达式必须为整型", line, col));
}
// 递归分析IF体和可能的ELSE体
visit(ctx->stmt(0));
if (ctx->stmt().size() >= 2) {
visit(ctx->stmt(1));
}
}
ctx->blockStmt()->accept(this);
return {};
}
// WHILE语句
else if (ctx->cond() && ctx->stmt().size() >= 1) {
ir_ctx_.EnterLoop(); // 标记进入循环
auto cond_ctx = ctx->cond();
// WHILE条件必须为整型
SymbolType cond_type = ir_ctx_.GetType(cond_ctx);
if (cond_type != SymbolType::TYPE_INT && cond_type != SymbolType::TYPE_UNKNOWN) {
int line = cond_ctx->getStart()->getLine();
int col = cond_ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("while条件表达式必须为整型", line, col));
}
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少语句块"));
}
const auto& items = ctx->blockItem();
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
}
if (seen_return_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
current_item_index_ = i;
total_items_ = items.size();
item->accept(this);
}
return {};
}
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
// 递归分析循环体
visit(ctx->stmt(0));
ir_ctx_.ExitLoop(); // 标记退出循环
}
if (ctx->decl()) {
ctx->decl()->accept(this);
return {};
// BREAK语句
else if (ctx->getText().find("break") != std::string::npos) {
if (!ir_ctx_.InLoop()) {
int line = ctx->getStart()->getLine();
int col = ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("break只能出现在循环语句中", line, col));
}
}
if (ctx->stmt()) {
ctx->stmt()->accept(this);
return {};
// CONTINUE语句
else if (ctx->getText().find("continue") != std::string::npos) {
if (!ir_ctx_.InLoop()) {
int line = ctx->getStart()->getLine();
int col = ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("continue只能出现在循环语句中", line, col));
}
}
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
// RETURN语句
else if (ctx->getText().find("return") != std::string::npos) {
SymbolType func_ret_type = ir_ctx_.GetCurrentFuncReturnType();
// 有返回表达式的情况
if (ctx->exp()) {
auto exp_ctx = ctx->exp();
SymbolType exp_type = ir_ctx_.GetType(exp_ctx);
std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
// 返回类型不匹配报错
if (exp_type != func_ret_type && exp_type != SymbolType::TYPE_UNKNOWN && func_ret_type != SymbolType::TYPE_UNKNOWN) {
std::string ret_type_str = (func_ret_type == SymbolType::TYPE_INT ? "int" : (func_ret_type == SymbolType::TYPE_FLOAT ? "float" : "void"));
std::string exp_type_str = (exp_type == SymbolType::TYPE_INT ? "int" : "float");
std::string err_msg = "return表达式类型与函数返回类型不匹配期望" + ret_type_str + ",实际为" + exp_type_str;
int line = exp_ctx->getStart()->getLine();
int col = exp_ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg(err_msg, line, col));
}
}
// 无返回表达式的情况
else {
if (func_ret_type != SymbolType::TYPE_VOID && func_ret_type != SymbolType::TYPE_UNKNOWN) {
int line = ctx->getStart()->getLine();
int col = ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("非void函数return必须带表达式", line, col));
}
}
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
}
auto* var_def = ctx->varDef();
if (!var_def || !var_def->lValue()) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
// 其他语句
return visitChildren(ctx);
}
// 8. 左值节点访问
std::any SemaVisitor::visitLVal(SysYParser::LValContext* ctx) {
return visitChildren(ctx);
}
// 9. 表达式节点访问
std::any SemaVisitor::visitExp(SysYParser::ExpContext* ctx) {
return visitChildren(ctx);
}
// 10. 条件表达式节点访问
std::any SemaVisitor::visitCond(SysYParser::CondContext* ctx) {
return visitChildren(ctx);
}
// 11. 基本表达式节点访问
std::any SemaVisitor::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
return visitChildren(ctx);
}
// 12. 一元表达式节点访问
std::any SemaVisitor::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// 带一元运算符的表达式(+/-/!
if (ctx->unaryOp() && ctx->unaryExp()) {
auto op_ctx = ctx->unaryOp();
auto uexp_ctx = ctx->unaryExp();
auto uexp_val = visit(uexp_ctx);
std::string op_text = op_ctx->getText();
SymbolType uexp_type = ir_ctx_.GetType(uexp_ctx);
// 正号 +x → 直接返回原值
if (op_text == "+") {
ir_ctx_.SetType(ctx, uexp_type);
ir_ctx_.SetConstVal(ctx, uexp_val);
return uexp_val;
}
// 负号 -x → 取反
else if (op_text == "-") {
if (ir_ctx_.IsIntType(uexp_val)) {
long val = std::any_cast<long>(uexp_val);
ir_ctx_.SetConstVal(ctx, std::any(-val));
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return std::any(-val);
} else if (ir_ctx_.IsFloatType(uexp_val)) {
double val = std::any_cast<double>(uexp_val);
ir_ctx_.SetConstVal(ctx, std::any(-val));
ir_ctx_.SetType(ctx, SymbolType::TYPE_FLOAT);
return std::any(-val);
}
}
// 逻辑非 !x → 0/1转换
else if (op_text == "!") {
if (ir_ctx_.IsIntType(uexp_val)) {
long val = std::any_cast<long>(uexp_val);
long res = (val == 0) ? 1L : 0L;
ir_ctx_.SetConstVal(ctx, std::any(res));
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return std::any(res);
}
}
}
const std::string name = GetLValueName(*var_def->lValue());
if (table_.Contains(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
// 函数调用表达式
else if (ctx->Ident() && ctx->funcRParams()) {
// 这里简化处理
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return std::any(0L);
}
if (auto* init = var_def->initValue()) {
if (!init->exp()) {
throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
}
init->exp()->accept(this);
// 基础表达式
else if (ctx->primaryExp()) {
auto val = visit(ctx->primaryExp());
ir_ctx_.SetType(ctx, ir_ctx_.GetType(ctx->primaryExp()));
ir_ctx_.SetConstVal(ctx, val);
return val;
}
table_.Add(name, var_def);
return {};
}
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx || !ctx->returnStmt()) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
ctx->returnStmt()->accept(this);
return {};
}
return std::any();
}
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
// 13. 乘法表达式节点访问
std::any SemaVisitor::visitMulExp(SysYParser::MulExpContext* ctx) {
auto uexps = ctx->unaryExp();
// 单操作数 → 直接返回
if (uexps.size() == 1) {
auto val = visit(uexps[0]);
ir_ctx_.SetType(ctx, ir_ctx_.GetType(uexps[0]));
ir_ctx_.SetConstVal(ctx, val);
return val;
}
ctx->exp()->accept(this);
seen_return_ = true;
if (current_item_index_ + 1 != total_items_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
return {};
}
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "非法括号表达式"));
}
ctx->exp()->accept(this);
return {};
}
// 多操作数 → 依次计算
std::any result = visit(uexps[0]);
SymbolType current_type = ir_ctx_.GetType(uexps[0]);
std::any visitVarExp(SysYParser::VarExpContext* ctx) override {
if (!ctx || !ctx->var()) {
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
for (size_t i = 1; i < uexps.size(); ++i) {
auto next_uexp = uexps[i];
auto next_val = visit(next_uexp);
SymbolType next_type = ir_ctx_.GetType(next_uexp);
// 类型统一int和float混合转为float
if (current_type == SymbolType::TYPE_INT && next_type == SymbolType::TYPE_FLOAT) {
current_type = SymbolType::TYPE_FLOAT;
} else if (current_type == SymbolType::TYPE_FLOAT && next_type == SymbolType::TYPE_INT) {
current_type = SymbolType::TYPE_FLOAT;
}
// 简化处理:这里假设是乘法运算
if (ir_ctx_.IsIntType(result) && ir_ctx_.IsIntType(next_val)) {
long v1 = std::any_cast<long>(result);
long v2 = std::any_cast<long>(next_val);
result = std::any(v1 * v2);
} else if (ir_ctx_.IsFloatType(result) && ir_ctx_.IsFloatType(next_val)) {
double v1 = std::any_cast<double>(result);
double v2 = std::any_cast<double>(next_val);
result = std::any(v1 * v2);
}
// 更新当前节点类型和常量值
ir_ctx_.SetType(ctx, current_type);
ir_ctx_.SetConstVal(ctx, result);
}
ctx->var()->accept(this);
return {};
}
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量"));
return result;
}
// 14. 加法表达式节点访问
std::any SemaVisitor::visitAddExp(SysYParser::AddExpContext* ctx) {
auto mexps = ctx->mulExp();
// 单操作数 → 直接返回
if (mexps.size() == 1) {
auto val = visit(mexps[0]);
ir_ctx_.SetType(ctx, ir_ctx_.GetType(mexps[0]));
ir_ctx_.SetConstVal(ctx, val);
return val;
}
return {};
}
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式"));
// 多操作数 → 依次计算
std::any result = visit(mexps[0]);
SymbolType current_type = ir_ctx_.GetType(mexps[0]);
for (size_t i = 1; i < mexps.size(); ++i) {
auto next_mexp = mexps[i];
auto next_val = visit(next_mexp);
SymbolType next_type = ir_ctx_.GetType(next_mexp);
// 类型统一
if (current_type == SymbolType::TYPE_INT && next_type == SymbolType::TYPE_FLOAT) {
current_type = SymbolType::TYPE_FLOAT;
} else if (current_type == SymbolType::TYPE_FLOAT && next_type == SymbolType::TYPE_INT) {
current_type = SymbolType::TYPE_FLOAT;
}
// 简化处理:这里假设是加法运算
if (ir_ctx_.IsIntType(result) && ir_ctx_.IsIntType(next_val)) {
long v1 = std::any_cast<long>(result);
long v2 = std::any_cast<long>(next_val);
result = std::any(v1 + v2);
} else if (ir_ctx_.IsFloatType(result) && ir_ctx_.IsFloatType(next_val)) {
double v1 = std::any_cast<double>(result);
double v2 = std::any_cast<double>(next_val);
result = std::any(v1 + v2);
}
ir_ctx_.SetType(ctx, current_type);
ir_ctx_.SetConstVal(ctx, result);
}
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitVar(SysYParser::VarContext* ctx) override {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
return result;
}
// 15. 关系表达式节点访问
std::any SemaVisitor::visitRelExp(SysYParser::RelExpContext* ctx) {
auto aexps = ctx->addExp();
// 单操作数 → 直接返回
if (aexps.size() == 1) {
auto val = visit(aexps[0]);
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return val;
}
const std::string name = ctx->ID()->getText();
auto* decl = table_.Lookup(name);
if (!decl) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
// 多操作数 → 简化处理
std::any result = std::any(1L);
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
ir_ctx_.SetConstVal(ctx, result);
return result;
}
// 16. 相等表达式节点访问
std::any SemaVisitor::visitEqExp(SysYParser::EqExpContext* ctx) {
auto rexps = ctx->relExp();
// 单操作数 → 直接返回
if (rexps.size() == 1) {
auto val = visit(rexps[0]);
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return val;
}
sema_.BindVarUse(ctx, decl);
return {};
}
SemanticContext TakeSemanticContext() { return std::move(sema_); }
// 多操作数 → 简化处理
std::any result = std::any(1L);
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
ir_ctx_.SetConstVal(ctx, result);
private:
SymbolTable table_;
SemanticContext sema_;
bool seen_return_ = false;
size_t current_item_index_ = 0;
size_t total_items_ = 0;
};
return result;
}
} // namespace
// 17. 逻辑与表达式节点访问
std::any SemaVisitor::visitLAndExp(SysYParser::LAndExpContext* ctx) {
return visitChildren(ctx);
}
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor;
comp_unit.accept(&visitor);
return visitor.TakeSemanticContext();
// 18. 逻辑或表达式节点访问
std::any SemaVisitor::visitLOrExp(SysYParser::LOrExpContext* ctx) {
return visitChildren(ctx);
}
// 19. 常量表达式节点访问
std::any SemaVisitor::visitConstExp(SysYParser::ConstExpContext* ctx) {
return visitChildren(ctx);
}
// 20. 数字节点访问
std::any SemaVisitor::visitNumber(SysYParser::NumberContext* ctx) {
// 这里简化处理,实际需要解析整型和浮点型
if (ctx->IntConst()) {
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
ir_ctx_.SetConstVal(ctx, std::any(0L));
return std::any(0L);
} else if (ctx->FloatConst()) {
ir_ctx_.SetType(ctx, SymbolType::TYPE_FLOAT);
ir_ctx_.SetConstVal(ctx, std::any(0.0));
return std::any(0.0);
}
return std::any();
}
// 21. 函数参数节点访问
std::any SemaVisitor::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
return visitChildren(ctx);
}
// ===================== 语义分析入口函数 =====================
void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx) {
if (!ctx) {
throw std::invalid_argument("CompUnitContext is null");
}
SemaVisitor visitor(ir_ctx);
visitor.visit(ctx);
}
IRGenContext RunSema(SysYParser::CompUnitContext& ctx) {
IRGenContext ctx_obj;
RunSemanticAnalysis(&ctx, ctx_obj);
return ctx_obj;
}

@ -1,17 +1,164 @@
// 维护局部变量声明的注册与查找。
#include "../../include/sem/SymbolTable.h"
#include <stdexcept>
#include <string>
#include <iostream>
#include "sem/SymbolTable.h"
// 进入新作用域
void SymbolTable::EnterScope() {
scopes_.push(ScopeEntry());
}
// 离开当前作用域
void SymbolTable::LeaveScope() {
if (scopes_.empty()) {
throw std::runtime_error("SymbolTable Error: 作用域栈为空,无法退出");
}
scopes_.pop();
}
// 绑定变量到当前作用域
void SymbolTable::BindVar(const std::string& name, const VarInfo& info, void* decl_ctx) {
if (CurrentScopeHasVar(name)) {
throw std::runtime_error("变量'" + name + "'在当前作用域重复定义");
}
scopes_.top().var_symbols[name] = {info, decl_ctx};
}
// 绑定函数到当前作用域
void SymbolTable::BindFunc(const std::string& name, const FuncInfo& info, void* decl_ctx) {
if (CurrentScopeHasFunc(name)) {
throw std::runtime_error("函数'" + name + "'在当前作用域重复定义");
}
scopes_.top().func_symbols[name] = {info, decl_ctx};
}
// 查找变量(从当前作用域向上遍历)
bool SymbolTable::LookupVar(const std::string& name, VarInfo& out_info, void*& out_decl_ctx) const {
if (scopes_.empty()) {
return false;
}
auto temp_stack = scopes_;
while (!temp_stack.empty()) {
auto& scope = temp_stack.top();
auto it = scope.var_symbols.find(name);
if (it != scope.var_symbols.end()) {
out_info = it->second.first;
out_decl_ctx = it->second.second;
return true;
}
temp_stack.pop();
}
return false;
}
// 查找函数(从当前作用域向上遍历,通常函数在全局作用域)
bool SymbolTable::LookupFunc(const std::string& name, FuncInfo& out_info, void*& out_decl_ctx) const {
if (scopes_.empty()) {
return false;
}
auto temp_stack = scopes_;
while (!temp_stack.empty()) {
auto& scope = temp_stack.top();
auto it = scope.func_symbols.find(name);
if (it != scope.func_symbols.end()) {
out_info = it->second.first;
out_decl_ctx = it->second.second;
return true;
}
temp_stack.pop();
}
return false;
}
void SymbolTable::Add(const std::string& name,
SysYParser::VarDefContext* decl) {
table_[name] = decl;
// 检查当前作用域是否包含指定变量
bool SymbolTable::CurrentScopeHasVar(const std::string& name) const {
if (scopes_.empty()) {
return false;
}
return scopes_.top().var_symbols.count(name) > 0;
}
bool SymbolTable::Contains(const std::string& name) const {
return table_.find(name) != table_.end();
// 检查当前作用域是否包含指定函数
bool SymbolTable::CurrentScopeHasFunc(const std::string& name) const {
if (scopes_.empty()) {
return false;
}
return scopes_.top().func_symbols.count(name) > 0;
}
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
auto it = table_.find(name);
return it == table_.end() ? nullptr : it->second;
// 进入循环
void SymbolTable::EnterLoop() {
loop_depth_++;
}
// 离开循环
void SymbolTable::ExitLoop() {
if (loop_depth_ > 0) loop_depth_--;
}
// 检查是否在循环内
bool SymbolTable::InLoop() const {
return loop_depth_ > 0;
}
// 清空所有作用域和状态
void SymbolTable::Clear() {
while (!scopes_.empty()) {
scopes_.pop();
}
loop_depth_ = 0;
}
// 获取当前作用域中所有变量名
std::vector<std::string> SymbolTable::GetCurrentScopeVarNames() const {
std::vector<std::string> names;
if (!scopes_.empty()) {
for (const auto& pair : scopes_.top().var_symbols) {
names.push_back(pair.first);
}
}
return names;
}
// 获取当前作用域中所有函数名
std::vector<std::string> SymbolTable::GetCurrentScopeFuncNames() const {
std::vector<std::string> names;
if (!scopes_.empty()) {
for (const auto& pair : scopes_.top().func_symbols) {
names.push_back(pair.first);
}
}
return names;
}
// 调试:打印符号表内容
void SymbolTable::Dump() const {
std::cout << "符号表内容 (作用域深度: " << scopes_.size() << "):\n";
int scope_idx = 0;
auto temp_stack = scopes_;
while (!temp_stack.empty()) {
std::cout << "\n作用域 " << scope_idx++ << ":\n";
auto& scope = temp_stack.top();
std::cout << " 变量:\n";
for (const auto& var_pair : scope.var_symbols) {
const VarInfo& info = var_pair.second.first;
std::cout << " " << var_pair.first << ": "
<< SymbolTypeToString(info.type)
<< (info.is_const ? " (const)" : "")
<< (info.IsArray() ? " [数组]" : "")
<< "\n";
}
std::cout << " 函数:\n";
for (const auto& func_pair : scope.func_symbols) {
const FuncInfo& info = func_pair.second.first;
std::cout << " " << func_pair.first << ": "
<< SymbolTypeToString(info.ret_type) << " ("
<< info.param_types.size() << " 个参数)\n";
}
temp_stack.pop();
}
}
Loading…
Cancel
Save