#include "sem/func.h"

#include <climits>
#include <cmath>    // std::ldexp, std::isnan
#include <cstdlib>  // std::strtod
#include <stdexcept>
#include <string>
#include <vector>

#include "utils/Log.h"

namespace {

// Returns the numeric value of a hexadecimal digit character, or -1 if `c` is
// not a hex digit.
int HexDigitValue(char c) {
  if (c >= '0' && c <= '9') return c - '0';
  if (c >= 'a' && c <= 'f') return c - 'a' + 10;
  if (c >= 'A' && c <= 'F') return c - 'A' + 10;
  return -1;
}

// Parses a hexadecimal floating-point literal of the form 0xH.Hp±E.
// The "0x" prefix and the binary exponent are both optional here. Characters
// that are not hex digits, '.', or part of the exponent are ignored.
//
// Every hex digit is accumulated into one integer-valued significand, and each
// digit after the point lowers the binary exponent by 4. The result is
// therefore exact whenever the digits fit in a double's 53-bit mantissa, with
// no per-digit rounding of fractional parts.
double ParseHexFloat(const std::string& str) {
  // Exponents beyond this magnitude cannot change the result: the significand
  // is at most ~2^1024, so ldexp has already over/underflowed. Capping avoids
  // signed overflow on absurdly long exponent strings.
  constexpr long kExponentCap = 100000;

  const char* s = str.c_str();
  if (s[0] == '0' && (s[1] == 'x' || s[1] == 'X')) s += 2;

  double significand = 0.0;
  long fraction_digits = 0;
  bool after_point = false;
  for (; *s != '\0' && *s != 'p' && *s != 'P'; ++s) {
    if (*s == '.') {
      after_point = true;
      continue;
    }
    const int digit = HexDigitValue(*s);
    if (digit < 0) continue;
    significand = significand * 16.0 + digit;
    if (after_point) ++fraction_digits;
  }

  long exponent = 0;
  if (*s == 'p' || *s == 'P') {
    ++s;
    bool negative = false;
    if (*s == '+' || *s == '-') {
      negative = (*s == '-');
      ++s;
    }
    for (; *s >= '0' && *s <= '9'; ++s) {
      if (exponent < kExponentCap) exponent = exponent * 10 + (*s - '0');
    }
    if (negative) exponent = -exponent;
  }

  // Each fractional hex digit was folded in as an integer digit; undo that
  // with a binary scale of 2^-4 per digit.
  long scale = exponent - 4 * fraction_digits;
  if (scale > kExponentCap) scale = kExponentCap;
  if (scale < -kExponentCap) scale = -kExponentCap;
  return std::ldexp(significand, static_cast<int>(scale));
}

}  // anonymous namespace

namespace sem {

// Rounds a double to the nearest float32 value and widens it back, so constant
// folding mirrors C `float` arithmetic.
static double ToFloat32(double v) {
  const float f = static_cast<float>(v);
  return static_cast<double>(f);
}

// Converts a floating value to an integer by truncating toward zero, like a C
// cast. Out-of-range values saturate to LLONG_MIN / LLONG_MAX and NaN maps to
// 0, because a plain static_cast is undefined behavior in those cases
// (e.g. 3.4e38 or inf).
// NOTE(review): assumes ConstValue::int_val is a 64-bit integer (it is
// assigned from std::stoll below) — confirm against sem/func.h.
static long long FloatToInt(double v) {
  if (std::isnan(v)) return 0;
  if (v >= static_cast<double>(LLONG_MAX)) return LLONG_MAX;
  if (v <= static_cast<double>(LLONG_MIN)) return LLONG_MIN;
  return static_cast<long long>(v);
}

// Builds an int-typed constant whose float mirror is kept in sync.
static ConstValue MakeIntConst(long long v) {
  ConstValue val;
  val.is_int = true;
  val.int_val = v;
  val.float_val = static_cast<double>(val.int_val);
  return val;
}

// Builds a float-typed constant, rounding `v` to float32 precision and keeping
// the integer mirror in sync.
static ConstValue MakeFloatConst(double v) {
  ConstValue val;
  val.is_int = false;
  val.float_val = ToFloat32(v);
  val.int_val = FloatToInt(val.float_val);
  return val;
}

// Returns the operand as a double, whichever representation is active.
static double AsDouble(const ConstValue& v) {
  return v.is_int ? static_cast<double>(v.int_val) : v.float_val;
}

// Evaluates a constant expression (constExp) at compile time.
ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) {
  return EvaluateExp(*ctx.addExp());
}

