You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1417 lines
60 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "sem/Sema.h"
#include <any>
#include <stdexcept>
#include <string>
#include <sstream>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
namespace {
// Helper: returns the identifier text of an l-value node.
// Throws std::runtime_error when the node carries no identifier token.
std::string GetLValueName(SysYParser::LValContext& lval) {
  auto* ident = lval.Ident();
  if (ident == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法左值"));
  }
  return ident->getText();
}
// Maps a BType parse node to an IR type.
// Yields the float type only when the node exists, has no Int token and has
// a Float token; every other case (including a null node) yields int32.
std::shared_ptr<ir::Type> GetTypeFromBType(SysYParser::BTypeContext* ctx) {
  if (ctx != nullptr && !ctx->Int() && ctx->Float()) {
    return ir::Type::GetFloatType();
  }
  return ir::Type::GetInt32Type();
}
// 语义分析 Visitor
class SemaVisitor final : public SysYBaseVisitor {
public:
// Constructs the visitor with a value-initialized symbol table.
SemaVisitor() : table_() {}
// Entry point for a compilation unit.
// Opens the global scope, pre-registers every function definition (so that
// functions may call each other regardless of order), then visits all
// declarations followed by all function definitions, checks `main`, and
// finally closes the global scope.
// Throws std::runtime_error when the compilation-unit node is missing.
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("sema", "缺少编译单元"));
  }

  table_.enterScope();

  const auto functions = ctx->funcDef();

  // Pass 1: collect function declarations.
  for (auto* fn : functions) {
    CollectFunctionDeclaration(fn);
  }

  // Pass 2: declarations first, then function bodies.
  for (auto* declaration : ctx->decl()) {
    if (declaration != nullptr) {
      declaration->accept(this);
    }
  }
  for (auto* fn : functions) {
    if (fn != nullptr) {
      fn->accept(this);
    }
  }

  CheckMainFunction();
  table_.exitScope();
  return {};
}
// Analyses one function definition.
// Resolves the return type (defaulting to int32 when the type node is absent
// or unrecognised), records it for `return` checking, opens a scope for the
// parameters, visits the body, and rejects non-void functions in which no
// return statement was seen.
// Throws std::runtime_error on a missing identifier or a missing return.
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
  if (ctx == nullptr || ctx->Ident() == nullptr) {
    throw std::runtime_error(FormatError("sema", "函数定义缺少标识符"));
  }
  const std::string name = ctx->Ident()->getText();

  const auto resolve_return_type = [ctx]() -> std::shared_ptr<ir::Type> {
    auto* func_type = ctx->funcType();
    if (func_type != nullptr) {
      if (func_type->Void()) {
        return ir::Type::GetVoidType();
      }
      if (!func_type->Int() && func_type->Float()) {
        return ir::Type::GetFloatType();
      }
    }
    return ir::Type::GetInt32Type();
  };
  std::shared_ptr<ir::Type> return_type = resolve_return_type();

  const char* type_label =
      return_type->IsInt32() ? "int" : (return_type->IsFloat() ? "float" : "void");
  std::cout << "[DEBUG] 进入函数: " << name
            << " 返回类型: " << type_label
            << std::endl;

  // State consulted by return-statement checking while the body is visited.
  current_func_return_type_ = return_type;
  current_func_has_return_ = false;

  table_.enterScope();
  if (auto* params = ctx->funcFParams()) {
    CollectFunctionParams(params);
  }
  if (auto* body = ctx->block()) {
    body->accept(this);
  }

  std::cout << "[DEBUG] 函数 " << name
            << " has_return: " << current_func_has_return_
            << " return_type_is_void: " << return_type->IsVoid()
            << std::endl;

  const bool needs_return = !return_type->IsVoid();
  if (needs_return && !current_func_has_return_) {
    throw std::runtime_error(FormatError("sema", "非 void 函数 " + name + " 缺少 return 语句"));
  }

  table_.exitScope();
  current_func_return_type_ = nullptr;
  current_func_has_return_ = false;
  return {};
}
// Visits a `{ ... }` block inside its own scope.
// Every non-null block item is visited in order; a return statement does not
// stop the iteration because the return handler records itself.
// Throws std::runtime_error when the block node is missing.
std::any visitBlock(SysYParser::BlockContext* ctx) override {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("sema", "缺少语句块"));
  }
  table_.enterScope();
  for (auto* item : ctx->blockItem()) {
    if (item == nullptr) {
      continue;
    }
    item->accept(this);
  }
  table_.exitScope();
  return {};
}
// Dispatches a block item to its declaration or statement child.
// Throws std::runtime_error when the node is null or has neither child.
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
  if (ctx != nullptr) {
    if (auto* declaration = ctx->decl()) {
      declaration->accept(this);
      return {};
    }
    if (auto* statement = ctx->stmt()) {
      statement->accept(this);
      return {};
    }
  }
  throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
