阶段性保存

Shrink 1 month ago
parent 1fbdbb2ea1
commit 4413cfc4f5

@ -10,10 +10,10 @@
5. 15_graph_coloring - 图着色算法 (使用2D数组和指针参数)
6. 22_matrix_multiply - 矩阵乘法 (2D数组)
7. 25_scope3 - 作用域测试
8. 29_break - break语句
9. 36_op_priority2 - 运算符优先级
10. simple_add - 简单加法
### ✗ 失败的测试 (1个):
- 95_float - **需要浮点数常量支持** (当前仅支持int)

@ -58,3 +58,4 @@ cmake --build build -j "$(nproc)"
若最终输出 `输出匹配: test/test_case/simple_add.out`,说明当前示例用例 `return a + b` 的完整后端链路已经跑通。
但最终不能只检查 `simple_add`。完成 Lab3 后,应对 `test/test_case` 下全部测试用例逐个回归,确认代码生成结果能够通过统一验证;如有需要,也可以自行编写批量测试脚本统一执行。

@ -188,6 +188,7 @@ enum class Opcode {
Div,
Mod,
Cmp,
Cast,
Br,
CondBr,
Call,
@ -199,6 +200,7 @@ enum class Opcode {
};
enum class CmpOp { Eq, Ne, Lt, Le, Gt, Ge };
enum class CastOp { IntToFloat, FloatToInt };
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
// 当前实现中只有 Instruction 继承自 User。
@ -229,14 +231,19 @@ class GlobalValue : public User {
// 数组:打印为 @name = global [count x i32] zeroinitializer。
class GlobalVariable : public GlobalValue {
public:
GlobalVariable(std::string name, int init_val = 0, int count = 1);
GlobalVariable(std::string name, std::shared_ptr<Type> ptr_ty,
int init_val = 0, int count = 1,
std::vector<int> init_elems = {});
int GetInitValue() const { return init_val_; }
int GetCount() const { return count_; }
bool IsArray() const { return count_ > 1; }
bool IsFloat() const { return GetType() && GetType()->IsPtrFloat32(); }
const std::vector<int>& GetInitElements() const { return init_elems_; }
private:
int init_val_;
int count_;
std::vector<int> init_elems_;
};
class Instruction : public User {
@ -272,6 +279,16 @@ class CmpInst : public Instruction {
CmpOp cmp_op_;
};
class CastInst : public Instruction {
public:
CastInst(CastOp op, std::shared_ptr<Type> ty, Value* val, std::string name);
CastOp GetCastOp() const;
Value* GetValue() const;
private:
CastOp cast_op_;
};
class ReturnInst : public Instruction {
public:
ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
@ -410,7 +427,10 @@ class Module {
Function* FindFunction(const std::string& name) const;
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
GlobalVariable* CreateGlobalVar(const std::string& name, int init_val = 0, int count = 1);
GlobalVariable* CreateGlobalVar(const std::string& name, int init_val = 0,
int count = 1,
std::shared_ptr<Type> ptr_ty = Type::GetPtrInt32Type(),
std::vector<int> init_elems = {});
GlobalVariable* FindGlobalVar(const std::string& name) const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobalVars() const;
@ -436,6 +456,8 @@ class IRBuilder {
BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name);
CmpInst* CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name);
CastInst* CreateSIToFP(Value* v, const std::string& name);
CastInst* CreateFPToSI(Value* v, const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaArray(int count, const std::string& name);
AllocaInst* CreateAllocaF32(const std::string& name);

@ -25,6 +25,8 @@ class IRGenImpl final : public SysYBaseVisitor {
public:
// const 变量名 -> 编译期整数值,用于数组维度折叠。
using ConstEnv = std::unordered_map<std::string, int>;
// const 变量名 -> 编译期浮点值,用于 float const 折叠。
using ConstFloatEnv = std::unordered_map<std::string, float>;
IRGenImpl(ir::Module& module, const SemanticContext& sema);
@ -81,8 +83,12 @@ class IRGenImpl final : public SysYBaseVisitor {
// 编译期常量整数求值(用于数组维度)。
int EvalConstExpr(SysYParser::ConstExpContext* ctx) const;
// 编译期常量浮点求值(用于 float const
float EvalConstExprAsFloat(SysYParser::ConstExpContext* ctx) const;
// 将 ExpContext即 addExp按编译期常量求值用于 funcFParam 维度)。
int EvalExpAsConst(SysYParser::ExpContext* ctx) const;
// 将 ExpContext 按编译期常量浮点求值(用于 float 全局初始化等)。
float EvalExpAsConstFloat(SysYParser::ExpContext* ctx) const;
// 查找变量的数组维度(先查局部,再查全局)。
const std::vector<int>* FindArrayDims(const std::string& name) const;
@ -91,15 +97,28 @@ class IRGenImpl final : public SysYBaseVisitor {
ir::Value* ComputeLinearIndex(const std::vector<int>& dims,
const std::vector<SysYParser::ExpContext*>& subs);
// 简单隐式类型转换i32 <-> float。
ir::Value* CastToFloat(ir::Value* v);
ir::Value* CastToInt(ir::Value* v);
// 扁平化 constInitValue 到整数数组(供 const 数组初始化使用)。
void FlattenConstInit(SysYParser::ConstInitValueContext* ctx,
const std::vector<int>& dims, int dim_idx,
std::vector<int>& out, int& pos);
void FlattenConstInitFloat(SysYParser::ConstInitValueContext* ctx,
const std::vector<int>& dims, int dim_idx,
std::vector<float>& out, int& pos);
// 扁平化 initValue 到 ir::Value* 数组(供普通数组初始化使用)。
void FlattenInit(SysYParser::InitValueContext* ctx,
const std::vector<int>& dims, int dim_idx,
std::vector<ir::Value*>& out, int& pos);
void FlattenGlobalInitInt(SysYParser::InitValueContext* ctx,
const std::vector<int>& dims, int dim_idx,
std::vector<int>& out, int& pos);
void FlattenGlobalInitFloat(SysYParser::InitValueContext* ctx,
const std::vector<int>& dims, int dim_idx,
std::vector<float>& out, int& pos);
ir::AllocaInst* CreateEntryAllocaI32(const std::string& name);
ir::AllocaInst* CreateEntryAllocaArray(int count, const std::string& name);
@ -121,6 +140,8 @@ class IRGenImpl final : public SysYBaseVisitor {
std::unordered_map<std::string, ir::Value*> global_storage_;
// 编译期 const 整数环境(全局 + 当前函数)。
ConstEnv const_env_;
// 编译期 const 浮点环境(全局 + 当前函数)。
ConstFloatEnv const_float_env_;
// 数组维度信息:全局数组(跨函数持久)。
std::unordered_map<std::string, std::vector<int>> global_array_dims_;
// 数组维度信息:局部数组/参数(每函数清空)。

@ -57,6 +57,8 @@ enum class Opcode {
FSubRR, // 浮点减法
FMulRR, // 浮点乘法
FDivRR, // 浮点除法
SIToFP, // 有符号整型转浮点
FPToSI, // 浮点转有符号整型
CmpRR,
FCmpRR, // 浮点比较
Bl,
@ -162,14 +164,17 @@ class MachineModule {
return functions_;
}
void AddGlobalVar(std::string name, int init_val, int count);
const std::vector<std::tuple<std::string, int, int>>& GetGlobalVars() const {
void AddGlobalVar(std::string name, int init_val, int count, bool is_float,
std::vector<int> init_elems = {});
const std::vector<std::tuple<std::string, int, int, bool, std::vector<int>>>&
GetGlobalVars() const {
return global_vars_;
}
private:
std::vector<std::unique_ptr<MachineFunction>> functions_;
std::vector<std::tuple<std::string, int, int>> global_vars_; // (name, init, count)
std::vector<std::tuple<std::string, int, int, bool, std::vector<int>>>
global_vars_; // (name, init, count, is_float, init_elements)
};
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);

@ -83,7 +83,8 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then
if diff -u <(perl -0pe 's/\n\z//' "$expected_file") \
<(perl -0pe 's/\n\z//' "$actual_file"); then
echo "输出匹配: $expected_file"
else
echo "输出不匹配: $expected_file" >&2

@ -7,9 +7,12 @@ namespace ir {
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {}
GlobalVariable::GlobalVariable(std::string name, int init_val, int count)
: GlobalValue(Type::GetPtrInt32Type(), std::move(name)),
GlobalVariable::GlobalVariable(std::string name, std::shared_ptr<Type> ptr_ty,
int init_val, int count,
std::vector<int> init_elems)
: GlobalValue(std::move(ptr_ty), std::move(name)),
init_val_(init_val),
count_(count) {}
count_(count),
init_elems_(std::move(init_elems)) {}
} // namespace ir

@ -75,6 +75,28 @@ CmpInst* IRBuilder::CreateCmp(CmpOp op, Value* lhs, Value* rhs,
name);
}
CastInst* IRBuilder::CreateSIToFP(Value* v, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!v) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateSIToFP 缺少操作数"));
}
return insert_block_->Append<CastInst>(CastOp::IntToFloat, Type::GetFloat32Type(),
v, name);
}
CastInst* IRBuilder::CreateFPToSI(Value* v, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!v) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateFPToSI 缺少操作数"));
}
return insert_block_->Append<CastInst>(CastOp::FloatToInt, Type::GetInt32Type(),
v, name);
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -116,7 +138,11 @@ GepInst* IRBuilder::CreateGep(Value* base, Value* index, const std::string& name
if (!base || !index) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateGep 缺少操作数"));
}
return insert_block_->Append<GepInst>(Type::GetPtrInt32Type(), base, index, name);
std::shared_ptr<Type> ptr_ty = Type::GetPtrInt32Type();
if (base->GetType() && base->GetType()->IsPtrFloat32()) {
ptr_ty = Type::GetPtrFloat32Type();
}
return insert_block_->Append<GepInst>(ptr_ty, base, index, name);
}
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {

@ -4,6 +4,8 @@
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <ostream>
#include <stdexcept>
#include <string>
@ -42,6 +44,8 @@ static const char* OpcodeToString(Opcode op) {
return "srem";
case Opcode::Cmp:
return "icmp";
case Opcode::Cast:
return "cast";
case Opcode::Br:
return "br";
case Opcode::CondBr:
@ -100,11 +104,20 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
// 先打印全局变量
for (const auto& gv : module.GetGlobalVars()) {
if (!gv) continue;
const char* elem_ty = gv->IsFloat() ? "float" : "i32";
if (gv->IsArray()) {
os << "@" << gv->GetName() << " = global [" << gv->GetCount()
<< " x i32] zeroinitializer\n";
<< " x " << elem_ty << "] zeroinitializer\n";
} else {
os << "@" << gv->GetName() << " = global i32 " << gv->GetInitValue() << "\n";
if (gv->IsFloat()) {
std::int32_t bits = static_cast<std::int32_t>(gv->GetInitValue());
float fval = 0.0f;
std::memcpy(&fval, &bits, sizeof(fval));
os << "@" << gv->GetName() << " = global float " << fval << "\n";
} else {
os << "@" << gv->GetName() << " = global i32 " << gv->GetInitValue()
<< "\n";
}
}
}
if (!module.GetGlobalVars().empty()) os << "\n";
@ -163,26 +176,41 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::Cast: {
auto* cast = static_cast<const CastInst*>(inst);
const char* cast_name =
(cast->GetCastOp() == CastOp::IntToFloat) ? "sitofp" : "fptosi";
os << " " << cast->GetName() << " = " << cast_name << " "
<< TypeToString(*cast->GetValue()->GetType()) << " "
<< ValueToString(cast->GetValue()) << " to "
<< TypeToString(*cast->GetType()) << "\n";
break;
}
case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst);
const char* elem_ty = alloca->GetType()->IsPtrFloat32() ? "float" : "i32";
if (alloca->IsArray()) {
os << " " << alloca->GetName() << " = alloca i32, i32 "
os << " " << alloca->GetName() << " = alloca " << elem_ty << ", i32 "
<< alloca->GetCount() << "\n";
} else {
os << " " << alloca->GetName() << " = alloca i32\n";
os << " " << alloca->GetName() << " = alloca " << elem_ty << "\n";
}
break;
}
case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load i32, i32* "
os << " " << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst);
os << " store i32 " << ValueToString(store->GetValue())
<< ", i32* " << ValueToString(store->GetPtr()) << "\n";
os << " store " << TypeToString(*store->GetValue()->GetType())
<< " " << ValueToString(store->GetValue())
<< ", " << TypeToString(*store->GetPtr()->GetType())
<< " " << ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Br: {
@ -219,18 +247,20 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
auto* base = gep->GetBase();
const char* elem_ty = base->GetType()->IsPtrFloat32() ? "float" : "i32";
// 全局数组用双下标 GEP局部 alloca 用平坦 GEP。
if (auto* gv = dynamic_cast<const GlobalVariable*>(base)) {
if (gv->IsArray()) {
os << " " << gep->GetName()
<< " = getelementptr [" << gv->GetCount() << " x i32], ["
<< gv->GetCount() << " x i32]* @" << gv->GetName()
<< " = getelementptr [" << gv->GetCount() << " x " << elem_ty << "], ["
<< gv->GetCount() << " x " << elem_ty << "]* @" << gv->GetName()
<< ", i32 0, i32 " << ValueToString(gep->GetIndex()) << "\n";
break;
}
}
os << " " << gep->GetName()
<< " = getelementptr i32, i32* " << ValueToString(base)
<< " = getelementptr " << elem_ty << ", "
<< TypeToString(*base->GetType()) << " " << ValueToString(base)
<< ", i32 " << ValueToString(gep->GetIndex()) << "\n";
break;
}

@ -124,8 +124,13 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
type_->GetKind() != lhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
}
if (!type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32"));
const bool is_i32 = type_->IsInt32();
const bool is_f32 = type_->IsFloat32();
if (!is_i32 && !is_f32) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32/float"));
}
if (op == Opcode::Mod && !is_i32) {
throw std::runtime_error(FormatError("ir", "BinaryInst 的 mod 仅支持 i32"));
}
AddOperand(lhs);
AddOperand(rhs);
@ -147,9 +152,11 @@ CmpInst::CmpInst(CmpOp op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
if (!type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "CmpInst 结果类型必须为 i32"));
}
if (!lhs->GetType()->IsInt32() || !rhs->GetType()->IsInt32()) {
const bool is_int_cmp = lhs->GetType()->IsInt32() && rhs->GetType()->IsInt32();
const bool is_float_cmp = lhs->GetType()->IsFloat32() && rhs->GetType()->IsFloat32();
if (!is_int_cmp && !is_float_cmp) {
throw std::runtime_error(FormatError(
"ir", "CmpInst 当前只支持 i32 比较,实际为 " +
"ir", "CmpInst 当前只支持 i32/float 同类型比较,实际为 " +
std::string(TypeKindToString(lhs->GetType()->GetKind())) +
"" +
std::string(TypeKindToString(rhs->GetType()->GetKind()))));
@ -164,6 +171,28 @@ Value* CmpInst::GetLhs() const { return GetOperand(0); }
Value* CmpInst::GetRhs() const { return GetOperand(1); }
CastInst::CastInst(CastOp op, std::shared_ptr<Type> ty, Value* val,
std::string name)
: Instruction(Opcode::Cast, std::move(ty), std::move(name)), cast_op_(op) {
if (!val || !val->GetType() || !type_) {
throw std::runtime_error(FormatError("ir", "CastInst 缺少类型信息或操作数"));
}
if (cast_op_ == CastOp::IntToFloat) {
if (!val->GetType()->IsInt32() || !type_->IsFloat32()) {
throw std::runtime_error(FormatError("ir", "IntToFloat 需要 i32 -> float"));
}
} else {
if (!val->GetType()->IsFloat32() || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "FloatToInt 需要 float -> i32"));
}
}
AddOperand(val);
}
CastOp CastInst::GetCastOp() const { return cast_op_; }
Value* CastInst::GetValue() const { return GetOperand(0); }
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!type_ || !type_->IsVoid()) {
@ -327,8 +356,9 @@ GepInst::GepInst(std::shared_ptr<Type> ptr_ty, Value* base, Value* index,
if (!base || !index) {
throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数"));
}
if (!base->GetType() || !base->GetType()->IsPtrInt32()) {
throw std::runtime_error(FormatError("ir", "GepInst base 必须为 i32*"));
if (!base->GetType() ||
(!base->GetType()->IsPtrInt32() && !base->GetType()->IsPtrFloat32())) {
throw std::runtime_error(FormatError("ir", "GepInst base 必须为 i32*/float*"));
}
if (!index->GetType() || !index->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "GepInst index 必须为 i32"));

@ -27,8 +27,12 @@ Function* Module::FindFunction(const std::string& name) const {
return nullptr;
}
GlobalVariable* Module::CreateGlobalVar(const std::string& name, int init_val, int count) {
global_vars_.push_back(std::make_unique<GlobalVariable>(name, init_val, count));
GlobalVariable* Module::CreateGlobalVar(const std::string& name, int init_val,
int count, std::shared_ptr<Type> ptr_ty,
std::vector<int> init_elems) {
global_vars_.push_back(
std::make_unique<GlobalVariable>(name, std::move(ptr_ty), init_val, count,
std::move(init_elems)));
return global_vars_.back().get();
}

@ -1,5 +1,7 @@
#include "irgen/IRGen.h"
#include <cmath>
#include <cstdlib>
#include <stdexcept>
#include <string>
@ -9,75 +11,103 @@
// 内部辅助:不依赖类成员,只需 ConstEnv。
namespace {
int EvalAddExp(SysYParser::AddExpContext* ctx,
const IRGenImpl::ConstEnv& env);
int EvalMulExp(SysYParser::MulExpContext* ctx,
const IRGenImpl::ConstEnv& env);
int EvalUnaryExp(SysYParser::UnaryExpContext* ctx,
const IRGenImpl::ConstEnv& env);
double EvalAddExp(SysYParser::AddExpContext* ctx,
const IRGenImpl::ConstEnv& int_env,
const IRGenImpl::ConstFloatEnv& float_env);
double EvalMulExp(SysYParser::MulExpContext* ctx,
const IRGenImpl::ConstEnv& int_env,
const IRGenImpl::ConstFloatEnv& float_env);
double EvalUnaryExp(SysYParser::UnaryExpContext* ctx,
const IRGenImpl::ConstEnv& int_env,
const IRGenImpl::ConstFloatEnv& float_env);
int EvalPrimary(SysYParser::PrimaryExpContext* ctx,
const IRGenImpl::ConstEnv& env) {
int ParseIntLiteral(const std::string& text) {
if (text.size() >= 2 && text[0] == '0' &&
(text[1] == 'x' || text[1] == 'X')) {
return std::stoi(text, nullptr, 16);
}
if (text.size() > 1 && text[0] == '0') {
return std::stoi(text, nullptr, 8);
}
return std::stoi(text);
}
double EvalPrimary(SysYParser::PrimaryExpContext* ctx,
const IRGenImpl::ConstEnv& int_env,
const IRGenImpl::ConstFloatEnv& float_env) {
if (!ctx) throw std::runtime_error(FormatError("consteval", "空主表达式"));
if (ctx->number()) {
if (!ctx->number()->ILITERAL())
throw std::runtime_error(
FormatError("consteval", "constExp 不支持浮点字面量"));
return std::stoi(ctx->number()->getText());
if (ctx->number()->ILITERAL()) {
return static_cast<double>(ParseIntLiteral(ctx->number()->getText()));
}
if (ctx->number()->FLITERAL()) {
return static_cast<double>(std::strtof(ctx->number()->getText().c_str(), nullptr));
}
throw std::runtime_error(FormatError("consteval", "非法数字字面量"));
}
if (ctx->exp()) return EvalAddExp(ctx->exp()->addExp(), env);
if (ctx->exp()) return EvalAddExp(ctx->exp()->addExp(), int_env, float_env);
if (ctx->lValue()) {
if (!ctx->lValue()->ID())
throw std::runtime_error(FormatError("consteval", "非法 lValue"));
const std::string name = ctx->lValue()->ID()->getText();
auto it = env.find(name);
if (it == env.end())
throw std::runtime_error(
FormatError("consteval", "constExp 引用非 const 变量: " + name));
return it->second;
auto it_int = int_env.find(name);
if (it_int != int_env.end()) return static_cast<double>(it_int->second);
auto it_float = float_env.find(name);
if (it_float != float_env.end()) return static_cast<double>(it_float->second);
throw std::runtime_error(
FormatError("consteval", "constExp 引用非 const 变量: " + name));
}
throw std::runtime_error(FormatError("consteval", "不支持的主表达式形式"));
}
int EvalUnaryExp(SysYParser::UnaryExpContext* ctx,
const IRGenImpl::ConstEnv& env) {
double EvalUnaryExp(SysYParser::UnaryExpContext* ctx,
const IRGenImpl::ConstEnv& int_env,
const IRGenImpl::ConstFloatEnv& float_env) {
if (!ctx) throw std::runtime_error(FormatError("consteval", "空一元表达式"));
if (ctx->primaryExp()) return EvalPrimary(ctx->primaryExp(), env);
if (ctx->primaryExp()) return EvalPrimary(ctx->primaryExp(), int_env, float_env);
if (ctx->unaryOp() && ctx->unaryExp()) {
int v = EvalUnaryExp(ctx->unaryExp(), env);
double v = EvalUnaryExp(ctx->unaryExp(), int_env, float_env);
if (ctx->unaryOp()->SUB()) return -v;
if (ctx->unaryOp()->ADD()) return v;
if (ctx->unaryOp()->NOT()) return (v == 0) ? 1 : 0;
if (ctx->unaryOp()->NOT()) return (v == 0.0) ? 1.0 : 0.0;
}
throw std::runtime_error(
FormatError("consteval", "函数调用不能出现在 constExp 中"));
}
int EvalMulExp(SysYParser::MulExpContext* ctx,
const IRGenImpl::ConstEnv& env) {
double EvalMulExp(SysYParser::MulExpContext* ctx,
const IRGenImpl::ConstEnv& int_env,
const IRGenImpl::ConstFloatEnv& float_env) {
if (!ctx) throw std::runtime_error(FormatError("consteval", "空乘法表达式"));
if (ctx->mulExp()) {
int lhs = EvalMulExp(ctx->mulExp(), env);
int rhs = EvalUnaryExp(ctx->unaryExp(), env);
double lhs = EvalMulExp(ctx->mulExp(), int_env, float_env);
double rhs = EvalUnaryExp(ctx->unaryExp(), int_env, float_env);
if (ctx->MUL()) return lhs * rhs;
if (ctx->DIV()) { if (!rhs) throw std::runtime_error("除以零"); return lhs / rhs; }
if (ctx->MOD()) { if (!rhs) throw std::runtime_error("模零"); return lhs % rhs; }
if (ctx->DIV()) {
if (rhs == 0.0) throw std::runtime_error("除以零");
return lhs / rhs;
}
if (ctx->MOD()) {
if (rhs == 0.0) throw std::runtime_error("模零");
return std::fmod(lhs, rhs);
}
throw std::runtime_error(FormatError("consteval", "未知乘法运算符"));
}
return EvalUnaryExp(ctx->unaryExp(), env);
return EvalUnaryExp(ctx->unaryExp(), int_env, float_env);
}
int EvalAddExp(SysYParser::AddExpContext* ctx,
const IRGenImpl::ConstEnv& env) {
double EvalAddExp(SysYParser::AddExpContext* ctx,
const IRGenImpl::ConstEnv& int_env,
const IRGenImpl::ConstFloatEnv& float_env) {
if (!ctx) throw std::runtime_error(FormatError("consteval", "空加法表达式"));
if (ctx->addExp()) {
int lhs = EvalAddExp(ctx->addExp(), env);
int rhs = EvalMulExp(ctx->mulExp(), env);
double lhs = EvalAddExp(ctx->addExp(), int_env, float_env);
double rhs = EvalMulExp(ctx->mulExp(), int_env, float_env);
if (ctx->ADD()) return lhs + rhs;
if (ctx->SUB()) return lhs - rhs;
throw std::runtime_error(FormatError("consteval", "未知加法运算符"));
}
return EvalMulExp(ctx->mulExp(), env);
return EvalMulExp(ctx->mulExp(), int_env, float_env);
}
} // namespace
@ -85,11 +115,23 @@ int EvalAddExp(SysYParser::AddExpContext* ctx,
int IRGenImpl::EvalConstExpr(SysYParser::ConstExpContext* ctx) const {
if (!ctx || !ctx->addExp())
throw std::runtime_error(FormatError("consteval", "空 constExp"));
return EvalAddExp(ctx->addExp(), const_env_);
return static_cast<int>(EvalAddExp(ctx->addExp(), const_env_, const_float_env_));
}
float IRGenImpl::EvalConstExprAsFloat(SysYParser::ConstExpContext* ctx) const {
if (!ctx || !ctx->addExp())
throw std::runtime_error(FormatError("consteval", "空 constExp"));
return static_cast<float>(EvalAddExp(ctx->addExp(), const_env_, const_float_env_));
}
int IRGenImpl::EvalExpAsConst(SysYParser::ExpContext* ctx) const {
if (!ctx || !ctx->addExp())
throw std::runtime_error(FormatError("consteval", "空 exp"));
return EvalAddExp(ctx->addExp(), const_env_);
return static_cast<int>(EvalAddExp(ctx->addExp(), const_env_, const_float_env_));
}
float IRGenImpl::EvalExpAsConstFloat(SysYParser::ExpContext* ctx) const {
if (!ctx || !ctx->addExp())
throw std::runtime_error(FormatError("consteval", "空 exp"));
return static_cast<float>(EvalAddExp(ctx->addExp(), const_env_, const_float_env_));
}

@ -1,5 +1,7 @@
#include "irgen/IRGen.h"
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include "SysYParser.h"
@ -10,6 +12,8 @@ std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
}
const auto saved_const_env = const_env_;
const auto saved_const_float_env = const_float_env_;
for (auto* item : ctx->blockItem()) {
if (item) {
if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
@ -17,6 +21,8 @@ std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
}
}
}
const_env_ = saved_const_env;
const_float_env_ = saved_const_float_env;
return {};
}
@ -98,6 +104,40 @@ void IRGenImpl::FlattenConstInit(SysYParser::ConstInitValueContext* ctx,
while (pos < start + agg_size) out[pos++] = 0;
}
void IRGenImpl::FlattenConstInitFloat(SysYParser::ConstInitValueContext* ctx,
const std::vector<int>& dims, int dim_idx,
std::vector<float>& out, int& pos) {
if (!ctx) return;
if (ctx->constExp()) {
out[pos++] = EvalConstExprAsFloat(ctx->constExp());
return;
}
int sub_size = 1;
for (int i = dim_idx + 1; i < (int)dims.size(); i++) sub_size *= dims[i];
int agg_size = (dim_idx < (int)dims.size()) ? dims[dim_idx] * sub_size : 1;
int start = pos;
for (auto* item : ctx->constInitValue()) {
if (!item || pos >= start + agg_size) break;
if (item->constExp()) {
out[pos++] = EvalConstExprAsFloat(item->constExp());
} else {
if (sub_size > 1) {
int offset = pos - start;
int rem = offset % sub_size;
if (rem != 0) pos += sub_size - rem;
}
int sub_start = pos;
FlattenConstInitFloat(item, dims, dim_idx + 1, out, pos);
int sub_end = sub_start + sub_size;
while (pos < sub_end && pos < start + agg_size) out[pos++] = 0.0f;
}
}
while (pos < start + agg_size) out[pos++] = 0.0f;
}
// ─── 工具:扁平化 initValue ───────────────────────────────────────────────
void IRGenImpl::FlattenInit(SysYParser::InitValueContext* ctx,
const std::vector<int>& dims, int dim_idx,
@ -133,6 +173,75 @@ void IRGenImpl::FlattenInit(SysYParser::InitValueContext* ctx,
while (pos < start + agg_size) pos++; // zeros
}
void IRGenImpl::FlattenGlobalInitInt(SysYParser::InitValueContext* ctx,
const std::vector<int>& dims, int dim_idx,
std::vector<int>& out, int& pos) {
if (!ctx) return;
if (ctx->exp()) {
out[pos++] = EvalExpAsConst(ctx->exp());
return;
}
int sub_size = 1;
for (int i = dim_idx + 1; i < (int)dims.size(); i++) sub_size *= dims[i];
int agg_size = (dim_idx < (int)dims.size()) ? dims[dim_idx] * sub_size : 1;
int start = pos;
for (auto* item : ctx->initValue()) {
if (!item || pos >= start + agg_size) break;
if (item->exp()) {
out[pos++] = EvalExpAsConst(item->exp());
} else {
if (sub_size > 1) {
int offset = pos - start;
int rem = offset % sub_size;
if (rem != 0) pos += sub_size - rem;
}
int sub_start = pos;
FlattenGlobalInitInt(item, dims, dim_idx + 1, out, pos);
int sub_end = sub_start + sub_size;
while (pos < sub_end && pos < start + agg_size) out[pos++] = 0;
}
}
while (pos < start + agg_size) out[pos++] = 0;
}
void IRGenImpl::FlattenGlobalInitFloat(SysYParser::InitValueContext* ctx,
const std::vector<int>& dims,
int dim_idx, std::vector<float>& out,
int& pos) {
if (!ctx) return;
if (ctx->exp()) {
out[pos++] = EvalExpAsConstFloat(ctx->exp());
return;
}
int sub_size = 1;
for (int i = dim_idx + 1; i < (int)dims.size(); i++) sub_size *= dims[i];
int agg_size = (dim_idx < (int)dims.size()) ? dims[dim_idx] * sub_size : 1;
int start = pos;
for (auto* item : ctx->initValue()) {
if (!item || pos >= start + agg_size) break;
if (item->exp()) {
out[pos++] = EvalExpAsConstFloat(item->exp());
} else {
if (sub_size > 1) {
int offset = pos - start;
int rem = offset % sub_size;
if (rem != 0) pos += sub_size - rem;
}
int sub_start = pos;
FlattenGlobalInitFloat(item, dims, dim_idx + 1, out, pos);
int sub_end = sub_start + sub_size;
while (pos < sub_end && pos < start + agg_size) out[pos++] = 0.0f;
}
}
while (pos < start + agg_size) out[pos++] = 0.0f;
}
// ─── const 声明 ───────────────────────────────────────────────────────────
std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
@ -140,16 +249,17 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
if (!ctx->btype()) {
throw std::runtime_error(FormatError("irgen", "缺少类型声明"));
}
// 暂时只处理int constfloat const留待后续实现
if (ctx->btype()->FLOAT()) {
throw std::runtime_error(FormatError("irgen", "暂不支持 float const 声明"));
}
if (!ctx->btype()->INT()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持 int const 声明"));
if (ctx->btype()->INT()) {
current_decl_type_ = ir::Type::GetInt32Type();
} else if (ctx->btype()->FLOAT()) {
current_decl_type_ = ir::Type::GetFloat32Type();
} else {
throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float const 声明"));
}
for (auto* def : ctx->constDef()) {
if (def) def->accept(this);
}
current_decl_type_ = nullptr;
return {};
}
@ -162,16 +272,34 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
if (!ctx->constInitValue() || !ctx->constInitValue()->constExp()) {
throw std::runtime_error(FormatError("irgen", "const 标量声明缺少初始值"));
}
int ival = EvalConstExpr(ctx->constInitValue()->constExp());
const_env_[name] = ival; // 存入编译期环境
if (IsGlobalScope()) {
auto* gv = module_.CreateGlobalVar(name, ival);
global_storage_[name] = gv;
const bool is_float_const = current_decl_type_ && current_decl_type_->IsFloat32();
if (is_float_const) {
float fval = EvalConstExprAsFloat(ctx->constInitValue()->constExp());
const_float_env_[name] = fval;
if (IsGlobalScope()) {
std::int32_t bits = 0;
std::memcpy(&bits, &fval, sizeof(bits));
auto* gv = module_.CreateGlobalVar(
name, static_cast<int>(bits), 1, ir::Type::GetPtrFloat32Type());
global_storage_[name] = gv;
} else {
auto* slot = CreateEntryAllocaF32(module_.GetContext().NextTemp());
named_storage_[name] = slot;
builder_.CreateStore(module_.GetContext().GetConstFloat(fval), slot);
}
} else {
auto* slot = CreateEntryAllocaI32(module_.GetContext().NextTemp());
named_storage_[name] = slot;
builder_.CreateStore(builder_.CreateConstInt(ival), slot);
int ival = EvalConstExpr(ctx->constInitValue()->constExp());
const_env_[name] = ival; // 存入编译期环境
if (IsGlobalScope()) {
auto* gv = module_.CreateGlobalVar(name, ival);
global_storage_[name] = gv;
} else {
auto* slot = CreateEntryAllocaI32(module_.GetContext().NextTemp());
named_storage_[name] = slot;
builder_.CreateStore(builder_.CreateConstInt(ival), slot);
}
}
return {};
}
@ -184,6 +312,40 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
int total = 1;
for (int d : dims) total *= d;
const bool is_float_const = current_decl_type_ && current_decl_type_->IsFloat32();
if (is_float_const) {
std::vector<float> flat(total, 0.0f);
if (ctx->constInitValue()) {
int pos = 0;
FlattenConstInitFloat(ctx->constInitValue(), dims, 0, flat, pos);
}
std::vector<int> init_bits;
init_bits.reserve(flat.size());
for (float v : flat) {
std::int32_t bits = 0;
std::memcpy(&bits, &v, sizeof(bits));
init_bits.push_back(static_cast<int>(bits));
}
if (IsGlobalScope()) {
auto* gv = module_.CreateGlobalVar(
name, 0, total, ir::Type::GetPtrFloat32Type(), std::move(init_bits));
global_storage_[name] = gv;
global_array_dims_[name] = dims;
} else {
auto* slot = CreateEntryAllocaF32Array(total, module_.GetContext().NextTemp());
named_storage_[name] = slot;
local_array_dims_[name] = dims;
for (int i = 0; i < total; i++) {
auto* idx = builder_.CreateConstInt(i);
auto* ptr = builder_.CreateGep(slot, idx, module_.GetContext().NextTemp());
builder_.CreateStore(module_.GetContext().GetConstFloat(flat[i]), ptr);
}
}
return {};
}
// 扁平化初始化值
std::vector<int> flat(total, 0);
if (ctx->constInitValue()) {
@ -192,9 +354,9 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
}
if (IsGlobalScope()) {
// 全局 const 数组:创建全局数组变量(仅支持零初始化;非零初始化暂用零)
// TODO: 支持全局 const 数组的非零初始化
auto* gv = module_.CreateGlobalVar(name, 0, total);
auto* gv = module_.CreateGlobalVar(name, 0, total,
ir::Type::GetPtrInt32Type(),
std::move(flat));
global_storage_[name] = gv;
global_array_dims_[name] = dims;
} else {
@ -255,11 +417,32 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
for (int d : dims) total *= d;
if (IsGlobalScope()) {
auto* gv = module_.CreateGlobalVar(name, 0, total);
std::vector<int> init_elems;
if (auto* init_val = ctx->initValue()) {
if (current_decl_type_->IsFloat32()) {
std::vector<float> flat(total, 0.0f);
int pos = 0;
FlattenGlobalInitFloat(init_val, dims, 0, flat, pos);
init_elems.reserve(flat.size());
for (float v : flat) {
std::int32_t bits = 0;
std::memcpy(&bits, &v, sizeof(bits));
init_elems.push_back(static_cast<int>(bits));
}
} else {
init_elems.assign(total, 0);
int pos = 0;
FlattenGlobalInitInt(init_val, dims, 0, init_elems, pos);
}
}
auto* gv = module_.CreateGlobalVar(
name, 0, total,
current_decl_type_->IsFloat32() ? ir::Type::GetPtrFloat32Type()
: ir::Type::GetPtrInt32Type(),
std::move(init_elems));
storage_map_[ctx] = gv;
global_storage_[name] = gv;
global_array_dims_[name] = dims;
// 全局数组:不支持运行时初始化(全零已足够)
} else {
// 根据当前声明类型创建数组alloca
ir::AllocaInst* slot;
@ -291,7 +474,13 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (flat[i] != nullptr) {
auto* idx = builder_.CreateConstInt(i);
auto* ptr = builder_.CreateGep(slot, idx, module_.GetContext().NextTemp());
builder_.CreateStore(flat[i], ptr);
ir::Value* val = flat[i];
if (ptr->GetType()->IsPtrFloat32()) {
val = CastToFloat(val);
} else {
val = CastToInt(val);
}
builder_.CreateStore(val, ptr);
}
}
}
@ -301,15 +490,32 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
// ── 标量变量 ──────────────────────────────────────────────────────────
if (IsGlobalScope()) {
int ival = 0;
if (auto* init_value = ctx->initValue()) {
if (!init_value->exp()) {
throw std::runtime_error(
FormatError("irgen", "全局标量变量仅支持表达式初始化"));
int init_bits_or_int = 0;
if (current_decl_type_->IsFloat32()) {
float fval = 0.0f;
if (auto* init_value = ctx->initValue()) {
if (!init_value->exp()) {
throw std::runtime_error(
FormatError("irgen", "全局标量变量仅支持表达式初始化"));
}
fval = EvalExpAsConstFloat(init_value->exp());
}
std::int32_t bits = 0;
std::memcpy(&bits, &fval, sizeof(bits));
init_bits_or_int = static_cast<int>(bits);
} else {
if (auto* init_value = ctx->initValue()) {
if (!init_value->exp()) {
throw std::runtime_error(
FormatError("irgen", "全局标量变量仅支持表达式初始化"));
}
init_bits_or_int = EvalExpAsConst(init_value->exp());
}
ival = EvalExpAsConst(init_value->exp());
}
auto* gv = module_.CreateGlobalVar(name, ival);
auto* gv = module_.CreateGlobalVar(
name, init_bits_or_int, 1,
current_decl_type_->IsFloat32() ? ir::Type::GetPtrFloat32Type()
: ir::Type::GetPtrInt32Type());
storage_map_[ctx] = gv;
global_storage_[name] = gv;
return {};
@ -343,6 +549,11 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
init = builder_.CreateConstInt(0);
}
}
if (current_decl_type_->IsFloat32()) {
init = CastToFloat(init);
} else {
init = CastToInt(init);
}
builder_.CreateStore(init, slot);
return {};
}

@ -14,18 +14,42 @@ ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) {
return std::any_cast<ir::Value*>(cond.accept(this));
}
ir::Value* IRGenImpl::CastToFloat(ir::Value* v) {
if (!v || !v->GetType()) {
throw std::runtime_error(FormatError("irgen", "CastToFloat 输入为空"));
}
if (v->GetType()->IsFloat32()) return v;
if (v->GetType()->IsInt32()) {
return builder_.CreateSIToFP(v, module_.GetContext().NextTemp());
}
throw std::runtime_error(FormatError("irgen", "不支持转换到 float 的类型"));
}
ir::Value* IRGenImpl::CastToInt(ir::Value* v) {
if (!v || !v->GetType()) {
throw std::runtime_error(FormatError("irgen", "CastToInt 输入为空"));
}
if (v->GetType()->IsInt32()) return v;
if (v->GetType()->IsFloat32()) {
return builder_.CreateFPToSI(v, module_.GetContext().NextTemp());
}
throw std::runtime_error(FormatError("irgen", "不支持转换到 i32 的类型"));
}
ir::Value* IRGenImpl::ToBoolValue(ir::Value* v) {
if (!v) {
throw std::runtime_error(FormatError("irgen", "条件值为空"));
}
if (v->GetType() && v->GetType()->IsPtrInt32()) {
if (v->GetType() && (v->GetType()->IsPtrInt32() || v->GetType()->IsPtrFloat32())) {
// SysY 中数组名退化得到的指针在当前实现里总是非空。
return builder_.CreateConstInt(1);
}
if (dynamic_cast<ir::CmpInst*>(v) != nullptr) {
return v;
}
auto* zero = builder_.CreateConstInt(0);
ir::Value* zero = v->GetType()->IsFloat32()
? static_cast<ir::Value*>(module_.GetContext().GetConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
return builder_.CreateCmp(ir::CmpOp::Ne, v, zero, module_.GetContext().NextTemp());
}
@ -60,7 +84,7 @@ ir::Value* IRGenImpl::ComputeLinearIndex(
int stride = 1;
for (int j = k + 1; j < (int)dims.size(); j++) stride *= dims[j];
ir::Value* idx = EvalExpr(*subs[k]);
ir::Value* idx = CastToInt(EvalExpr(*subs[k]));
if (stride != 1) {
auto* sv = builder_.CreateConstInt(stride);
idx = builder_.CreateMul(idx, sv, module_.GetContext().NextTemp());
@ -184,6 +208,15 @@ std::any IRGenImpl::visitLValue(SysYParser::LValueContext* ctx) {
const std::string name = ctx->ID()->getText();
if (ctx->exp().empty()) {
auto itf = const_float_env_.find(name);
if (itf != const_float_env_.end()) {
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(itf->second));
}
auto iti = const_env_.find(name);
if (iti != const_env_.end()) {
return static_cast<ir::Value*>(builder_.CreateConstInt(iti->second));
}
// 无下标:标量读取 或 数组基址引用
ir::Value* slot = ResolveStorage(ctx);
if (!slot) {
@ -230,7 +263,9 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
if (ctx->unaryOp() && ctx->unaryExp()) {
ir::Value* v = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
if (ctx->unaryOp()->SUB()) {
auto* zero = builder_.CreateConstInt(0);
ir::Value* zero = v->GetType()->IsFloat32()
? static_cast<ir::Value*>(module_.GetContext().GetConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
return static_cast<ir::Value*>(builder_.CreateSub(
zero, v, module_.GetContext().NextTemp()));
}
@ -239,7 +274,9 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
}
if (ctx->unaryOp()->NOT()) {
// !v ≡ (v == 0)
auto* zero = builder_.CreateConstInt(0);
ir::Value* zero = v->GetType()->IsFloat32()
? static_cast<ir::Value*>(module_.GetContext().GetConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Eq, v, zero, module_.GetContext().NextTemp()));
}
@ -255,8 +292,19 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
}
std::vector<ir::Value*> args;
if (auto* rparams = ctx->funcRParams()) {
const auto& param_types = callee->GetParamTypes();
size_t i = 0;
for (auto* ep : rparams->exp()) {
args.push_back(EvalExpr(*ep));
ir::Value* arg = EvalExpr(*ep);
if (i < param_types.size()) {
if (param_types[i]->IsFloat32()) {
arg = CastToFloat(arg);
} else if (param_types[i]->IsInt32()) {
arg = CastToInt(arg);
}
}
args.push_back(arg);
++i;
}
}
const std::string name =
@ -277,6 +325,11 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
const bool has_float = lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32();
if (has_float) {
lhs = CastToFloat(lhs);
rhs = CastToFloat(rhs);
}
if (ctx->MUL()) {
return static_cast<ir::Value*>(
builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp()));
@ -286,6 +339,8 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
builder_.CreateDiv(lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->MOD()) {
lhs = CastToInt(lhs);
rhs = CastToInt(rhs);
return static_cast<ir::Value*>(
builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
}
@ -307,6 +362,10 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
if (lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) {
lhs = CastToFloat(lhs);
rhs = CastToFloat(rhs);
}
if (ctx->ADD()) {
return static_cast<ir::Value*>(
builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp()));
@ -333,6 +392,10 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
if (lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) {
lhs = CastToFloat(lhs);
rhs = CastToFloat(rhs);
}
if (ctx->LT()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Lt, lhs, rhs, module_.GetContext().NextTemp()));
@ -367,6 +430,10 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
if (lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) {
lhs = CastToFloat(lhs);
rhs = CastToFloat(rhs);
}
if (ctx->EQ()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Eq, lhs, rhs, module_.GetContext().NextTemp()));

@ -93,6 +93,11 @@ void IRGenImpl::DeclareRuntimeFunctions() {
// 数组 I/O
decl("getarray", i32, {ir::Type::GetPtrInt32Type()});
decl("putarray", void_, {i32, ir::Type::GetPtrInt32Type()});
// 浮点 I/O
decl("getfloat", ir::Type::GetFloat32Type(), {});
decl("getfarray", i32, {ir::Type::GetPtrFloat32Type()});
decl("putfloat", void_, {ir::Type::GetFloat32Type()});
decl("putfarray", void_, {i32, ir::Type::GetPtrFloat32Type()});
// 时间
decl("starttime", void_, {});
decl("stoptime", void_, {});
@ -216,7 +221,12 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
}
} else {
// 标量参数alloca + store
auto* slot = CreateEntryAllocaI32(module_.GetContext().NextTemp());
ir::AllocaInst* slot = nullptr;
if (arg->GetType()->IsFloat32()) {
slot = CreateEntryAllocaF32(module_.GetContext().NextTemp());
} else {
slot = CreateEntryAllocaI32(module_.GetContext().NextTemp());
}
builder_.CreateStore(arg, slot);
if (!param_names[i].empty()) {
named_storage_[param_names[i]] = slot;

@ -29,6 +29,11 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
? ctx->lValue()->ID()->getText()
: "?")));
}
if (slot->GetType() && slot->GetType()->IsPtrFloat32()) {
rhs = CastToFloat(rhs);
} else if (slot->GetType() && slot->GetType()->IsPtrInt32()) {
rhs = CastToInt(rhs);
}
builder_.CreateStore(rhs, slot);
return BlockFlow::Continue;
}
@ -138,6 +143,13 @@ std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
return BlockFlow::Terminated;
}
ir::Value* v = EvalExpr(*ctx->exp());
if (func_ && func_->GetType()) {
if (func_->GetType()->IsFloat32()) {
v = CastToFloat(v);
} else if (func_->GetType()->IsInt32()) {
v = CastToInt(v);
}
}
builder_.CreateRet(v);
return BlockFlow::Terminated;
}

@ -1,5 +1,6 @@
#include "mir/MIR.h"
#include <cstdint>
#include <ostream>
#include <stdexcept>
@ -17,6 +18,43 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex());
}
void PrintMoveImm32(std::ostream& os, PhysReg reg, int imm) {
std::uint32_t u = static_cast<std::uint32_t>(imm);
std::uint32_t lo = u & 0xFFFFu;
std::uint32_t hi = (u >> 16) & 0xFFFFu;
os << " movz " << PhysRegName(reg) << ", #" << lo << "\n";
if (hi != 0) {
os << " movk " << PhysRegName(reg) << ", #" << hi << ", lsl #16\n";
}
}
void PrintStackAdjust(std::ostream& os, const char* mnemonic, int size) {
if (size >= 0 && size <= 4095) {
os << " " << mnemonic << " sp, sp, #" << size << "\n";
return;
}
PrintMoveImm32(os, PhysReg::X10, size);
os << " " << mnemonic << " sp, sp, x10\n";
}
void PrintAddrFromX29(std::ostream& os, PhysReg dst, int offset) {
if (offset >= -4095 && offset <= 4095) {
if (offset >= 0) {
os << " add " << PhysRegName(dst) << ", x29, #" << offset << "\n";
} else {
os << " sub " << PhysRegName(dst) << ", x29, #" << (-offset) << "\n";
}
return;
}
PrintMoveImm32(os, PhysReg::X10, offset < 0 ? -offset : offset);
if (offset >= 0) {
os << " add " << PhysRegName(dst) << ", x29, x10\n";
} else {
os << " sub " << PhysRegName(dst) << ", x29, x10\n";
}
}
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
int offset) {
// AArch64 ldur/stur 只支持 -256..255 的立即数偏移
@ -25,13 +63,10 @@ void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
<< "]\n";
} else {
// 大偏移:使用 x10 作为临时寄存器
// sub x10, x29, #abs(offset)
// ldr/str reg, [x10]
int abs_offset = -offset; // offset 是负数
bool is_load = (mnemonic[0] == 'l'); // ldur -> ldr
const char* base_mnemonic = is_load ? "ldr" : "str";
os << " sub x10, x29, #" << abs_offset << "\n";
PrintAddrFromX29(os, PhysReg::X10, offset);
os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x10]\n";
}
}
@ -42,7 +77,9 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
// 输出全局变量定义
if (!module.GetGlobalVars().empty()) {
os << ".data\n";
for (const auto& [name, init_val, count] : module.GetGlobalVars()) {
for (const auto& [name, init_val, count, is_float, init_elems] :
module.GetGlobalVars()) {
(void)is_float;
os << ".global " << name << "\n";
os << ".type " << name << ", %object\n";
os << name << ":\n";
@ -50,8 +87,20 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
// 标量全局变量
os << " .word " << init_val << "\n";
} else {
// 数组全局变量(全零初始化)
os << " .zero " << (count * 4) << "\n";
// 数组全局变量:优先输出显式初始化元素,剩余部分补零。
int emitted = 0;
for (int elem : init_elems) {
if (emitted >= count) {
break;
}
os << " .word " << elem << "\n";
++emitted;
}
if (emitted == 0) {
os << " .zero " << (count * 4) << "\n";
} else if (emitted < count) {
os << " .zero " << ((count - emitted) * 4) << "\n";
}
}
}
os << "\n";
@ -80,23 +129,31 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
PrintStackAdjust(os, "sub", function.GetFrameSize());
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
PrintStackAdjust(os, "add", function.GetFrameSize());
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
PrintMoveImm32(os, ops.at(0).GetReg(), ops.at(1).GetImm());
break;
case Opcode::MovReg:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FMovImm:
// 通用浮点立即数:先装载 bit pattern再位级移动到 s 寄存器。
PrintMoveImm32(os, PhysReg::W10, ops.at(1).GetImm());
os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", w10\n";
break;
case Opcode::FMovReg:
os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
@ -125,12 +182,7 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
// ops: xN, frame_index
// add xN, x29, #offset
const auto& slot = GetFrameSlot(function, ops.at(1));
int offset = slot.offset;
if (offset >= 0) {
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << offset << "\n";
} else {
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << (-offset) << "\n";
}
PrintAddrFromX29(os, ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::LoadIndirect: {
@ -196,6 +248,34 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FAddRR:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSubRR:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMulRR:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDivRR:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::ModRR:
// 不应该出现Mod 在 lowering 时已展开为 div+mul+sub
throw std::runtime_error(FormatError("mir", "ModRR 不应被打印"));
@ -222,6 +302,24 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
<< cond_suffix << "\n";
break;
}
case Opcode::FCmpRR: {
// ops: dst(wN), lhs(sN), rhs(sN), cmpop(imm)
auto cmp_op = static_cast<ir::CmpOp>(ops.at(3).GetImm());
const char* cond_suffix = "";
switch (cmp_op) {
case ir::CmpOp::Eq: cond_suffix = "eq"; break;
case ir::CmpOp::Ne: cond_suffix = "ne"; break;
case ir::CmpOp::Lt: cond_suffix = "lt"; break;
case ir::CmpOp::Le: cond_suffix = "le"; break;
case ir::CmpOp::Gt: cond_suffix = "gt"; break;
case ir::CmpOp::Ge: cond_suffix = "ge"; break;
}
os << " fcmp " << PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
<< cond_suffix << "\n";
break;
}
case Opcode::Bl:
os << " bl " << ops.at(0).GetSymbol() << "\n";
break;

@ -18,12 +18,6 @@ void RunFrameLowering(MachineFunction& function) {
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
// AArch64 ldur/stur 支持 -256 到 +255 的立即数偏移
// 如果超出范围,需要使用多条指令
// 这里暂时放宽限制到 4096单页大小
if (-cursor < -4096) {
throw std::runtime_error(FormatError("mir", "栈帧超过 4KB需要更复杂的栈帧处理"));
}
}
cursor = 0;

@ -1,5 +1,7 @@
#include "mir/MIR.h"
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <unordered_map>
@ -23,8 +25,12 @@ struct GepInfo {
};
using GepMap = std::unordered_map<const ir::Value*, GepInfo>;
void EmitValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) {
bool IsPointerType(const std::shared_ptr<ir::Type>& type) {
return type && (type->IsPtrInt32() || type->IsPtrFloat32());
}
void EmitIntValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(constant->GetValue())});
@ -48,6 +54,36 @@ void EmitValueToReg(const ir::Value* value, PhysReg target,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
}
void EmitFloatValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* constant = dynamic_cast<const ir::ConstantFloat*>(value)) {
std::int32_t bits = 0;
float fv = constant->GetValue();
std::memcpy(&bits, &fv, sizeof(bits));
block.Append(Opcode::FMovImm,
{Operand::Reg(target), Operand::Imm(static_cast<int>(bits))});
return;
}
auto it = slots.find(value);
if (it == slots.end()) {
throw std::runtime_error(
FormatError("mir", "找不到浮点值对应的栈槽: " + value->GetName()));
}
block.Append(Opcode::LoadStack,
{Operand::Reg(target), Operand::FrameIndex(it->second)});
}
void EmitValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) {
if (value->GetType() && value->GetType()->IsFloat32()) {
EmitFloatValueToReg(value, target, slots, block);
return;
}
EmitIntValueToReg(value, target, slots, block);
}
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
MachineBasicBlock& block, ValueSlotMap& slots,
GepMap& geps) {
@ -120,7 +156,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
}
// 检查 base 是否是指针参数:如果是 Argument 且类型是指针
if (dynamic_cast<const ir::Argument*>(base) && base->GetType()->IsPtrInt32()) {
if (dynamic_cast<const ir::Argument*>(base) && IsPointerType(base->GetType())) {
// 指针参数:从栈加载指针值,然后加上索引
if (auto* const_index = dynamic_cast<const ir::ConstantInt*>(index)) {
// 常量索引
@ -212,12 +248,15 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
case ir::Opcode::Store: {
auto& store = static_cast<const ir::StoreInst&>(inst);
auto* ptr = store.GetPtr();
const bool is_float_value =
store.GetValue()->GetType() && store.GetValue()->GetType()->IsFloat32();
const PhysReg src_reg = is_float_value ? PhysReg::S0 : PhysReg::W8;
// 检查是否是 GEP 结果(数组元素)
auto gep_it = geps.find(ptr);
if (gep_it != geps.end()) {
const auto& gep_info = gep_it->second;
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
EmitValueToReg(store.GetValue(), src_reg, slots, block);
if (gep_info.base_slot == -1) {
// 全局数组
@ -233,7 +272,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
Operand::Reg(PhysReg::X10)});
}
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
{Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
} else {
// 变量索引global_array[var_idx]
int index_slot = -1 - gep_info.byte_offset;
@ -254,12 +293,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
Operand::Reg(PhysReg::X10)});
// 5. 存储
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
{Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
}
} else if (gep_info.byte_offset >= 0) {
// 本地数组,常量索引
block.Append(Opcode::StoreStackOffset,
{Operand::Reg(PhysReg::W8),
{Operand::Reg(src_reg),
Operand::FrameIndex(gep_info.base_slot),
Operand::Imm(gep_info.byte_offset)});
} else {
@ -278,16 +317,16 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
{Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
}
return;
}
// 检查是否是全局变量
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(ptr)) {
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
EmitValueToReg(store.GetValue(), src_reg, slots, block);
block.Append(Opcode::StoreGlobal,
{Operand::Reg(PhysReg::W8), Operand::Symbol(gv->GetName())});
{Operand::Reg(src_reg), Operand::Symbol(gv->GetName())});
return;
}
@ -298,26 +337,28 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
FormatError("mir", "暂不支持对非栈/全局变量地址进行写入"));
}
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
EmitValueToReg(store.GetValue(), src_reg, slots, block);
// 检查是否是GEP结果如果ptr的类型是指针且slot大小是8字节说明存储的是地址
const auto& dst_slot = function.GetFrameSlot(dst->second);
if (ptr->GetType()->IsPtrInt32() && dst_slot.size == 8) {
if (IsPointerType(ptr->GetType()) && dst_slot.size == 8) {
// GEP结果先加载指针地址再通过指针存储值
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(dst->second)});
block.Append(Opcode::StoreIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
{Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)});
} else {
// 普通栈变量:直接存储
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
{Operand::Reg(src_reg), Operand::FrameIndex(dst->second)});
}
return;
}
case ir::Opcode::Load: {
auto& load = static_cast<const ir::LoadInst&>(inst);
auto* ptr = load.GetPtr();
const bool is_float_load = load.GetType() && load.GetType()->IsFloat32();
const PhysReg value_reg = is_float_load ? PhysReg::S0 : PhysReg::W8;
// 检查是否是 GEP 结果(数组元素)
auto gep_it = geps.find(ptr);
@ -338,7 +379,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
Operand::Reg(PhysReg::X10)});
}
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
{Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
} else {
// 变量索引
int index_slot = -1 - gep_info.byte_offset;
@ -354,12 +395,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
{Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
}
} else if (gep_info.byte_offset >= 0) {
// 本地数组,常量索引
block.Append(Opcode::LoadStackOffset,
{Operand::Reg(PhysReg::W8),
{Operand::Reg(value_reg),
Operand::FrameIndex(gep_info.base_slot),
Operand::Imm(gep_info.byte_offset)});
} else {
@ -378,11 +419,11 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
{Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
}
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
{Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
@ -391,9 +432,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(ptr)) {
int dst_slot = function.CreateFrameIndex();
block.Append(Opcode::LoadGlobal,
{Operand::Reg(PhysReg::W8), Operand::Symbol(gv->GetName())});
{Operand::Reg(value_reg), Operand::Symbol(gv->GetName())});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
{Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
@ -409,72 +450,112 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
// 检查是否是GEP结果如果ptr的类型是指针且slot大小是8字节说明存储的是地址
const auto& src_slot = function.GetFrameSlot(src->second);
if (ptr->GetType()->IsPtrInt32() && src_slot.size == 8) {
if (IsPointerType(ptr->GetType()) && src_slot.size == 8) {
// GEP结果先加载指针地址再通过指针加载值
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::X9), Operand::FrameIndex(src->second)});
block.Append(Opcode::LoadIndirect,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::X9)});
{Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)});
} else {
// 普通栈变量:直接加载
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
{Operand::Reg(value_reg), Operand::FrameIndex(src->second)});
}
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
{Operand::Reg(value_reg), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Add: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
if (bin.GetType()->IsFloat32()) {
EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block);
block.Append(Opcode::FAddRR, {Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S1)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
} else {
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
}
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Sub: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
if (bin.GetType()->IsFloat32()) {
EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block);
block.Append(Opcode::FSubRR, {Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S1)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
} else {
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
}
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Mul: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
if (bin.GetType()->IsFloat32()) {
EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block);
block.Append(Opcode::FMulRR, {Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S1)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
} else {
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
}
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Div: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
if (bin.GetType()->IsFloat32()) {
EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block);
block.Append(Opcode::FDivRR, {Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S1)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
} else {
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::DivRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
}
slots.emplace(&inst, dst_slot);
return;
}
@ -502,23 +583,53 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
case ir::Opcode::Cmp: {
auto& cmp = static_cast<const ir::CmpInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
// cmp 操作符通过 operand 传递
block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9),
Operand::Imm(static_cast<int>(cmp.GetCmpOp()))});
if (cmp.GetLhs()->GetType()->IsFloat32()) {
EmitValueToReg(cmp.GetLhs(), PhysReg::S0, slots, block);
EmitValueToReg(cmp.GetRhs(), PhysReg::S1, slots, block);
block.Append(Opcode::FCmpRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S1),
Operand::Imm(static_cast<int>(cmp.GetCmpOp()))});
} else {
EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
// cmp 操作符通过 operand 传递
block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9),
Operand::Imm(static_cast<int>(cmp.GetCmpOp()))});
}
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Cast: {
auto& cast = static_cast<const ir::CastInst&>(inst);
int dst_slot = function.CreateFrameIndex();
if (cast.GetCastOp() == ir::CastOp::IntToFloat) {
EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::SIToFP,
{Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::W8)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
} else {
EmitValueToReg(cast.GetValue(), PhysReg::S0, slots, block);
block.Append(Opcode::FPToSI,
{Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S0)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
}
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (ret.GetValue()) {
// int/float 返回值
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat32() ? PhysReg::S0
: PhysReg::W0;
EmitValueToReg(ret.GetValue(), ret_reg, slots, block);
}
// void 返回:不设置 w0
block.Append(Opcode::Ret);
@ -531,7 +642,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
throw std::runtime_error(FormatError("mir", "Call 指令缺少被调用函数"));
}
// 参数传递:根据类型使用 w0-w7整数或 x0-x7指针
// 参数传递:根据类型使用 w0-w7整数、s0-s7浮点或 x0-x7指针
size_t num_args = call.GetNumArgs();
if (num_args > 8) {
throw std::runtime_error(FormatError("mir", "暂不支持超过 8 个参数的函数调用"));
@ -540,8 +651,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
const auto& param_types = callee->GetParamTypes();
for (size_t i = 0; i < num_args; i++) {
auto* arg_value = call.GetArg(i);
// 检查参数类型是否是指针
bool is_ptr = (i < param_types.size() && param_types[i]->IsPtrInt32());
bool is_ptr =
(i < param_types.size() &&
(param_types[i]->IsPtrInt32() || param_types[i]->IsPtrFloat32()));
bool is_float = (i < param_types.size() && param_types[i]->IsFloat32());
if (is_ptr) {
// 指针参数:加载到 x 寄存器
@ -564,8 +677,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
FormatError("mir", "找不到指针参数的值: " + arg_value->GetName()));
}
} else {
// 整数参数:加载到 w 寄存器
PhysReg arg_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + i);
// 标量参数:整数用 w浮点用 s
PhysReg arg_reg = is_float
? static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + i)
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + i);
EmitValueToReg(arg_value, arg_reg, slots, block);
}
}
@ -576,8 +691,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
// 处理返回值
if (!call.GetType()->IsVoid()) {
int dst_slot = function.CreateFrameIndex();
PhysReg ret_reg = call.GetType()->IsFloat32() ? PhysReg::S0 : PhysReg::W0;
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W0), Operand::FrameIndex(dst_slot)});
{Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
}
return;
@ -597,7 +713,8 @@ std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
// 复制全局变量信息
for (const auto& gv_ptr : module.GetGlobalVars()) {
const auto& gv = *gv_ptr;
machine_module->AddGlobalVar(gv.GetName(), gv.GetInitValue(), gv.GetCount());
machine_module->AddGlobalVar(gv.GetName(), gv.GetInitValue(), gv.GetCount(),
gv.IsFloat(), gv.GetInitElements());
}
for (const auto& func_ptr : module.GetFunctions()) {
@ -632,15 +749,18 @@ std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
auto& entry_block = machine_func->GetEntry();
for (size_t i = 0; i < num_params; i++) {
auto* arg = func.GetArgument(i);
bool is_ptr = arg->GetType()->IsPtrInt32();
bool is_ptr = arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat32();
bool is_float = arg->GetType()->IsFloat32();
int slot_size = is_ptr ? 8 : 4; // 指针 8 字节,整数 4 字节
int slot = machine_func->CreateFrameIndex(slot_size);
slots.emplace(arg, slot);
// 根据参数类型选择寄存器:指针用 x0-x7整数用 w0-w7
// 根据参数类型选择寄存器:指针用 x0-x7整数用 w0-w7,浮点用 s0-s7
PhysReg param_reg;
if (is_ptr) {
param_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + i);
} else if (is_float) {
param_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + i);
} else {
param_reg = static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + i);
}

@ -52,8 +52,10 @@ MachineFunction* MachineModule::CreateFunction(std::string name) {
return functions_.back().get();
}
void MachineModule::AddGlobalVar(std::string name, int init_val, int count) {
global_vars_.emplace_back(std::move(name), init_val, count);
void MachineModule::AddGlobalVar(std::string name, int init_val, int count,
bool is_float, std::vector<int> init_elems) {
global_vars_.emplace_back(std::move(name), init_val, count, is_float,
std::move(init_elems));
}
} // namespace mir

Loading…
Cancel
Save