mirror 3 weeks ago
parent 7405f1327d
commit 8414298089

@ -1,31 +1,178 @@
//写这个
// 基于语法树的语义检查与名称绑定。
#pragma once
#ifndef SEMANTIC_ANALYSIS_H
#define SEMANTIC_ANALYSIS_H
#include "SymbolTable.h"
#include "../../generated/src/antlr4/SysYBaseVisitor.h"
#include <vector>
#include <string>
#include <sstream>
#include <unordered_map>
#include <any>
#include <memory>
#include "SysYParser.h"
class SemanticContext {
public:
void BindVarUse(SysYParser::VarContext* use,
SysYParser::VarDefContext* decl) {
var_uses_[use] = decl;
}
SysYParser::VarDefContext* ResolveVarUse(
const SysYParser::VarContext* use) const {
auto it = var_uses_.find(use);
return it == var_uses_.end() ? nullptr : it->second;
}
private:
std::unordered_map<const SysYParser::VarContext*,
SysYParser::VarDefContext*>
var_uses_;
// 错误信息结构体
struct ErrorMsg {
std::string msg;
int line;
int column;
ErrorMsg(std::string m, int l, int c) : msg(std::move(m)), line(l), column(c) {}
};
// 目前仅检查:
// - 变量先声明后使用
// - 局部变量不允许重复定义
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);
// 前向声明
namespace antlr4 {
class ParserRuleContext;
namespace tree {
class ParseTree;
}
}
// 语义/IR生成上下文核心类
class IRGenContext {
public:
// 错误管理
void RecordError(const ErrorMsg& err) { errors_.push_back(err); }
const std::vector<ErrorMsg>& GetErrors() const { return errors_; }
bool HasError() const { return !errors_.empty(); }
void ClearErrors() { errors_.clear(); }
// 类型绑定/查询 - 使用 void* 以兼容测试代码
void SetType(void* ctx, SymbolType type) {
node_type_map_[ctx] = type;
}
SymbolType GetType(void* ctx) const {
auto it = node_type_map_.find(ctx);
return it == node_type_map_.end() ? SymbolType::TYPE_UNKNOWN : it->second;
}
// 常量值绑定/查询 - 使用 void* 以兼容测试代码
void SetConstVal(void* ctx, const std::any& val) {
const_val_map_[ctx] = val;
}
std::any GetConstVal(void* ctx) const {
auto it = const_val_map_.find(ctx);
return it == const_val_map_.end() ? std::any() : it->second;
}
// 循环状态管理
void EnterLoop() { sym_table_.EnterLoop(); }
void ExitLoop() { sym_table_.ExitLoop(); }
bool InLoop() const { return sym_table_.InLoop(); }
// 类型判断工具函数
bool IsIntType(const std::any& val) const {
return val.type() == typeid(long) || val.type() == typeid(int);
}
bool IsFloatType(const std::any& val) const {
return val.type() == typeid(double) || val.type() == typeid(float);
}
// 当前函数返回类型
SymbolType GetCurrentFuncReturnType() const {
return current_func_ret_type_;
}
void SetCurrentFuncReturnType(SymbolType type) {
current_func_ret_type_ = type;
}
// 符号表访问
SymbolTable& GetSymbolTable() { return sym_table_; }
const SymbolTable& GetSymbolTable() const { return sym_table_; }
// 作用域管理
void EnterScope() { sym_table_.EnterScope(); }
void LeaveScope() { sym_table_.LeaveScope(); }
size_t GetScopeDepth() const { return sym_table_.GetScopeDepth(); }
private:
SymbolTable sym_table_;
std::unordered_map<void*, SymbolType> node_type_map_;
std::unordered_map<void*, std::any> const_val_map_;
std::vector<ErrorMsg> errors_;
SymbolType current_func_ret_type_ = SymbolType::TYPE_UNKNOWN;
};
// 错误信息格式化工具函数
inline std::string FormatErrMsg(const std::string& msg, int line, int col) {
std::ostringstream oss;
oss << "[行:" << line << ",列:" << col << "] " << msg;
return oss.str();
}
// 语义分析访问器 - 继承自生成的基类
class SemaVisitor : public SysYBaseVisitor {
public:
explicit SemaVisitor(IRGenContext& ctx) : ir_ctx_(ctx) {}
// 必须实现的 ANTLR4 接口
std::any visit(antlr4::tree::ParseTree* tree) override {
if (tree) {
return tree->accept(this);
}
return std::any();
}
std::any visitTerminal(antlr4::tree::TerminalNode* node) override {
return std::any();
}
std::any visitErrorNode(antlr4::tree::ErrorNode* node) override {
if (node) {
int line = node->getSymbol()->getLine();
int col = node->getSymbol()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("语法错误节点", line, col));
}
return std::any();
}
// 核心访问方法
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override;
std::any visitBlock(SysYParser::BlockContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitLVal(SysYParser::LValContext* ctx) override;
std::any visitExp(SysYParser::ExpContext* ctx) override;
std::any visitCond(SysYParser::CondContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override;
// 通用子节点访问
std::any visitChildren(antlr4::tree::ParseTree* node) override {
std::any result;
if (node) {
for (auto* child : node->children) {
if (child) {
result = child->accept(this);
}
}
}
return result;
}
// 获取上下文引用
IRGenContext& GetContext() { return ir_ctx_; }
const IRGenContext& GetContext() const { return ir_ctx_; }
private:
IRGenContext& ir_ctx_;
};
// 语义分析入口函数
void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx);
#endif // SEMANTIC_ANALYSIS_H