// Forwards a declaration to its const-declaration child if present, otherwise
// to its variable-declaration child if present.
// Throws std::runtime_error when the declaration node is null.
std::any visitDecl(SysYParser::DeclContext* ctx) override {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法变量声明"));
  }
  if (auto* const_decl = ctx->constDecl()) {
    const_decl->accept(this);
  } else if (auto* var_decl = ctx->varDecl()) {
    var_decl->accept(this);
  }
  return {};
}
// ==================== 变量声明 ====================
// Handles a variable declaration: resolves the base type once and checks each
// declared variable against it. A declaration is treated as global when the
// symbol table reports scope level 0.
// Throws std::runtime_error when the node or its base-type child is missing.
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
  if (ctx == nullptr || ctx->bType() == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法变量声明"));
  }
  const std::shared_ptr<ir::Type> element_type = GetTypeFromBType(ctx->bType());
  const bool at_global_scope = table_.currentScopeLevel() == 0;
  for (auto* definition : ctx->varDef()) {
    if (definition == nullptr) {
      continue;
    }
    CheckVarDef(definition, element_type, at_global_scope);
  }
  return {};
}
// Checks a single variable definition and registers it in the symbol table.
//
// ctx        the variable-definition node
// base_type  element type taken from the enclosing declaration
// is_global  true when the declaration sits in the global scope
//
// Rejects redefinitions in the current scope and non-positive array
// dimensions. For global variables with an initializer, the initializer is
// additionally passed to CheckGlobalInitIsConst.
// Throws std::runtime_error on any violation.
void CheckVarDef(SysYParser::VarDefContext* ctx,
                 std::shared_ptr<ir::Type> base_type,
                 bool is_global) {
  if (ctx == nullptr || ctx->Ident() == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法变量定义"));
  }
  const std::string name = ctx->Ident()->getText();
  if (table_.lookupCurrent(name)) {
    throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
  }

  // Label used only by the debug trace below.
  const auto type_label = [](const auto& t) -> const char* {
    if (t->IsInt32()) return "int";
    if (t->IsFloat()) return "float";
    return "unknown";
  };

  const auto dim_exprs = ctx->constExp();
  const bool is_array = !dim_exprs.empty();
  std::shared_ptr<ir::Type> type = base_type;
  std::vector<int> dims;

  std::cout << "[DEBUG] CheckVarDef: " << name
            << " base_type: " << type_label(base_type)
            << " is_array: " << is_array
            << " dim_count: " << dim_exprs.size() << std::endl;

  if (is_array) {
    // Each dimension must evaluate to a positive integer.
    for (auto* dim_exp : dim_exprs) {
      const int extent = EvaluateConstExp(dim_exp);
      if (extent <= 0) {
        throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
      }
      dims.push_back(extent);
      std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << extent << std::endl;
    }

    type = ir::Type::GetArrayType(base_type, dims);
    std::cout << "[DEBUG] 创建数组类型完成" << std::endl;
    std::cout << "[DEBUG] type->IsArray(): " << type->IsArray() << std::endl;
    std::cout << "[DEBUG] type->GetKind(): " << static_cast<int>(type->GetKind()) << std::endl;

    // Debug-only dump of the freshly created array type.
    if (type->IsArray()) {
      if (auto* array_type = dynamic_cast<ir::ArrayType*>(type.get())) {
        std::cout << "[DEBUG] ArrayType dimensions: ";
        for (int d : array_type->GetDimensions()) {
          std::cout << d << " ";
        }
        std::cout << std::endl;
        std::cout << "[DEBUG] Element type: "
                  << type_label(array_type->GetElementType())
                  << std::endl;
      }
    }
  }

  const bool has_init = ctx->initVal() != nullptr;
  if (is_global && has_init) {
    // Global initializers must be constant expressions.
    CheckGlobalInitIsConst(ctx->initVal());
  }

  Symbol sym;
  sym.name = name;
  sym.kind = SymbolKind::Variable;
  sym.type = type;
  sym.scope_level = table_.currentScopeLevel();
  sym.is_initialized = has_init;
  sym.var_def_ctx = ctx;
  if (is_array) {
    // Dimension info lives in `type`; keep the parameter list empty.
    sym.param_types.clear();
  }
  table_.addSymbol(sym);

  std::cout << "[DEBUG] 符号添加完成: " << name
            << " type_kind: " << static_cast<int>(sym.type->GetKind())
            << " is_array: " << sym.type->IsArray()
            << std::endl;
}
// ==================== 常量声明 ====================
// Handles a constant declaration: resolves the base type once and checks each
// constant definition against it.
// Throws std::runtime_error when the node or its base-type child is missing.
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
  if (ctx == nullptr || ctx->bType() == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法常量声明"));
  }
  const std::shared_ptr<ir::Type> element_type = GetTypeFromBType(ctx->bType());
  for (auto* definition : ctx->constDef()) {
    if (definition == nullptr) {
      continue;
    }
    CheckConstDef(definition, element_type);
  }
  return {};
}
// Checks a single constant definition and registers it in the symbol table.
//
// ctx        the constant-definition node
// base_type  element type taken from the enclosing declaration
//
// Rejects redefinitions in the current scope, non-positive array dimensions
// and initializers with more values than the array has elements. For scalar
// constants the folded initializer value is stored on the symbol; when the
// initializer's type differs from the declared type (e.g. `const float x = 1;`
// or `const int n = 2.5;`) the value is converted to the declared type instead
// of being silently dropped.
// Throws std::runtime_error on any violation.
void CheckConstDef(SysYParser::ConstDefContext* ctx,
                   std::shared_ptr<ir::Type> base_type) {
  if (!ctx || !ctx->Ident()) {
    throw std::runtime_error(FormatError("sema", "非法常量定义"));
  }
  std::string name = ctx->Ident()->getText();
  if (table_.lookupCurrent(name)) {
    throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
  }

  // Determine the declared type (scalar or array).
  std::shared_ptr<ir::Type> type = base_type;
  std::vector<int> dims;
  bool is_array = !ctx->constExp().empty();
  std::cout << "[DEBUG] CheckConstDef: " << name
            << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown")
            << " is_array: " << is_array
            << " dim_count: " << ctx->constExp().size() << std::endl;
  if (is_array) {
    for (auto* dim_exp : ctx->constExp()) {
      int dim = EvaluateConstExp(dim_exp);
      if (dim <= 0) {
        throw std::runtime_error(FormatError("sema", "数组维度必须为正整数"));
      }
      dims.push_back(dim);
      std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl;
    }
    type = ir::Type::GetArrayType(base_type, dims);
    std::cout << "[DEBUG] 创建数组类型完成IsArray: " << type->IsArray() << std::endl;
  }

  // Evaluate the initializer.
  std::vector<ConstValue> init_values;
  if (ctx->constInitVal()) {
    init_values = EvaluateConstInitVal(ctx->constInitVal(), dims, base_type);
    std::cout << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl;
  }

  // The initializer may not provide more values than there are elements.
  size_t expected_count = 1;
  if (is_array) {
    for (int d : dims) expected_count *= d;
    std::cout << "[DEBUG] 期望元素数量: " << expected_count << std::endl;
  }
  if (init_values.size() > expected_count) {
    throw std::runtime_error(FormatError("sema", "初始化值过多"));
  }

  Symbol sym;
  sym.name = name;
  sym.kind = SymbolKind::Constant;
  sym.type = type;
  sym.scope_level = table_.currentScopeLevel();
  sym.is_initialized = true;
  sym.var_def_ctx = nullptr;

  // Store the compile-time value (scalars only), converting to the declared
  // type so that mixed int/float initializers keep their value.
  if (!is_array && !init_values.empty()) {
    const ConstValue& init = init_values[0];
    if (base_type->IsInt32()) {
      sym.is_int_const = true;
      sym.const_value.i32 = init.is_int ? init.int_val : static_cast<int>(init.float_val);
      std::cout << "[DEBUG] 存储整型常量值: " << sym.const_value.i32 << std::endl;
    } else if (base_type->IsFloat()) {
      sym.is_int_const = false;
      sym.const_value.f32 = init.is_int ? static_cast<float>(init.int_val) : init.float_val;
      std::cout << "[DEBUG] 存储浮点常量值: " << sym.const_value.f32 << std::endl;
    }
  } else if (is_array) {
    std::cout << "[DEBUG] 数组常量,不存储单个常量值" << std::endl;
  }

  table_.addSymbol(sym);
  std::cout << "[DEBUG] 常量符号添加完成" << std::endl;
}
// ==================== 语句语义检查 ====================
// Handles every statement kind by inspecting which child tokens/nodes exist.
// Dispatch priority: return, assignment, expression statement, nested block,
// if, while, break, continue. Anything else (and a null node) is a no-op.
std::any visitStmt(SysYParser::StmtContext* ctx) override {
  if (ctx == nullptr) {
    return {};
  }

  // Note: the keyword accessors return terminal-node pointers.
  const bool is_return = ctx->Return() != nullptr;
  const bool is_if = ctx->If() != nullptr;
  const bool is_while = ctx->While() != nullptr;
  const bool is_break = ctx->Break() != nullptr;
  const bool is_continue = ctx->Continue() != nullptr;
  const bool is_assign = ctx->lVal() != nullptr && ctx->Assign() != nullptr;
  const bool is_exp_stmt = ctx->exp() != nullptr && ctx->Semi() != nullptr;
  const bool is_block = ctx->block() != nullptr;

  // Debug trace listing every category that matches.
  std::cout << "[DEBUG] visitStmt: ";
  if (is_return) std::cout << "Return ";
  if (is_if) std::cout << "If ";
  if (is_while) std::cout << "While ";
  if (is_break) std::cout << "Break ";
  if (is_continue) std::cout << "Continue ";
  if (is_assign) std::cout << "Assign ";
  if (is_exp_stmt) std::cout << "ExpStmt ";
  if (is_block) std::cout << "Block ";
  std::cout << std::endl;

  if (is_return) {
    std::cout << "[DEBUG] 检测到 return 语句" << std::endl;
    return visitReturnStmtInternal(ctx);
  }
  if (is_assign) {
    return visitAssignStmt(ctx);
  }
  if (is_exp_stmt) {
    return visitExpStmt(ctx);
  }
  if (is_block) {
    return ctx->block()->accept(this);
  }
  if (is_if) {
    return visitIfStmtInternal(ctx);
  }
  if (is_while) {
    return visitWhileStmtInternal(ctx);
  }
  if (is_break) {
    return visitBreakStmtInternal(ctx);
  }
  if (is_continue) {
    return visitContinueStmtInternal(ctx);
  }
  return {};
}
// Implementation of the `return` statement check.
// Requires an enclosing function, verifies that a returned value is allowed
// and type-compatible with the function's return type (recording an implicit
// conversion when the types differ), and marks the current function as having
// a return statement.
// Throws std::runtime_error on any violation.
std::any visitReturnStmtInternal(SysYParser::StmtContext* ctx) {
  std::cout << "[DEBUG] visitReturnStmtInternal 被调用" << std::endl;

  const std::shared_ptr<ir::Type> expected = current_func_return_type_;
  if (expected == nullptr) {
    throw std::runtime_error(FormatError("sema", "return 语句不在函数体内"));
  }

  auto* value_exp = ctx->exp();
  if (value_exp == nullptr) {
    // Bare `return;` is only legal in a void function.
    std::cout << "[DEBUG] 无返回值的 return" << std::endl;
    if (!expected->IsVoid()) {
      throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值"));
    }
  } else {
    // `return <exp>;` — the expression is analysed before the void check.
    std::cout << "[DEBUG] 有返回值的 return" << std::endl;
    ExprInfo returned = CheckExp(value_exp);
    if (expected->IsVoid()) {
      throw std::runtime_error(FormatError("sema", "void 函数不能返回值"));
    }
    if (!IsTypeCompatible(returned.type, expected)) {
      throw std::runtime_error(FormatError("sema", "返回值类型不匹配"));
    }
    if (returned.type != expected) {
      // Record the implicit conversion.
      sema_.AddConversion(ctx->exp(), returned.type, expected);
    }
  }

  current_func_has_return_ = true;
  std::cout << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl;
  return {};
}
// L-value expression (variable reference).
// Looks the identifier up, validates any subscripts, and records the
// resulting expression info for this node:
//   * array / pointer with subscripts: element type, assignable, not const;
//     the subscript count must equal the array rank (1 for pointers) and each
//     subscript must be int;
//   * array name without subscripts: pointer to the element type, not
//     assignable, const;
//   * pointer without subscripts, or plain scalar: the symbol's own type;
//     scalar constants additionally carry their compile-time value.
// Throws std::runtime_error on a null node, missing identifier, undefined
// name, subscript-count mismatch, non-int subscript, or subscripting a scalar.
std::any visitLVal(SysYParser::LValContext* ctx) override {
  // The null check must come before the debug trace, which dereferences ctx.
  if (!ctx) {
    throw std::runtime_error(FormatError("sema", "非法变量引用"));
  }
  std::cout << "[DEBUG] visitLVal: " << ctx->getText() << std::endl;
  if (!ctx->Ident()) {
    throw std::runtime_error(FormatError("sema", "非法变量引用"));
  }
  std::string name = ctx->Ident()->getText();
  auto* sym = table_.lookup(name);
  if (!sym) {
    throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
  }

  bool is_array_access = !ctx->exp().empty();
  std::cout << "[DEBUG] name: " << name
            << ", is_array_access: " << is_array_access
            << ", subscript_count: " << ctx->exp().size() << std::endl;

  ExprInfo result;

  // Arrays and pointer-typed symbols (array parameters) share one code path.
  bool is_array_or_ptr = false;
  if (sym->type) {
    is_array_or_ptr = sym->type->IsArray() || sym->type->IsPtrInt32() || sym->type->IsPtrFloat();
    std::cout << "[DEBUG] type_kind: " << (int)sym->type->GetKind()
              << ", is_array: " << sym->type->IsArray()
              << ", is_ptr: " << (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) << std::endl;
  }

  if (is_array_or_ptr) {
    // Work out how many subscripts are expected and the element type.
    size_t dim_count = 0;
    std::shared_ptr<ir::Type> elem_type = sym->type;
    if (sym->type->IsArray()) {
      if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
        dim_count = arr_type->GetDimensions().size();
        elem_type = arr_type->GetElementType();
        std::cout << "[DEBUG] 数组维度: " << dim_count << std::endl;
      }
    } else if (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) {
      dim_count = 1;
      if (sym->type->IsPtrInt32()) {
        elem_type = ir::Type::GetInt32Type();
      } else if (sym->type->IsPtrFloat()) {
        elem_type = ir::Type::GetFloatType();
      }
      std::cout << "[DEBUG] 指针类型, dim_count: 1" << std::endl;
    }

    if (is_array_access) {
      std::cout << "[DEBUG] 有下标访问,期望维度: " << dim_count
                << ", 实际下标数: " << ctx->exp().size() << std::endl;
      if (ctx->exp().size() != dim_count) {
        throw std::runtime_error(FormatError("sema", "数组下标个数不匹配"));
      }
      for (auto* idx_exp : ctx->exp()) {
        ExprInfo idx = CheckExp(idx_exp);
        if (!idx.type->IsInt32()) {
          throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型"));
        }
      }
      result.type = elem_type;
      result.is_lvalue = true;
      result.is_const = false;
    } else {
      std::cout << "[DEBUG] 无下标访问" << std::endl;
      if (sym->type->IsArray()) {
        // A bare array name is used as an address: it becomes a pointer to
        // the element type (int pointer unless the element type is float).
        std::cout << "[DEBUG] 数组名作为地址,转换为指针" << std::endl;
        result.type = ir::Type::GetPtrInt32Type();
        if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
          if (!arr_type->GetElementType()->IsInt32() && arr_type->GetElementType()->IsFloat()) {
            result.type = ir::Type::GetPtrFloatType();
          }
        }
        result.is_lvalue = false;
        result.is_const = true;
      } else {
        result.type = sym->type;
        result.is_lvalue = true;
        result.is_const = (sym->kind == SymbolKind::Constant);
      }
    }
  } else {
    if (is_array_access) {
      throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name));
    }
    result.type = sym->type;
    result.is_lvalue = true;
    result.is_const = (sym->kind == SymbolKind::Constant);
    // Scalar constants carry their folded value.
    if (result.is_const && sym->type && !sym->type->IsArray()) {
      if (sym->is_int_const) {
        result.is_const_int = true;
        result.const_int_value = sym->const_value.i32;
      } else {
        result.const_float_value = sym->const_value.f32;
      }
    }
  }

  sema_.SetExprType(ctx, result);
  return {};
}
// Implementation of the `if` statement check.
// Validates the condition (CheckCond performs the type handling), then visits
// the then-branch and, if present, the else-branch.
std::any visitIfStmtInternal(SysYParser::StmtContext* ctx) {
  if (auto* condition = ctx->cond()) {
    CheckCond(condition);
  }
  const auto branches = ctx->stmt();
  if (!branches.empty()) {
    branches[0]->accept(this);  // then
  }
  if (branches.size() > 1) {
    branches[1]->accept(this);  // else
  }
  return {};
}
// Implementation of the `while` statement check.
// Validates the condition, then visits the body with a loop entry pushed on
// the loop stack so that break/continue inside it are accepted.
std::any visitWhileStmtInternal(SysYParser::StmtContext* ctx) {
  if (auto* condition = ctx->cond()) {
    CheckCond(condition);
  }
  loop_stack_.push_back(LoopContext{true, ctx});
  const auto body = ctx->stmt();
  if (!body.empty()) {
    body[0]->accept(this);
  }
  loop_stack_.pop_back();
  return {};
}
// Implementation of the `break` statement check.
// Throws std::runtime_error unless the innermost loop-stack entry marks an
// enclosing loop.
std::any visitBreakStmtInternal(SysYParser::StmtContext* ctx) {
  const bool inside_loop = !loop_stack_.empty() && loop_stack_.back().in_loop;
  if (!inside_loop) {
    throw std::runtime_error(FormatError("sema", "break 语句必须在循环体内使用"));
  }
  return {};
}
// Implementation of the `continue` statement check.
// Throws std::runtime_error unless the innermost loop-stack entry marks an
// enclosing loop.
std::any visitContinueStmtInternal(SysYParser::StmtContext* ctx) {
  const bool inside_loop = !loop_stack_.empty() && loop_stack_.back().in_loop;
  if (!inside_loop) {
    throw std::runtime_error(FormatError("sema", "continue 语句必须在循环体内使用"));
  }
  return {};
}
// Implementation of the assignment statement check.
// The target is analysed first and must be a non-const l-value; the source
// expression is then analysed and must be type-compatible with the target.
// An implicit conversion is recorded when the two types differ.
// Throws std::runtime_error on any violation.
std::any visitAssignStmt(SysYParser::StmtContext* ctx) {
  auto* target_node = ctx->lVal();
  auto* source_node = ctx->exp();
  if (target_node == nullptr || source_node == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法赋值语句"));
  }

  ExprInfo target = CheckLValue(target_node);
  if (target.is_const) {
    throw std::runtime_error(FormatError("sema", "不能给常量赋值"));
  }
  if (!target.is_lvalue) {
    throw std::runtime_error(FormatError("sema", "赋值左边必须是左值"));
  }

  ExprInfo source = CheckExp(source_node);
  if (!IsTypeCompatible(source.type, target.type)) {
    throw std::runtime_error(FormatError("sema", "赋值类型不匹配"));
  }
  if (source.type != target.type) {
    sema_.AddConversion(ctx->exp(), source.type, target.type);
  }
  return {};
}
// Implementation of the expression statement check: analyses the expression
// when one is present and discards the result.
std::any visitExpStmt(SysYParser::StmtContext* ctx) {
  auto* expression = ctx->exp();
  if (expression != nullptr) {
    CheckExp(expression);
  }
  return {};
}
// ==================== 表达式类型推导 ====================
// Primary expression: an l-value, a float literal, an integer literal, or a
// parenthesised expression. Records the resulting expression info for ctx.
//
// Integer literals are parsed as an unsigned 64-bit magnitude and wrapped to
// 32 bits, so `2147483648` (as written in `-2147483648`) and hex literals up
// to 0xFFFFFFFF are accepted instead of raising std::out_of_range from
// std::stoi. Literals that do not fit in 32 bits are rejected with
// std::runtime_error.
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
  std::cout << "[DEBUG] visitPrimaryExp: " << ctx->getText() << std::endl;
  ExprInfo result;
  if (ctx->lVal()) {
    result = CheckLValue(ctx->lVal());
    // NOTE(review): this also marks bare array names (reported by CheckLValue
    // as non-l-values) as l-values; kept as in the original — confirm intent.
    result.is_lvalue = true;
  } else if (ctx->HEX_FLOAT() || ctx->DEC_FLOAT()) {
    // Float literal.
    result.type = ir::Type::GetFloatType();
    result.is_const = true;
    result.is_const_int = false;
    const std::string text =
        ctx->HEX_FLOAT() ? ctx->HEX_FLOAT()->getText() : ctx->DEC_FLOAT()->getText();
    result.const_float_value = std::stof(text);
  } else if (ctx->HEX_INT() || ctx->OCTAL_INT() || ctx->DECIMAL_INT() || ctx->ZERO()) {
    // Integer literal.
    result.type = ir::Type::GetInt32Type();
    result.is_const = true;
    result.is_const_int = true;
    std::string text;
    if (ctx->HEX_INT()) text = ctx->HEX_INT()->getText();
    else if (ctx->OCTAL_INT()) text = ctx->OCTAL_INT()->getText();
    else if (ctx->DECIMAL_INT()) text = ctx->DECIMAL_INT()->getText();
    else text = ctx->ZERO()->getText();
    // Base 0 lets the prefix (0x / 0 / none) select hex, octal or decimal.
    const unsigned long long magnitude = std::stoull(text, nullptr, 0);
    if (magnitude > 0xFFFFFFFFULL) {
      throw std::runtime_error(FormatError("sema", "整数字面量超出 32 位范围: " + text));
    }
    // Wrap to 32 bits so that 2147483648 becomes INT_MIN's bit pattern.
    result.const_int_value = static_cast<int>(static_cast<unsigned int>(magnitude));
  } else if (ctx->exp()) {
    // Parenthesised expression.
    result = CheckExp(ctx->exp());
  }
  sema_.SetExprType(ctx, result);
  return {};
}
// Unary expression: a primary expression, a function call, or a unary
// operator (`+`, `-`, `!`) applied to another unary expression. Records the
// resulting expression info for ctx.
//
// Constant folding:
//   * `!x`  folds to 1 when x is zero and 0 otherwise; a float constant is
//           tested against 0.0 (not truncated), so `!0.5` folds to 0;
//   * `-x`  negates the constant; int negation wraps instead of overflowing;
//   * `+x`  keeps the operand's constant value.
// Throws std::runtime_error when the operand's info is missing or its type is
// not arithmetic.
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
  std::cout << "[DEBUG] visitUnaryExp: " << ctx->getText() << std::endl;
  ExprInfo result;
  if (ctx->primaryExp()) {
    ctx->primaryExp()->accept(this);
    auto* info = sema_.GetExprType(ctx->primaryExp());
    if (info) result = *info;
  } else if (ctx->Ident() && ctx->L_PAREN()) {
    // Function call.
    std::cout << "[DEBUG] 函数调用: " << ctx->Ident()->getText() << std::endl;
    result = CheckFuncCall(ctx);
  } else if (ctx->unaryOp()) {
    // Unary operator.
    ctx->unaryExp()->accept(this);
    auto* operand = sema_.GetExprType(ctx->unaryExp());
    if (!operand) {
      throw std::runtime_error(FormatError("sema", "一元操作数类型推导失败"));
    }
    const std::string op = ctx->unaryOp()->getText();
    if (op == "!") {
      // Logical not: the operand must be int, or float converted to int.
      if (operand->type->IsFloat()) {
        sema_.AddConversion(ctx->unaryExp(), operand->type, ir::Type::GetInt32Type());
        operand->type = ir::Type::GetInt32Type();
        if (operand->is_const && !operand->is_const_int) {
          // In a logical context a float is "true" when it is non-zero.
          operand->const_int_value = (operand->const_float_value != 0.0f) ? 1 : 0;
          operand->is_const_int = true;
        }
      } else if (!operand->type->IsInt32()) {
        throw std::runtime_error(FormatError("sema", "逻辑非操作数必须是 int 类型或可以转换为 int 的 float 类型"));
      }
      result.type = ir::Type::GetInt32Type();
      result.is_lvalue = false;
      result.is_const = operand->is_const;
      if (operand->is_const && operand->is_const_int) {
        result.is_const_int = true;
        result.const_int_value = (operand->const_int_value == 0) ? 1 : 0;
      }
    } else {
      // Sign operators.
      if (!operand->type->IsInt32() && !operand->type->IsFloat()) {
        throw std::runtime_error(FormatError("sema", "正负号操作数必须是算术类型"));
      }
      result.type = operand->type;
      result.is_lvalue = false;
      result.is_const = operand->is_const;
      if (operand->is_const) {
        const bool negate = (op == "-");
        if (operand->type->IsInt32() && operand->is_const_int) {
          result.is_const_int = true;
          // Negate through unsigned arithmetic: -INT_MIN must wrap, not overflow.
          result.const_int_value =
              negate ? static_cast<int>(0u - static_cast<unsigned int>(operand->const_int_value))
                     : operand->const_int_value;
        } else if (operand->type->IsFloat()) {
          result.const_float_value =
              negate ? -operand->const_float_value : operand->const_float_value;
        }
      }
    }
  }
  sema_.SetExprType(ctx, result);
  return {};
}
// Multiplicative expression (`*`, `/`, `%`).
// With a single operand the operand's info is forwarded to this node. With
// two operands both sides are analysed and the result is produced by
// CheckBinaryOp.
// Throws std::runtime_error when either operand's info is missing.
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
  if (!ctx->mulExp()) {
    // Pass-through: unaryExp only.
    ctx->unaryExp()->accept(this);
    auto* info = sema_.GetExprType(ctx->unaryExp());
    if (info) {
      sema_.SetExprType(ctx, *info);
    }
    return {};
  }

  ctx->mulExp()->accept(this);
  ctx->unaryExp()->accept(this);
  auto* left_info = sema_.GetExprType(ctx->mulExp());
  auto* right_info = sema_.GetExprType(ctx->unaryExp());
  if (!left_info || !right_info) {
    throw std::runtime_error(FormatError("sema", "乘除模操作数类型推导失败"));
  }

  std::string op;
  if (ctx->MulOp()) {
    op = "*";
  } else if (ctx->DivOp()) {
    op = "/";
  } else if (ctx->QuoOp()) {
    op = "%";
  }

  ExprInfo result = CheckBinaryOp(left_info, right_info, op, ctx);
  sema_.SetExprType(ctx, result);
  return {};
}
// Additive expression (`+`, `-`).
// With a single operand the operand's info is forwarded to this node. With
// two operands both sides are analysed and the result is produced by
// CheckBinaryOp.
// Throws std::runtime_error when either operand's info is missing.
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
  if (!ctx->addExp()) {
    // Pass-through: mulExp only.
    ctx->mulExp()->accept(this);
    auto* info = sema_.GetExprType(ctx->mulExp());
    if (info) {
      sema_.SetExprType(ctx, *info);
    }
    return {};
  }

  ctx->addExp()->accept(this);
  ctx->mulExp()->accept(this);
  auto* left_info = sema_.GetExprType(ctx->addExp());
  auto* right_info = sema_.GetExprType(ctx->mulExp());
  if (!left_info || !right_info) {
    throw std::runtime_error(FormatError("sema", "加减操作数类型推导失败"));
  }

  std::string op;
  if (ctx->AddOp()) {
    op = "+";
  } else if (ctx->SubOp()) {
    op = "-";
  }

  ExprInfo result = CheckBinaryOp(left_info, right_info, op, ctx);
  sema_.SetExprType(ctx, result);
  return {};
}
// Relational expression (`<`, `>`, `<=`, `>=`).
// With a single operand the operand's info is forwarded to this node. With
// two operands both must be arithmetic (int or float); the result is a
// non-assignable int. When both operands are constant the comparison is
// folded: two int constants are compared exactly as ints, anything else is
// compared through GetFloatValue.
// Throws std::runtime_error when operand info is missing or an operand is not
// arithmetic.
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
  if (!ctx->relExp()) {
    // Pass-through: addExp only.
    ctx->addExp()->accept(this);
    auto* info = sema_.GetExprType(ctx->addExp());
    if (info) {
      sema_.SetExprType(ctx, *info);
    }
    return {};
  }

  ctx->relExp()->accept(this);
  ctx->addExp()->accept(this);
  auto* left_info = sema_.GetExprType(ctx->relExp());
  auto* right_info = sema_.GetExprType(ctx->addExp());
  if (!left_info || !right_info) {
    throw std::runtime_error(FormatError("sema", "关系操作数类型推导失败"));
  }

  // Both sides must be arithmetic, not only the left one.
  const auto is_arithmetic = [](const auto& info) {
    return info.type->IsInt32() || info.type->IsFloat();
  };
  if (!is_arithmetic(*left_info) || !is_arithmetic(*right_info)) {
    throw std::runtime_error(FormatError("sema", "关系运算操作数必须是算术类型"));
  }

  std::string op;
  if (ctx->LOp()) {
    op = "<";
  } else if (ctx->GOp()) {
    op = ">";
  } else if (ctx->LeOp()) {
    op = "<=";
  } else if (ctx->GeOp()) {
    op = ">=";
  }

  ExprInfo result;
  result.type = ir::Type::GetInt32Type();
  result.is_lvalue = false;
  if (left_info->is_const && right_info->is_const) {
    result.is_const = true;
    result.is_const_int = true;
    const auto fold = [&op, &result](auto l, auto r) {
      if (op == "<") result.const_int_value = (l < r) ? 1 : 0;
      else if (op == ">") result.const_int_value = (l > r) ? 1 : 0;
      else if (op == "<=") result.const_int_value = (l <= r) ? 1 : 0;
      else if (op == ">=") result.const_int_value = (l >= r) ? 1 : 0;
    };
    const bool both_int = left_info->type->IsInt32() && left_info->is_const_int &&
                          right_info->type->IsInt32() && right_info->is_const_int;
    if (both_int) {
      // float cannot represent every int above 2^24, so compare ints directly.
      fold(left_info->const_int_value, right_info->const_int_value);
    } else {
      fold(GetFloatValue(*left_info), GetFloatValue(*right_info));
    }
  }

  sema_.SetExprType(ctx, result);
  return {};
}
// Equality expression (`==`, `!=`).
// With a single operand the operand's info is forwarded to this node. With
// two operands the result is a non-assignable int. When both operands are
// constant the comparison is folded: two int constants are compared exactly
// as ints, anything else is compared through GetFloatValue.
// Throws std::runtime_error when operand info is missing.
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
  if (!ctx->eqExp()) {
    // Pass-through: relExp only.
    ctx->relExp()->accept(this);
    auto* info = sema_.GetExprType(ctx->relExp());
    if (info) {
      sema_.SetExprType(ctx, *info);
    }
    return {};
  }

  ctx->eqExp()->accept(this);
  ctx->relExp()->accept(this);
  auto* left_info = sema_.GetExprType(ctx->eqExp());
  auto* right_info = sema_.GetExprType(ctx->relExp());
  if (!left_info || !right_info) {
    throw std::runtime_error(FormatError("sema", "相等性操作数类型推导失败"));
  }

  std::string op;
  if (ctx->EqOp()) {
    op = "==";
  } else if (ctx->NeOp()) {
    op = "!=";
  }

  ExprInfo result;
  result.type = ir::Type::GetInt32Type();
  result.is_lvalue = false;
  if (left_info->is_const && right_info->is_const) {
    result.is_const = true;
    result.is_const_int = true;
    const auto fold = [&op, &result](auto l, auto r) {
      if (op == "==") result.const_int_value = (l == r) ? 1 : 0;
      else if (op == "!=") result.const_int_value = (l != r) ? 1 : 0;
    };
    const bool both_int = left_info->type->IsInt32() && left_info->is_const_int &&
                          right_info->type->IsInt32() && right_info->is_const_int;
    if (both_int) {
      // float cannot represent every int above 2^24 (16777217 would compare
      // equal to 16777216), so compare ints directly.
      fold(left_info->const_int_value, right_info->const_int_value);
    } else {
      fold(GetFloatValue(*left_info), GetFloatValue(*right_info));
    }
  }

  sema_.SetExprType(ctx, result);
  return {};
}
// Logical-and expression (`&&`).
// Every operand must be int, or float (which is converted to int and the
// conversion recorded). A float constant is converted by truth value
// (non-zero -> 1, zero -> 0) rather than by truncation, so `0.5 && 1` folds
// to 1. With a single operand the (possibly converted) operand info is
// forwarded; with two operands the result is a non-assignable int that is
// folded when both sides are int constants.
// Throws std::runtime_error when operand info is missing or an operand is
// neither int nor float.
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
  // Coerces one operand to int in place; `node` is the operand's parse node.
  const auto coerce_to_int = [this](auto* node, auto* info, const char* error_message) {
    if (info->type->IsInt32()) {
      return;
    }
    if (!info->type->IsFloat()) {
      throw std::runtime_error(FormatError("sema", error_message));
    }
    sema_.AddConversion(node, info->type, ir::Type::GetInt32Type());
    info->type = ir::Type::GetInt32Type();
    if (info->is_const && !info->is_const_int) {
      info->const_int_value = (info->const_float_value != 0.0f) ? 1 : 0;
      info->is_const_int = true;
    }
  };

  if (!ctx->lAndExp()) {
    // Single operand: still make sure it is int for use as a condition.
    ctx->eqExp()->accept(this);
    auto* info = sema_.GetExprType(ctx->eqExp());
    if (info) {
      coerce_to_int(ctx->eqExp(), info,
                    "逻辑与操作数必须是 int 类型或可以转换为 int 的 float 类型");
      sema_.SetExprType(ctx, *info);
    }
    return {};
  }

  ctx->lAndExp()->accept(this);
  ctx->eqExp()->accept(this);
  auto* left_info = sema_.GetExprType(ctx->lAndExp());
  auto* right_info = sema_.GetExprType(ctx->eqExp());
  if (!left_info || !right_info) {
    throw std::runtime_error(FormatError("sema", "逻辑与操作数类型推导失败"));
  }

  coerce_to_int(ctx->lAndExp(), left_info,
                "逻辑与左操作数必须是 int 类型或可以转换为 int 的 float 类型");
  coerce_to_int(ctx->eqExp(), right_info,
                "逻辑与右操作数必须是 int 类型或可以转换为 int 的 float 类型");

  ExprInfo result;
  result.type = ir::Type::GetInt32Type();
  result.is_lvalue = false;
  if (left_info->is_const && right_info->is_const &&
      left_info->is_const_int && right_info->is_const_int) {
    result.is_const = true;
    result.is_const_int = true;
    result.const_int_value =
        (left_info->const_int_value && right_info->const_int_value) ? 1 : 0;
  }

  sema_.SetExprType(ctx, result);
  return {};
}
// Logical-or expression (`||`).
// Every operand must be int, or float (which is converted to int and the
// conversion recorded). A float constant is converted by truth value
// (non-zero -> 1, zero -> 0) rather than by truncation, so `0.5 || 0` folds
// to 1. With a single operand the (possibly converted) operand info is
// forwarded; with two operands the result is a non-assignable int that is
// folded when both sides are int constants.
// Throws std::runtime_error when operand info is missing or an operand is
// neither int nor float.
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
  // Coerces one operand to int in place; `node` is the operand's parse node.
  const auto coerce_to_int = [this](auto* node, auto* info, const char* error_message) {
    if (info->type->IsInt32()) {
      return;
    }
    if (!info->type->IsFloat()) {
      throw std::runtime_error(FormatError("sema", error_message));
    }
    sema_.AddConversion(node, info->type, ir::Type::GetInt32Type());
    info->type = ir::Type::GetInt32Type();
    if (info->is_const && !info->is_const_int) {
      info->const_int_value = (info->const_float_value != 0.0f) ? 1 : 0;
      info->is_const_int = true;
    }
  };

  if (!ctx->lOrExp()) {
    // Single operand: still make sure it is int for use as a condition.
    ctx->lAndExp()->accept(this);
    auto* info = sema_.GetExprType(ctx->lAndExp());
    if (info) {
      coerce_to_int(ctx->lAndExp(), info,
                    "逻辑或操作数必须是 int 类型或可以转换为 int 的 float 类型");
      sema_.SetExprType(ctx, *info);
    }
    return {};
  }

  ctx->lOrExp()->accept(this);
  ctx->lAndExp()->accept(this);
  auto* left_info = sema_.GetExprType(ctx->lOrExp());
  auto* right_info = sema_.GetExprType(ctx->lAndExp());
  if (!left_info || !right_info) {
    throw std::runtime_error(FormatError("sema", "逻辑或操作数类型推导失败"));
  }

  coerce_to_int(ctx->lOrExp(), left_info,
                "逻辑或左操作数必须是 int 类型或可以转换为 int 的 float 类型");
  coerce_to_int(ctx->lAndExp(), right_info,
                "逻辑或右操作数必须是 int 类型或可以转换为 int 的 float 类型");

  ExprInfo result;
  result.type = ir::Type::GetInt32Type();
  result.is_lvalue = false;
  if (left_info->is_const && right_info->is_const &&
      left_info->is_const_int && right_info->is_const_int) {
    result.is_const = true;
    result.is_const_int = true;
    result.const_int_value =
        (left_info->const_int_value || right_info->const_int_value) ? 1 : 0;
  }

  sema_.SetExprType(ctx, result);
  return {};
}
// Hands the accumulated semantic context to the caller by moving it out of
// this visitor; the visitor's own `sema_` is left in a moved-from state.
SemanticContext TakeSemanticContext() { return std::move(sema_); }
private:
// Scoped symbol table used for declarations and name lookup.
SymbolTable table_;
// Semantic results gathered during the walk (expression info, conversions).
SemanticContext sema_;
// One entry per `while` statement whose body is currently being visited.
struct LoopContext {
// Consulted by the break/continue checks.
bool in_loop;
// The statement node that opened the loop.
antlr4::ParserRuleContext* loop_node;
};
// Pushed before a while body is visited and popped afterwards.
std::vector<LoopContext> loop_stack_;
// Return type of the function being analysed; null outside a function body.
std::shared_ptr<ir::Type> current_func_return_type_ = nullptr;
// Set to true by the return-statement check; reset per function.
bool current_func_has_return_ = false;
// ==================== 辅助函数 ====================
// Analyses an expression node and returns its expression info.
// The info computed for the underlying additive expression is also recorded
// for the expression node itself.
// Throws std::runtime_error when the node (or its additive child) is missing
// or when no info was produced for it.
ExprInfo CheckExp(SysYParser::ExpContext* ctx) {
  auto* add_exp = ctx ? ctx->addExp() : nullptr;
  if (add_exp == nullptr) {
    throw std::runtime_error(FormatError("sema", "无效表达式"));
  }
  std::cout << "[DEBUG] CheckExp: " << ctx->getText() << std::endl;

  add_exp->accept(this);
  auto* info = sema_.GetExprType(add_exp);
  if (info == nullptr) {
    throw std::runtime_error(FormatError("sema", "表达式类型推导失败"));
  }

  ExprInfo result = *info;
  sema_.SetExprType(ctx, result);
  return result;
}
// Helper dedicated to additive expressions (used for constant expressions):
// visits the node and returns the info recorded for it.
// Throws std::runtime_error when the node is null or no info was produced.
ExprInfo CheckAddExp(SysYParser::AddExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("sema", "无效表达式"));
  }
  ctx->accept(this);
  if (auto* info = sema_.GetExprType(ctx)) {
    return *info;
  }
  throw std::runtime_error(FormatError("sema", "表达式类型推导失败"));
}
// Analyses a condition node and returns the info of its logical-or child.
// The logical-or visitor already converts float operands, so the int check
// here is a final verification.
// Throws std::runtime_error when the node (or its logical-or child) is
// missing, when no info was produced, or when the result is not int.
ExprInfo CheckCond(SysYParser::CondContext* ctx) {
  auto* or_exp = ctx ? ctx->lOrExp() : nullptr;
  if (or_exp == nullptr) {
    throw std::runtime_error(FormatError("sema", "无效条件表达式"));
  }

  or_exp->accept(this);
  auto* info = sema_.GetExprType(or_exp);
  if (info == nullptr) {
    throw std::runtime_error(FormatError("sema", "条件表达式类型推导失败"));
  }

  ExprInfo result = *info;
  if (!result.type->IsInt32()) {
    throw std::runtime_error(FormatError("sema", "条件表达式必须是 int 类型"));
  }
  return result;
}
// Analyses an l-value node and returns its expression info without recording
// it for the node.
//   * array / pointer with subscripts: element type, assignable, not const;
//     the subscript count must equal the array rank (1 for pointers) and each
//     subscript must be int;
//   * array name without subscripts: pointer to the element type (int pointer
//     unless the element type is float), not assignable, const;
//   * pointer without subscripts: the symbol's own type;
//   * plain scalar: the symbol's own type; scalar constants additionally
//     carry their compile-time value, matching what visitLVal reports, so
//     that constant folding of expressions using named constants sees the
//     real value.
// Throws std::runtime_error on a null node, missing identifier, undefined
// name, subscript-count mismatch, non-int subscript, or subscripting a scalar.
ExprInfo CheckLValue(SysYParser::LValContext* ctx) {
  if (!ctx || !ctx->Ident()) {
    throw std::runtime_error(FormatError("sema", "非法左值"));
  }
  std::string name = ctx->Ident()->getText();
  auto* sym = table_.lookup(name);
  if (!sym) {
    throw std::runtime_error(FormatError("sema", "未定义的变量: " + name));
  }

  const bool is_const = (sym->kind == SymbolKind::Constant);
  const bool is_array = sym->type && sym->type->IsArray();
  const bool is_pointer = sym->type && (sym->type->IsPtrInt32() || sym->type->IsPtrFloat());
  const size_t subscript_count = ctx->exp().size();

  if (!is_array && !is_pointer) {
    if (subscript_count > 0) {
      throw std::runtime_error(FormatError("sema", "非数组变量不能使用下标: " + name));
    }
    ExprInfo info{sym->type, true, is_const};
    // Propagate the folded value of scalar constants.
    if (is_const && sym->type && !sym->type->IsArray()) {
      if (sym->is_int_const) {
        info.is_const_int = true;
        info.const_int_value = sym->const_value.i32;
      } else {
        info.const_float_value = sym->const_value.f32;
      }
    }
    return info;
  }

  // Expected number of subscripts and the element type.
  size_t dim_count = 0;
  std::shared_ptr<ir::Type> elem_type = sym->type;
  if (is_array) {
    if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
      dim_count = arr_type->GetDimensions().size();
      elem_type = arr_type->GetElementType();
    }
  } else {
    dim_count = 1;
    if (sym->type->IsPtrInt32()) {
      elem_type = ir::Type::GetInt32Type();
    } else if (sym->type->IsPtrFloat()) {
      elem_type = ir::Type::GetFloatType();
    }
  }

  if (subscript_count > 0) {
    // Subscripted access.
    if (subscript_count != dim_count) {
      throw std::runtime_error(FormatError("sema", "数组下标个数不匹配"));
    }
    for (auto* idx_exp : ctx->exp()) {
      ExprInfo idx = CheckExp(idx_exp);
      if (!idx.type->IsInt32()) {
        throw std::runtime_error(FormatError("sema", "数组下标必须是 int 类型"));
      }
    }
    return {elem_type, true, false};
  }

  if (is_array) {
    // A bare array name is used as an address (an r-value).
    if (auto* arr_type = dynamic_cast<ir::ArrayType*>(sym->type.get())) {
      if (!arr_type->GetElementType()->IsInt32() && arr_type->GetElementType()->IsFloat()) {
        return {ir::Type::GetPtrFloatType(), false, true};
      }
    }
    return {ir::Type::GetPtrInt32Type(), false, true};
  }

  // Pointer-typed symbols (e.g. array parameters) may be used without subscripts.
  return {sym->type, true, is_const};
}
// Type-checks a function-call expression.
//
// Resolves the callee in the symbol table (it must be a Function symbol),
// type-checks every actual argument with CheckExp, verifies that the number
// of arguments equals the number of declared parameters, and verifies that
// each argument type is compatible with its parameter type. When an argument
// type differs from the parameter type, a conversion is registered on the
// argument's expression node via sema_.AddConversion.
//
// Returns an rvalue, non-const ExprInfo whose type is the callee's return
// type; int is used when the return type cannot be determined from the
// symbol.
//
// Throws std::runtime_error for a malformed node, an unknown or non-function
// callee, an argument-count mismatch, or an incompatible argument type.
ExprInfo CheckFuncCall(SysYParser::UnaryExpContext* ctx) {
  if (!ctx || !ctx->Ident()) {
    throw std::runtime_error(FormatError("sema", "非法函数调用"));
  }
  const std::string func_name = ctx->Ident()->getText();
  auto* func_sym = table_.lookup(func_name);
  if (!func_sym || func_sym->kind != SymbolKind::Function) {
    throw std::runtime_error(FormatError("sema", "未定义的函数: " + func_name));
  }

  auto* params = ctx->funcRParams();
  std::vector<ExprInfo> args;
  if (params) {
    for (auto* exp : params->exp()) {
      if (exp) {
        args.push_back(CheckExp(exp));
      }
    }
  }
  if (args.size() != func_sym->param_types.size()) {
    throw std::runtime_error(FormatError("sema", "参数个数不匹配"));
  }

  // The counts are equal from here on, so one bound suffices. `args` is
  // non-empty only when `params` is non-null, and it never holds more entries
  // than params->exp(), so indexing the expression list with i is in range.
  for (size_t i = 0; i < args.size(); ++i) {
    const auto& expected = func_sym->param_types[i];
    if (!IsTypeCompatible(args[i].type, expected)) {
      throw std::runtime_error(FormatError("sema", "参数类型不匹配"));
    }
    if (args[i].type != expected) {
      sema_.AddConversion(params->exp()[i], args[i].type, expected);
    }
  }

  std::shared_ptr<ir::Type> return_type;
  if (func_sym->type && func_sym->type->IsFunction()) {
    if (auto* func_type = dynamic_cast<ir::FunctionType*>(func_sym->type.get())) {
      return_type = func_type->GetReturnType();
    }
  }
  if (!return_type) {
    // Fall back to int when the symbol carries no usable function type.
    return_type = ir::Type::GetInt32Type();
  }

  ExprInfo result;
  result.type = return_type;
  result.is_lvalue = false;
  result.is_const = false;
  return result;
}
// Type-checks an arithmetic binary operation and constant-folds it when both
// operands are constant.
//
// Parameters:
//   left, right - operand infos; both must be non-null and carry a type.
//   op          - operator spelling; "*", "/", "%", "+" and "-" are folded.
//   ctx         - the expression node; not used by this function.
//
// Result: an rvalue whose type is float when either operand is float and int
// otherwise. When both operands are constant the result is marked constant
// and carries the folded value, with one exception: an integer "/" or "%"
// whose divisor is the constant 0 is left unfolded (not marked constant),
// because evaluating it here would be undefined behaviour in the compiler.
// Integer "+", "-", "*" and INT_MIN / -1 wrap modulo 2^32 instead of
// overflowing.
//
// Throws std::runtime_error when an operand is missing, when an operand is
// not int/float, or when "%" is applied to a non-int operand.
ExprInfo CheckBinaryOp(const ExprInfo* left, const ExprInfo* right,
                       const std::string& op, antlr4::ParserRuleContext* ctx) {
  (void)ctx;
  if (!left || !right || !left->type || !right->type) {
    throw std::runtime_error(FormatError("sema", "表达式类型推导失败"));
  }
  if (!left->type->IsInt32() && !left->type->IsFloat()) {
    throw std::runtime_error(FormatError("sema", "左操作数必须是算术类型"));
  }
  if (!right->type->IsInt32() && !right->type->IsFloat()) {
    throw std::runtime_error(FormatError("sema", "右操作数必须是算术类型"));
  }
  if (op == "%" && (!left->type->IsInt32() || !right->type->IsInt32())) {
    throw std::runtime_error(FormatError("sema", "取模运算要求操作数为 int 类型"));
  }

  ExprInfo result;
  if (left->type->IsFloat() || right->type->IsFloat()) {
    result.type = ir::Type::GetFloatType();
  } else {
    result.type = ir::Type::GetInt32Type();
  }
  result.is_lvalue = false;

  if (!(left->is_const && right->is_const)) {
    return result;
  }

  if (result.type->IsInt32()) {
    // Read integer operands directly. Going through float first would lose
    // precision for magnitudes above 2^24.
    const auto as_int = [](const ExprInfo& operand) -> int {
      return operand.is_const_int ? operand.const_int_value
                                  : static_cast<int>(operand.const_float_value);
    };
    const int li = as_int(*left);
    const int ri = as_int(*right);

    if ((op == "/" || op == "%") && ri == 0) {
      // Division by a constant zero cannot be folded; leave it non-constant.
      return result;
    }

    // Unsigned arithmetic gives well-defined wrap-around for + - *.
    const unsigned ul = static_cast<unsigned>(li);
    const unsigned ur = static_cast<unsigned>(ri);

    result.is_const = true;
    result.is_const_int = true;
    if (op == "*") {
      result.const_int_value = static_cast<int>(ul * ur);
    } else if (op == "/") {
      // x / -1 is a negation; computed unsigned so INT_MIN / -1 wraps.
      result.const_int_value = (ri == -1) ? static_cast<int>(0u - ul) : li / ri;
    } else if (op == "%") {
      // x % -1 is always 0; avoids the INT_MIN % -1 trap.
      result.const_int_value = (ri == -1) ? 0 : li % ri;
    } else if (op == "+") {
      result.const_int_value = static_cast<int>(ul + ur);
    } else if (op == "-") {
      result.const_int_value = static_cast<int>(ul - ur);
    }
  } else {
    const float l = GetFloatValue(*left);
    const float r = GetFloatValue(*right);
    result.is_const = true;
    if (op == "*") {
      result.const_float_value = l * r;
    } else if (op == "/") {
      result.const_float_value = l / r;
    } else if (op == "+") {
      result.const_float_value = l + r;
    } else if (op == "-") {
      result.const_float_value = l - r;
    }
  }
  return result;
}
// Returns the constant carried by `info` as a float: the integer constant
// converted to float when the type is int and an integer constant is present,
// otherwise the stored float constant.
float GetFloatValue(const ExprInfo& info) {
  const bool holds_int = info.type->IsInt32() && info.is_const_int;
  if (holds_int) {
    return static_cast<float>(info.const_int_value);
  }
  return info.const_float_value;
}
bool IsTypeCompatible(std::shared_ptr<ir::Type> src, std::shared_ptr<ir::Type> dst) {
if (src == dst) return true;
if (src->IsInt32() && dst->IsFloat()) return true;
if (src->IsFloat() && dst->IsInt32()) return true;
return false;
}
// Registers a function's signature in the symbol table ahead of visiting its
// body. Does nothing when the node or its identifier is missing, or when the
// name is already present in the table.
//
// The return type comes from funcType (void / int / float, int otherwise).
// Each parameter's type comes from its bType (int unless float); a parameter
// declared with brackets becomes the matching pointer type.
void CollectFunctionDeclaration(SysYParser::FuncDefContext* ctx) {
  if (ctx == nullptr || ctx->Ident() == nullptr) {
    return;
  }
  const std::string name = ctx->Ident()->getText();
  if (table_.lookup(name) != nullptr) {
    return;
  }

  // Return type.
  std::shared_ptr<ir::Type> ret_type;
  if (auto* declared = ctx->funcType()) {
    if (declared->Void()) {
      ret_type = ir::Type::GetVoidType();
    } else if (declared->Int()) {
      ret_type = ir::Type::GetInt32Type();
    } else if (declared->Float()) {
      ret_type = ir::Type::GetFloatType();
    }
  }
  if (!ret_type) {
    ret_type = ir::Type::GetInt32Type();
  }

  // Parameter types.
  std::vector<std::shared_ptr<ir::Type>> param_types;
  if (auto* formals = ctx->funcFParams()) {
    for (auto* formal : formals->funcFParam()) {
      if (formal == nullptr) {
        continue;
      }
      std::shared_ptr<ir::Type> formal_type = GetTypeFromBType(formal->bType());
      const bool has_brackets = !formal->L_BRACK().empty();
      if (has_brackets) {
        // Array parameters are represented by a pointer to the element type.
        if (formal_type->IsInt32()) {
          formal_type = ir::Type::GetPtrInt32Type();
        } else if (formal_type->IsFloat()) {
          formal_type = ir::Type::GetPtrFloatType();
        }
      }
      param_types.push_back(formal_type);
    }
  }

  // Build the function symbol and register it.
  Symbol sym;
  sym.kind = SymbolKind::Function;
  sym.name = name;
  sym.type = ir::Type::GetFunctionType(ret_type, param_types);
  sym.param_types = param_types;
  sym.func_def_ctx = ctx;
  sym.scope_level = 0;
  sym.is_initialized = true;
  table_.addSymbol(sym);
}
// Adds every formal parameter of a function to the current scope as a
// Parameter symbol. Does nothing for a null parameter list; entries without
// an identifier are skipped.
//
// A parameter's type comes from its bType (int unless float); a parameter
// declared with brackets becomes the matching pointer type.
//
// Throws std::runtime_error when a parameter name already exists in the
// current scope.
void CollectFunctionParams(SysYParser::FuncFParamsContext* ctx) {
  if (!ctx) return;
  for (auto* param : ctx->funcFParam()) {
    if (!param || !param->Ident()) continue;
    const std::string name = param->Ident()->getText();
    if (table_.lookupCurrent(name)) {
      throw std::runtime_error(FormatError("sema", "重复定义参数: " + name));
    }

    std::shared_ptr<ir::Type> param_type;
    if (param->bType()) {
      if (param->bType()->Int()) {
        param_type = ir::Type::GetInt32Type();
      } else if (param->bType()->Float()) {
        param_type = ir::Type::GetFloatType();
      }
    }
    if (!param_type) param_type = ir::Type::GetInt32Type();

    const bool is_array = !param->L_BRACK().empty();
    if (is_array) {
      // Array parameters are represented by a pointer to the element type.
      if (param_type->IsInt32()) {
        param_type = ir::Type::GetPtrInt32Type();
      } else if (param_type->IsFloat()) {
        param_type = ir::Type::GetPtrFloatType();
      }
    }

    Symbol sym;
    sym.name = name;
    sym.kind = SymbolKind::Parameter;
    sym.type = param_type;
    sym.scope_level = table_.currentScopeLevel();
    sym.is_initialized = true;
    sym.var_def_ctx = nullptr;
    table_.addSymbol(sym);
  }
}
// Verifies that a global initializer consists only of constant expressions.
// A scalar initializer is type-checked and must be constant; a braced
// initializer is checked recursively, element by element. A null node is
// accepted. Throws std::runtime_error on the first non-constant expression.
void CheckGlobalInitIsConst(SysYParser::InitValContext* ctx) {
  if (ctx == nullptr) {
    return;
  }
  auto* scalar = ctx->exp();
  if (scalar == nullptr) {
    for (auto* nested : ctx->initVal()) {
      CheckGlobalInitIsConst(nested);
    }
    return;
  }
  const ExprInfo checked = CheckExp(scalar);
  if (!checked.is_const) {
    throw std::runtime_error(FormatError("sema", "全局变量初始化必须是常量表达式"));
  }
}
// Evaluates a constant expression to an int.
// Returns 0 when the node or its addExp is missing. Otherwise the addExp is
// checked and its integer constant is returned; std::runtime_error is thrown
// when the expression is not an integer constant.
int EvaluateConstExp(SysYParser::ConstExpContext* ctx) {
  auto* add_exp = ctx ? ctx->addExp() : nullptr;
  if (add_exp == nullptr) {
    return 0;
  }
  const ExprInfo evaluated = CheckAddExp(add_exp);
  if (!evaluated.is_const || !evaluated.is_const_int) {
    throw std::runtime_error(FormatError("sema", "常量表达式求值失败"));
  }
  return evaluated.const_int_value;
}
// One evaluated constant-initializer element.
// `is_int` selects which of the two payload fields is meaningful. Every
// member has a default so that a partially filled instance never carries an
// indeterminate value when it is copied.
struct ConstValue {
  bool is_int = true;      // true: int_val is the value; false: float_val is.
  int int_val = 0;         // integer payload
  float float_val = 0.0f;  // floating-point payload
};
// Evaluates a constant initializer into a flat list of values.
//
// Parameters:
//   ctx       - the initializer node; a null node yields an empty list.
//   dims      - array dimensions; only forwarded to the recursive calls, the
//               values are not padded or reshaped here.
//   base_type - element type of the declaration; each value is converted to
//               it (int -> float or float -> int) when the types differ.
//
// A scalar initializer produces one value. A braced initializer produces the
// concatenation of its elements' values, in order. An expression that does
// not yield a usable constant value produces a zero of the base type.
std::vector<ConstValue> EvaluateConstInitVal(SysYParser::ConstInitValContext* ctx,
                                             const std::vector<int>& dims,
                                             std::shared_ptr<ir::Type> base_type) {
  std::vector<ConstValue> result;
  if (!ctx) return result;

  if (auto* const_exp = ctx->constExp()) {
    const ExprInfo info = CheckAddExp(const_exp->addExp());
    // Value-initialize so the field that a branch leaves untouched is zero
    // instead of indeterminate when the element is copied into `result`.
    ConstValue val{};
    if (info.type->IsInt32() && info.is_const_int) {
      val.is_int = true;
      val.int_val = info.const_int_value;
      if (base_type->IsFloat()) {
        // Integer constant initializing a float element.
        val.is_int = false;
        val.float_val = static_cast<float>(info.const_int_value);
      }
    } else if (info.type->IsFloat() && info.is_const) {
      val.is_int = false;
      val.float_val = info.const_float_value;
      if (base_type->IsInt32()) {
        // Float constant initializing an int element: truncate.
        val.is_int = true;
        val.int_val = static_cast<int>(info.const_float_value);
      }
    } else {
      // No usable constant value: fall back to zero of the base type.
      val.is_int = base_type->IsInt32();
      val.int_val = 0;
      val.float_val = 0.0f;
    }
    result.push_back(val);
  } else {
    for (auto* init : ctx->constInitVal()) {
      const std::vector<ConstValue> sub_vals = EvaluateConstInitVal(init, dims, base_type);
      result.insert(result.end(), sub_vals.begin(), sub_vals.end());
    }
  }
  return result;
}
void CheckMainFunction() {
auto* main_sym = table_.lookup("main");
if (!main_sym || main_sym->kind != SymbolKind::Function) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数"));
}
std::shared_ptr<ir::Type> ret_type;
if (main_sym->type && main_sym->type->IsFunction()) {
auto* func_type = dynamic_cast<ir::FunctionType*>(main_sym->type.get());
if (func_type) {
ret_type = func_type->GetReturnType();
}
}
if (!ret_type || !ret_type->IsInt32()) {
throw std::runtime_error(FormatError("sema", "main 函数必须返回 int"));
}
if (!main_sym->param_types.empty()) {
throw std::runtime_error(FormatError("sema", "main 函数不能有参数"));
}
}
};
} // namespace
// Entry point of semantic analysis: runs a SemaVisitor over the compilation
// unit and returns the semantic context taken from that visitor.
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
  SemaVisitor analyzer;
  comp_unit.accept(&analyzer);
  return analyzer.TakeSemanticContext();
}