forked from ppxf25tqu/nudt-compiler-cpp
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
432 lines
14 KiB
432 lines
14 KiB
#include "sem/func.h"
|
|
|
|
#include <cstring>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
|
|
#include "utils/Log.h"
|
|
#include <cmath> // 提供 ldexp
|
|
|
|
namespace {
|
|
|
|
// Parses a hexadecimal floating-point literal of the form 0xH.Hp±E.
//
// The optional "0x"/"0X" prefix is skipped, hex digits before and after an
// optional '.' are accumulated into the significand, and an optional binary
// exponent introduced by 'p'/'P' (decimal digits, optional sign) scales the
// result by a power of two. Characters in the significand that are neither a
// hex digit nor '.' are ignored. No validation is performed beyond that.
//
// @param str  Literal text.
// @return     significand * 2^exponent (infinity / zero when the exponent is
//             out of the range of double).
double ParseHexFloat(const std::string& str) {
  const char* s = str.c_str();
  if (s[0] == '0' && (s[1] == 'x' || s[1] == 'X')) s += 2;

  double significand = 0.0;
  bool have_dot = false;
  double dot_scale = 1.0 / 16.0;  // weight of the next fractional hex digit

  for (; *s && *s != 'p' && *s != 'P'; ++s) {
    if (*s == '.') {
      have_dot = true;
      continue;
    }

    int digit = -1;
    if (*s >= '0' && *s <= '9') digit = *s - '0';
    else if (*s >= 'a' && *s <= 'f') digit = *s - 'a' + 10;
    else if (*s >= 'A' && *s <= 'F') digit = *s - 'A' + 10;
    if (digit < 0) continue;  // not a hex digit: skip it

    if (have_dot) {
      significand += digit * dot_scale;
      dot_scale /= 16.0;
    } else {
      significand = significand * 16 + digit;
    }
  }

  int exponent = 0;
  if (*s == 'p' || *s == 'P') {
    ++s;
    int sign = 1;
    if (*s == '-') {
      sign = -1;
      ++s;
    } else if (*s == '+') {
      ++s;
    }

    // Saturate instead of overflowing `int` on absurdly long exponents.
    // Any magnitude at or above this bound already drives ldexp to
    // infinity or zero, so the exact value no longer matters.
    constexpr int kExponentSaturation = 100000;
    while (*s >= '0' && *s <= '9') {
      if (exponent < kExponentSaturation) {
        exponent = exponent * 10 + (*s - '0');
      }
      ++s;
    }
    exponent *= sign;
  }

  return std::ldexp(significand, exponent);
}
|
|
|
|
} // anonymous namespace
|
|
|
|
namespace sem {
|
|
|
|
// Narrows a double to 32-bit float precision and widens it back, so that
// folded constants carry the precision of C `float` arithmetic.
static double ToFloat32(double v) {
  return static_cast<double>(static_cast<float>(v));
}
|
|
|
|
// 编译时求值常量表达式
|
|
ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) {
|
|
return EvaluateExp(*ctx.addExp());
|
|
}
|
|
|
|
// 求值表达式
|
|
ConstValue EvaluateExp(SysYParser::AddExpContext& ctx) {
|
|
ConstValue result = EvaluateMulExp(*ctx.mulExp(0));
|
|
for (size_t i = 1; i < ctx.mulExp().size(); ++i) {
|
|
ConstValue rhs = EvaluateMulExp(*ctx.mulExp(i));
|
|
if (ctx.AddOp(i-1)->getText() == "+") {
|
|
result = AddValues(result, rhs);
|
|
} else {
|
|
result = SubValues(result, rhs);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// 求值乘法表达式
|
|
ConstValue EvaluateMulExp(SysYParser::MulExpContext& ctx) {
|
|
ConstValue result = EvaluateUnaryExp(*ctx.unaryExp(0));
|
|
for (size_t i = 1; i < ctx.unaryExp().size(); ++i) {
|
|
ConstValue rhs = EvaluateUnaryExp(*ctx.unaryExp(i));
|
|
std::string op = ctx.MulOp(i-1)->getText();
|
|
if (op == "*") {
|
|
result = MulValues(result, rhs);
|
|
} else if (op == "/") {
|
|
if (IsZero(rhs)) {
|
|
throw std::runtime_error(FormatError("sema", "除零错误"));
|
|
}
|
|
result = DivValues(result, rhs);
|
|
} else if (op == "%") {
|
|
if (IsZero(rhs)) {
|
|
throw std::runtime_error(FormatError("sema", "取模除零错误"));
|
|
}
|
|
result = ModValues(result, rhs);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// 求值一元表达式
|
|
ConstValue EvaluateUnaryExp(SysYParser::UnaryExpContext& ctx) {
|
|
if (ctx.unaryOp()) {
|
|
ConstValue operand = EvaluateUnaryExp(*ctx.unaryExp());
|
|
std::string op = ctx.unaryOp()->getText();
|
|
if (op == "-") {
|
|
return NegValue(operand);
|
|
} else if (op == "!") {
|
|
return NotValue(operand);
|
|
} else {
|
|
return operand; // "+" 操作符
|
|
}
|
|
} else if (ctx.primaryExp()) {
|
|
return EvaluatePrimaryExp(*ctx.primaryExp());
|
|
} else {
|
|
throw std::runtime_error(FormatError("sema", "非法常量表达式"));
|
|
}
|
|
}
|
|
|
|
// 求值基本表达式
|
|
ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
|
|
if (ctx.exp()) {
|
|
return EvaluateExp(*ctx.exp()->addExp());
|
|
} else if (ctx.lVar()) {
|
|
// 处理变量引用:向上遍历 AST 找到对应的常量定义并求值
|
|
auto* ident = ctx.lVar()->Ident();
|
|
if (!ident) {
|
|
throw std::runtime_error(FormatError("sema", "非法变量引用"));
|
|
}
|
|
std::string name = ident->getText();
|
|
// 向上遍历 AST 找到作用域内的 constDef
|
|
antlr4::ParserRuleContext* scope =
|
|
dynamic_cast<antlr4::ParserRuleContext*>(ctx.lVar()->parent);
|
|
while (scope) {
|
|
// 检查当前作用域中的所有 constDecl
|
|
for (auto* tree_child : scope->children) {
|
|
auto* child = dynamic_cast<antlr4::ParserRuleContext*>(tree_child);
|
|
if (!child) continue;
|
|
auto* block_item = dynamic_cast<SysYParser::BlockItemContext*>(child);
|
|
if (block_item && block_item->decl()) {
|
|
auto* decl = block_item->decl();
|
|
if (decl->constDecl()) {
|
|
for (auto* def : decl->constDecl()->constDef()) {
|
|
if (def->Ident() && def->Ident()->getText() == name) {
|
|
if (def->constInitVal() && def->constInitVal()->constExp()) {
|
|
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
|
|
bool decl_is_int = decl->constDecl()->bType() &&
|
|
decl->constDecl()->bType()->Int();
|
|
if (decl_is_int) {
|
|
cv.is_int = true;
|
|
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
|
|
cv.float_val = static_cast<double>(cv.int_val);
|
|
}
|
|
return cv;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// compUnit 级别的 constDecl
|
|
auto* decl = dynamic_cast<SysYParser::DeclContext*>(child);
|
|
if (decl && decl->constDecl()) {
|
|
for (auto* def : decl->constDecl()->constDef()) {
|
|
if (def->Ident() && def->Ident()->getText() == name) {
|
|
if (def->constInitVal() && def->constInitVal()->constExp()) {
|
|
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
|
|
// If declared as int, truncate to integer
|
|
bool decl_is_int = decl->constDecl()->bType() &&
|
|
decl->constDecl()->bType()->Int();
|
|
if (decl_is_int) {
|
|
cv.is_int = true;
|
|
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
|
|
cv.float_val = static_cast<double>(cv.int_val);
|
|
}
|
|
return cv;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
scope = dynamic_cast<antlr4::ParserRuleContext*>(scope->parent);
|
|
}
|
|
// 未找到常量定义,返回 0
|
|
ConstValue val;
|
|
val.is_int = true;
|
|
val.int_val = 0;
|
|
val.float_val = 0.0;
|
|
return val;
|
|
} else if (ctx.number()) {
|
|
// 处理数字字面量
|
|
auto* int_const = ctx.number()->IntConst();
|
|
auto* float_const = ctx.number()->FloatConst();
|
|
|
|
ConstValue val;
|
|
if (int_const) {
|
|
val.is_int = true;
|
|
val.int_val = std::stoll(int_const->getText(), nullptr, 0);
|
|
val.float_val = static_cast<double>(val.int_val);
|
|
} else if (float_const) {
|
|
val.is_int = false;
|
|
std::string text = float_const->getText();
|
|
if (text.size() >= 2 && (text[1] == 'x' || text[1] == 'X')) {
|
|
val.float_val = ToFloat32(ParseHexFloat(text));
|
|
} else {
|
|
val.float_val = ToFloat32(std::stod(text));
|
|
}
|
|
val.int_val = static_cast<long long>(val.float_val);
|
|
} else {
|
|
throw std::runtime_error(FormatError("sema", "非法数字字面量"));
|
|
}
|
|
return val;
|
|
} else {
|
|
throw std::runtime_error(FormatError("sema", "非法基本表达式"));
|
|
}
|
|
}
|
|
|
|
// 辅助函数:检查值是否为零
|
|
bool IsZero(const ConstValue& val) {
|
|
if (val.is_int) {
|
|
return val.int_val == 0;
|
|
} else {
|
|
return val.float_val == 0.0;
|
|
}
|
|
}
|
|
|
|
// 辅助函数:加法
|
|
ConstValue AddValues(const ConstValue& lhs, const ConstValue& rhs) {
|
|
ConstValue result;
|
|
if (lhs.is_int && rhs.is_int) {
|
|
result.is_int = true;
|
|
result.int_val = lhs.int_val + rhs.int_val;
|
|
result.float_val = static_cast<double>(result.int_val);
|
|
} else {
|
|
result.is_int = false;
|
|
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
|
|
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
|
|
result.float_val = ToFloat32(l + r);
|
|
result.int_val = static_cast<long long>(result.float_val);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// 辅助函数:减法
|
|
ConstValue SubValues(const ConstValue& lhs, const ConstValue& rhs) {
|
|
ConstValue result;
|
|
if (lhs.is_int && rhs.is_int) {
|
|
result.is_int = true;
|
|
result.int_val = lhs.int_val - rhs.int_val;
|
|
result.float_val = static_cast<double>(result.int_val);
|
|
} else {
|
|
result.is_int = false;
|
|
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
|
|
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
|
|
result.float_val = ToFloat32(l - r);
|
|
result.int_val = static_cast<long long>(result.float_val);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// 辅助函数:乘法
|
|
ConstValue MulValues(const ConstValue& lhs, const ConstValue& rhs) {
|
|
ConstValue result;
|
|
if (lhs.is_int && rhs.is_int) {
|
|
result.is_int = true;
|
|
result.int_val = lhs.int_val * rhs.int_val;
|
|
result.float_val = static_cast<double>(result.int_val);
|
|
} else {
|
|
result.is_int = false;
|
|
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
|
|
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
|
|
result.float_val = ToFloat32(l * r);
|
|
result.int_val = static_cast<long long>(result.float_val);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// 辅助函数:除法
|
|
ConstValue DivValues(const ConstValue& lhs, const ConstValue& rhs) {
|
|
ConstValue result;
|
|
if (lhs.is_int && rhs.is_int) {
|
|
result.is_int = true;
|
|
result.int_val = lhs.int_val / rhs.int_val;
|
|
result.float_val = static_cast<double>(result.int_val);
|
|
} else {
|
|
result.is_int = false;
|
|
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
|
|
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
|
|
result.float_val = ToFloat32(l / r);
|
|
result.int_val = static_cast<long long>(result.float_val);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// 辅助函数:取模
|
|
ConstValue ModValues(const ConstValue& lhs, const ConstValue& rhs) {
|
|
ConstValue result;
|
|
if (!lhs.is_int || !rhs.is_int) {
|
|
throw std::runtime_error(FormatError("sema", "取模运算只能用于整数"));
|
|
}
|
|
result.is_int = true;
|
|
result.int_val = lhs.int_val % rhs.int_val;
|
|
result.float_val = static_cast<double>(result.int_val);
|
|
return result;
|
|
}
|
|
|
|
// 辅助函数:取负
|
|
ConstValue NegValue(const ConstValue& val) {
|
|
ConstValue result;
|
|
result.is_int = val.is_int;
|
|
if (val.is_int) {
|
|
result.int_val = -val.int_val;
|
|
result.float_val = static_cast<double>(result.int_val);
|
|
} else {
|
|
result.float_val = -val.float_val;
|
|
result.int_val = static_cast<long long>(result.float_val);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// 辅助函数:逻辑非
|
|
ConstValue NotValue(const ConstValue& val) {
|
|
ConstValue result;
|
|
result.is_int = true;
|
|
if (val.is_int) {
|
|
result.int_val = !val.int_val;
|
|
} else {
|
|
result.int_val = !val.float_val;
|
|
}
|
|
result.float_val = static_cast<double>(result.int_val);
|
|
return result;
|
|
}
|
|
|
|
// 检查常量初始化器
|
|
size_t CheckConstInitVal(SysYParser::ConstInitValContext& ctx,
|
|
const std::vector<size_t>& dimensions,
|
|
bool is_int,
|
|
size_t total_elements) {
|
|
if (ctx.constExp()) {
|
|
// 单个常量值
|
|
// 求值并检查常量表达式
|
|
ConstValue value = EvaluateConstExp(*ctx.constExp());
|
|
|
|
// 检查类型约束
|
|
if (is_int && !value.is_int) {
|
|
throw std::runtime_error(FormatError("sema", "整型数组的初始化列表中不能出现浮点型常量"));
|
|
}
|
|
|
|
// 检查值域
|
|
if (is_int) {
|
|
if (value.int_val < INT_MIN || value.int_val > INT_MAX) {
|
|
throw std::runtime_error(FormatError("sema", "整数值超过int类型表示范围"));
|
|
}
|
|
}
|
|
return 1;
|
|
} else if (ctx.L_BRACE()) {
|
|
// 花括号初始化列表
|
|
size_t count = 0;
|
|
auto init_vals = ctx.constInitVal();
|
|
|
|
for (auto* init_val : init_vals) {
|
|
// 计算剩余维度的总元素个数
|
|
size_t remaining_elements = total_elements;
|
|
if (!dimensions.empty()) {
|
|
remaining_elements = total_elements / dimensions[0];
|
|
}
|
|
|
|
count += CheckConstInitVal(*init_val,
|
|
std::vector<size_t>(dimensions.begin() + 1, dimensions.end()),
|
|
is_int,
|
|
remaining_elements);
|
|
}
|
|
// 检查总元素个数
|
|
if (count > total_elements) {
|
|
throw std::runtime_error(FormatError("sema", "初始化列表元素个数超过数组大小"));
|
|
}
|
|
return count;
|
|
} else {
|
|
// 空初始化列表
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
// 检查变量初始化器
|
|
size_t CheckInitVal(SysYParser::InitValContext& ctx,
|
|
const std::vector<size_t>& dimensions,
|
|
bool is_int,
|
|
size_t total_elements) {
|
|
if (ctx.exp()) {
|
|
// 单个表达式值
|
|
// 检查表达式中的变量引用
|
|
// 这里不需要编译时求值,只需要检查类型约束
|
|
// 类型检查在IR生成阶段进行
|
|
return 1;
|
|
} else if (ctx.L_BRACE()) {
|
|
// 花括号初始化列表
|
|
size_t count = 0;
|
|
auto init_vals = ctx.initVal();
|
|
|
|
for (auto* init_val : init_vals) {
|
|
// 计算剩余维度的总元素个数
|
|
size_t remaining_elements = total_elements;
|
|
if (!dimensions.empty()) {
|
|
remaining_elements = total_elements / dimensions[0];
|
|
}
|
|
|
|
count += CheckInitVal(*init_val,
|
|
std::vector<size_t>(dimensions.begin() + 1, dimensions.end()),
|
|
is_int,
|
|
remaining_elements);
|
|
}
|
|
// 检查总元素个数
|
|
if (count > total_elements) {
|
|
throw std::runtime_error(FormatError("sema", "初始化列表元素个数超过数组大小"));
|
|
}
|
|
return count;
|
|
} else {
|
|
// 空初始化列表
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
} // namespace sem
|