@ -1,18 +1,201 @@
//写这个
// 极简符号表:记录局部变量定义点。
#pragma once
#ifndef SYMBOL_TABLE_H
#define SYMBOL_TABLE_H
#include <any>
#include <string>
#include <vector>
#include <unordered_map>
#include <stack>
#include <utility>
#include "SysYParser.h"
// 核心类型枚举
enum class SymbolType {
TYPE_UNKNOWN, // 未知类型
TYPE_INT, // 整型
TYPE_FLOAT, // 浮点型
TYPE_VOID, // 空类型
TYPE_ARRAY, // 数组类型
TYPE_FUNCTION // 函数类型
};
// 获取类型名称字符串
inline const char* SymbolTypeToString(SymbolType type) {
switch (type) {
case SymbolType::TYPE_INT: return "int";
case SymbolType::TYPE_FLOAT: return "float";
case SymbolType::TYPE_VOID: return "void";
case SymbolType::TYPE_ARRAY: return "array";
case SymbolType::TYPE_FUNCTION: return "function";
default: return "unknown";
}
}
// 变量信息结构体
struct VarInfo {
SymbolType type = SymbolType::TYPE_UNKNOWN;
bool is_const = false;
std::any const_val;
std::vector<int> array_dims; // 数组维度,空表示非数组
void* decl_ctx = nullptr; // 关联的语法节点
// 检查是否为数组类型
bool IsArray() const { return !array_dims.empty(); }
// 获取数组元素总数
int GetArrayElementCount() const {
int count = 1;
for (int dim : array_dims) {
count *= dim;
}
return count;
}
};
// 函数信息结构体
struct FuncInfo {
SymbolType ret_type = SymbolType::TYPE_UNKNOWN;
std::string name;
std::vector<SymbolType> param_types; // 参数类型列表
void* decl_ctx = nullptr; // 关联的语法节点
// 检查参数匹配
bool CheckParams(const std::vector<SymbolType>& actual_params) const {
if (actual_params.size() != param_types.size()) {
return false;
}
for (size_t i = 0; i < param_types.size(); ++i) {
if (param_types[i] != actual_params[i] &&
param_types[i] != SymbolType::TYPE_UNKNOWN &&
actual_params[i] != SymbolType::TYPE_UNKNOWN) {
return false;
}
}
return true;
}
};
// 作用域条目结构体
struct ScopeEntry {
// 变量符号表:符号名 -> (符号信息, 声明节点)
std::unordered_map<std::string, std::pair<VarInfo, void*>> var_symbols;
// 函数符号表:符号名 -> (函数信息, 声明节点)
std::unordered_map<std::string, std::pair<FuncInfo, void*>> func_symbols;
// 清空作用域
void Clear() {
var_symbols.clear();
func_symbols.clear();
}
};
// 符号表核心类
class SymbolTable {
public:
void Add(const std::string& name, SysYParser::VarDefContext* decl);
bool Contains(const std::string& name) const;
SysYParser::VarDefContext* Lookup(const std::string& name) const;
public:
// ========== 作用域管理 ==========
// 进入新作用域
void EnterScope();
// 离开当前作用域
void LeaveScope();
// 获取当前作用域深度
size_t GetScopeDepth() const { return scopes_.size(); }
// 检查作用域栈是否为空
bool IsEmpty() const { return scopes_.empty(); }
// ========== 变量符号管理 ==========
// 检查当前作用域是否包含指定变量
bool CurrentScopeHasVar(const std::string& name) const;
// 绑定变量到当前作用域
void BindVar(const std::string& name, const VarInfo& info, void* decl_ctx);
// 查找变量(从当前作用域向上遍历)
bool LookupVar(const std::string& name, VarInfo& out_info, void*& out_decl_ctx) const;
// 快速查找变量(不获取详细信息)
bool HasVar(const std::string& name) const {
VarInfo info;
void* ctx;
return LookupVar(name, info, ctx);
}
// ========== 函数符号管理 ==========
// 检查当前作用域是否包含指定函数
bool CurrentScopeHasFunc(const std::string& name) const;
// 绑定函数到当前作用域
void BindFunc(const std::string& name, const FuncInfo& info, void* decl_ctx);
// 查找函数(从当前作用域向上遍历)
bool LookupFunc(const std::string& name, FuncInfo& out_info, void*& out_decl_ctx) const;
// 快速查找函数(不获取详细信息)
bool HasFunc(const std::string& name) const {
FuncInfo info;
void* ctx;
return LookupFunc(name, info, ctx);
}
// ========== 循环状态管理 ==========
// 进入循环
void EnterLoop();
// 离开循环
void ExitLoop();
// 检查是否在循环内
bool InLoop() const;
// 获取循环嵌套深度
int GetLoopDepth() const { return loop_depth_; }
// ========== 辅助功能 ==========
// 清空所有作用域和状态
void Clear();
// 获取当前作用域中所有变量名
std::vector<std::string> GetCurrentScopeVarNames() const;
// 获取当前作用域中所有函数名
std::vector<std::string> GetCurrentScopeFuncNames() const;
// 调试:打印符号表内容
void Dump() const;
private:
std::unordered_map<std::string, SysYParser::VarDefContext*> table_;
private:
// 作用域栈
std::stack<ScopeEntry> scopes_;
// 循环嵌套深度
int loop_depth_ = 0;
};
// 类型兼容性检查函数
inline bool IsTypeCompatible(SymbolType expected, SymbolType actual) {
if (expected == SymbolType::TYPE_UNKNOWN || actual == SymbolType::TYPE_UNKNOWN) {
return true; // 未知类型视为兼容
}
// 基本类型兼容规则
if (expected == actual) {
return true;
}
// int 可以隐式转换为 float
if (expected == SymbolType::TYPE_FLOAT && actual == SymbolType::TYPE_INT) {
return true;
}
return false;
}
#endif // SYMBOL_TABLE_H

