You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/sem/Sema.cpp

440 lines
15 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "../../include/sem/Sema.h"
#include "../../generated/src/antlr4/SysYParser.h"
#include <algorithm>
#include <any>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
using namespace antlr4;
// ===================== Core visitor implementations =====================
// 1. Compilation unit.
//
// No check is performed at this level: the node is handed to visitChildren()
// and whatever that call yields is returned unchanged.
std::any SemaVisitor::visitCompUnit(SysYParser::CompUnitContext* ctx) {
  std::any children_result = visitChildren(ctx);
  return children_result;
}
// 2. Function definition.
//
// Derives the declared return type from the text of funcType(), publishes it
// through ir_ctx_.SetCurrentFuncReturnType() (visitStmt reads it back when it
// checks `return` statements) and then analyses the body block.
// Always returns an empty std::any.
std::any SemaVisitor::visitFuncDef(SysYParser::FuncDefContext* ctx) {
  FuncInfo info;
  // Start from an explicit "unknown": when funcType() is absent or its text
  // is not one of the three spellings below, ret_type would otherwise keep
  // whatever FuncInfo was constructed with. visitStmt skips its return checks
  // for TYPE_UNKNOWN, so this is the safe fallback.
  info.ret_type = SymbolType::TYPE_UNKNOWN;

  if (auto* type_ctx = ctx->funcType()) {
    const std::string type_text = type_ctx->getText();
    if (type_text == "void") {
      info.ret_type = SymbolType::TYPE_VOID;
    } else if (type_text == "int") {
      info.ret_type = SymbolType::TYPE_INT;
    } else if (type_text == "float") {
      info.ret_type = SymbolType::TYPE_FLOAT;
    }
  }

  if (auto* ident = ctx->Ident()) {
    info.name = ident->getText();
  }
  // NOTE(review): `info` is filled in but not stored anywhere by this
  // function; only its return type is forwarded to the context.

  ir_ctx_.SetCurrentFuncReturnType(info.ret_type);

  // Only the body is visited here; formal parameters are not analysed.
  if (auto* body = ctx->block()) {
    visit(body);
  }
  return std::any();
}
// 3. Declaration.
//
// Nothing is checked here; the node's children are visited and the result of
// visitChildren() is passed back to the caller.
std::any SemaVisitor::visitDecl(SysYParser::DeclContext* ctx) {
  return this->visitChildren(ctx);
}
// 4. Constant declaration.
//
// Pure delegation: children are visited via visitChildren() and its result is
// returned as-is.
std::any SemaVisitor::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
  std::any children_result = visitChildren(ctx);
  return children_result;
}
// 5. Variable declaration.
//
// Pure delegation: children are visited via visitChildren() and its result is
// returned as-is.
std::any SemaVisitor::visitVarDecl(SysYParser::VarDeclContext* ctx) {
  return this->visitChildren(ctx);
}
// 6. Block.
//
// Brackets the visit of the block's children with ir_ctx_.EnterScope() /
// ir_ctx_.LeaveScope() and returns what visitChildren() produced.
std::any SemaVisitor::visitBlock(SysYParser::BlockContext* ctx) {
  ir_ctx_.EnterScope();                        // scope opened for this block
  std::any block_result = visitChildren(ctx);  // items inside the block
  ir_ctx_.LeaveScope();                        // scope closed again
  return block_result;
}
// 7. Statement.
//
// Recognised forms, checked in this order:
//   * assignment `lVal = exp;`  - reports an int/float style mismatch;
//   * `if` / `while`            - condition must be int; a `while` body is
//                                 analysed between EnterLoop()/ExitLoop();
//   * `break` / `continue`      - only legal while ir_ctx_.InLoop() is true;
//   * `return`                  - checked against the current function's
//                                 return type.
// Any other statement is handed to visitChildren().
//
// Statements are classified by the text of their FIRST token rather than by
// searching the whole statement text, so a nested block such as
// `{ return 1; }` or an identifier that merely contains "break" is no longer
// mistaken for the keyword statement itself. Each handled form visits its
// own operands exactly once (before their types are read) and returns, so
// bodies are not analysed a second time outside the loop bookkeeping.
//
// Returns an empty std::any for the handled forms, otherwise the result of
// visitChildren().
//
// NOTE(review): the checks rely on ir_ctx_.GetType() for the lVal / exp /
// cond nodes; they only fire if the visitors for those nodes (or GetType
// itself) make a type available - confirm against IRGenContext.
std::any SemaVisitor::visitStmt(SysYParser::StmtContext* ctx) {
  auto* first_token = ctx->getStart();
  const std::string keyword =
      first_token != nullptr ? first_token->getText() : std::string();

  // Records `message` at the first token of `node` (column is 1-based).
  auto report_at = [this](antlr4::ParserRuleContext* node,
                          const std::string& message) {
    int line = node->getStart()->getLine();
    int col = node->getStart()->getCharPositionInLine() + 1;
    ir_ctx_.RecordError(ErrorMsg(message, line, col));
  };

  // Human-readable spelling of a type for diagnostics.
  auto type_name = [](SymbolType type) -> std::string {
    if (type == SymbolType::TYPE_INT) return "int";
    if (type == SymbolType::TYPE_FLOAT) return "float";
    if (type == SymbolType::TYPE_VOID) return "void";
    return "unknown";
  };

  // ---- assignment: lVal = exp; ----
  if (ctx->lVal() && ctx->exp()) {
    auto* target = ctx->lVal();
    auto* value = ctx->exp();
    // Analyse both sides first so their types are available below.
    visit(target);
    visit(value);

    const SymbolType l_type = ir_ctx_.GetType(target);
    const SymbolType r_type = ir_ctx_.GetType(value);
    const bool both_known = l_type != SymbolType::TYPE_UNKNOWN &&
                            r_type != SymbolType::TYPE_UNKNOWN;
    if (both_known && l_type != r_type) {
      report_at(ctx, "赋值类型不匹配,左值为" + type_name(l_type) +
                         ",右值为" + type_name(r_type));
    }
    // Kept from the original: the lVal node takes over the right-hand type.
    ir_ctx_.SetType(target, r_type);
    return std::any();
  }

  // ---- if / while: both carry a cond() and at least one body ----
  if (ctx->cond() && !ctx->stmt().empty()) {
    auto* cond_ctx = ctx->cond();
    const bool is_loop = (keyword == "while");

    if (is_loop) {
      ir_ctx_.EnterLoop();
    }

    visit(cond_ctx);
    const SymbolType cond_type = ir_ctx_.GetType(cond_ctx);
    if (cond_type != SymbolType::TYPE_INT &&
        cond_type != SymbolType::TYPE_UNKNOWN) {
      report_at(cond_ctx, is_loop ? "while条件表达式必须为整型"
                                  : "if条件表达式必须为整型");
    }

    // First body: the loop body, or the `if` branch.
    visit(ctx->stmt(0));

    if (is_loop) {
      ir_ctx_.ExitLoop();
    } else if (ctx->stmt().size() >= 2) {
      visit(ctx->stmt(1));  // `else` branch
    }
    return std::any();
  }

  // ---- break / continue ----
  if (keyword == "break" || keyword == "continue") {
    if (!ir_ctx_.InLoop()) {
      report_at(ctx, keyword == "break" ? "break只能出现在循环语句中"
                                        : "continue只能出现在循环语句中");
    }
    return std::any();
  }

  // ---- return ----
  if (keyword == "return") {
    const SymbolType func_ret_type = ir_ctx_.GetCurrentFuncReturnType();
    if (auto* value = ctx->exp()) {
      visit(value);
      const SymbolType exp_type = ir_ctx_.GetType(value);
      const bool both_known = exp_type != SymbolType::TYPE_UNKNOWN &&
                              func_ret_type != SymbolType::TYPE_UNKNOWN;
      if (both_known && exp_type != func_ret_type) {
        report_at(value, "return表达式类型与函数返回类型不匹配期望" +
                             type_name(func_ret_type) + ",实际为" +
                             type_name(exp_type));
      }
    } else if (func_ret_type != SymbolType::TYPE_VOID &&
               func_ret_type != SymbolType::TYPE_UNKNOWN) {
      report_at(ctx, "非void函数return必须带表达式");
    }
    return std::any();
  }

  // ---- everything else (nested block, expression statement, `;`) ----
  return visitChildren(ctx);
}
// 8. L-value.
//
// Children are visited through visitChildren(); this visitor itself records
// neither a type nor a constant value for the node.
std::any SemaVisitor::visitLVal(SysYParser::LValContext* ctx) {
  std::any children_result = visitChildren(ctx);
  return children_result;
}
// 9. Expression.
//
// Visits every child, then copies the type and value of the wrapped
// sub-expression onto this node. visitStmt calls ir_ctx_.GetType() on the
// ExpContext itself (assignment and return checks), so the type recorded
// further down the tree (e.g. by visitAddExp) has to be forwarded here for
// those checks to see it.
//
// Returns the value produced by the last rule-context child; empty when the
// node has none.
std::any SemaVisitor::visitExp(SysYParser::ExpContext* ctx) {
  antlr4::ParserRuleContext* inner = nullptr;
  std::any value;
  for (auto* child : ctx->children) {
    std::any child_value = visit(child);
    // Terminals are visited too but carry no type; remember rule nodes only.
    if (auto* rule = dynamic_cast<antlr4::ParserRuleContext*>(child)) {
      inner = rule;
      value = std::move(child_value);
    }
  }
  if (inner != nullptr) {
    ir_ctx_.SetType(ctx, ir_ctx_.GetType(inner));
    ir_ctx_.SetConstVal(ctx, value);
  }
  return value;
}
// 10. Condition expression.
//
// Children are visited through visitChildren(); this visitor itself records
// neither a type nor a constant value for the node.
std::any SemaVisitor::visitCond(SysYParser::CondContext* ctx) {
  return this->visitChildren(ctx);
}
// 11. Primary expression.
//
// Visits every child, then copies the type and value of the contained rule
// node onto this node. visitUnaryExp reads ir_ctx_.GetType() on the
// PrimaryExpContext, so the type recorded on the child (e.g. by visitNumber)
// must be forwarded here. Taking the value from the last RULE child, rather
// than from the last child overall, also keeps the value of a parenthesised
// expression, whose final child is the closing-parenthesis token.
//
// Returns that child's value; empty when the node has no rule child.
std::any SemaVisitor::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  antlr4::ParserRuleContext* inner = nullptr;
  std::any value;
  for (auto* child : ctx->children) {
    std::any child_value = visit(child);
    // Terminals are visited too but carry no type; remember rule nodes only.
    if (auto* rule = dynamic_cast<antlr4::ParserRuleContext*>(child)) {
      inner = rule;
      value = std::move(child_value);
    }
  }
  if (inner != nullptr) {
    ir_ctx_.SetType(ctx, ir_ctx_.GetType(inner));
    ir_ctx_.SetConstVal(ctx, value);
  }
  return value;
}
// 12. Unary expression.
//
// Three forms are handled:
//   * `unaryOp unaryExp` (+, -, !): the operand is visited, the result type
//     is always recorded, and the value is folded when the operand is a
//     constant. A non-constant operand yields an empty value but still gets
//     a type (operand type for +/-, int for !).
//   * function call (Ident present): arguments, if any, are visited. The
//     result type is recorded as int - a simplification kept from the
//     original, since no callee lookup is done here - and the value is
//     empty because a call is not a compile-time constant.
//   * primaryExp: type and value are copied from the primary expression.
//
// Returns the folded constant, or an empty std::any when there is none.
std::any SemaVisitor::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
  // ---- operator form ----
  if (ctx->unaryOp() && ctx->unaryExp()) {
    auto* operand = ctx->unaryExp();
    std::any operand_val = visit(operand);
    const std::string op = ctx->unaryOp()->getText();
    const SymbolType operand_type = ir_ctx_.GetType(operand);

    std::any folded;                         // empty unless foldable
    SymbolType result_type = operand_type;   // +x / -x keep the operand type

    if (op == "+") {
      folded = operand_val;
    } else if (op == "-") {
      if (ir_ctx_.IsIntType(operand_val)) {
        const long val = std::any_cast<long>(operand_val);
        folded = -val;
        result_type = SymbolType::TYPE_INT;
      } else if (ir_ctx_.IsFloatType(operand_val)) {
        const double val = std::any_cast<double>(operand_val);
        folded = -val;
        result_type = SymbolType::TYPE_FLOAT;
      }
    } else if (op == "!") {
      // Logical not always produces an int 0/1, whatever the operand type.
      result_type = SymbolType::TYPE_INT;
      if (ir_ctx_.IsIntType(operand_val)) {
        const long val = std::any_cast<long>(operand_val);
        folded = (val == 0) ? 1L : 0L;
      } else if (ir_ctx_.IsFloatType(operand_val)) {
        const double val = std::any_cast<double>(operand_val);
        folded = (val == 0.0) ? 1L : 0L;
      }
    }

    ir_ctx_.SetType(ctx, result_type);
    ir_ctx_.SetConstVal(ctx, folded);
    return folded;
  }

  // ---- function call: Ident '(' arguments? ')' ----
  if (ctx->Ident()) {
    // Arguments are optional, so a call must not depend on funcRParams().
    if (auto* args = ctx->funcRParams()) {
      visit(args);
    }
    ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
    return std::any();
  }

  // ---- primary expression ----
  if (auto* primary = ctx->primaryExp()) {
    std::any val = visit(primary);
    ir_ctx_.SetType(ctx, ir_ctx_.GetType(primary));
    ir_ctx_.SetConstVal(ctx, val);
    return val;
  }

  return std::any();
}
// 13. Multiplicative expression: unaryExp (op unaryExp)*.
//
// Visits every operand, derives the node's type (an int/float mix becomes
// float) and folds the value left to right when both sides are constants.
// The operator applied at each step is read from the child sitting between
// the two operands, so `*`, `/` and `%` are each folded with their own
// semantics. The running value becomes empty (not a constant) as soon as
//   * either side is not a constant,
//   * the divisor of `/` or `%` is zero, or
//   * the operator is not applicable (e.g. `%` with a float operand).
// A mix of int and float constants is folded in double.
//
// Assumes the operator for operand i is child 2*i-1 of this node, which is
// what the flat operand list returned by ctx->unaryExp() implies.
//
// Returns the folded constant, or an empty std::any.
std::any SemaVisitor::visitMulExp(SysYParser::MulExpContext* ctx) {
  auto operands = ctx->unaryExp();
  if (operands.empty()) {
    return std::any();  // malformed node: nothing to analyse
  }

  std::any result = visit(operands[0]);
  SymbolType result_type = ir_ctx_.GetType(operands[0]);

  for (size_t i = 1; i < operands.size(); ++i) {
    std::any rhs = visit(operands[i]);
    const SymbolType rhs_type = ir_ctx_.GetType(operands[i]);

    // Type unification: int combined with float gives float.
    const bool int_float_mix =
        (result_type == SymbolType::TYPE_INT && rhs_type == SymbolType::TYPE_FLOAT) ||
        (result_type == SymbolType::TYPE_FLOAT && rhs_type == SymbolType::TYPE_INT);
    if (int_float_mix) {
      result_type = SymbolType::TYPE_FLOAT;
    }

    const size_t op_pos = 2 * i - 1;
    const std::string op = op_pos < ctx->children.size()
                               ? ctx->children[op_pos]->getText()
                               : std::string();

    const bool lhs_int = ir_ctx_.IsIntType(result);
    const bool lhs_float = ir_ctx_.IsFloatType(result);
    const bool rhs_int = ir_ctx_.IsIntType(rhs);
    const bool rhs_float = ir_ctx_.IsFloatType(rhs);

    std::any folded;  // stays empty when this step cannot be folded
    if (lhs_int && rhs_int) {
      const long a = std::any_cast<long>(result);
      const long b = std::any_cast<long>(rhs);
      if (op == "*") {
        folded = a * b;
      } else if (op == "/" && b != 0) {
        folded = a / b;
      } else if (op == "%" && b != 0) {
        folded = a % b;
      }
    } else if ((lhs_int || lhs_float) && (rhs_int || rhs_float)) {
      const double a = lhs_int ? static_cast<double>(std::any_cast<long>(result))
                               : std::any_cast<double>(result);
      const double b = rhs_int ? static_cast<double>(std::any_cast<long>(rhs))
                               : std::any_cast<double>(rhs);
      if (op == "*") {
        folded = a * b;
      } else if (op == "/" && b != 0.0) {
        folded = a / b;
      }
    }
    result = std::move(folded);
  }

  ir_ctx_.SetType(ctx, result_type);
  ir_ctx_.SetConstVal(ctx, result);
  return result;
}
// 14. Additive expression: mulExp (op mulExp)*.
//
// Visits every operand, derives the node's type (an int/float mix becomes
// float) and folds the value left to right when both sides are constants.
// The operator applied at each step is read from the child sitting between
// the two operands, so `+` and `-` are each folded with their own semantics.
// The running value becomes empty (not a constant) as soon as either side
// is not a constant or the operator is not recognised. A mix of int and
// float constants is folded in double.
//
// Assumes the operator for operand i is child 2*i-1 of this node, which is
// what the flat operand list returned by ctx->mulExp() implies.
//
// Returns the folded constant, or an empty std::any.
std::any SemaVisitor::visitAddExp(SysYParser::AddExpContext* ctx) {
  auto operands = ctx->mulExp();
  if (operands.empty()) {
    return std::any();  // malformed node: nothing to analyse
  }

  std::any result = visit(operands[0]);
  SymbolType result_type = ir_ctx_.GetType(operands[0]);

  for (size_t i = 1; i < operands.size(); ++i) {
    std::any rhs = visit(operands[i]);
    const SymbolType rhs_type = ir_ctx_.GetType(operands[i]);

    // Type unification: int combined with float gives float.
    const bool int_float_mix =
        (result_type == SymbolType::TYPE_INT && rhs_type == SymbolType::TYPE_FLOAT) ||
        (result_type == SymbolType::TYPE_FLOAT && rhs_type == SymbolType::TYPE_INT);
    if (int_float_mix) {
      result_type = SymbolType::TYPE_FLOAT;
    }

    const size_t op_pos = 2 * i - 1;
    const std::string op = op_pos < ctx->children.size()
                               ? ctx->children[op_pos]->getText()
                               : std::string();

    const bool lhs_int = ir_ctx_.IsIntType(result);
    const bool lhs_float = ir_ctx_.IsFloatType(result);
    const bool rhs_int = ir_ctx_.IsIntType(rhs);
    const bool rhs_float = ir_ctx_.IsFloatType(rhs);

    std::any folded;  // stays empty when this step cannot be folded
    if (lhs_int && rhs_int) {
      const long a = std::any_cast<long>(result);
      const long b = std::any_cast<long>(rhs);
      if (op == "+") {
        folded = a + b;
      } else if (op == "-") {
        folded = a - b;
      }
    } else if ((lhs_int || lhs_float) && (rhs_int || rhs_float)) {
      const double a = lhs_int ? static_cast<double>(std::any_cast<long>(result))
                               : std::any_cast<double>(result);
      const double b = rhs_int ? static_cast<double>(std::any_cast<long>(rhs))
                               : std::any_cast<double>(rhs);
      if (op == "+") {
        folded = a + b;
      } else if (op == "-") {
        folded = a - b;
      }
    }
    result = std::move(folded);
  }

  ir_ctx_.SetType(ctx, result_type);
  ir_ctx_.SetConstVal(ctx, result);
  return result;
}
// 15. Relational expression: addExp (op addExp)*.
//
// The node's type is always recorded as int.
//   * One operand: the operand is visited and its value returned unchanged.
//   * Several operands: EVERY operand is visited (so nested expressions are
//     analysed), and the comparison is folded left to right to 0/1 when both
//     sides are constants. Two int constants are compared as long, any
//     other constant pair as double. If either side is not a constant, or
//     the operator is not one of < > <= >=, the value is empty rather than
//     a made-up constant.
//
// Assumes the operator for operand i is child 2*i-1 of this node, which is
// what the flat operand list returned by ctx->addExp() implies.
std::any SemaVisitor::visitRelExp(SysYParser::RelExpContext* ctx) {
  auto operands = ctx->addExp();
  if (operands.empty()) {
    return std::any();  // malformed node: nothing to analyse
  }

  std::any result = visit(operands[0]);
  if (operands.size() == 1) {
    ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
    return result;
  }

  // Applies one relational operator; empty result for an unknown operator.
  const auto compare = [](const std::string& op, auto lhs, auto rhs) -> std::any {
    if (op == "<") return std::any(static_cast<long>(lhs < rhs));
    if (op == ">") return std::any(static_cast<long>(lhs > rhs));
    if (op == "<=") return std::any(static_cast<long>(lhs <= rhs));
    if (op == ">=") return std::any(static_cast<long>(lhs >= rhs));
    return std::any();
  };

  for (size_t i = 1; i < operands.size(); ++i) {
    std::any rhs = visit(operands[i]);

    const size_t op_pos = 2 * i - 1;
    const std::string op = op_pos < ctx->children.size()
                               ? ctx->children[op_pos]->getText()
                               : std::string();

    const bool lhs_int = ir_ctx_.IsIntType(result);
    const bool lhs_float = ir_ctx_.IsFloatType(result);
    const bool rhs_int = ir_ctx_.IsIntType(rhs);
    const bool rhs_float = ir_ctx_.IsFloatType(rhs);

    std::any folded;  // stays empty when this step cannot be folded
    if (lhs_int && rhs_int) {
      folded = compare(op, std::any_cast<long>(result), std::any_cast<long>(rhs));
    } else if ((lhs_int || lhs_float) && (rhs_int || rhs_float)) {
      const double a = lhs_int ? static_cast<double>(std::any_cast<long>(result))
                               : std::any_cast<double>(result);
      const double b = rhs_int ? static_cast<double>(std::any_cast<long>(rhs))
                               : std::any_cast<double>(rhs);
      folded = compare(op, a, b);
    }
    result = std::move(folded);
  }

  ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
  ir_ctx_.SetConstVal(ctx, result);
  return result;
}
// 16. Equality expression: relExp (op relExp)*.
//
// The node's type is always recorded as int.
//   * One operand: the operand is visited and its value returned unchanged.
//   * Several operands: EVERY operand is visited (so nested expressions are
//     analysed), and the comparison is folded left to right to 0/1 when both
//     sides are constants. Two int constants are compared as long, any
//     other constant pair as double. If either side is not a constant, or
//     the operator is neither == nor !=, the value is empty rather than a
//     made-up constant.
//
// Assumes the operator for operand i is child 2*i-1 of this node, which is
// what the flat operand list returned by ctx->relExp() implies.
std::any SemaVisitor::visitEqExp(SysYParser::EqExpContext* ctx) {
  auto operands = ctx->relExp();
  if (operands.empty()) {
    return std::any();  // malformed node: nothing to analyse
  }

  std::any result = visit(operands[0]);
  if (operands.size() == 1) {
    ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
    return result;
  }

  // Applies one equality operator; empty result for an unknown operator.
  const auto compare = [](const std::string& op, auto lhs, auto rhs) -> std::any {
    if (op == "==") return std::any(static_cast<long>(lhs == rhs));
    if (op == "!=") return std::any(static_cast<long>(lhs != rhs));
    return std::any();
  };

  for (size_t i = 1; i < operands.size(); ++i) {
    std::any rhs = visit(operands[i]);

    const size_t op_pos = 2 * i - 1;
    const std::string op = op_pos < ctx->children.size()
                               ? ctx->children[op_pos]->getText()
                               : std::string();

    const bool lhs_int = ir_ctx_.IsIntType(result);
    const bool lhs_float = ir_ctx_.IsFloatType(result);
    const bool rhs_int = ir_ctx_.IsIntType(rhs);
    const bool rhs_float = ir_ctx_.IsFloatType(rhs);

    std::any folded;  // stays empty when this step cannot be folded
    if (lhs_int && rhs_int) {
      folded = compare(op, std::any_cast<long>(result), std::any_cast<long>(rhs));
    } else if ((lhs_int || lhs_float) && (rhs_int || rhs_float)) {
      const double a = lhs_int ? static_cast<double>(std::any_cast<long>(result))
                               : std::any_cast<double>(result);
      const double b = rhs_int ? static_cast<double>(std::any_cast<long>(rhs))
                               : std::any_cast<double>(rhs);
      folded = compare(op, a, b);
    }
    result = std::move(folded);
  }

  ir_ctx_.SetType(ctx, SymbolType::TYPE_INT);
  ir_ctx_.SetConstVal(ctx, result);
  return result;
}
// 17. Logical-and expression.
//
// Children are visited through visitChildren(); this visitor itself records
// neither a type nor a constant value for the node.
std::any SemaVisitor::visitLAndExp(SysYParser::LAndExpContext* ctx) {
  std::any children_result = visitChildren(ctx);
  return children_result;
}
// 18. Logical-or expression.
//
// Children are visited through visitChildren(); this visitor itself records
// neither a type nor a constant value for the node.
std::any SemaVisitor::visitLOrExp(SysYParser::LOrExpContext* ctx) {
  return this->visitChildren(ctx);
}
// 19. Constant expression.
//
// Children are visited through visitChildren() and its result is returned;
// this visitor itself records neither a type nor a constant value.
std::any SemaVisitor::visitConstExp(SysYParser::ConstExpContext* ctx) {
  std::any children_result = visitChildren(ctx);
  return children_result;
}
// 20. Numeric literal.
//
// Parses the literal text into its actual value instead of recording a
// placeholder zero, so that the constant folding done by the expression
// visitors works on real numbers.
//   * IntConst   -> long via std::stol with base 0 (decimal, leading-0
//                   octal and 0x/0X hexadecimal are auto-detected);
//                   type TYPE_INT.
//   * FloatConst -> double via std::strtod (decimal and hexadecimal
//                   floating forms); type TYPE_FLOAT.
// A literal that cannot be parsed completely, or an integer that does not
// fit in long, is reported through ir_ctx_.RecordError() and recorded as
// zero so analysis can continue.
//
// Returns the parsed value; empty when the node is neither token kind.
std::any SemaVisitor::visitNumber(SysYParser::NumberContext* ctx) {
  const bool is_int = ctx->IntConst() != nullptr;
  const bool is_float = !is_int && ctx->FloatConst() != nullptr;
  if (!is_int && !is_float) {
    return std::any();
  }

  const std::string text = ctx->getText();
  std::any value;
  bool parsed_ok = false;

  if (is_int) {
    try {
      std::size_t used = 0;
      const long parsed = std::stol(text, &used, 0);
      if (used == text.size()) {
        value = parsed;
        parsed_ok = true;
      }
    } catch (const std::logic_error&) {
      // std::invalid_argument (no digits) or std::out_of_range (too large);
      // both are reported below.
    }
  } else {
    // strtod does not throw; range overflow/underflow is accepted as the
    // value it returns, only malformed text is rejected.
    char* end = nullptr;
    const double parsed = std::strtod(text.c_str(), &end);
    if (end != text.c_str() && end == text.c_str() + text.size()) {
      value = parsed;
      parsed_ok = true;
    }
  }

  if (!parsed_ok) {
    int line = ctx->getStart()->getLine();
    int col = ctx->getStart()->getCharPositionInLine() + 1;
    ir_ctx_.RecordError(
        ErrorMsg("数字字面量无效或超出可表示范围: " + text, line, col));
    value = is_int ? std::any(0L) : std::any(0.0);
  }

  ir_ctx_.SetType(ctx, is_int ? SymbolType::TYPE_INT : SymbolType::TYPE_FLOAT);
  ir_ctx_.SetConstVal(ctx, value);
  return value;
}
// 21. Function call arguments.
//
// Children are visited through visitChildren() and its result is returned
// unchanged; no per-argument check is done at this level.
std::any SemaVisitor::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
  return this->visitChildren(ctx);
}
// ===================== Semantic analysis entry point =====================
//
// Runs a SemaVisitor bound to `ir_ctx` over the parse tree rooted at `ctx`.
// Throws std::invalid_argument when `ctx` is null.
void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx) {
  if (ctx == nullptr) {
    throw std::invalid_argument("CompUnitContext is null");
  }

  SemaVisitor analyzer(ir_ctx);
  analyzer.visit(ctx);
}