You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

516 lines
16 KiB

#include "sem/Sema.h"

#include <any>
#include <cstddef>
#include <exception>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
namespace {
// Maps a parsed `bType` node to the matching BaseTypeKind.
//
// Throws std::runtime_error when `ctx` is null, or when the node carries
// neither an INT nor a FLOAT token.
static BaseTypeKind BaseTypeFromBType(SysYParser::BTypeContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("sema", "缺少 bType"));
  }
  // INT is probed first; FLOAT is only consulted when INT is absent.
  const bool is_int = ctx->INT() != nullptr;
  if (!is_int && ctx->FLOAT() == nullptr) {
    throw std::runtime_error(FormatError("sema", "未知基础类型"));
  }
  return is_int ? BaseTypeKind::Int : BaseTypeKind::Float;
}
// Maps a parsed `funcType` node to the BaseTypeKind of the function's return
// type. Tokens are probed in the order VOID, INT, FLOAT.
//
// Throws std::runtime_error when `ctx` is null or none of the three tokens is
// present.
static BaseTypeKind BaseTypeFromFuncType(SysYParser::FuncTypeContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("sema", "缺少 funcType"));
  }
  if (ctx->VOID() != nullptr) {
    return BaseTypeKind::Void;
  }
  const bool is_int = ctx->INT() != nullptr;
  if (!is_int && ctx->FLOAT() == nullptr) {
    throw std::runtime_error(FormatError("sema", "未知函数返回类型"));
  }
  return is_int ? BaseTypeKind::Int : BaseTypeKind::Float;
}
// Folds a constant expression subtree into an `int`.
//
// Every overridden visit method returns a std::any that holds an `int`.
// Identifiers are resolved through the SymbolTable passed at construction and
// must name a const entry that carries a recorded value. Any construct that
// cannot be folded (logical not, function calls, non-constant names, malformed
// literals, division by zero) is reported as std::runtime_error built with
// FormatError("sema", ...).
//
// Arithmetic wraps modulo 2^32 instead of invoking signed-overflow undefined
// behaviour.
class ConstEvalVisitor final : public SysYBaseVisitor {
 public:
  explicit ConstEvalVisitor(const SymbolTable& table) : table_(table) {}

  // constExp simply wraps an addExp.
  std::any visitConstExp(SysYParser::ConstExpContext* ctx) override {
    if (!ctx || !ctx->addExp()) {
      throw std::runtime_error(FormatError("sema", "非法常量表达式"));
    }
    return visitAddExp(ctx->addExp());
  }