// Evaluates an additive expression: mulExp (AddOp mulExp)*, left to right.
ConstValue EvaluateExp(SysYParser::AddExpContext& ctx) {
  const auto operands = ctx.mulExp();
  ConstValue result = EvaluateMulExp(*operands[0]);
  for (size_t i = 1; i < operands.size(); ++i) {
    const ConstValue rhs = EvaluateMulExp(*operands[i]);
    if (ctx.AddOp(i - 1)->getText() == "+") {
      result = AddValues(result, rhs);
    } else {
      result = SubValues(result, rhs);
    }
  }
  return result;
}

// Evaluates a multiplicative expression: unaryExp (MulOp unaryExp)*, left to
// right. Throws std::runtime_error when the divisor of '/' or '%' is zero.
ConstValue EvaluateMulExp(SysYParser::MulExpContext& ctx) {
  const auto operands = ctx.unaryExp();
  ConstValue result = EvaluateUnaryExp(*operands[0]);
  for (size_t i = 1; i < operands.size(); ++i) {
    const ConstValue rhs = EvaluateUnaryExp(*operands[i]);
    const std::string op = ctx.MulOp(i - 1)->getText();
    if (op == "*") {
      result = MulValues(result, rhs);
    } else if (op == "/") {
      if (IsZero(rhs)) {
        throw std::runtime_error(FormatError("sema", "除零错误"));
      }
      result = DivValues(result, rhs);
    } else if (op == "%") {
      if (IsZero(rhs)) {
        throw std::runtime_error(FormatError("sema", "取模除零错误"));
      }
      result = ModValues(result, rhs);
    }
  }
  return result;
}

// Evaluates a unary expression: a prefix operator ('-', '!', '+') applied to
// another unaryExp, or a primaryExp. Any other shape (e.g. a function call)
// is not a constant expression and throws std::runtime_error.
ConstValue EvaluateUnaryExp(SysYParser::UnaryExpContext& ctx) {
  if (ctx.unaryOp()) {
    const ConstValue operand = EvaluateUnaryExp(*ctx.unaryExp());
    const std::string op = ctx.unaryOp()->getText();
    if (op == "-") return NegValue(operand);
    if (op == "!") return NotValue(operand);
    return operand;  // unary '+'
  }
  if (ctx.primaryExp()) {
    return EvaluatePrimaryExp(*ctx.primaryExp());
  }
  throw std::runtime_error(FormatError("sema", "非法常量表达式"));
}

// Converts a folded initializer value to the declared base type of a constant.
// An `int` declaration truncates toward zero; any other bType is treated as
// float (presumably SysY's only other bType) and rounds to float32. Without
// this, `const float a = 5;` would stay int and `a / 2` would fold to 2.
static ConstValue ConvertToDeclaredType(ConstValue cv,
                                        SysYParser::ConstDeclContext& const_decl) {
  auto* btype = const_decl.bType();
  if (!btype) return cv;
  if (btype->Int()) {
    return MakeIntConst(cv.is_int ? static_cast<long long>(cv.int_val)
                                  : FloatToInt(cv.float_val));
  }
  return MakeFloatConst(AsDouble(cv));
}

// Returns the constDef named `name` in `decl` whose last token precedes token
// `ref_index`, or nullptr. A definition that has not ended before the
// reference is not yet in scope. This rejects uses that come before the
// declaration, and it stops `const int a = a;` from recursing forever.
static SysYParser::ConstDefContext* FindConstDefBefore(
    SysYParser::DeclContext& decl, const std::string& name, size_t ref_index) {
  auto* const_decl = decl.constDecl();
  if (!const_decl) return nullptr;
  for (auto* def : const_decl->constDef()) {
    if (!def->Ident() || def->Ident()->getText() != name) continue;
    auto* stop = def->getStop();
    if (stop && stop->getTokenIndex() < ref_index) return def;
  }
  return nullptr;
}

