You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1048 lines
33 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "sem/Sema.h"
#include <any>
#include <cmath>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
namespace {
// Sentinel stored in a dimension list for an extent that has no compile-time
// value: the leading dimension of an array parameter.
constexpr int kUnknownArrayDim = -1;
// Summary of one analysed expression, passed between visitor methods.
struct ExprInfo {
  // Element type of the expression.
  SemanticType type = SemanticType::Int;
  // True when the expression names a fully indexed object.
  bool is_lvalue = false;
  // True when the named object was declared const.
  bool is_const_object = false;
  // Remaining (un-indexed) array extents; empty for non-array values.
  std::vector<int> dimensions;
  // True when const_value holds a folded compile-time value.
  bool has_const_value = false;
  ScalarConstant const_value;

  // A scalar has no remaining array extents and is not void.
  bool IsScalar() const {
    if (!dimensions.empty()) {
      return false;
    }
    return type != SemanticType::Void;
  }

  // An array still has at least one un-indexed extent.
  bool IsArray() const { return dimensions.size() != 0; }
};
// Maps a bType parse node to its semantic type (int or float).
// Throws std::runtime_error when neither keyword is present.
SemanticType ParseBType(SysYParser::BTypeContext& ctx) {
  if (ctx.INT() != nullptr) return SemanticType::Int;
  if (ctx.FLOAT() != nullptr) return SemanticType::Float;
  throw std::runtime_error(FormatError("sema", "未知基础类型"));
}
// Maps a funcType parse node to the function's return type (void, int or
// float). Throws std::runtime_error when no known keyword is present.
SemanticType ParseFuncType(SysYParser::FuncTypeContext& ctx) {
  if (ctx.VOID() != nullptr) return SemanticType::Void;
  if (ctx.INT() != nullptr) return SemanticType::Int;
  if (ctx.FLOAT() != nullptr) return SemanticType::Float;
  throw std::runtime_error(FormatError("sema", "未知函数返回类型"));
}
// Reads the stored number of a constant as an int (truncating conversion).
int ConvertToInt(const ScalarConstant& value) {
  const double raw = value.number;
  return static_cast<int>(raw);
}
// Reads the stored number of a constant as a floating-point value.
double ConvertToFloat(const ScalarConstant& value) {
  const double raw = value.number;
  return raw;
}
// True for the two arithmetic types, int and float.
bool IsNumericType(SemanticType type) {
  const bool is_int = type == SemanticType::Int;
  const bool is_float = type == SemanticType::Float;
  return is_int || is_float;
}
// A value converts implicitly when the types are identical or when both are
// numeric (int <-> float).
bool CanImplicitlyConvert(SemanticType from, SemanticType to) {
  return from == to || (IsNumericType(from) && IsNumericType(to));
}
// Converts a compile-time constant to type `to`.
// Throws std::runtime_error when the conversion is not implicitly allowed.
ScalarConstant CastConstant(const ScalarConstant& value, SemanticType to) {
  const bool allowed = CanImplicitlyConvert(value.type, to);
  if (!allowed) {
    throw std::runtime_error(FormatError("sema", "非法常量类型转换"));
  }
  ScalarConstant converted;
  converted.type = to;
  if (to == SemanticType::Int) {
    // Round-trip through int so the stored number is truncated.
    converted.number = static_cast<double>(ConvertToInt(value));
  } else {
    converted.number = ConvertToFloat(value);
  }
  return converted;
}
// Truth value of a constant: non-zero means true. Float constants are
// compared as doubles, everything else after conversion to int.
bool IsTrue(const ScalarConstant& value) {
  const bool is_float = value.type == SemanticType::Float;
  return is_float ? value.number != 0.0 : ConvertToInt(value) != 0;
}
// Builds an int-typed constant holding `value`.
ScalarConstant MakeInt(int value) {
  const double widened = static_cast<double>(value);
  return ScalarConstant{SemanticType::Int, widened};
}
// Builds a float-typed constant holding `value`.
ScalarConstant MakeFloat(double value) {
  ScalarConstant constant{SemanticType::Float, value};
  return constant;
}
// First token of a parse-tree node, or nullptr when the node itself is null.
const antlr4::Token* StartToken(const antlr4::ParserRuleContext* ctx) {
  if (ctx == nullptr) {
    return nullptr;
  }
  return ctx->getStart();
}
// Raises a semantic error as std::runtime_error. When `ctx` has a start token
// the message carries that token's line and column; otherwise it is formatted
// without a location.
[[noreturn]] void ThrowSemaError(const antlr4::ParserRuleContext* ctx,
                                 std::string_view msg) {
  const antlr4::Token* anchor = StartToken(ctx);
  if (anchor == nullptr) {
    throw std::runtime_error(FormatError("sema", msg));
  }
  throw std::runtime_error(FormatErrorAt("sema", anchor->getLine(),
                                         anchor->getCharPositionInLine(), msg));
}
// Parses an integer literal. Base is auto-detected (decimal, 0-prefixed octal,
// 0x-prefixed hexadecimal).
//
// Malformed or out-of-range literals are reported as located semantic errors
// (std::runtime_error) instead of letting std::stoi's std::invalid_argument /
// std::out_of_range escape, and literals with trailing characters are
// rejected rather than silently truncated.
//
// NOTE(review): the literal 2147483648 (as in `-2147483648`) does not fit in
// int and is rejected here; supporting it would need a wider return type.
int ParseIntLiteral(SysYParser::IntConstContext& ctx) {
  const std::string text = ctx.getText();
  std::size_t consumed = 0;
  int value = 0;
  try {
    value = std::stoi(text, &consumed, 0);
  } catch (const std::invalid_argument&) {
    ThrowSemaError(&ctx, "非法整数字面量: " + text);
  } catch (const std::out_of_range&) {
    ThrowSemaError(&ctx, "整数字面量超出 int 范围: " + text);
  }
  if (consumed != text.size()) {
    ThrowSemaError(&ctx, "非法整数字面量: " + text);
  }
  return value;
}
// Parses a floating-point literal with std::strtod.
//
// The whole text must be consumed: a literal with trailing characters, or one
// for which strtod performs no conversion at all (e.g. empty text), is
// reported as a located semantic error (std::runtime_error).
double ParseFloatLiteral(SysYParser::FloatConstContext& ctx) {
  const std::string text = ctx.getText();
  const char* const begin = text.c_str();
  char* end = nullptr;
  const double value = std::strtod(begin, &end);
  // end == begin means strtod converted nothing and merely returned 0.
  if (end == nullptr || end == begin || *end != '\0') {
    ThrowSemaError(&ctx, "非法浮点字面量: " + text);
  }
  return value;
}
// Builds the binding for a built-in function with the given signature.
FunctionBinding MakeBuiltinFunction(std::string name, SemanticType return_type,
                                    std::vector<ObjectBinding> params) {
  FunctionBinding builtin;
  builtin.is_builtin = true;
  builtin.return_type = return_type;
  builtin.params = std::move(params);
  builtin.name = std::move(name);
  return builtin;
}
// Builds a parameter binding. `dimensions` is empty for scalars;
// `is_array_param` marks parameters declared with array syntax.
ObjectBinding MakeParam(std::string name, SemanticType type,
                        std::vector<int> dimensions = {},
                        bool is_array_param = false) {
  ObjectBinding binding;
  binding.decl_kind = ObjectBinding::DeclKind::Param;
  binding.type = type;
  binding.is_array_param = is_array_param;
  binding.dimensions = std::move(dimensions);
  binding.name = std::move(name);
  return binding;
}
class SemaVisitor final : public SysYBaseVisitor {
public:
// Registers the built-in functions before any user code is analysed.
SemaVisitor() {
  RegisterBuiltins();
}
// Entry point: pre-collects all function signatures, analyses every top-level
// item in order, then checks that a user-defined `int main()` without
// parameters exists.
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "缺少编译单元");
  }
  CollectFunctions(*ctx);
  for (auto* item : ctx->topLevelItem()) {
    if (item != nullptr) {
      item->accept(this);
    }
  }
  const FunctionBinding* entry = sema_.ResolveFunction("main");
  const bool has_user_main = entry != nullptr && !entry->is_builtin;
  if (!has_user_main) {
    ThrowSemaError(ctx, "缺少 main 函数定义");
  }
  const bool signature_ok =
      entry->return_type == SemanticType::Int && entry->params.empty();
  if (!signature_ok) {
    ThrowSemaError(entry->func_def, "main 函数必须是无参 int main()");
  }
  return {};
}
// Dispatches a top-level item to its declaration or function definition.
std::any visitTopLevelItem(SysYParser::TopLevelItemContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "缺少顶层定义");
  }
  if (auto* decl = ctx->decl()) {
    decl->accept(this);
  } else if (auto* func = ctx->funcDef()) {
    func->accept(this);
  } else {
    ThrowSemaError(ctx, "暂不支持的顶层定义");
  }
  return {};
}
// Dispatches a declaration to its const or variable form.
std::any visitDecl(SysYParser::DeclContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "缺少声明");
  }
  if (auto* const_decl = ctx->constDecl()) {
    const_decl->accept(this);
  } else if (auto* var_decl = ctx->varDecl()) {
    var_decl->accept(this);
  } else {
    ThrowSemaError(ctx, "非法声明");
  }
  return {};
}
// Declares every constant of a const declaration with the shared base type.
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
  if (ctx == nullptr || ctx->bType() == nullptr) {
    ThrowSemaError(ctx, "非法常量声明");
  }
  const SemanticType base_type = ParseBType(*ctx->bType());
  const auto defs = ctx->constDef();
  for (std::size_t i = 0; i < defs.size(); ++i) {
    DeclareConst(*defs[i], base_type);
  }
  return {};
}
// Declares every variable of a variable declaration with the shared base type.
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
  if (ctx == nullptr || ctx->bType() == nullptr) {
    ThrowSemaError(ctx, "非法变量声明");
  }
  const SemanticType base_type = ParseBType(*ctx->bType());
  const auto defs = ctx->varDef();
  for (std::size_t i = 0; i < defs.size(); ++i) {
    DeclareVar(*defs[i], base_type);
  }
  return {};
}
// Analyses a function body. The signature must already have been registered
// by the pre-collection pass; parameters live in their own scope that wraps
// the body block.
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
  const bool well_formed = ctx != nullptr && ctx->ID() != nullptr &&
                           ctx->funcType() != nullptr &&
                           ctx->block() != nullptr;
  if (!well_formed) {
    ThrowSemaError(ctx, "非法函数定义");
  }
  const FunctionBinding* binding = sema_.ResolveFunction(ctx->ID()->getText());
  if (binding == nullptr) {
    ThrowSemaError(ctx, "函数未完成预收集: " + ctx->ID()->getText());
  }
  // Remember the enclosing function so it can be restored afterwards.
  const FunctionBinding* const enclosing =
      std::exchange(current_function_, binding);
  symbols_.EnterScope();
  for (const auto& param : binding->params) {
    const bool inserted = symbols_.Add(param);
    if (!inserted) {
      ThrowSemaError(ctx, "函数形参重复定义: " + param.name);
    }
  }
  ctx->block()->accept(this);
  symbols_.ExitScope();
  current_function_ = enclosing;
  return {};
}
// Analyses a block inside a fresh scope that is closed when the block ends.
std::any visitBlock(SysYParser::BlockContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "缺少语句块");
  }
  symbols_.EnterScope();
  const auto items = ctx->blockItem();
  for (std::size_t i = 0; i < items.size(); ++i) {
    if (items[i] == nullptr) {
      continue;
    }
    items[i]->accept(this);
  }
  symbols_.ExitScope();
  return {};
}
// Dispatches a block item to its declaration or statement.
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "缺少块内语句");
  }
  if (auto* decl = ctx->decl()) {
    decl->accept(this);
  } else if (auto* stmt = ctx->stmt()) {
    stmt->accept(this);
  } else {
    ThrowSemaError(ctx, "非法块内语句");
  }
  return {};
}
// Checks one statement.
//
// - break / continue must appear inside a loop.
// - return is validated against the current function.
// - while / if require a scalar condition; their bodies are analysed, with
//   the loop depth raised while inside a while body.
// - nested blocks, assignments and expression statements are delegated.
//
// Missing sub-nodes (condition, body, assignment right-hand side) are
// reported as semantic errors instead of being dereferenced.
std::any visitStmt(SysYParser::StmtContext* ctx) override {
  if (!ctx) {
    ThrowSemaError(ctx, "缺少语句");
  }
  if (ctx->BREAK()) {
    if (loop_depth_ == 0) {
      ThrowSemaError(ctx, "break 只能出现在循环内部");
    }
    return {};
  }
  if (ctx->CONTINUE()) {
    if (loop_depth_ == 0) {
      ThrowSemaError(ctx, "continue 只能出现在循环内部");
    }
    return {};
  }
  if (ctx->RETURN()) {
    CheckReturn(*ctx);
    return {};
  }
  if (ctx->WHILE()) {
    auto* cond = ctx->cond();
    auto* body = ctx->stmt(0);
    if (!cond || !body) {
      ThrowSemaError(ctx, "while 语句缺少条件或循环体");
    }
    RequireScalar(cond, EvalCond(*cond), "while 条件必须是标量表达式");
    ++loop_depth_;
    body->accept(this);
    --loop_depth_;
    return {};
  }
  if (ctx->IF()) {
    auto* cond = ctx->cond();
    auto* then_branch = ctx->stmt(0);
    if (!cond || !then_branch) {
      ThrowSemaError(ctx, "if 语句缺少条件或分支");
    }
    RequireScalar(cond, EvalCond(*cond), "if 条件必须是标量表达式");
    then_branch->accept(this);
    // stmt(1) is null when there is no else branch.
    if (auto* else_branch = ctx->stmt(1)) {
      else_branch->accept(this);
    }
    return {};
  }
  if (ctx->block()) {
    ctx->block()->accept(this);
    return {};
  }
  if (ctx->lVal() && ctx->ASSIGN()) {
    if (!ctx->exp()) {
      ThrowSemaError(ctx, "赋值语句缺少右值表达式");
    }
    CheckAssignment(*ctx);
    return {};
  }
  if (ctx->exp()) {
    EvalExpr(*ctx->exp());
    return {};
  }
  // Empty statement: nothing to check.
  return {};
}
// An exp is analysed through its additive sub-expression.
std::any visitExp(SysYParser::ExpContext* ctx) override {
  auto* additive = ctx != nullptr ? ctx->addExp() : nullptr;
  if (additive == nullptr) {
    ThrowSemaError(ctx, "非法表达式");
  }
  return EvalExpr(*additive);
}
// A cond is analysed through its logical-or sub-expression.
std::any visitCond(SysYParser::CondContext* ctx) override {
  auto* disjunction = ctx != nullptr ? ctx->lOrExp() : nullptr;
  if (disjunction == nullptr) {
    ThrowSemaError(ctx, "非法条件表达式");
  }
  return EvalExpr(*disjunction);
}
// Resolves an lvalue and returns its ExprInfo.
std::any visitLVal(SysYParser::LValContext* ctx) override {
  if (ctx == nullptr || ctx->ID() == nullptr) {
    ThrowSemaError(ctx, "非法左值");
  }
  ExprInfo info = AnalyzeLVal(*ctx);
  return info;
}
// A primary expression is a parenthesised exp, an lvalue, or a number.
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法基础表达式");
  }
  if (auto* parenthesised = ctx->exp()) {
    return EvalExpr(*parenthesised);
  }
  if (auto* lvalue = ctx->lVal()) {
    return AnalyzeLVal(*lvalue);
  }
  if (auto* literal = ctx->number()) {
    return EvalExpr(*literal);
  }
  ThrowSemaError(ctx, "非法基础表达式");
}
// A number is either an integer or a floating-point literal.
std::any visitNumber(SysYParser::NumberContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法数字字面量");
  }
  if (auto* integer = ctx->intConst()) {
    return EvalExpr(*integer);
  }
  if (auto* floating = ctx->floatConst()) {
    return EvalExpr(*floating);
  }
  ThrowSemaError(ctx, "非法数字字面量");
}
// An integer literal is an int-typed compile-time constant.
std::any visitIntConst(SysYParser::IntConstContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法整数字面量");
  }
  ExprInfo literal;
  literal.const_value = MakeInt(ParseIntLiteral(*ctx));
  literal.has_const_value = true;
  literal.type = SemanticType::Int;
  return literal;
}
// A floating-point literal is a float-typed compile-time constant.
std::any visitFloatConst(SysYParser::FloatConstContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法浮点字面量");
  }
  ExprInfo literal;
  literal.const_value = MakeFloat(ParseFloatLiteral(*ctx));
  literal.has_const_value = true;
  literal.type = SemanticType::Float;
  return literal;
}
// A unary expression is a primary expression, a function call, or a signed
// operand. Unary plus/minus keep the operand's type and fold constants.
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法一元表达式");
  }
  if (auto* primary = ctx->primaryExp()) {
    return EvalExpr(*primary);
  }
  if (ctx->ID() != nullptr) {
    return AnalyzeCall(*ctx);
  }
  auto* sign = ctx->addUnaryOp();
  auto* inner = ctx->unaryExp();
  if (sign == nullptr || inner == nullptr) {
    ThrowSemaError(ctx, "非法一元表达式");
  }
  const ExprInfo operand = EvalExpr(*inner);
  RequireScalar(inner, operand, "一元运算要求标量操作数");
  ExprInfo result;
  result.type = operand.type;
  if (!operand.has_const_value) {
    return result;
  }
  if (sign->SUB() != nullptr) {
    result.has_const_value = true;
    result.const_value = operand.const_value;
    result.const_value.number = -result.const_value.number;
  } else if (sign->ADD() != nullptr) {
    result.has_const_value = true;
    result.const_value = operand.const_value;
  }
  return result;
}
// Multiplicative expression: either a bare unaryExp or
// `mulExp (* | / | %) unaryExp`.
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法乘法表达式");
  }
  const bool has_mul = ctx->MUL() != nullptr;
  const bool has_div = ctx->DIV() != nullptr;
  const bool has_mod = ctx->MOD() != nullptr;
  // No operator token: this is the `mulExp : unaryExp` alternative.
  if (!(has_mul || has_div || has_mod)) {
    return EvalExpr(*ctx->unaryExp());
  }
  // Otherwise: `mulExp MUL/DIV/MOD unaryExp`.
  char op = '%';
  if (has_mul) {
    op = '*';
  } else if (has_div) {
    op = '/';
  }
  const ExprInfo lhs = EvalExpr(*ctx->mulExp());
  const ExprInfo rhs = EvalExpr(*ctx->unaryExp());
  return EvalArithmetic(*ctx, lhs, rhs, op);
}
// Additive expression: either a bare mulExp or `addExp (+ | -) mulExp`.
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法加法表达式");
  }
  const bool has_add = ctx->ADD() != nullptr;
  const bool has_sub = ctx->SUB() != nullptr;
  // No operator token: this is the `addExp : mulExp` alternative.
  if (!(has_add || has_sub)) {
    return EvalExpr(*ctx->mulExp());
  }
  // Otherwise: `addExp ADD/SUB mulExp`.
  const char op = has_add ? '+' : '-';
  const ExprInfo lhs = EvalExpr(*ctx->addExp());
  const ExprInfo rhs = EvalExpr(*ctx->mulExp());
  return EvalArithmetic(*ctx, lhs, rhs, op);
}
// Relational expression: a bare addExp or `relExp <op> addExp`.
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法关系表达式");
  }
  auto* nested = ctx->relExp();
  if (nested == nullptr) {
    return EvalExpr(*ctx->addExp());
  }
  const ExprInfo lhs = EvalExpr(*nested);
  const ExprInfo rhs = EvalExpr(*ctx->addExp());
  return EvalCompare(*ctx, lhs, rhs);
}
// Equality expression: a bare relExp or `eqExp (== | !=) relExp`.
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法相等表达式");
  }
  auto* nested = ctx->eqExp();
  if (nested == nullptr) {
    return EvalExpr(*ctx->relExp());
  }
  const ExprInfo lhs = EvalExpr(*nested);
  const ExprInfo rhs = EvalExpr(*ctx->relExp());
  return EvalCompare(*ctx, lhs, rhs);
}
// Condition-level unary expression: a bare eqExp, or a logical negation of a
// nested condUnaryExp. Negation yields int and folds constants.
std::any visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法条件一元表达式");
  }
  if (auto* equality = ctx->eqExp()) {
    return EvalExpr(*equality);
  }
  auto* inner = ctx->condUnaryExp();
  const ExprInfo operand = EvalExpr(*inner);
  RequireScalar(inner, operand, "逻辑非要求标量操作数");
  ExprInfo negated;
  negated.type = SemanticType::Int;
  if (operand.has_const_value) {
    const int flipped = IsTrue(operand.const_value) ? 0 : 1;
    negated.const_value = MakeInt(flipped);
    negated.has_const_value = true;
  }
  return negated;
}
// Logical-and expression: a bare condUnaryExp or `lAndExp && condUnaryExp`.
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法逻辑与表达式");
  }
  auto* nested = ctx->lAndExp();
  if (nested == nullptr) {
    return EvalExpr(*ctx->condUnaryExp());
  }
  const ExprInfo lhs = EvalExpr(*nested);
  const ExprInfo rhs = EvalExpr(*ctx->condUnaryExp());
  return EvalLogical(*ctx, lhs, rhs, /*is_and=*/true);
}
// Logical-or expression: a bare lAndExp or `lOrExp || lAndExp`.
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
  if (ctx == nullptr) {
    ThrowSemaError(ctx, "非法逻辑或表达式");
  }
  auto* nested = ctx->lOrExp();
  if (nested == nullptr) {
    return EvalExpr(*ctx->lAndExp());
  }
  const ExprInfo lhs = EvalExpr(*nested);
  const ExprInfo rhs = EvalExpr(*ctx->lAndExp());
  return EvalLogical(*ctx, lhs, rhs, /*is_and=*/false);
}
// A constExp must evaluate to a scalar with a known compile-time value.
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override {
  if (ctx == nullptr || ctx->addExp() == nullptr) {
    ThrowSemaError(ctx, "非法常量表达式");
  }
  ExprInfo folded = EvalExpr(*ctx->addExp());
  const bool usable = folded.IsScalar() && folded.has_const_value;
  if (!usable) {
    ThrowSemaError(ctx, "要求编译期常量表达式");
  }
  return folded;
}
// Moves the accumulated semantic context out of the visitor.
SemanticContext TakeSemanticContext() {
  SemanticContext taken = std::move(sema_);
  return taken;
}
private:
// Visits `node` and unwraps the ExprInfo it produced. std::any_cast throws
// std::bad_any_cast when the visit result does not hold an ExprInfo.
ExprInfo EvalExpr(antlr4::tree::ParseTree& node) {
  std::any visited = node.accept(this);
  return std::any_cast<ExprInfo>(std::move(visited));
}
// Evaluates a condition node; a thin alias for EvalExpr.
ExprInfo EvalCond(SysYParser::CondContext& cond) {
  ExprInfo info = EvalExpr(cond);
  return info;
}
// Resolves an lvalue against the symbol table, records the binding for this
// use site, checks the subscripts, and describes the resulting expression.
// Indexing fewer dimensions than declared yields an array of the remaining
// extents; a fully indexed object is an lvalue. Scalars with a known constant
// value carry it along.
ExprInfo AnalyzeLVal(SysYParser::LValContext& ctx) {
  const std::string name = ctx.ID()->getText();
  const ObjectBinding* symbol = symbols_.Lookup(name);
  if (symbol == nullptr) {
    ThrowSemaError(&ctx, "使用了未声明的标识符:" + name);
  }
  sema_.BindObjectUse(&ctx, *symbol);

  const auto subscripts = ctx.exp();
  const auto used = subscripts.size();
  const auto declared = symbol->dimensions.size();
  if (used > declared) {
    ThrowSemaError(&ctx, "数组下标过多: " + name);
  }
  for (auto* subscript : subscripts) {
    const ExprInfo index = EvalExpr(*subscript);
    RequireScalar(subscript, index, "数组下标必须是标量表达式");
  }

  ExprInfo result;
  result.type = symbol->type;
  result.is_lvalue = used == declared;
  result.is_const_object =
      symbol->decl_kind == ObjectBinding::DeclKind::Const;
  // Keep only the extents that were not consumed by a subscript.
  result.dimensions.assign(symbol->dimensions.begin() + used,
                           symbol->dimensions.end());
  const bool foldable = result.dimensions.empty() && symbol->has_const_value;
  if (foldable) {
    result.const_value = symbol->const_value;
    result.has_const_value = true;
  }
  return result;
}
// Checks a function call: the name must not be shadowed by an object, the
// function must exist, and the arguments must match the parameters in count
// and kind. Records the callee for this call site and returns an expression
// of the callee's return type.
ExprInfo AnalyzeCall(SysYParser::UnaryExpContext& ctx) {
  const std::string name = ctx.ID()->getText();
  const ObjectBinding* shadowing = symbols_.Lookup(name);
  if (shadowing != nullptr) {
    ThrowSemaError(&ctx, "标识符不是函数: " + shadowing->name);
  }
  const FunctionBinding* callee = sema_.ResolveFunction(name);
  if (callee == nullptr) {
    ThrowSemaError(&ctx, "调用了未定义的函数: " + name);
  }

  std::vector<ExprInfo> actuals;
  if (auto* list = ctx.funcRParams()) {
    for (auto* exp : list->exp()) {
      actuals.push_back(EvalExpr(*exp));
    }
  }
  if (actuals.size() != callee->params.size()) {
    ThrowSemaError(&ctx, "函数参数个数不匹配: " + name);
  }
  for (size_t position = 0; position < actuals.size(); ++position) {
    CheckArgument(ctx, callee->params[position], actuals[position], position);
  }

  sema_.BindFunctionCall(&ctx, *callee);
  ExprInfo call_value;
  call_value.type = callee->return_type;
  return call_value;
}
// Checks one call argument against its parameter.
//
// Scalar parameters need a scalar argument of an implicitly convertible type.
// Array parameters need an array argument with the same element type and
// rank; every dimension after the first must match unless the parameter
// leaves it unknown.
//
// `index` is zero-based; messages report the one-based ordinal. The messages
// now start with "第" so they read "第 N 个参数…" (the N-th argument) rather
// than "N 个参数…", which reads like a count.
void CheckArgument(const antlr4::ParserRuleContext& call_site,
                   const ObjectBinding& param, const ExprInfo& arg,
                   size_t index) {
  const std::string ordinal = "第 " + std::to_string(index + 1);
  const auto fail = [&](const char* detail) {
    ThrowSemaError(&call_site, ordinal + detail);
  };

  if (param.dimensions.empty()) {
    if (!arg.IsScalar()) {
      fail(" 个参数需要标量实参");
    }
    if (!CanImplicitlyConvert(arg.type, param.type)) {
      fail(" 个参数类型不匹配");
    }
    return;
  }

  if (!arg.IsArray()) {
    fail(" 个参数需要数组实参");
  }
  if (arg.type != param.type ||
      arg.dimensions.size() != param.dimensions.size()) {
    fail(" 个数组参数类型不匹配");
  }
  // The first dimension is never compared, so the loop starts at 1.
  for (size_t dim = 1; dim < param.dimensions.size(); ++dim) {
    if (param.dimensions[dim] != kUnknownArrayDim &&
        arg.dimensions[dim] != param.dimensions[dim]) {
      fail(" 个数组参数维度不匹配");
    }
  }
}
// Checks `lVal = exp;`: the target must be a writable, non-const scalar
// lvalue and the value a scalar of an implicitly convertible type.
void CheckAssignment(SysYParser::StmtContext& ctx) {
  const ExprInfo target = AnalyzeLVal(*ctx.lVal());
  const bool writable_scalar = target.IsScalar() && target.is_lvalue;
  if (!writable_scalar) {
    ThrowSemaError(&ctx, "赋值语句左侧必须是可写标量左值");
  }
  if (target.is_const_object) {
    ThrowSemaError(&ctx, "不能给 const 对象赋值");
  }
  auto* value_node = ctx.exp();
  const ExprInfo value = EvalExpr(*value_node);
  RequireScalar(value_node, value, "赋值语句右侧必须是标量表达式");
  if (!CanImplicitlyConvert(value.type, target.type)) {
    ThrowSemaError(&ctx, "赋值语句两侧类型不兼容");
  }
}
// Checks a return statement against the enclosing function: void functions
// must not return a value; other functions must return a scalar convertible
// to the declared return type.
void CheckReturn(SysYParser::StmtContext& ctx) {
  if (current_function_ == nullptr) {
    ThrowSemaError(&ctx, "return 语句不在函数内部");
  }
  auto* value_node = ctx.exp();
  const SemanticType expected = current_function_->return_type;
  if (expected == SemanticType::Void) {
    if (value_node != nullptr) {
      ThrowSemaError(&ctx, "void 函数不能返回值");
    }
    return;
  }
  if (value_node == nullptr) {
    ThrowSemaError(&ctx, "非 void 函数必须返回值");
  }
  const ExprInfo returned = EvalExpr(*value_node);
  RequireScalar(value_node, returned, "return 表达式必须是标量");
  if (!CanImplicitlyConvert(returned.type, expected)) {
    ThrowSemaError(&ctx, "return 表达式类型与函数返回类型不匹配");
  }
}
// Declares one constant in the current scope. The initializer is mandatory
// and is validated before the symbol is added; scalar constants also record
// their folded value.
void DeclareConst(SysYParser::ConstDefContext& ctx, SemanticType type) {
  ObjectBinding symbol;
  symbol.name = ctx.ID()->getText();
  symbol.type = type;
  symbol.decl_kind = ObjectBinding::DeclKind::Const;
  symbol.const_def = &ctx;
  symbol.dimensions = EvalArrayDims(ctx.constIndex(), true);

  if (symbols_.ContainsInCurrentScope(symbol.name)) {
    ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name);
  }
  // At depth 1 (global scope) an object may not share a function's name.
  const bool at_global_scope = symbols_.Depth() == 1;
  if (at_global_scope && sema_.ResolveFunction(symbol.name)) {
    ThrowSemaError(&ctx, "全局对象与函数重名: " + symbol.name);
  }

  auto* init = ctx.constInitVal();
  if (init == nullptr) {
    ThrowSemaError(&ctx, "const 对象缺少初始化");
  }
  const bool is_array = !symbol.dimensions.empty();
  if (is_array) {
    ValidateConstInitAggregate(*init, type);
  } else {
    symbol.const_value = ValidateConstInitScalar(*init, type);
    symbol.has_const_value = true;
  }

  const bool inserted = symbols_.Add(symbol);
  if (!inserted) {
    ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name);
  }
}
// Declares one variable in the current scope. The symbol is added before its
// optional initializer is validated; at global scope (depth 1) initializers
// must be compile-time constants.
void DeclareVar(SysYParser::VarDefContext& ctx, SemanticType type) {
  ObjectBinding symbol;
  symbol.name = ctx.ID()->getText();
  symbol.type = type;
  symbol.decl_kind = ObjectBinding::DeclKind::Var;
  symbol.var_def = &ctx;
  symbol.dimensions = EvalArrayDims(ctx.constIndex(), true);

  if (symbols_.ContainsInCurrentScope(symbol.name)) {
    ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name);
  }
  if (symbols_.Depth() == 1 && sema_.ResolveFunction(symbol.name)) {
    ThrowSemaError(&ctx, "全局对象与函数重名: " + symbol.name);
  }
  const bool inserted = symbols_.Add(symbol);
  if (!inserted) {
    ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name);
  }

  auto* init = ctx.initVal();
  if (init == nullptr) {
    return;
  }
  const bool is_array = !symbol.dimensions.empty();
  if (is_array) {
    ValidateVarInitAggregate(*init, type, symbols_.Depth() == 1);
  } else {
    ValidateVarInitScalar(*init, type, symbols_.Depth() == 1);
  }
}
// Evaluates the declared array extents of a definition.
//
// Each extent must be an int-typed compile-time constant. Float constants are
// now rejected, matching the error message; previously they were silently
// truncated to int. When `require_positive` is set, extents must be > 0.
// Returns the extents in declaration order; raises a semantic error
// (std::runtime_error) otherwise.
std::vector<int> EvalArrayDims(
    const std::vector<SysYParser::ConstIndexContext*>& indices,
    bool require_positive) {
  std::vector<int> dims;
  dims.reserve(indices.size());
  for (auto* index : indices) {
    if (!index || !index->constExp()) {
      ThrowSemaError(index, "数组维度缺少常量表达式");
    }
    const ExprInfo expr = EvalExpr(*index->constExp());
    const bool is_int_constant = expr.IsScalar() && expr.has_const_value &&
                                 expr.type == SemanticType::Int;
    if (!is_int_constant) {
      ThrowSemaError(index, "数组维度必须是整型常量表达式");
    }
    const int dim = ConvertToInt(expr.const_value);
    if (require_positive && dim <= 0) {
      ThrowSemaError(index, "数组维度必须为正整数");
    }
    dims.push_back(dim);
  }
  return dims;
}
// Validates the initializer of a scalar constant and returns its value
// converted to `target_type`. The initializer must be a single scalar
// constant expression (not a brace list).
ScalarConstant ValidateConstInitScalar(SysYParser::ConstInitValContext& init,
                                       SemanticType target_type) {
  auto* value_node = init.constExp();
  if (value_node == nullptr) {
    ThrowSemaError(&init, "标量 const 初始化必须是常量表达式");
  }
  const ExprInfo value = EvalExpr(*value_node);
  const bool usable = value.IsScalar() && value.has_const_value;
  if (!usable) {
    ThrowSemaError(&init, "标量 const 初始化必须是常量表达式");
  }
  return CastConstant(value.const_value, target_type);
}
// Recursively validates the initializer of a constant array: every leaf must
// be a scalar constant convertible to `target_type`. Nested brace lists are
// walked element by element.
void ValidateConstInitAggregate(SysYParser::ConstInitValContext& init,
                                SemanticType target_type) {
  auto* leaf = init.constExp();
  if (leaf == nullptr) {
    for (auto* child : init.constInitVal()) {
      if (child != nullptr) {
        ValidateConstInitAggregate(*child, target_type);
      }
    }
    return;
  }
  const ExprInfo value = EvalExpr(*leaf);
  const bool usable = value.IsScalar() && value.has_const_value;
  if (!usable) {
    ThrowSemaError(&init, "数组 const 初始化要求常量表达式");
  }
  // Called only for its check: it throws when the conversion is illegal.
  CastConstant(value.const_value, target_type);
}
// Validates the initializer of a scalar variable: a single scalar expression
// convertible to `target_type`, and a compile-time constant when
// `require_constant` is set.
void ValidateVarInitScalar(SysYParser::InitValContext& init,
                           SemanticType target_type, bool require_constant) {
  auto* value_node = init.exp();
  if (value_node == nullptr) {
    ThrowSemaError(&init, "标量初始化非法");
  }
  const ExprInfo value = EvalExpr(*value_node);
  RequireScalar(&init, value, "标量初始化要求标量表达式");
  const bool convertible = CanImplicitlyConvert(value.type, target_type);
  if (!convertible) {
    ThrowSemaError(&init, "初始化表达式类型不兼容");
  }
  const bool missing_constant = require_constant && !value.has_const_value;
  if (missing_constant) {
    ThrowSemaError(&init, "全局变量初始化要求编译期常量");
  }
}
// Recursively validates the initializer of an array variable: every leaf must
// be a scalar convertible to `target_type`, and a compile-time constant when
// `require_constant` is set.
void ValidateVarInitAggregate(SysYParser::InitValContext& init,
                              SemanticType target_type, bool require_constant) {
  auto* leaf = init.exp();
  if (leaf == nullptr) {
    for (auto* child : init.initVal()) {
      if (child != nullptr) {
        ValidateVarInitAggregate(*child, target_type, require_constant);
      }
    }
    return;
  }
  const ExprInfo value = EvalExpr(*leaf);
  RequireScalar(&init, value, "数组初始化元素必须是标量表达式");
  const bool convertible = CanImplicitlyConvert(value.type, target_type);
  if (!convertible) {
    ThrowSemaError(&init, "数组初始化元素类型不兼容");
  }
  const bool missing_constant = require_constant && !value.has_const_value;
  if (missing_constant) {
    ThrowSemaError(&init, "全局数组初始化要求编译期常量");
  }
}
// Type-checks a binary arithmetic operation and folds it when both operands
// are compile-time constants.
//
// `op` is one of '+', '-', '*', '/', '%'. The result is float when either
// operand is float, otherwise int.
//
// Changes from the previous behaviour:
// - `%` on a float operand is rejected whether or not the operands are
//   constants (it used to be diagnosed only during constant folding).
// - Integer division or modulo by a constant zero is not folded; the result
//   simply has no constant value, so contexts that need a constant reject it
//   and the compiler no longer executes a division by zero.
// - Integer folding is done in 64-bit and wrapped to 32-bit, so overflow and
//   INT_MIN / -1 no longer invoke undefined behaviour.
ExprInfo EvalArithmetic(const antlr4::ParserRuleContext& ctx, const ExprInfo& lhs,
                        const ExprInfo& rhs, char op) {
  RequireScalar(&ctx, lhs, "算术运算要求标量操作数");
  RequireScalar(&ctx, rhs, "算术运算要求标量操作数");

  const bool is_float =
      lhs.type == SemanticType::Float || rhs.type == SemanticType::Float;
  ExprInfo result;
  result.type = is_float ? SemanticType::Float : SemanticType::Int;
  if (is_float && op == '%') {
    ThrowSemaError(&ctx, "浮点数不支持取模运算");
  }
  if (!lhs.has_const_value || !rhs.has_const_value) {
    return result;
  }

  const ScalarConstant lc = CastConstant(lhs.const_value, result.type);
  const ScalarConstant rc = CastConstant(rhs.const_value, result.type);

  if (is_float) {
    double value = 0.0;
    switch (op) {
      case '+': value = lc.number + rc.number; break;
      case '-': value = lc.number - rc.number; break;
      case '*': value = lc.number * rc.number; break;
      case '/': value = lc.number / rc.number; break;
      default: break;
    }
    result.has_const_value = true;
    result.const_value = MakeFloat(value);
    return result;
  }

  // Widen to 64-bit so no intermediate result can overflow.
  const long long li = ConvertToInt(lc);
  const long long ri = ConvertToInt(rc);
  if ((op == '/' || op == '%') && ri == 0) {
    return result;  // Not a constant expression; leave it unfolded.
  }
  long long wide = 0;
  switch (op) {
    case '+': wide = li + ri; break;
    case '-': wide = li - ri; break;
    case '*': wide = li * ri; break;
    case '/': wide = li / ri; break;
    case '%': wide = li % ri; break;
    default: break;
  }
  // Reduce modulo 2^32 via unsigned, then reinterpret as int.
  const int value = static_cast<int>(static_cast<unsigned int>(wide));
  result.has_const_value = true;
  result.const_value = MakeInt(value);
  return result;
}
// Type-checks a relational or equality comparison and folds it when both
// operands are constants. The result is always int (1 or 0). The operator is
// read from the concrete context type (RelExpContext or EqExpContext).
ExprInfo EvalCompare(antlr4::ParserRuleContext& ctx, const ExprInfo& lhs,
                     const ExprInfo& rhs) {
  RequireScalar(&ctx, lhs, "比较运算要求标量操作数");
  RequireScalar(&ctx, rhs, "比较运算要求标量操作数");

  ExprInfo result;
  result.type = SemanticType::Int;
  const bool foldable = lhs.has_const_value && rhs.has_const_value;
  if (!foldable) {
    return result;
  }

  // Compare in float when either side is float, otherwise in int.
  const bool any_float =
      lhs.type == SemanticType::Float || rhs.type == SemanticType::Float;
  const SemanticType promoted =
      any_float ? SemanticType::Float : SemanticType::Int;
  const double a = CastConstant(lhs.const_value, promoted).number;
  const double b = CastConstant(rhs.const_value, promoted).number;

  bool outcome = false;
  if (auto* rel = dynamic_cast<SysYParser::RelExpContext*>(&ctx)) {
    if (rel->LT()) outcome = a < b;
    if (rel->GT()) outcome = a > b;
    if (rel->LE()) outcome = a <= b;
    if (rel->GE()) outcome = a >= b;
  } else if (auto* eq = dynamic_cast<SysYParser::EqExpContext*>(&ctx)) {
    if (eq->EQ()) outcome = a == b;
    if (eq->NE()) outcome = a != b;
  }

  result.const_value = MakeInt(outcome ? 1 : 0);
  result.has_const_value = true;
  return result;
}
// Type-checks a logical and/or and folds it when both operands are constants.
// The result is always int (1 or 0).
ExprInfo EvalLogical(const antlr4::ParserRuleContext& ctx, const ExprInfo& lhs,
                     const ExprInfo& rhs, bool is_and) {
  RequireScalar(&ctx, lhs, "逻辑运算要求标量操作数");
  RequireScalar(&ctx, rhs, "逻辑运算要求标量操作数");

  ExprInfo result;
  result.type = SemanticType::Int;
  const bool foldable = lhs.has_const_value && rhs.has_const_value;
  if (!foldable) {
    return result;
  }

  bool outcome = false;
  if (is_and) {
    outcome = IsTrue(lhs.const_value) && IsTrue(rhs.const_value);
  } else {
    outcome = IsTrue(lhs.const_value) || IsTrue(rhs.const_value);
  }
  result.const_value = MakeInt(outcome ? 1 : 0);
  result.has_const_value = true;
  return result;
}
// Raises a semantic error with `message` at `ctx` unless `expr` is a scalar.
void RequireScalar(const antlr4::ParserRuleContext* ctx, const ExprInfo& expr,
                   std::string_view message) {
  if (expr.IsScalar()) {
    return;
  }
  ThrowSemaError(ctx, message);
}
// Pre-pass over the compilation unit: registers the signature of every
// function definition so calls can resolve regardless of definition order.
// Rejects duplicate function names and names already present in the current
// symbol scope.
void CollectFunctions(SysYParser::CompUnitContext& ctx) {
  for (auto* item : ctx.topLevelItem()) {
    auto* def = item != nullptr ? item->funcDef() : nullptr;
    if (def == nullptr) {
      continue;
    }
    FunctionBinding signature = BuildFunctionSignature(*def);
    if (sema_.ResolveFunction(signature.name) != nullptr) {
      ThrowSemaError(def, "重复定义函数: " + signature.name);
    }
    if (symbols_.ContainsInCurrentScope(signature.name)) {
      ThrowSemaError(def, "函数与全局对象重名: " + signature.name);
    }
    sema_.RegisterFunction(std::move(signature));
  }
}
// Builds the binding (name, return type, parameters) for a function
// definition without analysing its body.
//
// This runs during pre-collection, before visitFuncDef has validated the
// node, so the name and return-type nodes are checked here rather than
// dereferenced blindly. Raises a semantic error (std::runtime_error) on a
// malformed definition or parameter.
FunctionBinding BuildFunctionSignature(SysYParser::FuncDefContext& ctx) {
  if (!ctx.ID() || !ctx.funcType()) {
    ThrowSemaError(&ctx, "非法函数定义");
  }
  FunctionBinding fn;
  fn.name = ctx.ID()->getText();
  fn.return_type = ParseFuncType(*ctx.funcType());
  fn.func_def = &ctx;
  if (auto* formals = ctx.funcFParams()) {
    for (auto* param : formals->funcFParam()) {
      if (!param) {
        ThrowSemaError(&ctx, "非法函数形参");
      }
      fn.params.push_back(BuildParamBinding(*param));
    }
  }
  return fn;
}
// Builds the binding for one formal parameter.
//
// Array parameters get an unknown first dimension followed by the declared
// trailing dimensions. Each trailing dimension must be a positive, int-typed
// compile-time constant; float constants are now rejected, matching the error
// message (previously they were silently truncated to int).
//
// NOTE(review): this runs during function pre-collection, before any global
// declaration has been visited, so a dimension that names a global constant
// will not resolve here. Confirm whether that case must be supported.
ObjectBinding BuildParamBinding(SysYParser::FuncFParamContext& ctx) {
  if (!ctx.ID() || !ctx.bType()) {
    ThrowSemaError(&ctx, "非法函数形参");
  }
  ObjectBinding param;
  param.name = ctx.ID()->getText();
  param.type = ParseBType(*ctx.bType());
  param.decl_kind = ObjectBinding::DeclKind::Param;
  param.func_param = &ctx;
  if (ctx.LBRACK().empty()) {
    return param;  // Scalar parameter.
  }

  param.is_array_param = true;
  param.dimensions.push_back(kUnknownArrayDim);
  for (auto* exp : ctx.exp()) {
    if (!exp) {
      ThrowSemaError(&ctx, "数组形参维度必须是整型常量表达式");
    }
    const ExprInfo dim = EvalExpr(*exp);
    const bool is_int_constant = dim.IsScalar() && dim.has_const_value &&
                                 dim.type == SemanticType::Int;
    if (!is_int_constant) {
      ThrowSemaError(&ctx, "数组形参维度必须是整型常量表达式");
    }
    const int value = ConvertToInt(dim.const_value);
    if (value <= 0) {
      ThrowSemaError(&ctx, "数组形参维度必须为正整数");
    }
    param.dimensions.push_back(value);
  }
  return param;
}
void RegisterBuiltins() {
sema_.RegisterFunction(MakeBuiltinFunction("getint", SemanticType::Int, {}));
sema_.RegisterFunction(MakeBuiltinFunction("getch", SemanticType::Int, {}));
sema_.RegisterFunction(
MakeBuiltinFunction("getfloat", SemanticType::Float, {}));
sema_.RegisterFunction(MakeBuiltinFunction(
"getarray", SemanticType::Int,
{MakeParam("a", SemanticType::Int, {kUnknownArrayDim}, true)}));
sema_.RegisterFunction(MakeBuiltinFunction(
"getfarray", SemanticType::Int,
{MakeParam("a", SemanticType::Float, {kUnknownArrayDim}, true)}));
sema_.RegisterFunction(MakeBuiltinFunction(
"putint", SemanticType::Void, {MakeParam("x", SemanticType::Int)}));
sema_.RegisterFunction(MakeBuiltinFunction(
"putch", SemanticType::Void, {MakeParam("x", SemanticType::Int)}));
sema_.RegisterFunction(MakeBuiltinFunction(
"putfloat", SemanticType::Void, {MakeParam("x", SemanticType::Float)}));
sema_.RegisterFunction(MakeBuiltinFunction(
"putarray", SemanticType::Void,
{MakeParam("n", SemanticType::Int),
MakeParam("a", SemanticType::Int, {kUnknownArrayDim}, true)}));
sema_.RegisterFunction(MakeBuiltinFunction(
"putfarray", SemanticType::Void,
{MakeParam("n", SemanticType::Int),
MakeParam("a", SemanticType::Float, {kUnknownArrayDim}, true)}));
sema_.RegisterFunction(
MakeBuiltinFunction("starttime", SemanticType::Void, {}));
sema_.RegisterFunction(
MakeBuiltinFunction("stoptime", SemanticType::Void, {}));
}
// Scoped table of object bindings (variables, constants, parameters).
SymbolTable symbols_;
// Function table plus the bindings recorded for each use site; moved out by
// TakeSemanticContext().
SemanticContext sema_;
// Function whose body is currently being visited; null outside any function.
const FunctionBinding* current_function_ = nullptr;
// Number of while-loops enclosing the statement being visited.
int loop_depth_ = 0;
};
} // namespace
// Records which object an lvalue use refers to; a later binding for the same
// use replaces the earlier one.
void SemanticContext::BindObjectUse(const SysYParser::LValContext* use,
                                    ObjectBinding binding) {
  auto& slot = object_uses_[use];
  slot = std::move(binding);
}
// Returns the object bound to an lvalue use, or nullptr when none was
// recorded.
const ObjectBinding* SemanticContext::ResolveObjectUse(
    const SysYParser::LValContext* use) const {
  const auto found = object_uses_.find(use);
  if (found == object_uses_.end()) {
    return nullptr;
  }
  return &found->second;
}
// Records which function a call expression refers to; a later binding for the
// same call replaces the earlier one.
void SemanticContext::BindFunctionCall(const SysYParser::UnaryExpContext* call,
                                       FunctionBinding binding) {
  auto& slot = function_calls_[call];
  slot = std::move(binding);
}
// Returns the function bound to a call expression, or nullptr when none was
// recorded.
const FunctionBinding* SemanticContext::ResolveFunctionCall(
    const SysYParser::UnaryExpContext* call) const {
  const auto found = function_calls_.find(call);
  if (found == function_calls_.end()) {
    return nullptr;
  }
  return &found->second;
}
// Stores a function binding under its name, replacing any existing entry.
void SemanticContext::RegisterFunction(FunctionBinding binding) {
  // Look up the slot by name first, then move the whole binding into it.
  auto& slot = functions_[binding.name];
  slot = std::move(binding);
}
// Returns the function registered under `name`, or nullptr when there is
// none.
const FunctionBinding* SemanticContext::ResolveFunction(
    const std::string& name) const {
  const auto found = functions_.find(name);
  if (found == functions_.end()) {
    return nullptr;
  }
  return &found->second;
}
// Runs semantic analysis over a compilation unit and returns the collected
// context. Semantic errors are thrown as std::runtime_error.
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
  SemaVisitor analyzer;
  comp_unit.accept(&analyzer);
  SemanticContext context = analyzer.TakeSemanticContext();
  return context;
}