  // Folds `mulExp (('+' | '-') mulExp)*` from left to right.
  std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
    const auto operands = ctx->mulExp();
    if (operands.empty()) return 0;
    int value = std::any_cast<int>(operands[0]->accept(this));
    for (size_t i = 1; i < operands.size(); ++i) {
      const int rhs = std::any_cast<int>(operands[i]->accept(this));
      // The operator joining operand i-1 and operand i is read from child
      // index 2*i-1 of this node.
      auto* node = ctx->children.at(2 * i - 1);
      const std::string op = node ? node->getText() : "+";
      if (op == "+") {
        value = WrapAdd(value, rhs);
      } else if (op == "-") {
        value = WrapSub(value, rhs);
      } else {
        throw std::runtime_error(FormatError("sema", "非法加法运算符"));
      }
    }
    return value;
  }

  // Folds `unaryExp (('*' | '/' | '%') unaryExp)*` from left to right.
  // Throws when the right-hand side of '/' or '%' is zero.
  std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
    const auto operands = ctx->unaryExp();
    if (operands.empty()) return 0;
    int value = std::any_cast<int>(operands[0]->accept(this));
    for (size_t i = 1; i < operands.size(); ++i) {
      const int rhs = std::any_cast<int>(operands[i]->accept(this));
      auto* node = ctx->children.at(2 * i - 1);
      const std::string op = node ? node->getText() : "*";
      if (op == "*") {
        value = WrapMul(value, rhs);
      } else if (op == "/" || op == "%") {
        if (rhs == 0) {
          // Dividing by zero is undefined behaviour and would crash the
          // compiler itself; report it as a semantic error instead.
          throw std::runtime_error(FormatError("sema", "constExp 除数为 0"));
        }
        if (rhs == -1) {
          // INT_MIN / -1 and INT_MIN % -1 overflow; handle -1 explicitly.
          value = (op == "/") ? WrapNeg(value) : 0;
        } else if (op == "/") {
          value /= rhs;
        } else {
          value %= rhs;
        }
      } else {
        throw std::runtime_error(FormatError("sema", "非法乘法运算符"));
      }
    }
    return value;
  }

  // Handles a primary expression or a unary '+' / '-'. Logical not and
  // function calls are rejected.
  std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
    if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
    if (ctx->unaryOp() && ctx->unaryExp()) {
      const int val = std::any_cast<int>(ctx->unaryExp()->accept(this));
      const auto op = ctx->unaryOp()->getText();
      if (op == "+") return val;
      if (op == "-") return WrapNeg(val);  // plain -val is UB for INT_MIN
      throw std::runtime_error(FormatError("sema", "constExp 不支持 !"));
    }
    throw std::runtime_error(FormatError("sema", "constExp 不支持函数调用"));
  }

  // Dispatches to the parenthesised expression, the lVal or the number child;
  // yields 0 when none of them is present.
  std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
    if (ctx->exp()) return ctx->exp()->accept(this);
    if (ctx->lVal()) return ctx->lVal()->accept(this);
    if (ctx->number()) return ctx->number()->accept(this);
    return 0;
  }

  // Converts a literal to int. Integer literals are parsed with base
  // auto-detection (decimal, 0x hex, leading-0 octal) and narrowed to int by a
  // cast; float literals are truncated toward zero.
  std::any visitNumber(SysYParser::NumberContext* ctx) override {
    if (ctx->INT_CONST()) {
      const std::string text = ctx->getText();
      size_t idx = 0;
      long long val = 0;
      try {
        val = std::stoll(text, &idx, 0);
      } catch (const std::exception&) {
        // std::stoll reports failures with invalid_argument / out_of_range;
        // translate them to the error type used by the rest of this pass.
        throw std::runtime_error(FormatError("sema", "非法整数常量"));
      }
      if (idx != text.size()) {
        throw std::runtime_error(FormatError("sema", "非法整数常量"));
      }
      return static_cast<int>(val);
    }
    if (ctx->FLOAT_CONST()) {
      float parsed = 0.0f;
      try {
        parsed = std::stof(ctx->getText());
      } catch (const std::exception&) {
        throw std::runtime_error(FormatError("sema", "非法浮点常量"));
      }
      // Casting a float outside the int range (or NaN) to int is undefined
      // behaviour; the negated comparison also catches NaN.
      if (!(parsed >= -2147483648.0f && parsed < 2147483648.0f)) {
        throw std::runtime_error(FormatError("sema", "浮点常量超出 int 范围"));
      }
      return static_cast<int>(parsed);
    }
    throw std::runtime_error(FormatError("sema", "constExp 仅支持整数"));
  }

  // Resolves an identifier to the value stored in its symbol-table entry.
  // Throws unless the entry exists, is const and has a recorded value.
  std::any visitLVal(SysYParser::LValContext* ctx) override {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "constExp 非法变量"));
    }
    const auto* entry = table_.Lookup(ctx->ID()->getText());
    if (!entry || !entry->is_const || !entry->const_value.has_value()) {
      throw std::runtime_error(FormatError("sema", "constExp 使用了非常量"));
    }
    return entry->const_value.value();
  }

 private:
  // Two's-complement wrapping helpers: the arithmetic is done on unsigned
  // values, where overflow is well defined, and converted back to int.
  // (The unsigned-to-int conversion is modular on all mainstream compilers and
  // guaranteed to be so from C++20.)
  static int WrapAdd(int a, int b) {
    return static_cast<int>(static_cast<unsigned>(a) + static_cast<unsigned>(b));
  }
  static int WrapSub(int a, int b) {
    return static_cast<int>(static_cast<unsigned>(a) - static_cast<unsigned>(b));
  }
  static int WrapMul(int a, int b) {
    return static_cast<int>(static_cast<unsigned>(a) * static_cast<unsigned>(b));
  }
  static int WrapNeg(int a) {
    return static_cast<int>(0u - static_cast<unsigned>(a));
  }

  const SymbolTable& table_;
};
// Semantic-analysis pass over a parsed compilation unit.
//
// While walking the tree the visitor
//   * maintains a scoped SymbolTable of constants, variables and parameters,
//   * records declarations, name uses and function calls in a SemanticContext
//     (handed to the caller via TakeSemanticContext()),
//   * rejects ill-formed programs by throwing std::runtime_error built with
//     FormatError("sema", ...): duplicate or undefined names, break/continue
//     outside a loop, mismatched return statements, assignment to a constant,
//     wrong argument count for a user-defined function, negative array
//     dimensions, and a missing `main`.
//
// All visit methods return an empty std::any; their effect is on the members.
class SemaVisitor final : public SysYBaseVisitor {
 public:
  // Entry point: registers all functions, then checks global declarations,
  // then checks every function body, and finally requires `main` to exist.
  std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
    if (!ctx) {
      throw std::runtime_error(FormatError("sema", "缺少编译单元"));
    }
    // Functions are collected up front so that a call may refer to a function
    // defined later in the file.
    for (auto* func : ctx->funcDef()) {
      if (!func || !func->ID()) continue;
      const std::string name = func->ID()->getText();
      const bool inserted = func_table_.emplace(name, func).second;
      if (!inserted) {
        throw std::runtime_error(FormatError("sema", "重复定义函数: " + name));
      }
    }
    for (auto* decl : ctx->decl()) {
      if (decl) decl->accept(this);
    }
    for (auto* func : ctx->funcDef()) {
      if (func) func->accept(this);
    }
    if (func_table_.find("main") == func_table_.end()) {
      throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
    }
    return {};
  }

  // Checks one function definition: builds and records its type, opens a
  // scope holding the parameters, checks the body, and requires that a
  // non-void function contains at least one return statement.
  std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
    if (!ctx || !ctx->block()) {
      throw std::runtime_error(FormatError("sema", "函数体为空"));
    }
    if (!ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "缺少函数名"));
    }
    FuncTypeDesc fty;
    fty.ret.base = BaseTypeFromFuncType(ctx->funcType());
    if (ctx->funcFParams()) {
      for (auto* param : ctx->funcFParams()->funcFParam()) {
        fty.params.push_back(BuildParamType(param));
      }
    }
    sema_.RegisterFunc(ctx, fty);

    // Per-function state consulted by the return-statement checks.
    current_ret_ = fty.ret.base;
    seen_return_ = false;

    table_.EnterScope();
    if (ctx->funcFParams()) {
      for (auto* param : ctx->funcFParams()->funcFParam()) {
        RegisterParam(param);
      }
    }
    ctx->block()->accept(this);
    table_.ExitScope();

    if (current_ret_ != BaseTypeKind::Void && !seen_return_) {
      throw std::runtime_error(FormatError("sema", "非 void 函数缺少 return"));
    }
    return {};
  }

  // Visits the items of a block inside a fresh scope.
  std::any visitBlock(SysYParser::BlockContext* ctx) override {
    if (!ctx) return {};
    table_.EnterScope();
    for (auto* item : ctx->blockItem()) {
      if (item) item->accept(this);
    }
    table_.ExitScope();
    return {};
  }

  // A block item is either a declaration or a statement.
  std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
    if (!ctx) return {};
    if (ctx->decl()) return ctx->decl()->accept(this);
    if (ctx->stmt()) return ctx->stmt()->accept(this);
    return {};
  }

  // A declaration is either a const declaration or a variable declaration.
  std::any visitDecl(SysYParser::DeclContext* ctx) override {
    if (!ctx) return {};
    if (auto* c = ctx->constDecl()) return c->accept(this);
    if (auto* v = ctx->varDecl()) return v->accept(this);
    return {};
  }

  // Registers every constant defined by the declaration.
  std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
    if (!ctx || !ctx->bType()) return {};
    const BaseTypeKind base = BaseTypeFromBType(ctx->bType());
    for (auto* def : ctx->constDef()) {
      RegisterConst(def, base);
    }
    return {};
  }

  // Registers every variable defined by the declaration.
  std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
    if (!ctx || !ctx->bType()) return {};
    const BaseTypeKind base = BaseTypeFromBType(ctx->bType());
    for (auto* def : ctx->varDef()) {
      RegisterVar(def, base);
    }
    return {};
  }

  // Checks one statement. The alternatives are tested in this order:
  // assignment, nested block, if, while, break, continue, return, and finally
  // a bare expression statement.
  std::any visitStmt(SysYParser::StmtContext* ctx) override {
    if (!ctx) return {};
    if (ctx->lVal() && ctx->ASSIGN()) {
      auto* target = ctx->lVal();
      // Throws when the target has no ID or names an undeclared symbol, so
      // ID() is known to be non-null afterwards.
      target->accept(this);
      const std::string name = target->ID()->getText();
      const SymbolEntry* entry = table_.Lookup(name);
      if (entry && entry->is_const) {
        throw std::runtime_error(FormatError("sema", "不能给常量赋值: " + name));
      }
      if (ctx->exp()) ctx->exp()->accept(this);
      return {};
    }
    if (ctx->block()) return ctx->block()->accept(this);
    if (ctx->IF()) {
      if (ctx->cond()) ctx->cond()->accept(this);
      if (ctx->stmt(0)) ctx->stmt(0)->accept(this);
      if (ctx->stmt(1)) ctx->stmt(1)->accept(this);
      return {};
    }
    if (ctx->WHILE()) {
      // loop_depth_ lets break/continue verify they are nested in a loop.
      loop_depth_++;
      if (ctx->cond()) ctx->cond()->accept(this);
      if (ctx->stmt(0)) ctx->stmt(0)->accept(this);
      loop_depth_--;
      return {};
    }
    if (ctx->BREAK()) {
      if (loop_depth_ == 0) {
        throw std::runtime_error(FormatError("sema", "break 不在循环内"));
      }
      return {};
    }
    if (ctx->CONTINUE()) {
      if (loop_depth_ == 0) {
        throw std::runtime_error(FormatError("sema", "continue 不在循环内"));
      }
      return {};
    }
    if (ctx->RETURN()) {
      if (ctx->exp()) ctx->exp()->accept(this);
      if (current_ret_ == BaseTypeKind::Void && ctx->exp()) {
        throw std::runtime_error(FormatError("sema", "void 函数不能返回值"));
      }
      if (current_ret_ != BaseTypeKind::Void && !ctx->exp()) {
        throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值"));
      }
      seen_return_ = true;
      return {};
    }
    if (ctx->exp()) ctx->exp()->accept(this);
    return {};
  }

  // The expression visitors below only descend into their operands so that
  // every lVal and function call underneath gets checked and bound.
  std::any visitExp(SysYParser::ExpContext* ctx) override {
    if (ctx->addExp()) return ctx->addExp()->accept(this);
    return {};
  }

  std::any visitCond(SysYParser::CondContext* ctx) override {
    if (ctx->lOrExp()) return ctx->lOrExp()->accept(this);
    return {};
  }

  std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
    for (auto* e : ctx->lAndExp()) e->accept(this);
    return {};
  }

  std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
    for (auto* e : ctx->eqExp()) e->accept(this);
    return {};
  }

  std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
    for (auto* e : ctx->relExp()) e->accept(this);
    return {};
  }

  std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
    for (auto* e : ctx->addExp()) e->accept(this);
    return {};
  }

  std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
    for (auto* mul : ctx->mulExp()) mul->accept(this);
    return {};
  }

  std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
    for (auto* unary : ctx->unaryExp()) unary->accept(this);
    return {};
  }

  // Handles a primary expression, a function call, or a unary operator
  // applied to another unaryExp. For a call, the callee must be a function of
  // this compilation unit or one of builtin_funcs_; calls to functions of this
  // unit are bound in the SemanticContext and checked for argument count.
  std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
    if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
    if (ctx->ID() && ctx->LPAREN()) {
      const std::string name = ctx->ID()->getText();
      const auto it = func_table_.find(name);
      if (it == func_table_.end()) {
        if (builtin_funcs_.find(name) == builtin_funcs_.end()) {
          throw std::runtime_error(FormatError("sema", "未定义的函数: " + name));
        }
      } else {
        auto* callee = it->second;
        const size_t expected =
            callee->funcFParams() ? callee->funcFParams()->funcFParam().size() : 0;
        const size_t actual =
            ctx->funcRParams() ? ctx->funcRParams()->exp().size() : 0;
        if (expected != actual) {
          throw std::runtime_error(
              FormatError("sema", "函数调用参数个数不匹配: " + name));
        }
        sema_.BindFuncCall(ctx, callee);
      }
      if (ctx->funcRParams()) ctx->funcRParams()->accept(this);
      return {};
    }
    if (ctx->unaryExp()) return ctx->unaryExp()->accept(this);
    return {};
  }

  // Checks every actual argument expression of a call.
  std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override {
    for (auto* e : ctx->exp()) e->accept(this);
    return {};
  }

  std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
    if (ctx->exp()) return ctx->exp()->accept(this);
    if (ctx->lVal()) return ctx->lVal()->accept(this);
    if (ctx->number()) return ctx->number()->accept(this);
    return {};
  }

  // A number must carry either an integer or a float literal token.
  std::any visitNumber(SysYParser::NumberContext* ctx) override {
    if (!ctx->INT_CONST() && !ctx->FLOAT_CONST()) {
      throw std::runtime_error(FormatError("sema", "非法常量"));
    }
    return {};
  }

  // Resolves a name use against the symbol table, records which declaration
  // it refers to, and then checks any index expressions.
  std::any visitLVal(SysYParser::LValContext* ctx) override {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "非法变量引用"));
    }
    const std::string name = ctx->ID()->getText();
    const SymbolEntry* entry = table_.Lookup(name);
    if (!entry) {
      throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
    }
    BoundDecl bound;
    if (entry->kind == SymbolKind::Var) {
      bound.kind = BoundDecl::Kind::Var;
      bound.var_decl = entry->var_decl;
    } else if (entry->kind == SymbolKind::Const) {
      bound.kind = BoundDecl::Kind::Const;
      bound.const_decl = entry->const_decl;
    } else {
      // Every remaining symbol kind is treated as a parameter.
      bound.kind = BoundDecl::Kind::Param;
      bound.param_decl = entry->param_decl;
    }
    sema_.BindVarUse(ctx, bound);
    for (auto* exp : ctx->exp()) {
      if (exp) {
        exp->accept(this);
      }
    }
    return {};
  }

  // Moves the accumulated SemanticContext out of the visitor. The visitor's
  // own context is left in a moved-from state afterwards.
  SemanticContext TakeSemanticContext() { return std::move(sema_); }

 private:
  // Rejects a negative array dimension and otherwise returns it unchanged.
  static int CheckedDim(int dim) {
    if (dim < 0) {
      throw std::runtime_error(FormatError("sema", "数组维度不能为负数"));
    }
    return dim;
  }

  // Builds the TypeDesc of a formal parameter. When the parameter has at
  // least one '[', the first dimension is stored as -1 and the remaining
  // dimensions are folded from the bracketed expressions.
  TypeDesc BuildParamType(SysYParser::FuncFParamContext* ctx) {
    if (!ctx || !ctx->bType()) {
      throw std::runtime_error(FormatError("sema", "非法参数"));
    }
    TypeDesc ty;
    ty.base = BaseTypeFromBType(ctx->bType());
    if (!ctx->LBRACK().empty()) {
      ty.dims.push_back(-1);
      for (auto* exp : ctx->exp()) {
        ty.dims.push_back(CheckedDim(EvalConstExp(exp)));
      }
    }
    return ty;
  }

  // Adds a formal parameter to the current scope and records it in the
  // SemanticContext. Throws on a missing name or a duplicate in this scope.
  void RegisterParam(SysYParser::FuncFParamContext* ctx) {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "参数缺少名称"));
    }
    const std::string name = ctx->ID()->getText();
    if (table_.ContainsInCurrentScope(name)) {
      throw std::runtime_error(FormatError("sema", "重复定义参数: " + name));
    }
    const TypeDesc ty = BuildParamType(ctx);
    SymbolEntry entry;
    entry.kind = SymbolKind::Param;
    entry.param_decl = ctx;
    entry.is_const = false;
    entry.type = ty;
    table_.Add(name, entry);
    sema_.RegisterParam(ctx, ty);
  }

  // Adds a variable to the current scope, records it in the SemanticContext,
  // and then checks its initializer (if any). The name is added before the
  // initializer is visited, so the initializer can already see it.
  void RegisterVar(SysYParser::VarDefContext* ctx, BaseTypeKind base) {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "变量声明缺少名称"));
    }
    const std::string name = ctx->ID()->getText();
    if (table_.ContainsInCurrentScope(name)) {
      throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
    }
    TypeDesc ty;
    ty.base = base;
    for (auto* dim : ctx->constExp()) {
      ty.dims.push_back(CheckedDim(EvalConstExp(dim)));
    }
    SymbolEntry entry;
    entry.kind = SymbolKind::Var;
    entry.var_decl = ctx;
    entry.is_const = false;
    entry.type = ty;
    table_.Add(name, entry);
    sema_.RegisterVarDecl(ctx, ty);
    if (auto* init = ctx->initVal()) {
      init->accept(this);
    }
  }

  // Adds a constant to the current scope and records it in the
  // SemanticContext. For a scalar int constant whose initializer is a single
  // constExp, the folded value is stored in the entry so later constant
  // expressions can refer to it.
  void RegisterConst(SysYParser::ConstDefContext* ctx, BaseTypeKind base) {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "常量声明缺少名称"));
    }
    const std::string name = ctx->ID()->getText();
    if (table_.ContainsInCurrentScope(name)) {
      throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
    }
    TypeDesc ty;
    ty.base = base;
    ty.is_const = true;
    for (auto* dim : ctx->constExp()) {
      ty.dims.push_back(CheckedDim(EvalConstExp(dim)));
    }
    SymbolEntry entry;
    entry.kind = SymbolKind::Const;
    entry.const_decl = ctx;
    entry.is_const = true;
    entry.type = ty;
    if (ctx->constInitVal() && ty.dims.empty() && ty.base == BaseTypeKind::Int) {
      if (auto* exp = ctx->constInitVal()->constExp()) {
        // Folded before the name is added, so the initializer cannot refer to
        // the constant being defined.
        entry.const_value = EvalConstExp(exp);
      }
    }
    table_.Add(name, entry);
    sema_.RegisterConstDecl(ctx, ty);
    if (auto* init = ctx->constInitVal()) {
      init->accept(this);
    }
  }

  // Folds a constExp node to an int using the current symbol table.
  int EvalConstExp(SysYParser::ConstExpContext* ctx) {
    if (!ctx) {
      throw std::runtime_error(FormatError("sema", "非法常量表达式"));
    }
    ConstEvalVisitor visitor(table_);
    return std::any_cast<int>(ctx->accept(&visitor));
  }

  // Folds an ordinary exp node (used for parameter dimensions) to an int.
  int EvalConstExp(SysYParser::ExpContext* ctx) {
    if (!ctx || !ctx->addExp()) {
      throw std::runtime_error(FormatError("sema", "非法常量表达式"));
    }
    ConstEvalVisitor visitor(table_);
    return std::any_cast<int>(ctx->addExp()->accept(&visitor));
  }

  SymbolTable table_;       // scoped constants, variables and parameters
  SemanticContext sema_;    // result handed out by TakeSemanticContext()
  // Functions defined in this compilation unit, keyed by name.
  std::unordered_map<std::string, SysYParser::FuncDefContext*> func_table_;
  // Names that may be called without a definition in this compilation unit.
  const std::unordered_set<std::string> builtin_funcs_ = {
      "getint", "getch", "getarray", "putint", "putch", "putarray",
      "getfloat", "getfarray", "putfloat", "putfarray", "starttime",
      "stoptime"};
  BaseTypeKind current_ret_ = BaseTypeKind::Void;  // return type of the function being checked
  bool seen_return_ = false;  // set once any return statement has been visited
  int loop_depth_ = 0;        // number of enclosing while statements
};
} // namespace
// Runs semantic analysis on `comp_unit` and returns the collected
// SemanticContext. Semantic errors surface as exceptions thrown by the
// visitor while it walks the tree.
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
  SemaVisitor sema_pass;
  SysYParser::CompUnitContext* const root = &comp_unit;
  root->accept(&sema_pass);
  SemanticContext collected = sema_pass.TakeSemanticContext();
  return collected;
}