@ -1,200 +1,440 @@
#include "sem/Sema.h"
#include <any>
#include "../../include/sem/Sema.h"
#include "../../generated/src/antlr4/SysYParser.h"
#include <stdexcept>
#include <string>
#include <algorithm>
#include <iostream>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
using namespace antlr4;
namespace {
// ===================== 核心访问器实现 =====================
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("sema", "非法左值"));
}
return lvalue.ID()->getText();
// 1. 编译单元节点访问
std::any SemaVisitor::visitCompUnit(SysYParser::CompUnitContext* ctx) {
// 分析编译单元中的所有子节点
return visitChildren(ctx);
}
class SemaVisitor final : public SysYBaseVisitor {
public:
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元"));
}
auto* func = ctx->funcDef();
if (!func || !func->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
// 2. 函数定义节点访问
std::any SemaVisitor::visitFuncDef(SysYParser::FuncDefContext* ctx) {
FuncInfo info;
// 通过funcType()获取函数类型
if (ctx->funcType()) {
std::string func_type_text = ctx->funcType()->getText();
if (func_type_text == "void") {
info.ret_type = SymbolType::TYPE_VOID;
} else if (func_type_text == "int") {
info.ret_type = SymbolType::TYPE_INT;
} else if (func_type_text == "float") {
info.ret_type = SymbolType::TYPE_FLOAT;
}
}
if (!func->ID() || func->ID()->getText() != "main") {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
func->accept(this);
if (!seen_return_) {
throw std::runtime_error(
FormatError("sema", "main 函数必须包含 return 语句"));
// 绑定函数名和返回类型
if (ctx->Ident()) {
info.name = ctx->Ident()->getText();
}
return {};
}
ir_ctx_.SetCurrentFuncReturnType(info.ret_type);
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
// 递归分析函数体
if (ctx->block()) {
visit(ctx->block());
}
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持 int main"));
return std::any();
}
// 3. 声明节点访问
std::any SemaVisitor::visitDecl(SysYParser::DeclContext* ctx) {
return visitChildren(ctx);
}
// 4. 常量声明节点访问
std::any SemaVisitor::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
return visitChildren(ctx);
}
// 5. 变量声明节点访问
std::any SemaVisitor::visitVarDecl(SysYParser::VarDeclContext* ctx) {
return visitChildren(ctx);
}
// 6. 代码块节点访问
std::any SemaVisitor::visitBlock(SysYParser::BlockContext* ctx) {
// 进入新的作用域
ir_ctx_.EnterScope();
// 访问块内的语句
std::any result = visitChildren(ctx);
// 离开作用域
ir_ctx_.LeaveScope();
return result;
}
// 7. 语句节点访问
std::any SemaVisitor::visitStmt(SysYParser::StmtContext* ctx) {
// 赋值语句lVal = exp;
if (ctx->lVal() && ctx->exp()) {
auto l_val_ctx = ctx->lVal();
auto exp_ctx = ctx->exp();
// 解析左右值类型
SymbolType l_type = ir_ctx_.GetType(l_val_ctx);
SymbolType r_type = ir_ctx_.GetType(exp_ctx);
// 类型不匹配报错
if (l_type != r_type && l_type != SymbolType::TYPE_UNKNOWN && r_type != SymbolType::TYPE_UNKNOWN) {
std::string l_type_str = (l_type == SymbolType::TYPE_INT ? "int" : "float");
std::string r_type_str = (r_type == SymbolType::TYPE_INT ? "int" : "float");
std::string err_msg = "赋值类型不匹配,左值为" + l_type_str + ",右值为" + r_type_str;
int line = ctx->getStart()->getLine();
int col = ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg(err_msg, line, col));
}
// 绑定左值类型(同步右值类型)
ir_ctx_.SetType(l_val_ctx, r_type);
}
const auto& items = ctx->blockStmt()->blockItem();
if (items.empty()) {
throw std::runtime_error(
FormatError("sema", "main 函数不能为空,且必须以 return 结束"));
// IF语句
else if (ctx->cond() && ctx->stmt().size() >= 1) {
auto cond_ctx = ctx->cond();
// IF条件必须为整型
SymbolType cond_type = ir_ctx_.GetType(cond_ctx);
if (cond_type != SymbolType::TYPE_INT && cond_type != SymbolType::TYPE_UNKNOWN) {
int line = cond_ctx->getStart()->getLine();
int col = cond_ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("if条件表达式必须为整型", line, col));
}
// 递归分析IF体和可能的ELSE体
visit(ctx->stmt(0));
if (ctx->stmt().size() >= 2) {
visit(ctx->stmt(1));
}
}
ctx->blockStmt()->accept(this);
return {};
}
// WHILE语句
else if (ctx->cond() && ctx->stmt().size() >= 1) {
ir_ctx_.EnterLoop(); // 标记进入循环
auto cond_ctx = ctx->cond();
// WHILE条件必须为整型
SymbolType cond_type = ir_ctx_.GetType(cond_ctx);
if (cond_type != SymbolType::TYPE_INT && cond_type != SymbolType::TYPE_UNKNOWN) {
int line = cond_ctx->getStart()->getLine();
int col = cond_ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("while条件表达式必须为整型", line, col));
}
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少语句块"));
}
const auto& items = ctx->blockItem();
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
}
if (seen_return_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
current_item_index_ = i;
total_items_ = items.size();
item->accept(this);
}
return {};
}
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
// 递归分析循环体
visit(ctx->stmt(0));
ir_ctx_.ExitLoop(); // 标记退出循环
}
if (ctx->decl()) {
ctx->decl()->accept(this);
return {};
// BREAK语句
else if (ctx->getText().find("break") != std::string::npos) {
if (!ir_ctx_.InLoop()) {
int line = ctx->getStart()->getLine();
int col = ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("break只能出现在循环语句中", line, col));
}
}
if (ctx->stmt()) {
ctx->stmt()->accept(this);
return {};
// CONTINUE语句
else if (ctx->getText().find("continue") != std::string::npos) {
if (!ir_ctx_.InLoop()) {
int line = ctx->getStart()->getLine();
int col = ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("continue只能出现在循环语句中", line, col));
}
}
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
// RETURN语句
else if (ctx->getText().find("return") != std::string::npos) {
SymbolType func_ret_type = ir_ctx_.GetCurrentFuncReturnType();
// 有返回表达式的情况
if (ctx->exp()) {
auto exp_ctx = ctx->exp();
SymbolType exp_type = ir_ctx_.GetType(exp_ctx);
std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
// 返回类型不匹配报错
if (exp_type != func_ret_type && exp_type != SymbolType::TYPE_UNKNOWN && func_ret_type != SymbolType::TYPE_UNKNOWN) {
std::string ret_type_str = (func_ret_type == SymbolType::TYPE_INT ? "int" : (func_ret_type == SymbolType::TYPE_FLOAT ? "float" : "void"));
std::string exp_type_str = (exp_type == SymbolType::TYPE_INT ? "int" : "float");
std::string err_msg = "return表达式类型与函数返回类型不匹配期望" + ret_type_str + ",实际为" + exp_type_str;
int line = exp_ctx->getStart()->getLine();
int col = exp_ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg(err_msg, line, col));
}
}
// 无返回表达式的情况
else {
if (func_ret_type != SymbolType::TYPE_VOID && func_ret_type != SymbolType::TYPE_UNKNOWN) {
int line = ctx->getStart()->getLine();
int col = ctx->getStart()->getCharPositionInLine() + 1;
ir_ctx_.RecordError(ErrorMsg("非void函数return必须带表达式", line, col));
}
}
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
}
auto* var_def = ctx->varDef();
if (!var_def || !var_def->lValue()) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
// 其他语句
return visitChildren(ctx);
}
// 8. 左值节点访问
std::any SemaVisitor::visitLVal(SysYParser::LValContext* ctx) {
return visitChildren(ctx);
}
// 9. 表达式节点访问
std::any SemaVisitor::visitExp(SysYParser::ExpContext* ctx) {
return visitChildren(ctx);
}
// 10. 条件表达式节点访问
std::any SemaVisitor::visitCond(SysYParser::CondContext* ctx) {
return visitChildren(ctx);
}
// 11. 基本表达式节点访问
std::any SemaVisitor::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
return visitChildren(ctx);
}
// 12. 一元表达式节点访问
std::any SemaVisitor::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// 带一元运算符的表达式(+/-/!
if (ctx->unaryOp() && ctx->unaryExp()) {
auto op_ctx = ctx->unaryOp();
auto uexp_ctx = ctx->unaryExp();
auto uexp_val = visit(uexp_ctx);
std::string op_text = op_ctx->getText();
SymbolType uexp_type = ir_ctx_.GetType(uexp_ctx);
// 正号 +x → 直接返回原值
if (op_text == "+") {
ir_ctx_.SetType(ctx, uexp_type);
ir_ctx_.SetConstVal(ctx, uexp_val);
return uexp_val;
}
// 负号 -x → 取反
else if (op_text == "-") {
if (ir_ctx_.IsIntType(uexp_val)) {
long val = std::any_cast<long>(uexp_val);
ir_ctx_.SetConstVal(ctx, std::any(-val));
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return std::any(-val);
} else if (ir_ctx_.IsFloatType(uexp_val)) {
double val = std::any_cast<double>(uexp_val);
ir_ctx_.SetConstVal(ctx, std::any(-val));
ir_ctx_.SetType(ctx, SymbolType::TYPE_FLOAT);
return std::any(-val);
}
}
// 逻辑非 !x → 0/1转换
else if (op_text == "!") {
if (ir_ctx_.IsIntType(uexp_val)) {
long val = std::any_cast<long>(uexp_val);
long res = (val == 0) ? 1L : 0L;
ir_ctx_.SetConstVal(ctx, std::any(res));
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return std::any(res);
}
}
}
const std::string name = GetLValueName(*var_def->lValue());
if (table_.Contains(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
// 函数调用表达式
else if (ctx->Ident() && ctx->funcRParams()) {
// 这里简化处理
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return std::any(0L);
}
if (auto* init = var_def->initValue()) {
if (!init->exp()) {
throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
}
init->exp()->accept(this);
// 基础表达式
else if (ctx->primaryExp()) {
auto val = visit(ctx->primaryExp());
ir_ctx_.SetType(ctx, ir_ctx_.GetType(ctx->primaryExp()));
ir_ctx_.SetConstVal(ctx, val);
return val;
}
table_.Add(name, var_def);
return {};
}
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx || !ctx->returnStmt()) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
ctx->returnStmt()->accept(this);
return {};
}
return std::any();
}
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
}
ctx->exp()->accept(this);
seen_return_ = true;
if (current_item_index_ + 1 != total_items_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
// 13. 乘法表达式节点访问
std::any SemaVisitor::visitMulExp(SysYParser::MulExpContext* ctx) {
auto uexps = ctx->unaryExp();
// 单操作数 → 直接返回
if (uexps.size() == 1) {
auto val = visit(uexps[0]);
ir_ctx_.SetType(ctx, ir_ctx_.GetType(uexps[0]));
ir_ctx_.SetConstVal(ctx, val);
return val;
}
return {};
}
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "非法括号表达式"));
}
ctx->exp()->accept(this);
return {};
}
// 多操作数 → 依次计算
std::any result = visit(uexps[0]);
SymbolType current_type = ir_ctx_.GetType(uexps[0]);
for (size_t i = 1; i < uexps.size(); ++i) {
auto next_uexp = uexps[i];
auto next_val = visit(next_uexp);
SymbolType next_type = ir_ctx_.GetType(next_uexp);
// 类型统一int和float混合转为float
if (current_type == SymbolType::TYPE_INT && next_type == SymbolType::TYPE_FLOAT) {
current_type = SymbolType::TYPE_FLOAT;
} else if (current_type == SymbolType::TYPE_FLOAT && next_type == SymbolType::TYPE_INT) {
current_type = SymbolType::TYPE_FLOAT;
}
// 简化处理:这里假设是乘法运算
if (ir_ctx_.IsIntType(result) && ir_ctx_.IsIntType(next_val)) {
long v1 = std::any_cast<long>(result);
long v2 = std::any_cast<long>(next_val);
result = std::any(v1 * v2);
} else if (ir_ctx_.IsFloatType(result) && ir_ctx_.IsFloatType(next_val)) {
double v1 = std::any_cast<double>(result);
double v2 = std::any_cast<double>(next_val);
result = std::any(v1 * v2);
}
std::any visitVarExp(SysYParser::VarExpContext* ctx) override {
if (!ctx || !ctx->var()) {
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
// 更新当前节点类型和常量值
ir_ctx_.SetType(ctx, current_type);
ir_ctx_.SetConstVal(ctx, result);
}
ctx->var()->accept(this);
return {};
}
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量"));
return result;
}
// 14. 加法表达式节点访问
std::any SemaVisitor::visitAddExp(SysYParser::AddExpContext* ctx) {
auto mexps = ctx->mulExp();
// 单操作数 → 直接返回
if (mexps.size() == 1) {
auto val = visit(mexps[0]);
ir_ctx_.SetType(ctx, ir_ctx_.GetType(mexps[0]));
ir_ctx_.SetConstVal(ctx, val);
return val;
}
return {};
}
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式"));
// 多操作数 → 依次计算
std::any result = visit(mexps[0]);
SymbolType current_type = ir_ctx_.GetType(mexps[0]);
for (size_t i = 1; i < mexps.size(); ++i) {
auto next_mexp = mexps[i];
auto next_val = visit(next_mexp);
SymbolType next_type = ir_ctx_.GetType(next_mexp);
// 类型统一
if (current_type == SymbolType::TYPE_INT && next_type == SymbolType::TYPE_FLOAT) {
current_type = SymbolType::TYPE_FLOAT;
} else if (current_type == SymbolType::TYPE_FLOAT && next_type == SymbolType::TYPE_INT) {
current_type = SymbolType::TYPE_FLOAT;
}
// 简化处理:这里假设是加法运算
if (ir_ctx_.IsIntType(result) && ir_ctx_.IsIntType(next_val)) {
long v1 = std::any_cast<long>(result);
long v2 = std::any_cast<long>(next_val);
result = std::any(v1 + v2);
} else if (ir_ctx_.IsFloatType(result) && ir_ctx_.IsFloatType(next_val)) {
double v1 = std::any_cast<double>(result);
double v2 = std::any_cast<double>(next_val);
result = std::any(v1 + v2);
}
ir_ctx_.SetType(ctx, current_type);
ir_ctx_.SetConstVal(ctx, result);
}
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitVar(SysYParser::VarContext* ctx) override {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
return result;
}
// 15. 关系表达式节点访问
std::any SemaVisitor::visitRelExp(SysYParser::RelExpContext* ctx) {
auto aexps = ctx->addExp();
// 单操作数 → 直接返回
if (aexps.size() == 1) {
auto val = visit(aexps[0]);
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return val;
}
const std::string name = ctx->ID()->getText();
auto* decl = table_.Lookup(name);
if (!decl) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
// 多操作数 → 简化处理
std::any result = std::any(1L);
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
ir_ctx_.SetConstVal(ctx, result);
return result;
}
// 16. 相等表达式节点访问
std::any SemaVisitor::visitEqExp(SysYParser::EqExpContext* ctx) {
auto rexps = ctx->relExp();
// 单操作数 → 直接返回
if (rexps.size() == 1) {
auto val = visit(rexps[0]);
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
return val;
}
sema_.BindVarUse(ctx, decl);
return {};
}
SemanticContext TakeSemanticContext() { return std::move(sema_); }
// 多操作数 → 简化处理
std::any result = std::any(1L);
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
ir_ctx_.SetConstVal(ctx, result);
return result;
}
// 17. 逻辑与表达式节点访问
std::any SemaVisitor::visitLAndExp(SysYParser::LAndExpContext* ctx) {
return visitChildren(ctx);
}
// 18. 逻辑或表达式节点访问
std::any SemaVisitor::visitLOrExp(SysYParser::LOrExpContext* ctx) {
return visitChildren(ctx);
}
private:
SymbolTable table_;
SemanticContext sema_;
bool seen_return_ = false;
size_t current_item_index_ = 0;
size_t total_items_ = 0;
};
// 19. 常量表达式节点访问
std::any SemaVisitor::visitConstExp(SysYParser::ConstExpContext* ctx) {
return visitChildren(ctx);
}
} // namespace
// 20. 数字节点访问
std::any SemaVisitor::visitNumber(SysYParser::NumberContext* ctx) {
// 这里简化处理,实际需要解析整型和浮点型
if (ctx->IntConst()) {
ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
ir_ctx_.SetConstVal(ctx, std::any(0L));
return std::any(0L);
} else if (ctx->FloatConst()) {
ir_ctx_.SetType(ctx, SymbolType::TYPE_FLOAT);
ir_ctx_.SetConstVal(ctx, std::any(0.0));
return std::any(0.0);
}
return std::any();
}
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor;
comp_unit.accept(&visitor);
return visitor.TakeSemanticContext();
// 21. 函数参数节点访问
std::any SemaVisitor::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
return visitChildren(ctx);
}
// ===================== 语义分析入口函数 =====================
void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx) {
if (!ctx) {
throw std::invalid_argument("CompUnitContext is null");
}
SemaVisitor visitor(ir_ctx);
visitor.visit(ctx);
}