// Evaluates a primary expression: a parenthesized expression, a reference to
// a named constant, or a numeric literal.
//
// Named constants are resolved by walking up the parse tree from the
// reference. At each enclosing scope, the search covers the declarations
// directly inside it: block items, or compUnit-level decls. Only definitions
// that end before the reference count. The nearest match wins. A non-scalar
// (array) constant, or a name with no visible constant definition, evaluates
// to int 0.
ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
  if (ctx.exp()) {
    return EvaluateExp(*ctx.exp()->addExp());
  }

  if (ctx.lVar()) {
    auto* ident = ctx.lVar()->Ident();
    if (!ident) {
      throw std::runtime_error(FormatError("sema", "非法变量引用"));
    }
    const std::string name = ident->getText();
    const size_t ref_index = ident->getSymbol()->getTokenIndex();

    for (auto* scope = dynamic_cast<antlr4::ParserRuleContext*>(ctx.lVar()->parent);
         scope != nullptr;
         scope = dynamic_cast<antlr4::ParserRuleContext*>(scope->parent)) {
      for (auto* tree_child : scope->children) {
        SysYParser::DeclContext* decl = nullptr;
        if (auto* item = dynamic_cast<SysYParser::BlockItemContext*>(tree_child)) {
          decl = item->decl();  // declaration inside a block
        } else {
          decl = dynamic_cast<SysYParser::DeclContext*>(tree_child);  // compUnit level
        }
        if (!decl) continue;

        auto* def = FindConstDefBefore(*decl, name, ref_index);
        if (!def) continue;

        // The nearest visible constant shadows any outer one, so stop here
        // even if it cannot be folded as a scalar.
        if (def->constInitVal() && def->constInitVal()->constExp()) {
          return ConvertToDeclaredType(
              EvaluateConstExp(*def->constInitVal()->constExp()),
              *decl->constDecl());
        }
        return MakeIntConst(0);
      }
    }
    // No constant definition found: fall back to 0.
    return MakeIntConst(0);
  }

  if (ctx.number()) {
    auto* int_const = ctx.number()->IntConst();
    auto* float_const = ctx.number()->FloatConst();
    if (int_const) {
      long long parsed = 0;
      try {
        // Base 0 accepts decimal, octal (leading 0) and hex (0x) forms.
        parsed = std::stoll(int_const->getText(), nullptr, 0);
      } catch (const std::logic_error&) {
        // std::stoll reports overflow as std::out_of_range (a logic_error);
        // surface it like every other semantic error.
        throw std::runtime_error(FormatError("sema", "整数字面量超出范围"));
      }
      return MakeIntConst(parsed);
    }
    if (float_const) {
      const std::string text = float_const->getText();
      const bool is_hex = text.size() >= 2 && (text[1] == 'x' || text[1] == 'X');
      // Unlike std::stod, std::strtod does not throw on overflow/underflow; it
      // yields +-HUGE_VAL or a tiny/zero value, which ToFloat32 then rounds.
      const double parsed =
          is_hex ? ParseHexFloat(text) : std::strtod(text.c_str(), nullptr);
      return MakeFloatConst(parsed);
    }
    throw std::runtime_error(FormatError("sema", "非法数字字面量"));
  }

  throw std::runtime_error(FormatError("sema", "非法基本表达式"));
}

// Returns true when the value is zero in its active representation.
bool IsZero(const ConstValue& val) {
  return val.is_int ? val.int_val == 0 : val.float_val == 0.0;
}

// Addition: int + int stays int; otherwise float32 arithmetic.
ConstValue AddValues(const ConstValue& lhs, const ConstValue& rhs) {
  if (lhs.is_int && rhs.is_int) return MakeIntConst(lhs.int_val + rhs.int_val);
  return MakeFloatConst(AsDouble(lhs) + AsDouble(rhs));
}

// Subtraction: int - int stays int; otherwise float32 arithmetic.
ConstValue SubValues(const ConstValue& lhs, const ConstValue& rhs) {
  if (lhs.is_int && rhs.is_int) return MakeIntConst(lhs.int_val - rhs.int_val);
  return MakeFloatConst(AsDouble(lhs) - AsDouble(rhs));
}

// Multiplication: int * int stays int; otherwise float32 arithmetic.
ConstValue MulValues(const ConstValue& lhs, const ConstValue& rhs) {
  if (lhs.is_int && rhs.is_int) return MakeIntConst(lhs.int_val * rhs.int_val);
  return MakeFloatConst(AsDouble(lhs) * AsDouble(rhs));
}

// Division: int / int truncates toward zero; otherwise float32 arithmetic.
// Throws std::runtime_error on integer division by zero.
ConstValue DivValues(const ConstValue& lhs, const ConstValue& rhs) {
  if (lhs.is_int && rhs.is_int) {
    if (rhs.int_val == 0) {
      throw std::runtime_error(FormatError("sema", "除零错误"));
    }
    return MakeIntConst(lhs.int_val / rhs.int_val);
  }
  return MakeFloatConst(AsDouble(lhs) / AsDouble(rhs));
}

