Shrink: Compile pass with IRGen fixed

Pomelo
Shrink 2 weeks ago
parent 477720eb5e
commit 04a29b2bf9

@ -29,13 +29,22 @@ class IRGenImpl final : public SysYBaseVisitor {
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
std::any visitExp(SysYParser::ExpContext* ctx) override;
std::any visitCond(SysYParser::CondContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitLValue(SysYParser::LValueContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
private:
enum class BlockFlow {
@ -43,8 +52,16 @@ class IRGenImpl final : public SysYBaseVisitor {
Terminated,
};
struct LoopTargets {
ir::BasicBlock* continue_target;
ir::BasicBlock* break_target;
};
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
ir::Value* EvalExpr(SysYParser::ExpContext& expr);
ir::Value* EvalCond(SysYParser::CondContext& cond);
ir::Value* ToBoolValue(ir::Value* v);
std::string NextBlockName();
ir::Module& module_;
const SemanticContext& sema_;
@ -52,6 +69,8 @@ class IRGenImpl final : public SysYBaseVisitor {
ir::IRBuilder builder_;
// 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
std::unordered_map<std::string, ir::Value*> named_storage_;
std::vector<LoopTargets> loop_stack_;
};
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,

@ -1,30 +1,213 @@
// 基于语法树的语义检查与名称绑定。
#pragma once
#ifndef SEMANTIC_ANALYSIS_H
#define SEMANTIC_ANALYSIS_H
#include "SymbolTable.h"
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include <vector>
#include <string>
#include <sstream>
#include <unordered_map>
#include <any>
#include <memory>
#include "SysYParser.h"
// 错误信息结构体
struct ErrorMsg {
std::string msg;
int line;
int column;
ErrorMsg(std::string m, int l, int c) : msg(std::move(m)), line(l), column(c) {}
};
// 前向声明
namespace antlr4 {
class ParserRuleContext;
namespace tree {
class ParseTree;
}
}
// 语义/IR生成上下文核心类
class IRGenContext {
public:
// 错误管理
void RecordError(const ErrorMsg& err) { errors_.push_back(err); }
const std::vector<ErrorMsg>& GetErrors() const { return errors_; }
bool HasError() const { return !errors_.empty(); }
void ClearErrors() { errors_.clear(); }
// 类型绑定/查询 - 使用 void* 以兼容测试代码
void SetType(void* ctx, SymbolType type) {
node_type_map_[ctx] = type;
}
SymbolType GetType(void* ctx) const {
auto it = node_type_map_.find(ctx);
return it == node_type_map_.end() ? SymbolType::TYPE_UNKNOWN : it->second;
}
// 常量值绑定/查询 - 使用 void* 以兼容测试代码
void SetConstVal(void* ctx, const std::any& val) {
const_val_map_[ctx] = val;
}
std::any GetConstVal(void* ctx) const {
auto it = const_val_map_.find(ctx);
return it == const_val_map_.end() ? std::any() : it->second;
}
// 循环状态管理
void EnterLoop() { sym_table_.EnterLoop(); }
void ExitLoop() { sym_table_.ExitLoop(); }
bool InLoop() const { return sym_table_.InLoop(); }
// 类型判断工具函数
bool IsIntType(const std::any& val) const {
return val.type() == typeid(long) || val.type() == typeid(int);
}
bool IsFloatType(const std::any& val) const {
return val.type() == typeid(double) || val.type() == typeid(float);
}
// 当前函数返回类型
SymbolType GetCurrentFuncReturnType() const {
return current_func_ret_type_;
}
void SetCurrentFuncReturnType(SymbolType type) {
current_func_ret_type_ = type;
}
// 符号表访问
SymbolTable& GetSymbolTable() { return sym_table_; }
const SymbolTable& GetSymbolTable() const { return sym_table_; }
// 作用域管理
void EnterScope() { sym_table_.EnterScope(); }
void LeaveScope() { sym_table_.LeaveScope(); }
size_t GetScopeDepth() const { return sym_table_.GetScopeDepth(); }
private:
SymbolTable sym_table_;
std::unordered_map<void*, SymbolType> node_type_map_;
std::unordered_map<void*, std::any> const_val_map_;
std::vector<ErrorMsg> errors_;
SymbolType current_func_ret_type_ = SymbolType::TYPE_UNKNOWN;
};
// 与现有 IRGen/主流程保持兼容的语义上下文占位。
class SemanticContext {
public:
void BindVarUse(SysYParser::VarContext* use,
SysYParser::VarDefContext* decl) {
var_uses_[use] = decl;
}
void BindVarUse(const SysYParser::LValueContext* use,
SysYParser::VarDefContext* decl) {
var_uses_[use] = decl;
}
SysYParser::VarDefContext* ResolveVarUse(
const SysYParser::VarContext* use) const {
auto it = var_uses_.find(use);
return it == var_uses_.end() ? nullptr : it->second;
}
SysYParser::VarDefContext* ResolveVarUse(
const SysYParser::LValueContext* use) const {
auto it = var_uses_.find(use);
return it == var_uses_.end() ? nullptr : it->second;
}
private:
std::unordered_map<const SysYParser::VarContext*,
SysYParser::VarDefContext*>
var_uses_;
std::unordered_map<const SysYParser::LValueContext*,
SysYParser::VarDefContext*>
var_uses_;
};
// 目前仅检查:
// - 变量先声明后使用
// - 局部变量不允许重复定义
// 错误信息格式化工具函数
inline std::string FormatErrMsg(const std::string& msg, int line, int col) {
std::ostringstream oss;
oss << "[行:" << line << ",列:" << col << "] " << msg;
return oss.str();
}
// 语义分析访问器 - 继承自生成的基类
class SemaVisitor : public SysYBaseVisitor {
public:
explicit SemaVisitor(IRGenContext& ctx) : ir_ctx_(ctx) {}
// 必须实现的 ANTLR4 接口
std::any visit(antlr4::tree::ParseTree* tree) override {
if (tree) {
return tree->accept(this);
}
return std::any();
}
std::any visitTerminal(antlr4::tree::TerminalNode* node) override {
return std::any();
}
std::any visitErrorNode(antlr4::tree::ErrorNode* node) override {
if (node) {
int line = node->getSymbol()->getLine();
int col = node->getSymbol()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("语法错误节点", line, col));
}
return std::any();
}
// 核心访问方法
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override;
std::any visitBtype(SysYParser::BtypeContext* ctx) override;
std::any visitConstDef(SysYParser::ConstDefContext* ctx) override;
std::any visitConstInitValue(SysYParser::ConstInitValueContext* ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitInitValue(SysYParser::InitValueContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override;
std::any visitFuncFParams(SysYParser::FuncFParamsContext* ctx) override;
std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitExp(SysYParser::ExpContext* ctx) override;
std::any visitCond(SysYParser::CondContext* ctx) override;
std::any visitLValue(SysYParser::LValueContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override;
std::any visitUnaryOp(SysYParser::UnaryOpContext* ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override;
// 通用子节点访问
std::any visitChildren(antlr4::tree::ParseTree* node) override {
std::any result;
if (node) {
for (auto* child : node->children) {
if (child) {
result = child->accept(this);
}
}
}
return result;
}
// 获取上下文引用
IRGenContext& GetContext() { return ir_ctx_; }
const IRGenContext& GetContext() const { return ir_ctx_; }
private:
IRGenContext& ir_ctx_;
};
// 语义分析入口函数
void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx);
// 兼容旧流程入口。
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);
#endif // SEMANTIC_ANALYSIS_H

@ -1,17 +1,201 @@
// 极简符号表:记录局部变量定义点。
#pragma once
#ifndef SYMBOL_TABLE_H
#define SYMBOL_TABLE_H
#include <any>
#include <string>
#include <vector>
#include <unordered_map>
#include <stack>
#include <utility>
#include "SysYParser.h"
// 核心类型枚举
enum class SymbolType {
TYPE_UNKNOWN, // 未知类型
TYPE_INT, // 整型
TYPE_FLOAT, // 浮点型
TYPE_VOID, // 空类型
TYPE_ARRAY, // 数组类型
TYPE_FUNCTION // 函数类型
};
// 获取类型名称字符串
inline const char* SymbolTypeToString(SymbolType type) {
switch (type) {
case SymbolType::TYPE_INT: return "int";
case SymbolType::TYPE_FLOAT: return "float";
case SymbolType::TYPE_VOID: return "void";
case SymbolType::TYPE_ARRAY: return "array";
case SymbolType::TYPE_FUNCTION: return "function";
default: return "unknown";
}
}
// 变量信息结构体
struct VarInfo {
SymbolType type = SymbolType::TYPE_UNKNOWN;
bool is_const = false;
std::any const_val;
std::vector<int> array_dims; // 数组维度,空表示非数组
void* decl_ctx = nullptr; // 关联的语法节点
// 检查是否为数组类型
bool IsArray() const { return !array_dims.empty(); }
// 获取数组元素总数
int GetArrayElementCount() const {
int count = 1;
for (int dim : array_dims) {
count *= dim;
}
return count;
}
};
// 函数信息结构体
struct FuncInfo {
SymbolType ret_type = SymbolType::TYPE_UNKNOWN;
std::string name;
std::vector<SymbolType> param_types; // 参数类型列表
void* decl_ctx = nullptr; // 关联的语法节点
// 检查参数匹配
bool CheckParams(const std::vector<SymbolType>& actual_params) const {
if (actual_params.size() != param_types.size()) {
return false;
}
for (size_t i = 0; i < param_types.size(); ++i) {
if (param_types[i] != actual_params[i] &&
param_types[i] != SymbolType::TYPE_UNKNOWN &&
actual_params[i] != SymbolType::TYPE_UNKNOWN) {
return false;
}
}
return true;
}
};
// 作用域条目结构体
struct ScopeEntry {
// 变量符号表:符号名 -> (符号信息, 声明节点)
std::unordered_map<std::string, std::pair<VarInfo, void*>> var_symbols;
// 函数符号表:符号名 -> (函数信息, 声明节点)
std::unordered_map<std::string, std::pair<FuncInfo, void*>> func_symbols;
// 清空作用域
void Clear() {
var_symbols.clear();
func_symbols.clear();
}
};
// 符号表核心类
class SymbolTable {
public:
void Add(const std::string& name, SysYParser::VarDefContext* decl);
bool Contains(const std::string& name) const;
SysYParser::VarDefContext* Lookup(const std::string& name) const;
public:
// ========== 作用域管理 ==========
// 进入新作用域
void EnterScope();
// 离开当前作用域
void LeaveScope();
// 获取当前作用域深度
size_t GetScopeDepth() const { return scopes_.size(); }
// 检查作用域栈是否为空
bool IsEmpty() const { return scopes_.empty(); }
// ========== 变量符号管理 ==========
// 检查当前作用域是否包含指定变量
bool CurrentScopeHasVar(const std::string& name) const;
// 绑定变量到当前作用域
void BindVar(const std::string& name, const VarInfo& info, void* decl_ctx);
// 查找变量(从当前作用域向上遍历)
bool LookupVar(const std::string& name, VarInfo& out_info, void*& out_decl_ctx) const;
// 快速查找变量(不获取详细信息)
bool HasVar(const std::string& name) const {
VarInfo info;
void* ctx;
return LookupVar(name, info, ctx);
}
// ========== 函数符号管理 ==========
// 检查当前作用域是否包含指定函数
bool CurrentScopeHasFunc(const std::string& name) const;
// 绑定函数到当前作用域
void BindFunc(const std::string& name, const FuncInfo& info, void* decl_ctx);
// 查找函数(从当前作用域向上遍历)
bool LookupFunc(const std::string& name, FuncInfo& out_info, void*& out_decl_ctx) const;
// 快速查找函数(不获取详细信息)
bool HasFunc(const std::string& name) const {
FuncInfo info;
void* ctx;
return LookupFunc(name, info, ctx);
}
// ========== 循环状态管理 ==========
// 进入循环
void EnterLoop();
// 离开循环
void ExitLoop();
// 检查是否在循环内
bool InLoop() const;
// 获取循环嵌套深度
int GetLoopDepth() const { return loop_depth_; }
// ========== 辅助功能 ==========
// 清空所有作用域和状态
void Clear();
// 获取当前作用域中所有变量名
std::vector<std::string> GetCurrentScopeVarNames() const;
// 获取当前作用域中所有函数名
std::vector<std::string> GetCurrentScopeFuncNames() const;
// 调试:打印符号表内容
void Dump() const;
private:
std::unordered_map<std::string, SysYParser::VarDefContext*> table_;
private:
// 作用域栈
std::stack<ScopeEntry> scopes_;
// 循环嵌套深度
int loop_depth_ = 0;
};
// 类型兼容性检查函数
inline bool IsTypeCompatible(SymbolType expected, SymbolType actual) {
if (expected == SymbolType::TYPE_UNKNOWN || actual == SymbolType::TYPE_UNKNOWN) {
return true; // 未知类型视为兼容
}
// 基本类型兼容规则
if (expected == actual) {
return true;
}
// int 可以隐式转换为 float
if (expected == SymbolType::TYPE_FLOAT && actual == SymbolType::TYPE_INT) {
return true;
}
return false;
}
#endif // SYMBOL_TABLE_H

@ -60,17 +60,29 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
// - const、数组、全局变量等不同声明形态
// - 更丰富的类型系统。
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
}
if (!ctx->varDecl()) {
// 当前先忽略 constDecl 与其它声明形态。
return {};
}
return ctx->varDecl()->accept(this);
}
std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
}
auto* var_def = ctx->varDef();
if (!var_def) {
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
for (auto* var_def : ctx->varDef()) {
if (!var_def) {
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
}
var_def->accept(this);
}
var_def->accept(this);
return {};
}
@ -83,15 +95,16 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量定义"));
}
if (!ctx->lValue()) {
if (!ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
}
GetLValueName(*ctx->lValue());
const std::string name = ctx->ID()->getText();
if (storage_map_.find(ctx) != storage_map_.end()) {
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
}
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
named_storage_[name] = slot;
ir::Value* init = nullptr;
if (auto* init_value = ctx->initValue()) {

@ -24,21 +24,62 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
return std::any_cast<ir::Value*>(expr.accept(this));
}
ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) {
return std::any_cast<ir::Value*>(cond.accept(this));
}
std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法括号表达式"));
ir::Value* IRGenImpl::ToBoolValue(ir::Value* v) {
if (!v) {
throw std::runtime_error(FormatError("irgen", "条件值为空"));
}
return EvalExpr(*ctx->exp());
auto* zero = builder_.CreateConstInt(0);
return builder_.CreateCmp(ir::CmpOp::Ne, v, zero, module_.GetContext().NextTemp());
}
std::string IRGenImpl::NextBlockName() {
std::string temp = module_.GetContext().NextTemp();
if (!temp.empty() && temp.front() == '%') {
return "bb" + temp.substr(1);
}
return "bb" + temp;
}
std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("irgen", "非法表达式"));
}
return ctx->addExp()->accept(this);
}
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
if (!ctx || !ctx->lOrExp()) {
throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
}
return ctx->lOrExp()->accept(this);
}
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法基本表达式"));
}
if (ctx->exp()) {
return EvalExpr(*ctx->exp());
}
if (ctx->number()) {
return ctx->number()->accept(this);
}
if (ctx->lValue()) {
return ctx->lValue()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "不支持的基本表达式"));
}
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
if (!ctx || !ctx->ILITERAL()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量"));
}
return static_cast<ir::Value*>(
builder_.CreateConstInt(std::stoi(ctx->number()->getText())));
builder_.CreateConstInt(std::stoi(ctx->getText())));
}
// 变量使用的处理流程:
@ -47,34 +88,192 @@ std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
// 3. 最后生成 load把内存中的值读出来。
//
// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。
std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) {
if (!ctx || !ctx->var() || !ctx->var()->ID()) {
std::any IRGenImpl::visitLValue(SysYParser::LValueContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
}
auto* decl = sema_.ResolveVarUse(ctx->var());
if (!decl) {
throw std::runtime_error(
FormatError("irgen",
"变量使用缺少语义绑定: " + ctx->var()->ID()->getText()));
}
auto it = storage_map_.find(decl);
if (it == storage_map_.end()) {
const std::string name = ctx->ID()->getText();
auto it = named_storage_.find(name);
if (it == named_storage_.end()) {
throw std::runtime_error(
FormatError("irgen",
"变量声明缺少存储槽位: " + ctx->var()->ID()->getText()));
FormatError("irgen", "变量声明缺少存储槽位: " + name));
}
return static_cast<ir::Value*>(
builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
}
if (ctx->primaryExp()) {
return ctx->primaryExp()->accept(this);
}
if (ctx->unaryOp() && ctx->unaryExp()) {
ir::Value* v = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
if (ctx->unaryOp()->SUB()) {
auto* zero = builder_.CreateConstInt(0);
return static_cast<ir::Value*>(builder_.CreateSub(
zero, v, module_.GetContext().NextTemp()));
}
if (ctx->unaryOp()->ADD()) {
return v;
}
throw std::runtime_error(FormatError("irgen", "当前不支持逻辑非运算"));
}
throw std::runtime_error(FormatError("irgen", "当前不支持函数调用表达式"));
}
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
if (ctx->mulExp()) {
if (!ctx->unaryExp()) {
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
if (ctx->MUL()) {
return static_cast<ir::Value*>(
builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->DIV()) {
return static_cast<ir::Value*>(
builder_.CreateDiv(lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->MOD()) {
return static_cast<ir::Value*>(
builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
}
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
if (ctx->unaryExp()) {
return ctx->unaryExp()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
return static_cast<ir::Value*>(
builder_.CreateBinary(ir::Opcode::Add, lhs, rhs,
module_.GetContext().NextTemp()));
if (ctx->addExp()) {
if (!ctx->mulExp()) {
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
if (ctx->ADD()) {
return static_cast<ir::Value*>(
builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->SUB()) {
return static_cast<ir::Value*>(
builder_.CreateSub(lhs, rhs, module_.GetContext().NextTemp()));
}
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
if (ctx->mulExp()) {
return ctx->mulExp()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
if (ctx->relExp()) {
if (!ctx->addExp()) {
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
if (ctx->LT()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Lt, lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->LE()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Le, lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->GT()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Gt, lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->GE()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Ge, lhs, rhs, module_.GetContext().NextTemp()));
}
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
if (ctx->addExp()) {
return ctx->addExp()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
if (ctx->eqExp()) {
if (!ctx->relExp()) {
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
if (ctx->EQ()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Eq, lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->NE()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Ne, lhs, rhs, module_.GetContext().NextTemp()));
}
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
if (ctx->relExp()) {
return ctx->relExp()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
}
if (ctx->lAndExp()) {
if (!ctx->eqExp()) {
throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
}
auto* lhs = ToBoolValue(std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this)));
auto* rhs = ToBoolValue(std::any_cast<ir::Value*>(ctx->eqExp()->accept(this)));
return static_cast<ir::Value*>(
builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->eqExp()) {
return ToBoolValue(std::any_cast<ir::Value*>(ctx->eqExp()->accept(this)));
}
throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
}
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
}
if (ctx->lOrExp()) {
if (!ctx->lAndExp()) {
throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
}
auto* lhs = ToBoolValue(std::any_cast<ir::Value*>(ctx->lOrExp()->accept(this)));
auto* rhs = ToBoolValue(std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this)));
auto* sum = builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(ToBoolValue(sum));
}
if (ctx->lAndExp()) {
return ToBoolValue(std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this)));
}
throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
}

@ -38,11 +38,14 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
}
auto* func = ctx->funcDef();
if (!func) {
if (ctx->funcDef().empty()) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
func->accept(this);
for (auto* func : ctx->funcDef()) {
if (func) {
func->accept(this);
}
}
return {};
}
@ -79,6 +82,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type());
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
named_storage_.clear();
ctx->blockStmt()->accept(this);
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。

@ -19,9 +19,101 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句"));
}
if (ctx->lValue() && ctx->ASSIGN() && ctx->exp()) {
if (!ctx->lValue()->ID()) {
throw std::runtime_error(FormatError("irgen", "赋值语句左值非法"));
}
const std::string name = ctx->lValue()->ID()->getText();
auto slot_it = named_storage_.find(name);
if (slot_it == named_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "赋值目标未声明: " + name));
}
ir::Value* rhs = EvalExpr(*ctx->exp());
builder_.CreateStore(rhs, slot_it->second);
return BlockFlow::Continue;
}
if (ctx->blockStmt()) {
ctx->blockStmt()->accept(this);
return builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()
? BlockFlow::Terminated
: BlockFlow::Continue;
}
if (ctx->IF()) {
if (!ctx->cond() || ctx->stmt().empty()) {
throw std::runtime_error(FormatError("irgen", "if 语句不完整"));
}
auto* then_bb = func_->CreateBlock(NextBlockName());
auto* merge_bb = func_->CreateBlock(NextBlockName());
auto* else_bb = ctx->ELSE() ? func_->CreateBlock(NextBlockName()) : merge_bb;
ir::Value* cond = ToBoolValue(EvalCond(*ctx->cond()));
builder_.CreateCondBr(cond, then_bb, else_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
if (then_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
if (ctx->ELSE()) {
builder_.SetInsertPoint(else_bb);
auto else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
if (else_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
}
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
if (ctx->WHILE()) {
if (!ctx->cond() || ctx->stmt().empty()) {
throw std::runtime_error(FormatError("irgen", "while 语句不完整"));
}
auto* cond_bb = func_->CreateBlock(NextBlockName());
auto* body_bb = func_->CreateBlock(NextBlockName());
auto* exit_bb = func_->CreateBlock(NextBlockName());
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(cond_bb);
ir::Value* cond = ToBoolValue(EvalCond(*ctx->cond()));
builder_.CreateCondBr(cond, body_bb, exit_bb);
loop_stack_.push_back({cond_bb, exit_bb});
builder_.SetInsertPoint(body_bb);
auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
if (body_flow != BlockFlow::Terminated) {
builder_.CreateBr(cond_bb);
}
loop_stack_.pop_back();
builder_.SetInsertPoint(exit_bb);
return BlockFlow::Continue;
}
if (ctx->BREAK()) {
if (loop_stack_.empty()) {
throw std::runtime_error(FormatError("irgen", "break 不在循环中"));
}
builder_.CreateBr(loop_stack_.back().break_target);
return BlockFlow::Terminated;
}
if (ctx->CONTINUE()) {
if (loop_stack_.empty()) {
throw std::runtime_error(FormatError("irgen", "continue 不在循环中"));
}
builder_.CreateBr(loop_stack_.back().continue_target);
return BlockFlow::Terminated;
}
if (ctx->returnStmt()) {
return ctx->returnStmt()->accept(this);
}
if (ctx->exp()) {
EvalExpr(*ctx->exp());
return BlockFlow::Continue;
}
if (ctx->SEMICOLON()) {
return BlockFlow::Continue;
}
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}

@ -1,200 +1,224 @@
#include "sem/Sema.h"
#include <any>
#include <stdexcept>
#include <string>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
namespace {
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("sema", "非法左值"));
SymbolType ParseType(const std::string& text) {
if (text == "int") {
return SymbolType::TYPE_INT;
}
return lvalue.ID()->getText();
}
class SemaVisitor final : public SysYBaseVisitor {
public:
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元"));
}
auto* func = ctx->funcDef();
if (!func || !func->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
if (!func->ID() || func->ID()->getText() != "main") {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
func->accept(this);
if (!seen_return_) {
throw std::runtime_error(
FormatError("sema", "main 函数必须包含 return 语句"));
}
return {};
if (text == "float") {
return SymbolType::TYPE_FLOAT;
}
if (text == "void") {
return SymbolType::TYPE_VOID;
}
return SymbolType::TYPE_UNKNOWN;
}
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持 int main"));
}
const auto& items = ctx->blockStmt()->blockItem();
if (items.empty()) {
throw std::runtime_error(
FormatError("sema", "main 函数不能为空,且必须以 return 结束"));
}
ctx->blockStmt()->accept(this);
return {};
} // namespace
std::any SemaVisitor::visitCompUnit(SysYParser::CompUnitContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitDecl(SysYParser::DeclContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitBtype(SysYParser::BtypeContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitConstDef(SysYParser::ConstDefContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitConstInitValue(SysYParser::ConstInitValueContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitVarDecl(SysYParser::VarDeclContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitVarDef(SysYParser::VarDefContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitInitValue(SysYParser::InitValueContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitFuncDef(SysYParser::FuncDefContext* ctx) {
SymbolType ret_type = SymbolType::TYPE_UNKNOWN;
if (ctx && ctx->funcType()) {
ret_type = ParseType(ctx->funcType()->getText());
}
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少语句块"));
}
const auto& items = ctx->blockItem();
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
}
if (seen_return_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
current_item_index_ = i;
total_items_ = items.size();
item->accept(this);
}
ir_ctx_.SetCurrentFuncReturnType(ret_type);
return visitChildren(ctx);
}
std::any SemaVisitor::visitFuncType(SysYParser::FuncTypeContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitFuncFParams(SysYParser::FuncFParamsContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitFuncFParam(SysYParser::FuncFParamContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
ir_ctx_.EnterScope();
std::any result = visitChildren(ctx);
ir_ctx_.LeaveScope();
return result;
}
std::any SemaVisitor::visitBlockItem(SysYParser::BlockItemContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) {
return {};
}
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
if (ctx->decl()) {
ctx->decl()->accept(this);
return {};
}
if (ctx->stmt()) {
ctx->stmt()->accept(this);
return {};
}
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
if (ctx->WHILE()) {
ir_ctx_.EnterLoop();
std::any result = visitChildren(ctx);
ir_ctx_.ExitLoop();
return result;
}
std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
}
auto* var_def = ctx->varDef();
if (!var_def || !var_def->lValue()) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
}
const std::string name = GetLValueName(*var_def->lValue());
if (table_.Contains(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
}
if (auto* init = var_def->initValue()) {
if (!init->exp()) {
throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
}
init->exp()->accept(this);
}
table_.Add(name, var_def);
return {};
if (ctx->BREAK() && !ir_ctx_.InLoop()) {
ir_ctx_.RecordError(
ErrorMsg("break 只能出现在循环语句中", ctx->getStart()->getLine(),
ctx->getStart()->getCharPositionInLine() + 1));
}
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx || !ctx->returnStmt()) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
ctx->returnStmt()->accept(this);
return {};
if (ctx->CONTINUE() && !ir_ctx_.InLoop()) {
ir_ctx_.RecordError(
ErrorMsg("continue 只能出现在循环语句中", ctx->getStart()->getLine(),
ctx->getStart()->getCharPositionInLine() + 1));
}
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
}
ctx->exp()->accept(this);
seen_return_ = true;
if (current_item_index_ + 1 != total_items_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
return {};
}
return visitChildren(ctx);
}
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "非法括号表达式"));
}
ctx->exp()->accept(this);
std::any SemaVisitor::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
if (!ctx) {
return {};
}
std::any visitVarExp(SysYParser::VarExpContext* ctx) override {
if (!ctx || !ctx->var()) {
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
}
ctx->var()->accept(this);
return {};
if (ctx->exp() && ir_ctx_.GetCurrentFuncReturnType() == SymbolType::TYPE_VOID) {
ir_ctx_.RecordError(
ErrorMsg("void 函数不应返回表达式", ctx->getStart()->getLine(),
ctx->getStart()->getCharPositionInLine() + 1));
}
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量"));
}
return {};
if (!ctx->exp() &&
ir_ctx_.GetCurrentFuncReturnType() != SymbolType::TYPE_VOID &&
ir_ctx_.GetCurrentFuncReturnType() != SymbolType::TYPE_UNKNOWN) {
ir_ctx_.RecordError(
ErrorMsg("非 void 函数 return 必须带表达式", ctx->getStart()->getLine(),
ctx->getStart()->getCharPositionInLine() + 1));
}
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式"));
}
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return visitChildren(ctx);
}
std::any SemaVisitor::visitExp(SysYParser::ExpContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitCond(SysYParser::CondContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitLValue(SysYParser::LValueContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitNumber(SysYParser::NumberContext* ctx) {
if (!ctx) {
return {};
}
std::any visitVar(SysYParser::VarContext* ctx) override {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
}
const std::string name = ctx->ID()->getText();
auto* decl = table_.Lookup(name);
if (!decl) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
}
sema_.BindVarUse(ctx, decl);
return {};
if (ctx->ILITERAL()) {
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
ir_ctx_.SetConstVal(ctx, std::any(0L));
} else if (ctx->FLITERAL()) {
ir_ctx_.SetType(ctx, SymbolType::TYPE_FLOAT);
ir_ctx_.SetConstVal(ctx, std::any(0.0));
}
SemanticContext TakeSemanticContext() { return std::move(sema_); }
return {};
}
private:
SymbolTable table_;
SemanticContext sema_;
bool seen_return_ = false;
size_t current_item_index_ = 0;
size_t total_items_ = 0;
};
std::any SemaVisitor::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
return visitChildren(ctx);
}
} // namespace
std::any SemaVisitor::visitUnaryOp(SysYParser::UnaryOpContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitMulExp(SysYParser::MulExpContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitAddExp(SysYParser::AddExpContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitRelExp(SysYParser::RelExpContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitEqExp(SysYParser::EqExpContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitLAndExp(SysYParser::LAndExpContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitLOrExp(SysYParser::LOrExpContext* ctx) {
return visitChildren(ctx);
}
std::any SemaVisitor::visitConstExp(SysYParser::ConstExpContext* ctx) {
return visitChildren(ctx);
}
void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx) {
if (!ctx) {
throw std::invalid_argument("CompUnitContext is null");
}
SemaVisitor visitor(ir_ctx);
visitor.visit(ctx);
}
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor;
comp_unit.accept(&visitor);
return visitor.TakeSemanticContext();
IRGenContext ctx;
RunSemanticAnalysis(&comp_unit, ctx);
return SemanticContext();
}

@ -1,17 +1,164 @@
// 维护局部变量声明的注册与查找。
#include "../../include/sem/SymbolTable.h"
#include <stdexcept>
#include <string>
#include <iostream>
#include "sem/SymbolTable.h"
// 进入新作用域
void SymbolTable::EnterScope() {
scopes_.push(ScopeEntry());
}
// 离开当前作用域
void SymbolTable::LeaveScope() {
if (scopes_.empty()) {
throw std::runtime_error("SymbolTable Error: 作用域栈为空,无法退出");
}
scopes_.pop();
}
// 绑定变量到当前作用域
void SymbolTable::BindVar(const std::string& name, const VarInfo& info, void* decl_ctx) {
if (CurrentScopeHasVar(name)) {
throw std::runtime_error("变量'" + name + "'在当前作用域重复定义");
}
scopes_.top().var_symbols[name] = {info, decl_ctx};
}
// 绑定函数到当前作用域
void SymbolTable::BindFunc(const std::string& name, const FuncInfo& info, void* decl_ctx) {
if (CurrentScopeHasFunc(name)) {
throw std::runtime_error("函数'" + name + "'在当前作用域重复定义");
}
scopes_.top().func_symbols[name] = {info, decl_ctx};
}
// 查找变量(从当前作用域向上遍历)
bool SymbolTable::LookupVar(const std::string& name, VarInfo& out_info, void*& out_decl_ctx) const {
if (scopes_.empty()) {
return false;
}
auto temp_stack = scopes_;
while (!temp_stack.empty()) {
auto& scope = temp_stack.top();
auto it = scope.var_symbols.find(name);
if (it != scope.var_symbols.end()) {
out_info = it->second.first;
out_decl_ctx = it->second.second;
return true;
}
temp_stack.pop();
}
return false;
}
// 查找函数(从当前作用域向上遍历,通常函数在全局作用域)
bool SymbolTable::LookupFunc(const std::string& name, FuncInfo& out_info, void*& out_decl_ctx) const {
if (scopes_.empty()) {
return false;
}
auto temp_stack = scopes_;
while (!temp_stack.empty()) {
auto& scope = temp_stack.top();
auto it = scope.func_symbols.find(name);
if (it != scope.func_symbols.end()) {
out_info = it->second.first;
out_decl_ctx = it->second.second;
return true;
}
temp_stack.pop();
}
return false;
}
void SymbolTable::Add(const std::string& name,
SysYParser::VarDefContext* decl) {
table_[name] = decl;
// 检查当前作用域是否包含指定变量
bool SymbolTable::CurrentScopeHasVar(const std::string& name) const {
if (scopes_.empty()) {
return false;
}
return scopes_.top().var_symbols.count(name) > 0;
}
bool SymbolTable::Contains(const std::string& name) const {
return table_.find(name) != table_.end();
// 检查当前作用域是否包含指定函数
bool SymbolTable::CurrentScopeHasFunc(const std::string& name) const {
if (scopes_.empty()) {
return false;
}
return scopes_.top().func_symbols.count(name) > 0;
}
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
auto it = table_.find(name);
return it == table_.end() ? nullptr : it->second;
// 进入循环
void SymbolTable::EnterLoop() {
loop_depth_++;
}
// 离开循环
void SymbolTable::ExitLoop() {
if (loop_depth_ > 0) loop_depth_--;
}
// 检查是否在循环内
bool SymbolTable::InLoop() const {
return loop_depth_ > 0;
}
// 清空所有作用域和状态
void SymbolTable::Clear() {
while (!scopes_.empty()) {
scopes_.pop();
}
loop_depth_ = 0;
}
// 获取当前作用域中所有变量名
std::vector<std::string> SymbolTable::GetCurrentScopeVarNames() const {
std::vector<std::string> names;
if (!scopes_.empty()) {
for (const auto& pair : scopes_.top().var_symbols) {
names.push_back(pair.first);
}
}
return names;
}
// 获取当前作用域中所有函数名
std::vector<std::string> SymbolTable::GetCurrentScopeFuncNames() const {
std::vector<std::string> names;
if (!scopes_.empty()) {
for (const auto& pair : scopes_.top().func_symbols) {
names.push_back(pair.first);
}
}
return names;
}
// 调试:打印符号表内容
void SymbolTable::Dump() const {
std::cout << "符号表内容 (作用域深度: " << scopes_.size() << "):\n";
int scope_idx = 0;
auto temp_stack = scopes_;
while (!temp_stack.empty()) {
std::cout << "\n作用域 " << scope_idx++ << ":\n";
auto& scope = temp_stack.top();
std::cout << " 变量:\n";
for (const auto& var_pair : scope.var_symbols) {
const VarInfo& info = var_pair.second.first;
std::cout << " " << var_pair.first << ": "
<< SymbolTypeToString(info.type)
<< (info.is_const ? " (const)" : "")
<< (info.IsArray() ? " [数组]" : "")
<< "\n";
}
std::cout << " 函数:\n";
for (const auto& func_pair : scope.func_symbols) {
const FuncInfo& info = func_pair.second.first;
std::cout << " " << func_pair.first << ": "
<< SymbolTypeToString(info.ret_type) << " ("
<< info.param_types.size() << " 个参数)\n";
}
temp_stack.pop();
}
}
Loading…
Cancel
Save