@ -1,17 +1,164 @@
// 维护局部变量声明的注册与查找。
#include "../../include/sem/SymbolTable.h"
#include <stdexcept>
#include <string>
#include <iostream>
#include "sem/SymbolTable.h"
// 进入新作用域
void SymbolTable::EnterScope() {
scopes_.push(ScopeEntry());
}
// 离开当前作用域
void SymbolTable::LeaveScope() {
if (scopes_.empty()) {
throw std::runtime_error("SymbolTable Error: 作用域栈为空,无法退出");
}
scopes_.pop();
}
// 绑定变量到当前作用域
void SymbolTable::BindVar(const std::string& name, const VarInfo& info, void* decl_ctx) {
if (CurrentScopeHasVar(name)) {
throw std::runtime_error("变量'" + name + "'在当前作用域重复定义");
}
scopes_.top().var_symbols[name] = {info, decl_ctx};
}
// 绑定函数到当前作用域
void SymbolTable::BindFunc(const std::string& name, const FuncInfo& info, void* decl_ctx) {
if (CurrentScopeHasFunc(name)) {
throw std::runtime_error("函数'" + name + "'在当前作用域重复定义");
}
scopes_.top().func_symbols[name] = {info, decl_ctx};
}
// 查找变量(从当前作用域向上遍历)
bool SymbolTable::LookupVar(const std::string& name, VarInfo& out_info, void*& out_decl_ctx) const {
if (scopes_.empty()) {
return false;
}
auto temp_stack = scopes_;
while (!temp_stack.empty()) {
auto& scope = temp_stack.top();
auto it = scope.var_symbols.find(name);
if (it != scope.var_symbols.end()) {
out_info = it->second.first;
out_decl_ctx = it->second.second;
return true;
}
temp_stack.pop();
}
return false;
}
// 查找函数(从当前作用域向上遍历,通常函数在全局作用域)
bool SymbolTable::LookupFunc(const std::string& name, FuncInfo& out_info, void*& out_decl_ctx) const {
if (scopes_.empty()) {
return false;
}
auto temp_stack = scopes_;
while (!temp_stack.empty()) {
auto& scope = temp_stack.top();
auto it = scope.func_symbols.find(name);
if (it != scope.func_symbols.end()) {
out_info = it->second.first;
out_decl_ctx = it->second.second;
return true;
}
temp_stack.pop();
}
return false;
}
void SymbolTable::Add(const std::string& name,
SysYParser::VarDefContext* decl) {
table_[name] = decl;
// 检查当前作用域是否包含指定变量
bool SymbolTable::CurrentScopeHasVar(const std::string& name) const {
if (scopes_.empty()) {
return false;
}
return scopes_.top().var_symbols.count(name) > 0;
}
bool SymbolTable::Contains(const std::string& name) const {
return table_.find(name) != table_.end();
// 检查当前作用域是否包含指定函数
bool SymbolTable::CurrentScopeHasFunc(const std::string& name) const {
if (scopes_.empty()) {
return false;
}
return scopes_.top().func_symbols.count(name) > 0;
}
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
auto it = table_.find(name);
return it == table_.end() ? nullptr : it->second;
// 进入循环
void SymbolTable::EnterLoop() {
loop_depth_++;
}
// 离开循环
void SymbolTable::ExitLoop() {
if (loop_depth_ > 0) loop_depth_--;
}
// 检查是否在循环内
bool SymbolTable::InLoop() const {
return loop_depth_ > 0;
}
// 清空所有作用域和状态
void SymbolTable::Clear() {
while (!scopes_.empty()) {
scopes_.pop();
}
loop_depth_ = 0;
}
// 获取当前作用域中所有变量名
std::vector<std::string> SymbolTable::GetCurrentScopeVarNames() const {
std::vector<std::string> names;
if (!scopes_.empty()) {
for (const auto& pair : scopes_.top().var_symbols) {
names.push_back(pair.first);
}
}
return names;
}
// 获取当前作用域中所有函数名
std::vector<std::string> SymbolTable::GetCurrentScopeFuncNames() const {
std::vector<std::string> names;
if (!scopes_.empty()) {
for (const auto& pair : scopes_.top().func_symbols) {
names.push_back(pair.first);
}
}
return names;
}
// 调试:打印符号表内容
void SymbolTable::Dump() const {
std::cout << "符号表内容 (作用域深度: " << scopes_.size() << "):\n";
int scope_idx = 0;
auto temp_stack = scopes_;
while (!temp_stack.empty()) {
std::cout << "\n作用域 " << scope_idx++ << ":\n";
auto& scope = temp_stack.top();
std::cout << " 变量:\n";
for (const auto& var_pair : scope.var_symbols) {
const VarInfo& info = var_pair.second.first;
std::cout << " " << var_pair.first << ": "
<< SymbolTypeToString(info.type)
<< (info.is_const ? " (const)" : "")
<< (info.IsArray() ? " [数组]" : "")
<< "\n";
}
std::cout << " 函数:\n";
for (const auto& func_pair : scope.func_symbols) {
const FuncInfo& info = func_pair.second.first;
std::cout << " " << func_pair.first << ": "
<< SymbolTypeToString(info.ret_type) << " ("
<< info.param_types.size() << " 个参数)\n";
}
temp_stack.pop();
}
}
Loading…
Cancel
Save