// Remainder: defined only for two ints. Throws std::runtime_error for float
// operands or a zero divisor.
ConstValue ModValues(const ConstValue& lhs, const ConstValue& rhs) {
  if (!lhs.is_int || !rhs.is_int) {
    throw std::runtime_error(FormatError("sema", "取模运算只能用于整数"));
  }
  if (rhs.int_val == 0) {
    throw std::runtime_error(FormatError("sema", "取模除零错误"));
  }
  return MakeIntConst(lhs.int_val % rhs.int_val);
}

// Arithmetic negation, preserving the operand's type.
ConstValue NegValue(const ConstValue& val) {
  if (val.is_int) return MakeIntConst(-val.int_val);
  ConstValue result;
  result.is_int = false;
  result.float_val = -val.float_val;  // negation of a float32 value is exact
  result.int_val = FloatToInt(result.float_val);
  return result;
}

// Logical not: always yields int 0 or 1.
ConstValue NotValue(const ConstValue& val) {
  return MakeIntConst(val.is_int ? !val.int_val : !val.float_val);
}

// Shape of one element of an array with shape `dimensions`; empty when there
// are no dimensions left (scalar, or braces nested deeper than the array).
// NOTE(review): the dimension element type is assumed to be int — confirm
// against the declarations in sem/func.h.
static std::vector<int> SubDimensions(const std::vector<int>& dimensions) {
  if (dimensions.empty()) return {};
  return std::vector<int>(dimensions.begin() + 1, dimensions.end());
}

// Number of scalar slots in one element of the outermost dimension. With no
// dimensions left the whole budget is passed down; a non-positive dimension
// yields 0 instead of dividing by zero.
static size_t SubAggregateSize(const std::vector<int>& dimensions,
                               size_t total_elements) {
  if (dimensions.empty()) return total_elements;
  if (dimensions[0] <= 0) return 0;
  return total_elements / static_cast<size_t>(dimensions[0]);
}

// Validates a constant initializer and returns how many scalar initializers
// it contains.
//   dimensions     - array dimensions remaining at this nesting level
//   is_int         - whether the element type is int
//   total_elements - scalar slots available at this nesting level
// Throws std::runtime_error if any of these holds:
//   - a float value appears in an int array;
//   - an int value lies outside the 32-bit range;
//   - a brace level holds more initializers than it has slots.
// Each nested list is checked against one sub-aggregate, and the counts are
// summed. This does not model C's full brace-elision alignment.
size_t CheckConstInitVal(SysYParser::ConstInitValContext& ctx,
                         const std::vector<int>& dimensions, bool is_int,
                         size_t total_elements) {
  if (ctx.constExp()) {
    // A single constant value: evaluate it and check type and range.
    const ConstValue value = EvaluateConstExp(*ctx.constExp());
    if (is_int && !value.is_int) {
      throw std::runtime_error(FormatError("sema", "整型数组的初始化列表中不能出现浮点型常量"));
    }
    if (is_int && (value.int_val < INT_MIN || value.int_val > INT_MAX)) {
      throw std::runtime_error(FormatError("sema", "整数值超过int类型表示范围"));
    }
    return 1;
  }

  if (!ctx.L_BRACE()) return 0;  // neither a value nor a braced list

  // Braced list: every entry initializes one element of the outer dimension.
  const std::vector<int> sub_dims = SubDimensions(dimensions);
  const size_t sub_total = SubAggregateSize(dimensions, total_elements);
  size_t count = 0;
  for (auto* init_val : ctx.constInitVal()) {
    count += CheckConstInitVal(*init_val, sub_dims, is_int, sub_total);
  }
  if (count > total_elements) {
    throw std::runtime_error(FormatError("sema", "初始化列表元素个数超过数组大小"));
  }
  return count;
}

// Validates a (non-constant) variable initializer and returns how many scalar
// initializers it contains. Expressions are not evaluated here; type checks
// for them happen during IR generation. Parameters and the element-count
// check match CheckConstInitVal.
size_t CheckInitVal(SysYParser::InitValContext& ctx,
                    const std::vector<int>& dimensions, bool is_int,
                    size_t total_elements) {
  if (ctx.exp()) return 1;  // a single expression value

  if (!ctx.L_BRACE()) return 0;  // neither a value nor a braced list

  const std::vector<int> sub_dims = SubDimensions(dimensions);
  const size_t sub_total = SubAggregateSize(dimensions, total_elements);
  size_t count = 0;
  for (auto* init_val : ctx.initVal()) {
    count += CheckInitVal(*init_val, sub_dims, is_int, sub_total);
  }
  if (count > total_elements) {
    throw std::runtime_error(FormatError("sema", "初始化列表元素个数超过数组大小"));
  }
  return count;
}

}  // namespace sem