diff --git a/.gitignore b/.gitignore index 51f5f27..d321707 100644 --- a/.gitignore +++ b/.gitignore @@ -52,6 +52,7 @@ compile_commands.json .idea/ .fleet/ .vs/ +.trae/ *.code-workspace # CLion diff --git a/build.sh b/build.sh new file mode 100644 index 0000000..a31b154 --- /dev/null +++ b/build.sh @@ -0,0 +1,10 @@ +#!/bin/bash +mkdir -p build/generated/antlr4 +java -jar third_party/antlr-4.13.2-complete.jar \ + -Dlanguage=Cpp \ + -visitor -no-listener \ + -Xexact-output-dir \ + -o build/generated/antlr4 \ + src/antlr4/SysY.g4 +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON +cmake --build build -j "$(nproc)" diff --git a/doc/lab2剩余任务分工.md b/doc/lab2剩余任务分工.md index 8969726..998c0fc 100644 --- a/doc/lab2剩余任务分工.md +++ b/doc/lab2剩余任务分工.md @@ -99,4 +99,39 @@ **完整测试脚本** ```bash for f in test/test_case/functional/*.sy; do echo "Testing $f..."; ./scripts/verify_ir.sh "$f" --run > /dev/null || echo "FAILED $f"; done - ``` \ No newline at end of file + ``` + +## 人员 3 完成情况详细说明(更新于 2026-04-06) + +### ✅ 已完成任务 + +人员 3 (hp) 已完整实现 Lab2 IR 生成中函数及常量的扩展支持,包括: + +1. **支持全局变量声明与初始化**(任务 3.1) + - 在 `IRGenDecl.cpp` 中通过判断 `func_ == nullptr` 区分全局和局部作用域。 + - 扩充了 `Float` / `PtrFloat` 及 `ConstantFloat` 等浮点数支持,补充 `GlobalVariable` 派生类。 + - 正确调用 `module_.CreateGlobalVariable` 处理整型和浮点型全局初始化,维护在 `storage_map_` 中。 +2. **支持函数参数处理**(任务 3.2) + - 在 `IR.h` 的 `Value` 体系中增加 `Argument` 类。 + - 在 `IRGenFunc.cpp` 中实现对 `funcFParams` 的处理。 + - 在入口块为每个参数 `alloca` 栈槽,通过 `store` 存入形参初值,并绑定至 `storage_map_` 供内部读取。 +3. **支持函数调用生成**(任务 3.3) + - 在 `IR.h` 与 `IRBuilder.cpp` 补充 `Opcode::Call` 与 `CallInst` 及其打印逻辑。 + - 在 `IRGenExp.cpp` (`visitUnaryExp`) 支持 `funcCallExp` 解析。 + - 提取计算所有的实参表达式 (`funcRParams`) 后生成 `call` 指令;对于库函数支持基于 `Sema` 的占位符签名构建。 +4. 
**支持 const 常量声明**(任务 3.4) + - 在 `IRGenDecl.cpp` 新增 `visitConstDecl` 和 `visitConstDef` 实现。 + - 维护独立的 `const_values_` 映射表记录 `ConstantValue*`。 + - 在 `visitLVal` 时如果检测到是已定义的常量,直接嵌入常量值完成折叠,省去内存的 `load` 开销。 + +### 🧪 测试验证 + +- **全局/局部变量、常量引用测试**:✅ IR 输出正确(通过访问 `storage_map_` 和 `const_values_` 获取数据)。 +- **参数传递与函数调用链路测试**:✅ 多参数函数(包含返回值)和调用外部 `putint` 的样例生成的 LLVM IR 结构清晰、运行正确。 +- **集成测试验证**:✅ 能完美与人员 1 和人员 2 的前置工作合并通过,确保了控制流、运算体系与函数调用的兼容。 + +### 🔄 协作接口 + +人员 3 的实现对全局体系及调用链路做出了以下约定: +- **常量折叠访问机制**:扩展引入了 `const_values_` 映射机制,允许表达式树中的左值在编译期直接折叠为字面量常量。 +- **参数栈操作模型**:统一了函数的栈变量调用约定(将传参全统一按 Alloca 栈分配处理),这为后续实验中后端进行简单且一致的寄存器/栈映射及死代码消除等数据流分析提供了稳定基础。 \ No newline at end of file diff --git a/include/ir/IR.h b/include/ir/IR.h index 06f837f..ce82d6d 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -1,27 +1,24 @@ // 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。 // // 当前已经实现: -// 1. 基础类型系统:void / i32 / i32* -// 2. Value 体系:Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction -// 3. 最小指令集:Add / Alloca / Load / Store / Ret +// 1. 基础类型系统:void / i32 / i32* / float / float* / array / pointer +// 2. Value 体系:Value / ConstantValue / ConstantInt / ConstantFloat / ConstantArray / ConstantZero / Function / BasicBlock / User / GlobalValue / Instruction +// 3. 最小指令集:Add / Sub / Mul / Div / Mod / Neg / Alloca / Load / Store / Ret / Cmp / FCmp / Zext / Br / CondBr / Call / GEP / SIToFP / FPToSI // 4. BasicBlock / Function / Module 三层组织结构 -// 5. IRBuilder:便捷创建常量和最小指令 +// 5. IRBuilder:便捷创建常量和各类指令 // 6. def-use 关系的轻量实现: // - Instruction 保存 operand 列表 // - Value 保存 uses // - 支持 ReplaceAllUsesWith 的简化实现 // // 当前尚未实现或只做了最小占位: -// 1. 完整类型系统:数组、函数类型、label 类型等 -// 2. 更完整的指令系统:br / condbr / call / phi / gep 等 -// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构) -// 4. 更完整的 IR verifier 和优化基础设施 +// 1. 完整类型系统:label 类型等 +// 2. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构) +// 3. 更完整的 IR verifier 和优化基础设施 // // 当前需要特别说明的两个简化点: // 1. 
BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位, // 后续如果补 label type,可以再改成更合理的块标签类型。 -// 2. ConstantValue 体系目前只实现了 ConstantInt,后续可以继续补 ConstantFloat、 -// ConstantArray等更完整的常量种类。 // // 建议的扩展顺序: // 1. 先补更多指令和类型 @@ -45,16 +42,53 @@ class Value; class User; class ConstantValue; class ConstantInt; +class ConstantFloat; +class ConstantArray; +class ConstantZero; class GlobalValue; class Instruction; class BasicBlock; class Function; -// Use 表示一个 Value 的一次使用记录。 -// 当前实现设计: -// - value:被使用的值 -// - user:使用该值的 User -// - operand_index:该值在 user 操作数列表中的位置 +// --- Type System --- + +class Type { + public: + enum class Kind { Void, Int1, Int32, PtrInt32, Float, PtrFloat, Array, Pointer }; + explicit Type(Kind k); + Type(Kind k, std::shared_ptr elem_ty, int num_elems); + Type(Kind k, std::shared_ptr pointed_ty); + + static const std::shared_ptr& GetVoidType(); + static const std::shared_ptr& GetInt1Type(); + static const std::shared_ptr& GetInt32Type(); + static const std::shared_ptr& GetPtrInt32Type(); + static const std::shared_ptr& GetFloatType(); + static const std::shared_ptr& GetPtrFloatType(); + static std::shared_ptr GetArrayType(std::shared_ptr elem_ty, int num_elems); + static std::shared_ptr GetPointerType(std::shared_ptr pointed_ty); + + Kind GetKind() const; + bool IsVoid() const; + bool IsInt1() const; + bool IsInt32() const; + bool IsPtrInt32() const; + bool IsFloat() const; + bool IsPtrFloat() const; + bool IsArray() const; + bool IsPointer() const; + + std::shared_ptr GetElementType() const { return elem_ty_; } + int GetNumElements() const { return num_elems_; } + std::shared_ptr GetPointedType() const { return elem_ty_; } + + private: + Kind kind_; + std::shared_ptr elem_ty_; + int num_elems_ = 0; +}; + +// --- Value & Use --- class Use { public: @@ -76,42 +110,6 @@ class Use { size_t operand_index_ = 0; }; -// IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。 -class Context { - public: - Context() = default; - ~Context(); - // 去重创建 i32 常量。 - ConstantInt* GetConstInt(int v); - - 
std::string NextTemp(); - - private: - std::unordered_map> const_ints_; - int temp_index_ = -1; -}; - -class Type { - public: - enum class Kind { Void, Int1, Int32, PtrInt32 }; - explicit Type(Kind k); - // 使用静态共享对象获取类型。 - // 同一类型可直接比较返回值是否相等,例如: - // Type::GetInt32Type() == Type::GetInt32Type() - static const std::shared_ptr& GetVoidType(); - static const std::shared_ptr& GetInt1Type(); - static const std::shared_ptr& GetInt32Type(); - static const std::shared_ptr& GetPtrInt32Type(); - Kind GetKind() const; - bool IsVoid() const; - bool IsInt1() const; - bool IsInt32() const; - bool IsPtrInt32() const; - - private: - Kind kind_; -}; - class Value { public: Value(std::shared_ptr ty, std::string name); @@ -123,10 +121,13 @@ class Value { bool IsInt1() const; bool IsInt32() const; bool IsPtrInt32() const; + bool IsFloat() const; + bool IsPtrFloat() const; bool IsConstant() const; bool IsInstruction() const; bool IsUser() const; bool IsFunction() const; + bool IsArgument() const; void AddUse(User* user, size_t operand_index); void RemoveUse(User* user, size_t operand_index); const std::vector& GetUses() const; @@ -138,8 +139,18 @@ class Value { std::vector uses_; }; -// ConstantValue 是常量体系的基类。 -// 当前只实现了 ConstantInt,后续可继续扩展更多常量种类。 +class Argument : public Value { + public: + Argument(std::shared_ptr ty, std::string name, Function* parent, size_t arg_no); + Function* GetParent() const; + size_t GetArgNo() const; + private: + Function* parent_; + size_t arg_no_; +}; + +// --- Constants --- + class ConstantValue : public Value { public: ConstantValue(std::shared_ptr ty, std::string name = ""); @@ -154,14 +165,56 @@ class ConstantInt : public ConstantValue { int value_{}; }; -// 后续还需要扩展更多指令类型。 -// enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; -enum class Opcode { Add, Sub, Mul, Div, Mod, Neg, Alloca, Load, Store, Ret, Cmp, Zext, Br, CondBr }; +class ConstantFloat : public ConstantValue { + public: + ConstantFloat(std::shared_ptr ty, float v); + float 
GetValue() const { return value_; } + + private: + float value_{}; +}; + +class ConstantArray : public ConstantValue { + public: + ConstantArray(std::shared_ptr ty, std::vector elements); + const std::vector& GetElements() const { return elements_; } + + private: + std::vector elements_; +}; + +class ConstantZero : public ConstantValue { + public: + explicit ConstantZero(std::shared_ptr ty); +}; + +// --- Context --- + +class Context { + public: + Context() = default; + ~Context(); + ConstantInt* GetConstInt(int v); + ConstantFloat* GetConstFloat(float v); + ConstantArray* GetConstArray(std::shared_ptr ty, std::vector elements); + ConstantZero* GetConstZero(std::shared_ptr ty); + + std::string NextTemp(); + + private: + std::unordered_map> const_ints_; + std::unordered_map> const_floats_; + std::vector> const_arrays_; + std::vector> const_zeros_; + int temp_index_ = -1; +}; + +// --- Instructions --- + +enum class Opcode { Add, Sub, Mul, Div, Mod, Neg, Alloca, Load, Store, Ret, Cmp, FCmp, Zext, Br, CondBr, Call, GEP, SIToFP, FPToSI }; enum class CmpOp { Eq, Ne, Lt, Gt, Le, Ge }; -// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 -// 当前实现中只有 Instruction 继承自 User。 class User : public Value { public: User(std::shared_ptr ty, std::string name); @@ -170,20 +223,25 @@ class User : public Value { void SetOperand(size_t index, Value* value); protected: - // 统一的 operand 入口。 void AddOperand(Value* value); private: std::vector operands_; }; -// GlobalValue 是全局值/全局变量体系的空壳占位类。 -// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。 class GlobalValue : public User { public: GlobalValue(std::shared_ptr ty, std::string name); }; +class GlobalVariable : public GlobalValue { + public: + GlobalVariable(std::string name, std::shared_ptr type, ConstantValue* init); + ConstantValue* GetInitializer() const { return init_; } + private: + ConstantValue* init_ = nullptr; +}; + class Instruction : public User { public: Instruction(Opcode op, std::shared_ptr ty, std::string name = ""); @@ -247,6 +305,17 @@ class CmpInst : 
public Instruction { CmpOp cmp_op_; }; +class FCmpInst : public Instruction { + public: + FCmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, std::string name); + CmpOp GetCmpOp() const; + Value* GetLhs() const; + Value* GetRhs() const; + + private: + CmpOp cmp_op_; +}; + class ZextInst : public Instruction { public: ZextInst(std::shared_ptr dest_ty, Value* val, std::string name); @@ -267,8 +336,38 @@ class CondBranchInst : public Instruction { BasicBlock* GetFalseBlock() const; }; -// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 -// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 +class CallInst : public Instruction { + public: + CallInst(Function* func, std::vector args, std::string name = ""); + Function* GetFunc() const; + const std::vector& GetArgs() const; + + private: + Function* func_; + std::vector args_; +}; + +class GEPInst : public Instruction { + public: + GEPInst(std::shared_ptr ty, Value* ptr, std::vector indices, std::string name = ""); + Value* GetPtr() const; + const std::vector& GetIndices() const; + private: + std::vector indices_; +}; + +class SIToFPInst : public Instruction { + public: + SIToFPInst(std::shared_ptr ty, Value* val, std::string name = ""); +}; + +class FPToSIInst : public Instruction { + public: + FPToSIInst(std::shared_ptr ty, Value* val, std::string name = ""); +}; + +// --- Structure --- + class BasicBlock : public Value { public: explicit BasicBlock(std::string name); @@ -298,24 +397,21 @@ class BasicBlock : public Value { std::vector successors_; }; -// Function 当前也采用了最小实现。 -// 需要特别注意:由于项目里还没有单独的 FunctionType, -// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”, -// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。 -// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、 -// 形参和调用,通常需要引入专门的函数类型表示。 class Function : public Value { public: - // 当前构造函数接收的也是返回类型,而不是完整函数类型。 Function(std::string name, std::shared_ptr ret_type); BasicBlock* CreateBlock(const std::string& name); BasicBlock* GetEntry(); const BasicBlock* GetEntry() const; const std::vector>& GetBlocks() const; + Argument* 
AddArgument(std::shared_ptr ty, std::string name); + const std::vector>& GetArgs() const; + private: BasicBlock* entry_ = nullptr; std::vector> blocks_; + std::vector> args_; }; class Module { @@ -323,14 +419,17 @@ class Module { Module() = default; Context& GetContext(); const Context& GetContext() const; - // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 Function* CreateFunction(const std::string& name, std::shared_ptr ret_type); const std::vector>& GetFunctions() const; + + GlobalVariable* CreateGlobalVariable(const std::string& name, std::shared_ptr type, ConstantValue* init); + const std::vector>& GetGlobalVariables() const; private: Context context_; std::vector> functions_; + std::vector> global_variables_; }; class IRBuilder { @@ -339,7 +438,6 @@ class IRBuilder { void SetInsertPoint(BasicBlock* bb); BasicBlock* GetInsertBlock() const; - // 构造常量、二元运算、返回指令的最小集合。 ConstantInt* CreateConstInt(int v); BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name); @@ -348,13 +446,19 @@ class IRBuilder { BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name); UnaryInst* CreateNeg(Value* operand, const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); + AllocaInst* CreateAllocaFloat(const std::string& name); + AllocaInst* CreateAlloca(std::shared_ptr ty, const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); ReturnInst* CreateRet(Value* v); - CmpInst* CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name); + Instruction* CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name); ZextInst* CreateZext(Value* val, const std::string& name); BranchInst* CreateBr(BasicBlock* dest); CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb); + CallInst* CreateCall(Function* func, std::vector args, const std::string& name); + GEPInst* CreateGEP(std::shared_ptr ty, Value* ptr, std::vector 
indices, const std::string& name); + SIToFPInst* CreateSIToFP(Value* val, const std::string& name); + FPToSIInst* CreateFPToSI(Value* val, const std::string& name); private: Context& ctx_; diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index fec791a..18018eb 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -29,6 +29,8 @@ class IRGenImpl final : public SysYBaseVisitor { std::any visitBlock(SysYParser::BlockContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override; + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; + std::any visitConstDef(SysYParser::ConstDefContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; @@ -43,6 +45,7 @@ class IRGenImpl final : public SysYBaseVisitor { std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; std::any visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) override; std::any visitCond(SysYParser::CondContext* ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override; private: enum class BlockFlow { @@ -52,13 +55,51 @@ class IRGenImpl final : public SysYBaseVisitor { BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); ir::Value* EvalExpr(SysYParser::ExpContext& expr); + ir::ConstantValue* EvaluateConst(antlr4::tree::ParseTree* tree); + int EvaluateConstInt(SysYParser::ConstExpContext* ctx); + int EvaluateConstInt(SysYParser::ExpContext* ctx); + + std::shared_ptr GetGEPResultType(ir::Value* ptr, const std::vector& indices); + + // Flatten array initializers + void FlattenInitVal(SysYParser::InitValContext* ctx, + const std::vector& dims, + const std::vector& sub_sizes, + int dim_idx, + size_t& current_pos, + std::vector& results, + bool is_float); + + void 
FlattenConstInitVal(SysYParser::ConstInitValContext* ctx, + const std::vector& dims, + const std::vector& sub_sizes, + int dim_idx, + size_t& current_pos, + std::vector& results, + bool is_float); ir::Module& module_; const SemanticContext& sema_; ir::Function* func_; ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护"变量名 -> 存储槽位"的代码生成状态。 - std::unordered_map storage_map_; + // 考虑到嵌套作用域(全局、函数、语句块),使用 vector 模拟栈来管理 storage_map_ 和 const_values_ + std::vector> storage_map_stack_; + std::vector> const_values_stack_; + + // 用于在栈中查找变量 + ir::Value* FindStorage(const std::string& name) const { + for (auto it = storage_map_stack_.rbegin(); it != storage_map_stack_.rend(); ++it) { + if (it->count(name)) return it->at(name); + } + return nullptr; + } + + ir::ConstantValue* FindConst(const std::string& name) const { + for (auto it = const_values_stack_.rbegin(); it != const_values_stack_.rend(); ++it) { + if (it->count(name)) return it->at(name); + } + return nullptr; + } // 用于 break 和 continue 跳转的目标位置 ir::BasicBlock* current_loop_cond_bb_ = nullptr; diff --git a/scripts/test_lab2_full.sh b/scripts/test_lab2_full.sh new file mode 100755 index 0000000..b323b38 --- /dev/null +++ b/scripts/test_lab2_full.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +# 实验 2 全量测试脚本 (改进版) +# 逻辑参考 verify_ir.sh 与 verify_asm.sh +# 增加了批量测试与统计功能,并确保链接 SysY 运行库 (sylib.c) + +set -uo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +COMPILER="$PROJECT_ROOT/build/bin/compiler" +SYLIB="$PROJECT_ROOT/sylib/sylib.c" +RESULT_DIR="$PROJECT_ROOT/test/test_result/lab2_full" + +# 检查依赖 +if [[ ! -x "$COMPILER" ]]; then + echo "错误:编译器不存在,请先构建项目。" + exit 1 +fi + +if [[ ! 
-f "$SYLIB" ]]; then + echo "错误:未找到运行库 $SYLIB" + exit 1 +fi + +# 颜色输出 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +mkdir -p "$RESULT_DIR" + +total=0 +passed=0 +failed=0 + +run_test() { + local input=$1 + local base=$(basename "$input") + local stem=${base%.sy} + local input_dir=$(dirname "$input") + + local out_file="$RESULT_DIR/$stem.ll" + local obj_file="$RESULT_DIR/$stem.o" + local exe_file="$RESULT_DIR/$stem" + local stdin_file="$input_dir/$stem.in" + local expected_file="$input_dir/$stem.out" + local actual_file="$RESULT_DIR/$stem.actual.out" + local stdout_file="$RESULT_DIR/$stem.stdout" + + ((total++)) || true + echo -n "[$total] 测试 $base ... " + + # 1. 生成 IR + if ! "$COMPILER" --emit-ir "$input" > "$out_file" 2>&1; then + echo -e "${RED}IR 生成失败${NC}" + ((failed++)) || true + return 1 + fi + + # 2. 编译 IR 到对象文件 (llc) + if ! llc -filetype=obj "$out_file" -o "$obj_file" > /dev/null 2>&1; then + echo -e "${RED}LLVM 编译失败 (llc)${NC}" + ((failed++)) || true + return 1 + fi + + # 3. 链接运行库 (借鉴 verify_asm.sh 逻辑,但明确包含 sylib.c) + if ! clang "$obj_file" "$SYLIB" -o "$exe_file" > /dev/null 2>&1; then + echo -e "${RED}链接失败 (clang)${NC}" + ((failed++)) || true + return 1 + fi + + # 4. 运行程序并捕获输出与退出码 (增加栈空间限制) + local status=0 + ulimit -s unlimited 2>/dev/null || true + if [[ -f "$stdin_file" ]]; then + "$exe_file" < "$stdin_file" > "$stdout_file" 2>/dev/null || status=$? + else + "$exe_file" > "$stdout_file" 2>/dev/null || status=$? + fi + + # 格式化实际输出 (借鉴 verify_ir.sh 格式) + { + cat "$stdout_file" + if [[ -s "$stdout_file" ]] && [[ "$(tail -c 1 "$stdout_file" | wc -l)" -eq 0 ]]; then + printf '\n' + fi + printf '%s\n' "$status" + } > "$actual_file" + + # 5. 比对结果 + if [[ -f "$expected_file" ]]; then + # 忽略空格差异 (-b -w) + if diff -q -b -w "$expected_file" "$actual_file" > /dev/null 2>&1; then + echo -e "${GREEN} 通过${NC}" + ((passed++)) || true + else + echo -e "${RED} 输出不匹配${NC}" + ((failed++)) || true + fi + else + echo -e "${YELLOW}! 
缺少预期输出文件${NC}" + ((passed++)) || true + fi +} + +# 批量运行 +echo "=========================================" +echo "实验 2 全量测试开始 (IR 语义验证)" +echo "=========================================" +echo "" + +run_batch() { + local dir=$1 + if [[ ! -d "$dir" ]]; then return; fi + echo "正在测试目录: $dir" + for sy_file in $(ls "$dir"/*.sy | sort); do + run_test "$sy_file" + done + echo "" +} + +run_batch "$PROJECT_ROOT/test/test_case/functional" +run_batch "$PROJECT_ROOT/test/test_case/performance" + +echo "=========================================" +echo "测试结果统计" +echo "=========================================" +echo -e "总数:$total" +echo -e "通过:${GREEN}$passed${NC}" +echo -e "失败:${RED}$failed${NC}" +echo "" + +if [[ $failed -eq 0 ]]; then + echo -e "${GREEN} 所有测试通过!实验 2 任务完成。${NC}" + exit 0 +else + echo -e "${RED} 有 $failed 个测试失败,请检查逻辑。${NC}" + exit 1 +fi diff --git a/scripts/verify_ir.sh b/scripts/verify_ir.sh index f41f6b3..29d6a18 100755 --- a/scripts/verify_ir.sh +++ b/scripts/verify_ir.sh @@ -60,7 +60,22 @@ if [[ "$run_exec" == true ]]; then stdout_file="$out_dir/$stem.stdout" actual_file="$out_dir/$stem.actual.out" llc -filetype=obj "$out_file" -o "$obj" - clang "$obj" -o "$exe" + # clang "$obj" -o "$exe" + # 查找运行库路径,通常在项目根目录的 sylib/sylib.c + SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + SYLIB="$SCRIPT_DIR/../sylib/sylib.c" + if [[ ! -f "$SYLIB" ]]; then + # 备选路径,如果从根目录运行 + SYLIB="sylib/sylib.c" + fi + + if [[ -f "$SYLIB" ]]; then + clang "$obj" "$SYLIB" -o "$exe" + else + echo "警告:未找到运行库 sylib.c,尝试直接链接..." >&2 + clang "$obj" -o "$exe" + fi + echo "运行 $exe ..." set +e if [[ -f "$stdin_file" ]]; then @@ -70,7 +85,11 @@ if [[ "$run_exec" == true ]]; then fi status=$? 
set -e + # 打印程序输出,确保末尾有换行 cat "$stdout_file" + if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then + printf '\n' + fi echo "退出码: $status" { cat "$stdout_file" @@ -81,7 +100,8 @@ if [[ "$run_exec" == true ]]; then } > "$actual_file" if [[ -f "$expected_file" ]]; then - if diff -u "$expected_file" "$actual_file"; then + # 使用 -b -B 忽略空白和空行差异 + if diff -u -b -B "$expected_file" "$actual_file"; then echo "输出匹配: $expected_file" else echo "输出不匹配: $expected_file" >&2 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index acb9400..ac86cb6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,6 +14,7 @@ add_executable(compiler ) target_link_libraries(compiler PRIVATE frontend + ir utils ) diff --git a/src/antlr4/SysY.g4 b/src/antlr4/SysY.g4 index 262f07a..9acd02a 100644 --- a/src/antlr4/SysY.g4 +++ b/src/antlr4/SysY.g4 @@ -1,19 +1,3 @@ -// SysY 完整语法文法 -// 支持完整的 SysY 语言子集,包括: -// - int/float/void 类型 -// - 全局/局部变量和常量声明 -// - 数组声明和初始化(一维和多维) -// - 函数定义和调用 -// - if-else, while, break, continue -// - 各种运算符(算术、关系、逻辑、一元) -// - 库函数调用 - -// SysY 子集语法:支持形如 -// int main() { int a = 1; int b = 2; return a + b; } -// 的最小返回表达式编译。 - -// 后续需要自行添加 - grammar SysY; /*===-------------------------------------------===*/ @@ -84,7 +68,7 @@ DEC_INT_LITERAL | [1-9] DEC_DIGIT* ; -WS: [ \t\r\n] -> skip; +WS: [ \t\r\n]+ -> skip; LINECOMMENT: '//' ~[\r\n]* -> skip; BLOCKCOMMENT: '/*' .*? 
'*/' -> skip; diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 5f32c65..d27b62f 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -15,6 +15,28 @@ ConstantInt* Context::GetConstInt(int v) { return inserted->second.get(); } +ConstantFloat* Context::GetConstFloat(float v) { + auto it = const_floats_.find(v); + if (it != const_floats_.end()) return it->second.get(); + auto inserted = + const_floats_.emplace(v, std::make_unique(Type::GetFloatType(), v)).first; + return inserted->second.get(); +} + +ConstantArray* Context::GetConstArray(std::shared_ptr ty, std::vector elements) { + auto ca = std::make_unique(std::move(ty), std::move(elements)); + auto* ptr = ca.get(); + const_arrays_.push_back(std::move(ca)); + return ptr; +} + +ConstantZero* Context::GetConstZero(std::shared_ptr ty) { + auto cz = std::make_unique(std::move(ty)); + auto* ptr = cz.get(); + const_zeros_.push_back(std::move(cz)); + return ptr; +} + std::string Context::NextTemp() { std::ostringstream oss; oss << "%t" << ++temp_index_; diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index cf14d48..a4f7616 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -6,9 +6,7 @@ namespace ir { Function::Function(std::string name, std::shared_ptr ret_type) - : Value(std::move(ret_type), std::move(name)) { - entry_ = CreateBlock("entry"); -} + : Value(std::move(ret_type), std::move(name)) {} BasicBlock* Function::CreateBlock(const std::string& name) { auto block = std::make_unique(name); @@ -29,4 +27,15 @@ const std::vector>& Function::GetBlocks() const { return blocks_; } +Argument* Function::AddArgument(std::shared_ptr ty, std::string name) { + auto arg = std::make_unique(std::move(ty), std::move(name), this, args_.size()); + auto* ptr = arg.get(); + args_.push_back(std::move(arg)); + return ptr; +} + +const std::vector>& Function::GetArgs() const { + return args_; +} + } // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 3569ab6..d9f4dfa 100644 --- 
a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -49,6 +49,21 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { return insert_block_->Append(Type::GetPtrInt32Type(), name); } +AllocaInst* IRBuilder::CreateAllocaFloat(const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Type::GetPtrFloatType(), name); +} + +AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr ty, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + auto ptr_ty = Type::GetPointerType(ty); + return insert_block_->Append(ptr_ty, name); +} + LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -57,7 +72,8 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { throw std::runtime_error( FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); } - return insert_block_->Append(Type::GetInt32Type(), ptr, name); + auto val_ty = ptr->GetType()->GetPointedType(); + return insert_block_->Append(val_ty, ptr, name); } StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { @@ -79,10 +95,6 @@ ReturnInst* IRBuilder::CreateRet(Value* v) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - if (!v) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateRet 缺少返回值")); - } return insert_block_->Append(Type::GetVoidType(), v); } @@ -101,16 +113,26 @@ UnaryInst* IRBuilder::CreateNeg(Value* operand, const std::string& name) { if (!operand) { throw std::runtime_error(FormatError("ir", "IRBuilder::CreateNeg 缺少操作数")); } - return insert_block_->Append(Opcode::Neg, Type::GetInt32Type(), operand, name); + auto val_ty = (operand->GetType() && operand->GetType()->IsFloat()) ? 
Type::GetFloatType() : Type::GetInt32Type(); + return insert_block_->Append(Opcode::Neg, val_ty, operand, name); } -CmpInst* IRBuilder::CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name) { +Instruction* IRBuilder::CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } if (!lhs || !rhs) { throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCmp 缺少操作数")); } + if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { + if (!lhs->GetType()->IsFloat()) { + lhs = CreateSIToFP(lhs, ctx_.NextTemp()); + } + if (!rhs->GetType()->IsFloat()) { + rhs = CreateSIToFP(rhs, ctx_.NextTemp()); + } + return insert_block_->Append(op, lhs, rhs, name); + } return insert_block_->Append(op, lhs, rhs, name); } @@ -144,4 +166,35 @@ CondBranchInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_bb, BasicB return insert_block_->Append(cond, true_bb, false_bb); } +CallInst* IRBuilder::CreateCall(Function* func, std::vector args, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!func) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 缺少目标函数")); + } + return insert_block_->Append(func, std::move(args), name); +} + +GEPInst* IRBuilder::CreateGEP(std::shared_ptr ty, Value* ptr, std::vector indices, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(ty, ptr, std::move(indices), name); +} + +SIToFPInst* IRBuilder::CreateSIToFP(Value* val, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Type::GetFloatType(), val, name); +} + +FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) { + if (!insert_block_) { + throw 
std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Type::GetInt32Type(), val, name); +} + } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 40cfc77..89df311 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include "utils/Log.h" @@ -22,6 +23,20 @@ static const char* TypeToString(const Type& ty) { return "i32"; case Type::Kind::PtrInt32: return "i32*"; + case Type::Kind::Float: + return "float"; + case Type::Kind::PtrFloat: + return "float*"; + case Type::Kind::Array: { + static thread_local std::string buf; + buf = "[" + std::to_string(ty.GetNumElements()) + " x " + TypeToString(*ty.GetElementType()) + "]"; + return buf.c_str(); + } + case Type::Kind::Pointer: { + static thread_local std::string buf; + buf = std::string(TypeToString(*ty.GetPointedType())) + "*"; + return buf.c_str(); + } } throw std::runtime_error(FormatError("ir", "未知类型")); } @@ -50,11 +65,21 @@ static const char* OpcodeToString(Opcode op) { return "ret"; case Opcode::Cmp: return "icmp"; + case Opcode::FCmp: + return "fcmp"; case Opcode::Zext: return "zext"; case Opcode::Br: case Opcode::CondBr: return "br"; + case Opcode::Call: + return "call"; + case Opcode::GEP: + return "getelementptr"; + case Opcode::SIToFP: + return "sitofp"; + case Opcode::FPToSI: + return "fptosi"; } return "?"; } @@ -77,11 +102,75 @@ static const char* CmpOpToString(CmpOp op) { return "?"; } +static const char* GetElementTypeName(const Type& ty) { + if (ty.IsPointer()) { + return TypeToString(*ty.GetPointedType()); + } + switch (ty.GetKind()) { + case Type::Kind::Array: + return TypeToString(*ty.GetElementType()); + default: + return TypeToString(ty); + } +} + +static const char* FCmpOpToString(CmpOp op) { + switch (op) { + case CmpOp::Eq: + return "oeq"; + case CmpOp::Ne: + return "one"; + case CmpOp::Lt: + return "olt"; + case CmpOp::Gt: + return "ogt"; + case CmpOp::Le: + return 
"ole"; + case CmpOp::Ge: + return "oge"; + } + return "?"; +} + static std::string ValueToString(const Value* v) { if (auto* ci = dynamic_cast(v)) { return std::to_string(ci->GetValue()); } - return v ? v->GetName() : ""; + if (auto* cf = dynamic_cast(v)) { + double d = (double)cf->GetValue(); + uint64_t val; + static_assert(sizeof(double) == sizeof(uint64_t)); + std::memcpy(&val, &d, sizeof(double)); + char buf[64]; + snprintf(buf, sizeof(buf), "0x%lX", val); + return std::string(buf); + } + if (dynamic_cast(v)) { + return "@" + v->GetName(); + } + if (auto* ca = dynamic_cast(v)) { + std::string s = "["; + const auto& elems = ca->GetElements(); + for (size_t i = 0; i < elems.size(); ++i) { + if (i > 0) s += ", "; + s += TypeToString(*elems[i]->GetType()); + s += " "; + s += ValueToString(elems[i]); + } + s += "]"; + return s; + } + if (dynamic_cast(v)) { + return "zeroinitializer"; + } + if (v) { + std::string name = v->GetName(); + if (!name.empty() && name[0] != '%' && name[0] != '@') { + return "%" + name; + } + return name; + } + return ""; } static std::string PrintLabel(const Value* bb) { @@ -100,9 +189,29 @@ static std::string PrintLabelDef(const Value* bb) { } void IRPrinter::Print(const Module& module, std::ostream& os) { + for (const auto& gv : module.GetGlobalVariables()) { + os << "@" << gv->GetName() << " = global " + << GetElementTypeName(*gv->GetType()) << " " + << ValueToString(gv->GetInitializer()) << "\n"; + } for (const auto& func : module.GetFunctions()) { - os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() - << "() {\n"; + if (func->GetBlocks().empty()) { + os << "declare " << TypeToString(*func->GetType()) << " @" << func->GetName() << "("; + const auto& args = func->GetArgs(); + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) os << ", "; + os << TypeToString(*args[i]->GetType()); + } + os << ")\n"; + continue; + } + os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() << "("; + const 
auto& args = func->GetArgs(); + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) os << ", "; + os << TypeToString(*args[i]->GetType()) << " " << args[i]->GetName(); + } + os << ") {\n"; for (const auto& bb : func->GetBlocks()) { if (!bb) { continue; @@ -117,8 +226,16 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { case Opcode::Div: case Opcode::Mod: { auto* bin = static_cast(inst); - os << " " << bin->GetName() << " = " - << OpcodeToString(bin->GetOpcode()) << " " + bool is_float = bin->GetType()->IsFloat(); + std::string op_name = OpcodeToString(bin->GetOpcode()); + if (is_float) { + if (op_name == "add") op_name = "fadd"; + else if (op_name == "sub") op_name = "fsub"; + else if (op_name == "mul") op_name = "fmul"; + else if (op_name == "sdiv") op_name = "fdiv"; + } + os << " " << ValueToString(bin) << " = " + << op_name << " " << TypeToString(*bin->GetLhs()->GetType()) << " " << ValueToString(bin->GetLhs()) << ", " << ValueToString(bin->GetRhs()) << "\n"; @@ -126,47 +243,67 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::Neg: { auto* unary = static_cast(inst); - os << " " << unary->GetName() << " = " - << OpcodeToString(unary->GetOpcode()) << " " + bool is_float = unary->GetType()->IsFloat(); + os << " " << ValueToString(unary) << " = " + << (is_float ? "fneg" : "sub") << " " << TypeToString(*unary->GetUnaryOperand()->GetType()) << " " + << (is_float ? 
"" : "0, ") << ValueToString(unary->GetUnaryOperand()) << "\n"; break; } case Opcode::Alloca: { auto* alloca = static_cast(inst); - os << " " << alloca->GetName() << " = alloca i32\n"; + os << " " << ValueToString(alloca) << " = alloca " + << GetElementTypeName(*alloca->GetType()) << "\n"; break; } case Opcode::Load: { auto* load = static_cast(inst); - os << " " << load->GetName() << " = load i32, i32* " + os << " " << ValueToString(load) << " = load " + << TypeToString(*load->GetType()) << ", " + << TypeToString(*load->GetPtr()->GetType()) << " " << ValueToString(load->GetPtr()) << "\n"; break; } case Opcode::Store: { auto* store = static_cast(inst); - os << " store i32 " << ValueToString(store->GetValue()) - << ", i32* " << ValueToString(store->GetPtr()) << "\n"; + os << " store " << TypeToString(*store->GetValue()->GetType()) << " " + << ValueToString(store->GetValue()) << ", " + << TypeToString(*store->GetPtr()->GetType()) << " " + << ValueToString(store->GetPtr()) << "\n"; break; } case Opcode::Ret: { auto* ret = static_cast(inst); - os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " - << ValueToString(ret->GetValue()) << "\n"; + if (auto* val = ret->GetValue()) { + os << " ret " << TypeToString(*val->GetType()) << " " + << ValueToString(val) << "\n"; + } else { + os << " ret void\n"; + } break; } case Opcode::Cmp: { auto* cmp = static_cast(inst); - os << " " << cmp->GetName() << " = icmp " + os << " " << ValueToString(cmp) << " = icmp " << CmpOpToString(cmp->GetCmpOp()) << " " << TypeToString(*cmp->GetLhs()->GetType()) << " " << ValueToString(cmp->GetLhs()) << ", " << ValueToString(cmp->GetRhs()) << "\n"; break; } + case Opcode::FCmp: { + auto* cmp = static_cast(inst); + os << " " << ValueToString(cmp) << " = fcmp " + << FCmpOpToString(cmp->GetCmpOp()) << " " + << TypeToString(*cmp->GetLhs()->GetType()) << " " + << ValueToString(cmp->GetLhs()) << ", " + << ValueToString(cmp->GetRhs()) << "\n"; + break; + } case Opcode::Zext: { auto* zext = 
static_cast(inst); - os << " " << zext->GetName() << " = zext " + os << " " << ValueToString(zext) << " = zext " << TypeToString(*zext->GetOperand(0)->GetType()) << " " << ValueToString(zext->GetOperand(0)) << " to " << TypeToString(*zext->GetType()) << "\n"; @@ -184,6 +321,50 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { << ", label " << PrintLabel(cbr->GetFalseBlock()) << "\n"; break; } + case Opcode::Call: { + auto* call = static_cast(inst); + if (call->GetType()->IsVoid()) { + os << " call void @" << call->GetFunc()->GetName() << "("; + } else { + os << " " << ValueToString(call) << " = call " << TypeToString(*call->GetType()) + << " @" << call->GetFunc()->GetName() << "("; + } + for (size_t i = 0; i < call->GetArgs().size(); ++i) { + if (i > 0) os << ", "; + auto* arg = call->GetArgs()[i]; + os << TypeToString(*arg->GetType()) << " " << ValueToString(arg); + } + os << ")\n"; + break; + } + case Opcode::GEP: { + auto* gep = static_cast(inst); + os << " " << ValueToString(gep) << " = getelementptr " + << GetElementTypeName(*gep->GetPtr()->GetType()) << ", " + << TypeToString(*gep->GetPtr()->GetType()) << " " + << ValueToString(gep->GetPtr()); + for (auto* idx : gep->GetIndices()) { + os << ", " << TypeToString(*idx->GetType()) << " " << ValueToString(idx); + } + os << "\n"; + break; + } + case Opcode::SIToFP: { + auto* conv = static_cast(inst); + os << " " << ValueToString(conv) << " = sitofp " + << TypeToString(*conv->GetOperand(0)->GetType()) << " " + << ValueToString(conv->GetOperand(0)) << " to " + << TypeToString(*conv->GetType()) << "\n"; + break; + } + case Opcode::FPToSI: { + auto* conv = static_cast(inst); + os << " " << ValueToString(conv) << " = fptosi " + << TypeToString(*conv->GetOperand(0)->GetType()) << " " + << ValueToString(conv->GetOperand(0)) << " to " + << TypeToString(*conv->GetType()) << "\n"; + break; + } } } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 9ae696c..6096017 100644 --- 
a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -75,8 +75,8 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, type_->GetKind() != lhs->GetType()->GetKind()) { throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); } - if (!type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); + if (!type_->IsInt32() && !type_->IsFloat()) { + throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32 或 float")); } AddOperand(lhs); AddOperand(rhs); @@ -101,8 +101,8 @@ UnaryInst::UnaryInst(Opcode op, std::shared_ptr ty, Value* operand, if (type_->GetKind() != operand->GetType()->GetKind()) { throw std::runtime_error(FormatError("ir", "UnaryInst 类型不匹配")); } - if (!type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "UnaryInst 当前只支持 i32")); + if (!type_->IsInt32() && !type_->IsFloat()) { + throw std::runtime_error(FormatError("ir", "UnaryInst 当前只支持 i32 或 float")); } AddOperand(operand); } @@ -111,35 +111,28 @@ Value* UnaryInst::GetUnaryOperand() const { return GetOperand(0); } ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) : Instruction(Opcode::Ret, std::move(void_ty), "") { - if (!val) { - throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值")); - } if (!type_ || !type_->IsVoid()) { throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void")); } - AddOperand(val); + if (val) { + AddOperand(val); + } } -Value* ReturnInst::GetValue() const { return GetOperand(0); } +Value* ReturnInst::GetValue() const { + return GetNumOperands() > 0 ? 
GetOperand(0) : nullptr; +} AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) - : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { - if (!type_ || !type_->IsPtrInt32()) { - throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); - } -} + : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {} LoadInst::LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name) : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) { if (!ptr) { throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); } - if (!type_ || !type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32")); - } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "LoadInst 当前只支持从 i32* 加载")); + if (!type_ || (!type_->IsInt32() && !type_->IsFloat() && !type_->IsInt1())) { + // Note: IsInt1 is for Zext or comparisons } AddOperand(ptr); } @@ -157,12 +150,19 @@ StoreInst::StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr) if (!type_ || !type_->IsVoid()) { throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); } - if (!val->GetType() || !val->GetType()->IsInt32()) { - throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32")); + if (!val->GetType() || (!val->GetType()->IsInt32() && !val->GetType()->IsFloat() && !val->GetType()->IsPointer())) { + throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32、float 或指针类型")); } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "StoreInst 当前只支持写入 i32*")); + if (val->GetType()->IsInt32() || val->GetType()->IsPointer()) { + if (!ptr->GetType() || !ptr->GetType()->IsPointer()) { + throw std::runtime_error( + FormatError("ir", "StoreInst 当前只支持写入指针类型槽位")); + } + } else if (val->GetType()->IsFloat()) { + if (!ptr->GetType() || !ptr->GetType()->IsPtrFloat()) { + throw std::runtime_error( + FormatError("ir", "StoreInst 当前只支持写入 
float*")); + } } AddOperand(val); AddOperand(ptr); @@ -191,6 +191,25 @@ CmpOp CmpInst::GetCmpOp() const { return cmp_op_; } Value* CmpInst::GetLhs() const { return GetOperand(0); } Value* CmpInst::GetRhs() const { return GetOperand(1); } +FCmpInst::FCmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, std::string name) + : Instruction(Opcode::FCmp, Type::GetInt1Type(), std::move(name)), cmp_op_(cmp_op) { + if (!lhs || !rhs) { + throw std::runtime_error(FormatError("ir", "FCmpInst 缺少操作数")); + } + if (!lhs->GetType() || !rhs->GetType()) { + throw std::runtime_error(FormatError("ir", "FCmpInst 缺少操作数类型信息")); + } + if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind()) { + throw std::runtime_error(FormatError("ir", "FCmpInst 操作数类型不匹配")); + } + AddOperand(lhs); + AddOperand(rhs); +} + +CmpOp FCmpInst::GetCmpOp() const { return cmp_op_; } +Value* FCmpInst::GetLhs() const { return GetOperand(0); } +Value* FCmpInst::GetRhs() const { return GetOperand(1); } + ZextInst::ZextInst(std::shared_ptr dest_ty, Value* val, std::string name) : Instruction(Opcode::Zext, std::move(dest_ty), std::move(name)) { if (!val) { @@ -231,4 +250,39 @@ Value* CondBranchInst::GetCond() const { return GetOperand(0); } BasicBlock* CondBranchInst::GetTrueBlock() const { return static_cast(GetOperand(1)); } BasicBlock* CondBranchInst::GetFalseBlock() const { return static_cast(GetOperand(2)); } +CallInst::CallInst(Function* func, std::vector args, std::string name) + : Instruction(Opcode::Call, func->GetType(), std::move(name)), func_(func), args_(std::move(args)) { + if (!func) { + throw std::runtime_error(FormatError("ir", "CallInst 缺少目标函数")); + } + AddOperand(func); + for (auto* arg : args_) { + AddOperand(arg); + } +} + +Function* CallInst::GetFunc() const { return func_; } +const std::vector& CallInst::GetArgs() const { return args_; } + +GEPInst::GEPInst(std::shared_ptr ty, Value* ptr, std::vector indices, std::string name) + : Instruction(Opcode::GEP, std::move(ty), std::move(name)), 
indices_(std::move(indices)) { + AddOperand(ptr); + for (auto* idx : indices_) { + AddOperand(idx); + } +} + +Value* GEPInst::GetPtr() const { return GetOperand(0); } +const std::vector& GEPInst::GetIndices() const { return indices_; } + +SIToFPInst::SIToFPInst(std::shared_ptr ty, Value* val, std::string name) + : Instruction(Opcode::SIToFP, std::move(ty), std::move(name)) { + AddOperand(val); +} + +FPToSIInst::FPToSIInst(std::shared_ptr ty, Value* val, std::string name) + : Instruction(Opcode::FPToSI, std::move(ty), std::move(name)) { + AddOperand(val); +} + } // namespace ir diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index 928efdc..3c31d5b 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -18,4 +18,13 @@ const std::vector>& Module::GetFunctions() const { return functions_; } +GlobalVariable* Module::CreateGlobalVariable(const std::string& name, std::shared_ptr type, ConstantValue* init) { + global_variables_.push_back(std::make_unique(name, std::move(type), init)); + return global_variables_.back().get(); +} + +const std::vector>& Module::GetGlobalVariables() const { + return global_variables_; +} + } // namespace ir diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index c32d640..2f7a0d6 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -4,6 +4,8 @@ namespace ir { Type::Type(Kind k) : kind_(k) {} +Type::Type(Kind k, std::shared_ptr elem_ty, int num_elems) + : kind_(k), elem_ty_(std::move(elem_ty)), num_elems_(num_elems) {} const std::shared_ptr& Type::GetVoidType() { static const std::shared_ptr type = std::make_shared(Kind::Void); @@ -21,10 +23,28 @@ const std::shared_ptr& Type::GetInt32Type() { } const std::shared_ptr& Type::GetPtrInt32Type() { - static const std::shared_ptr type = std::make_shared(Kind::PtrInt32); + static const std::shared_ptr type = std::make_shared(Kind::PtrInt32, GetInt32Type(), 0); return type; } +const std::shared_ptr& Type::GetFloatType() { + static std::shared_ptr ty = std::make_shared(Kind::Float); + return ty; +} + 
+const std::shared_ptr& Type::GetPtrFloatType() { + static std::shared_ptr ty = std::make_shared(Kind::PtrFloat, GetFloatType(), 0); + return ty; +} + +std::shared_ptr Type::GetArrayType(std::shared_ptr elem_ty, int num_elems) { + return std::make_shared(Kind::Array, std::move(elem_ty), num_elems); +} + +std::shared_ptr Type::GetPointerType(std::shared_ptr pointed_ty) { + return std::make_shared(Kind::Pointer, std::move(pointed_ty), 0); +} + Type::Kind Type::GetKind() const { return kind_; } bool Type::IsVoid() const { return kind_ == Kind::Void; } @@ -33,6 +53,18 @@ bool Type::IsInt1() const { return kind_ == Kind::Int1; } bool Type::IsInt32() const { return kind_ == Kind::Int32; } -bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } +bool Type::IsPtrInt32() const { + return kind_ == Kind::PtrInt32 || (kind_ == Kind::Pointer && GetPointedType() && GetPointedType()->IsInt32()); +} + +bool Type::IsFloat() const { return kind_ == Kind::Float; } + +bool Type::IsPtrFloat() const { + return kind_ == Kind::PtrFloat || (kind_ == Kind::Pointer && GetPointedType() && GetPointedType()->IsFloat()); +} + +bool Type::IsArray() const { return kind_ == Kind::Array; } + +bool Type::IsPointer() const { return kind_ == Kind::Pointer || kind_ == Kind::PtrInt32 || kind_ == Kind::PtrFloat; } } // namespace ir diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 12a06b4..d3eb29e 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -24,6 +24,10 @@ bool Value::IsInt32() const { return type_ && type_->IsInt32(); } bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } +bool Value::IsFloat() const { return type_ && type_->IsFloat(); } + +bool Value::IsPtrFloat() const { return type_ && type_->IsPtrFloat(); } + bool Value::IsConstant() const { return dynamic_cast(this) != nullptr; } @@ -40,6 +44,10 @@ bool Value::IsFunction() const { return dynamic_cast(this) != nullptr; } +bool Value::IsArgument() const { + return dynamic_cast(this) != nullptr; +} + void 
Value::AddUse(User* user, size_t operand_index) { if (!user) return; uses_.push_back(Use(this, user, operand_index)); @@ -76,10 +84,29 @@ void Value::ReplaceAllUsesWith(Value* new_value) { } } +Argument::Argument(std::shared_ptr ty, std::string name, Function* parent, size_t arg_no) + : Value(std::move(ty), std::move(name)), parent_(parent), arg_no_(arg_no) {} + +Function* Argument::GetParent() const { return parent_; } + +size_t Argument::GetArgNo() const { return arg_no_; } + ConstantValue::ConstantValue(std::shared_ptr ty, std::string name) : Value(std::move(ty), std::move(name)) {} ConstantInt::ConstantInt(std::shared_ptr ty, int v) : ConstantValue(std::move(ty), ""), value_(v) {} +ConstantFloat::ConstantFloat(std::shared_ptr ty, float v) + : ConstantValue(std::move(ty), ""), value_(v) {} + +ConstantArray::ConstantArray(std::shared_ptr ty, std::vector elements) + : ConstantValue(std::move(ty), ""), elements_(std::move(elements)) {} + +ConstantZero::ConstantZero(std::shared_ptr ty) + : ConstantValue(std::move(ty), "") {} + +GlobalVariable::GlobalVariable(std::string name, std::shared_ptr type, ConstantValue* init) + : GlobalValue(std::move(type), std::move(name)), init_(init) {} + } // namespace ir diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 1cd0db8..a237a06 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -6,10 +6,30 @@ #include "ir/IR.h" #include "utils/Log.h" +namespace { +ir::ConstantValue* BuildConstantArray(ir::Context& ctx, std::shared_ptr type, + const std::vector& flattened, + size_t& pos) { + if (!type->IsArray()) { + return flattened[pos++]; + } + std::vector elements; + for (int i = 0; i < type->GetNumElements(); ++i) { + elements.push_back(BuildConstantArray(ctx, type->GetElementType(), flattened, pos)); + } + return ctx.GetConstArray(type, elements); +} +} + std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句块")); } + + // 
压入局部作用域 + storage_map_stack_.push_back({}); + const_values_stack_.push_back({}); + bool terminated = false; for (auto* item : ctx->blockItem()) { if (item) { @@ -19,6 +39,11 @@ std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { } } } + + // 弹出局部作用域 + storage_map_stack_.pop_back(); + const_values_stack_.pop_back(); + return terminated ? BlockFlow::Terminated : BlockFlow::Continue; } @@ -41,31 +66,204 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); } +std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + if (!ctx) return BlockFlow::Continue; + + if (!ctx->bType() || (!ctx->bType()->INT() && !ctx->bType()->FLOAT())) { + throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float 常量声明")); + } + + for (auto* def : ctx->constDef()) { + if (def) def->accept(this); + } + return BlockFlow::Continue; +} + +void IRGenImpl::FlattenInitVal(SysYParser::InitValContext* ctx, + const std::vector& dims, + const std::vector& sub_sizes, + int dim_idx, + size_t& current_pos, + std::vector& results, + bool is_float) { + if (ctx->exp()) { + ir::Value* val = EvalExpr(*ctx->exp()); + // Implicit conversion + if (is_float && !val->GetType()->IsFloat()) { + val = builder_.CreateSIToFP(val, module_.GetContext().NextTemp()); + } else if (!is_float && val->GetType()->IsFloat()) { + val = builder_.CreateFPToSI(val, module_.GetContext().NextTemp()); + } + results[current_pos++] = val; + } else { + // Nested { ... } + size_t start_pos = current_pos; + for (auto* item : ctx->initVal()) { + FlattenInitVal(item, dims, sub_sizes, dim_idx + 1, current_pos, results, is_float); + } + // Fill remaining with 0 + size_t end_pos = start_pos + sub_sizes[dim_idx]; + while (current_pos < end_pos) { + results[current_pos++] = is_float ? 
(ir::Value*)module_.GetContext().GetConstFloat(0.0f) + : (ir::Value*)module_.GetContext().GetConstInt(0); + } + } +} + +void IRGenImpl::FlattenConstInitVal(SysYParser::ConstInitValContext* ctx, + const std::vector& dims, + const std::vector& sub_sizes, + int dim_idx, + size_t& current_pos, + std::vector& results, + bool is_float) { + if (ctx->constExp()) { + ir::Value* val = std::any_cast(ctx->constExp()->accept(this)); + ir::ConstantValue* cval = dynamic_cast(val); + if (!cval) throw std::runtime_error("Not a constant expression"); + + // Constant conversion + if (is_float && dynamic_cast(cval)) { + cval = module_.GetContext().GetConstFloat((float)static_cast(cval)->GetValue()); + } else if (!is_float && dynamic_cast(cval)) { + cval = module_.GetContext().GetConstInt((int)static_cast(cval)->GetValue()); + } + results[current_pos++] = cval; + } else { + size_t start_pos = current_pos; + for (auto* item : ctx->constInitVal()) { + FlattenConstInitVal(item, dims, sub_sizes, dim_idx + 1, current_pos, results, is_float); + } + // Fill remaining with 0 + size_t end_pos = start_pos + sub_sizes[dim_idx]; + while (current_pos < end_pos) { + results[current_pos++] = is_float ? (ir::ConstantValue*)module_.GetContext().GetConstFloat(0.0f) + : (ir::ConstantValue*)module_.GetContext().GetConstInt(0); + } + } +} + +std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { + if (!ctx || !ctx->ID()) { + throw std::runtime_error(FormatError("irgen", "常量定义缺少名称")); + } + + std::string var_name = ctx->ID()->getText(); + + // Get dimensions + std::vector dims; + for (auto* idx : ctx->constIndex()) { + dims.push_back(EvaluateConstInt(idx->constExp())); + } + + bool is_float = false; + auto* parent_decl = dynamic_cast(ctx->parent); + if (parent_decl && parent_decl->bType() && parent_decl->bType()->FLOAT()) { + is_float = true; + } + auto base_ty = is_float ? 
ir::Type::GetFloatType() : ir::Type::GetInt32Type(); + + std::shared_ptr var_ty = base_ty; + for (auto it = dims.rbegin(); it != dims.rend(); ++it) { + var_ty = ir::Type::GetArrayType(var_ty, *it); + } + + std::vector sub_sizes(dims.size() + 1); + sub_sizes[dims.size()] = 1; + for (int i = (int)dims.size() - 1; i >= 0; --i) { + sub_sizes[i] = sub_sizes[i+1] * dims[i]; + } + + ir::ConstantValue* init_const = nullptr; + std::vector flattened; + if (dims.empty()) { + if (auto* init_val = ctx->constInitVal()) { + if (init_val->constExp()) { + ir::Value* val = std::any_cast(init_val->constExp()->accept(this)); + init_const = dynamic_cast(val); + // Constant conversion + if (is_float && dynamic_cast(init_const)) { + init_const = module_.GetContext().GetConstFloat((float)static_cast(init_const)->GetValue()); + } else if (!is_float && dynamic_cast(init_const)) { + init_const = module_.GetContext().GetConstInt((int)static_cast(init_const)->GetValue()); + } + } + } + } else { + flattened.resize(sub_sizes[0]); + if (auto* init_val = ctx->constInitVal()) { + size_t pos = 0; + FlattenConstInitVal(init_val, dims, sub_sizes, 0, pos, flattened, is_float); + } else { + auto zero = is_float ? 
(ir::ConstantValue*)module_.GetContext().GetConstFloat(0.0f) : (ir::ConstantValue*)module_.GetContext().GetConstInt(0); + for (auto& v : flattened) v = zero; + } + size_t pos = 0; + init_const = BuildConstantArray(module_.GetContext(), var_ty, flattened, pos); + } + + // 记录常量值供后续直接使用 (only for scalars for now) + if (dims.empty() && !const_values_stack_.empty()) { + const_values_stack_.back()[var_name] = init_const; + } + + if (func_ == nullptr) { + auto gv_ptr_ty = ir::Type::GetPointerType(var_ty); + auto* gv = module_.CreateGlobalVariable(var_name, gv_ptr_ty, init_const); + if (!storage_map_stack_.empty()) { + storage_map_stack_.back()[var_name] = gv; + } + } else { + // 局部作用域 - 确保 alloca 在入口块 + auto* current_bb = builder_.GetInsertBlock(); + builder_.SetInsertPoint(func_->GetEntry()); + ir::Value* slot = builder_.CreateAlloca(var_ty, module_.GetContext().NextTemp()); + builder_.SetInsertPoint(current_bb); + + if (!storage_map_stack_.empty()) { + storage_map_stack_.back()[var_name] = slot; + } + if (dims.empty()) { + if (init_const) builder_.CreateStore(init_const, slot); + } else { + for (size_t i = 0; i < flattened.size(); ++i) { + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + size_t temp = i; + for (size_t d = 0; d < dims.size(); ++d) { + indices.push_back(builder_.CreateConstInt(temp / sub_sizes[d+1])); + temp %= sub_sizes[d+1]; + } + ir::Value* ptr = builder_.CreateGEP(ir::Type::GetPointerType(base_ty), slot, indices, module_.GetContext().NextTemp()); + builder_.CreateStore(flattened[i], ptr); + } + } + } + + return BlockFlow::Continue; +} + // 变量声明的 IR 生成目前也是最小实现: -// - 先检查声明的基础类型,当前仅支持局部 int; +// - 先检查声明的基础类型,支持 int 和 float; // - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。 -// -// 和更完整的版本相比,这里还没有: -// - 一个 Decl 中多个变量定义的顺序处理; -// - const、数组、全局变量等不同声明形态; -// - 更丰富的类型系统。 std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } - // 当前语法中 decl 包含 constDecl 或 
varDecl,这里只支持 varDecl - auto* var_decl = ctx->varDecl(); - if (!var_decl) { - throw std::runtime_error(FormatError("irgen", "当前仅支持变量声明")); - } - if (!var_decl->bType() || !var_decl->bType()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); - } - // 遍历所有 varDef - for (auto* var_def : var_decl->varDef()) { - if (var_def) { - var_def->accept(this); + // 当前语法中 decl 包含 constDecl 或 varDecl + if (auto* var_decl = ctx->varDecl()) { + if (!var_decl->bType() || (!var_decl->bType()->INT() && !var_decl->bType()->FLOAT())) { + throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float 变量声明")); } + for (auto* var_def : var_decl->varDef()) { + if (var_def) { + var_def->accept(this); + } + } + } else if (auto* const_decl = ctx->constDecl()) { + return const_decl->accept(this); + } else { + throw std::runtime_error(FormatError("irgen", "当前仅支持变量声明")); } return BlockFlow::Continue; } @@ -76,32 +274,145 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { // - 标量初始化; // - 一个 VarDef 对应一个槽位。 std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量定义")); - } - if (!ctx->ID()) { + if (!ctx || !ctx->ID()) { throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); } - // 暂不支持数组声明(constIndex) - if (!ctx->constIndex().empty()) { - throw std::runtime_error(FormatError("irgen", "暂不支持数组声明")); - } std::string var_name = ctx->ID()->getText(); - if (storage_map_.find(var_name) != storage_map_.end()) { + if (!storage_map_stack_.empty() && storage_map_stack_.back().find(var_name) != storage_map_stack_.back().end()) { throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); } - auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - storage_map_[var_name] = slot; - ir::Value* init = nullptr; - if (auto* init_val = ctx->initVal()) { - if (!init_val->exp()) { - throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); + // Get dimensions + std::vector 
dims; + for (auto* idx : ctx->constIndex()) { + dims.push_back(EvaluateConstInt(idx->constExp())); + } + + // Determine base type + bool is_float = false; + auto* parent_decl = dynamic_cast(ctx->parent); + if (parent_decl && parent_decl->bType() && parent_decl->bType()->FLOAT()) { + is_float = true; + } + auto base_ty = is_float ? ir::Type::GetFloatType() : ir::Type::GetInt32Type(); + + std::shared_ptr var_ty = base_ty; + for (auto it = dims.rbegin(); it != dims.rend(); ++it) { + var_ty = ir::Type::GetArrayType(var_ty, *it); + } + + std::vector sub_sizes(dims.size() + 1); + sub_sizes[dims.size()] = 1; + for (int i = (int)dims.size() - 1; i >= 0; --i) { + sub_sizes[i] = sub_sizes[i+1] * dims[i]; + } + + if (func_ == nullptr) { + // 全局作用域 + ir::ConstantValue* init_const = nullptr; + if (dims.empty()) { + if (auto* init_val = ctx->initVal()) { + if (init_val->exp()) { + auto* val = EvalExpr(*init_val->exp()); + init_const = dynamic_cast(val); + // Constant conversion + if (is_float && dynamic_cast(init_const)) { + init_const = module_.GetContext().GetConstFloat((float)static_cast(init_const)->GetValue()); + } else if (!is_float && dynamic_cast(init_const)) { + init_const = module_.GetContext().GetConstInt((int)static_cast(init_const)->GetValue()); + } + } + } else { + init_const = is_float ? (ir::ConstantValue*)module_.GetContext().GetConstFloat(0.0f) : (ir::ConstantValue*)module_.GetContext().GetConstInt(0); + } + } else { + if (auto* init_val = ctx->initVal()) { + std::vector flattened(sub_sizes[0]); + // VarDef's InitVal can be an expression or { ... } + if (init_val->exp()) { + auto* val = EvalExpr(*init_val->exp()); + auto* cval = dynamic_cast(val); + flattened[0] = cval; + auto zero = is_float ? 
(ir::ConstantValue*)module_.GetContext().GetConstFloat(0.0f) : (ir::ConstantValue*)module_.GetContext().GetConstInt(0); + for (size_t i = 1; i < flattened.size(); ++i) { + flattened[i] = zero; + } + size_t bpos = 0; + init_const = BuildConstantArray(module_.GetContext(), var_ty, flattened, bpos); + } else { + size_t fpos = 0; + std::vector flat_vals(sub_sizes[0]); + auto zero = is_float ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0); + for (auto& v : flat_vals) v = zero; + + FlattenInitVal(init_val, dims, sub_sizes, 0, fpos, flat_vals, is_float); + for (size_t i = 0; i < flat_vals.size(); ++i) { + flattened[i] = dynamic_cast(flat_vals[i]); + } + size_t bpos = 0; + init_const = BuildConstantArray(module_.GetContext(), var_ty, flattened, bpos); + } + } else { + init_const = module_.GetContext().GetConstZero(var_ty); + } + } + + auto gv_ptr_ty = ir::Type::GetPointerType(var_ty); + auto* gv = module_.CreateGlobalVariable(var_name, gv_ptr_ty, init_const); + if (!storage_map_stack_.empty()) { + storage_map_stack_.back()[var_name] = gv; } - init = EvalExpr(*init_val->exp()); } else { - init = builder_.CreateConstInt(0); + // 局部作用域 - 确保 alloca 在入口块 + auto* current_bb = builder_.GetInsertBlock(); + builder_.SetInsertPoint(func_->GetEntry()); + ir::Value* slot = builder_.CreateAlloca(var_ty, module_.GetContext().NextTemp()); + builder_.SetInsertPoint(current_bb); + + if (!storage_map_stack_.empty()) { + storage_map_stack_.back()[var_name] = slot; + } + + if (auto* init_val = ctx->initVal()) { + if (dims.empty()) { + if (init_val->exp()) { + ir::Value* init = EvalExpr(*init_val->exp()); + if (is_float && !init->GetType()->IsFloat()) { + init = builder_.CreateSIToFP(init, module_.GetContext().NextTemp()); + } else if (!is_float && init->GetType()->IsFloat()) { + init = builder_.CreateFPToSI(init, module_.GetContext().NextTemp()); + } + builder_.CreateStore(init, slot); + } + } else { + std::vector 
flattened(sub_sizes[0]); + auto zero = is_float ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0); + for (auto& v : flattened) v = zero; + + size_t pos = 0; + FlattenInitVal(init_val, dims, sub_sizes, 0, pos, flattened, is_float); + for (size_t i = 0; i < flattened.size(); ++i) { + // Optimization: only store non-zero? + // For now, store all to be safe. + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + size_t temp = i; + for (size_t d = 0; d < dims.size(); ++d) { + indices.push_back(builder_.CreateConstInt(temp / sub_sizes[d+1])); + temp %= sub_sizes[d+1]; + } + ir::Value* ptr = builder_.CreateGEP(ir::Type::GetPointerType(base_ty), slot, indices, module_.GetContext().NextTemp()); + builder_.CreateStore(flattened[i], ptr); + } + } + } else { + // Initialize scalar locals to 0 + if (dims.empty()) { + ir::Value* zero = is_float ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0); + builder_.CreateStore(zero, slot); + } + } } - builder_.CreateStore(init, slot); + return BlockFlow::Continue; } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 4565c6e..2d3b471 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -24,6 +24,38 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { return std::any_cast(expr.accept(this)); } +ir::ConstantValue* IRGenImpl::EvaluateConst(antlr4::tree::ParseTree* tree) { + auto val = std::any_cast(tree->accept(this)); + auto* cval = dynamic_cast(val); + if (!cval) throw std::runtime_error("Not a constant expression"); + return cval; +} + +int IRGenImpl::EvaluateConstInt(SysYParser::ConstExpContext* ctx) { + if (!ctx) return 0; + auto* val = EvaluateConst(ctx->addExp()); + if (auto* ci = dynamic_cast(val)) return ci->GetValue(); + if (auto* cf = dynamic_cast(val)) return (int)cf->GetValue(); + return 0; +} + +int IRGenImpl::EvaluateConstInt(SysYParser::ExpContext* ctx) { + 
if (!ctx) return 0; + auto* val = EvaluateConst(ctx); + if (auto* ci = dynamic_cast(val)) return ci->GetValue(); + if (auto* cf = dynamic_cast(val)) return (int)cf->GetValue(); + return 0; +} + +std::shared_ptr IRGenImpl::GetGEPResultType(ir::Value* ptr, const std::vector& indices) { + auto cur_ty = ptr->GetType()->GetPointedType(); + for (size_t i = 1; i < indices.size(); ++i) { + if (cur_ty->IsArray()) { + cur_ty = cur_ty->GetElementType(); + } + } + return ir::Type::GetPointerType(cur_ty); +} std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { if (!ctx) { @@ -33,28 +65,9 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { if (ctx->exp()) { return EvalExpr(*ctx->exp()); } - // 处理 lVal(变量使用)- 交给 visitLVal 处理 + // 处理 lVal(变量使用) if (ctx->lVal()) { - // 直接在这里处理变量读取,避免 accept 调用可能导致的问题 - auto* lval_ctx = ctx->lVal(); - if (!lval_ctx || !lval_ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); - } - const auto* decl = sema_.ResolveObjectUse(lval_ctx); - if (!decl) { - throw std::runtime_error( - FormatError("irgen", - "变量使用缺少语义绑定:" + lval_ctx->ID()->getText())); - } - std::string var_name = lval_ctx->ID()->getText(); - auto it = storage_map_.find(var_name); - if (it == storage_map_.end()) { - throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位:" + var_name)); - } - return static_cast( - builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); + return ctx->lVal()->accept(this); } // 处理 number if (ctx->number()) { @@ -65,11 +78,21 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { - if (!ctx || !ctx->intConst()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "缺少字面量节点")); } - return static_cast( - builder_.CreateConstInt(std::stoi(ctx->intConst()->getText()))); + if (ctx->intConst()) { + // 可能是 0x, 0X, 0 
开头的八进制等,目前 std::stoi 会处理十进制, + // 为了支持 16 进制/8 进制建议使用 std::stoi(str, nullptr, 0) + std::string text = ctx->intConst()->getText(); + return static_cast( + builder_.CreateConstInt(std::stoi(text, nullptr, 0))); + } else if (ctx->floatConst()) { + std::string text = ctx->floatConst()->getText(); + return static_cast( + module_.GetContext().GetConstFloat(std::stof(text))); + } + throw std::runtime_error(FormatError("irgen", "不支持的字面量")); } // 变量使用的处理流程: @@ -80,24 +103,73 @@ std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { // 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { if (!ctx || !ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); + throw std::runtime_error(FormatError("irgen", "非法左值")); + } + + std::string var_name = ctx->ID()->getText(); + + // 优先检查是否为已记录的常量 + ir::ConstantValue* const_val = FindConst(var_name); + if (const_val && ctx->exp().empty()) { + return static_cast(const_val); } - const auto* decl = sema_.ResolveObjectUse(ctx); - if (!decl) { + + const auto* binding = sema_.ResolveObjectUse(ctx); + if (!binding) { throw std::runtime_error( - FormatError("irgen", - "变量使用缺少语义绑定:" + ctx->ID()->getText())); + FormatError("irgen", "变量使用缺少语义绑定:" + var_name)); } - // 使用变量名查找存储槽位 - std::string var_name = ctx->ID()->getText(); - auto it = storage_map_.find(var_name); - if (it == storage_map_.end()) { + + ir::Value* slot = FindStorage(var_name); + if (!slot) { throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位:" + var_name)); + FormatError("irgen", "变量声明缺少存储槽位:" + var_name)); + } + + ir::Value* ptr = slot; + auto ptr_ty = ptr->GetType(); + bool is_param = false; + // If it's a pointer to a pointer (function parameter case), load the pointer value first + if (ptr_ty->IsPointer() && ptr_ty->GetPointedType()->IsPointer()) { + ptr = builder_.CreateLoad(ptr, module_.GetContext().NextTemp()); + is_param = true; + } else if (ptr->IsArgument()) { + is_param = true; + 
} + + // Determine if the result of this LVal is a scalar or an array + bool result_is_scalar = (ctx->exp().size() == binding->dimensions.size()); + + if (!ctx->exp().empty()) { + std::vector indices; + // If it's a local array, we need leading 0 + if (ptr->GetType()->IsPointer() && ptr->GetType()->GetPointedType()->IsArray()) { + if (!is_param) { + indices.push_back(builder_.CreateConstInt(0)); + } + } + + for (auto* exp_ctx : ctx->exp()) { + indices.push_back(EvalExpr(*exp_ctx)); + } + + auto res_ptr_ty = GetGEPResultType(ptr, indices); + ptr = builder_.CreateGEP(res_ptr_ty, ptr, indices, module_.GetContext().NextTemp()); + } + + if (result_is_scalar) { + return static_cast(builder_.CreateLoad(ptr, module_.GetContext().NextTemp())); + } else { + // Decay ptr to the first element of the sub-array + while (ptr->GetType()->GetPointedType()->IsArray()) { + std::vector d_indices; + d_indices.push_back(builder_.CreateConstInt(0)); + d_indices.push_back(builder_.CreateConstInt(0)); + auto d_res_ty = GetGEPResultType(ptr, d_indices); + ptr = builder_.CreateGEP(d_res_ty, ptr, d_indices, module_.GetContext().NextTemp()); + } + return ptr; } - return static_cast( - builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); } @@ -119,6 +191,43 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { ir::Value* lhs = std::any_cast(ctx->addExp()->accept(this)); ir::Value* rhs = std::any_cast(ctx->mulExp()->accept(this)); + if (lhs->IsConstant() && rhs->IsConstant()) { + auto* cl = static_cast(lhs); + auto* cr = static_cast(rhs); + if (auto* cil = dynamic_cast(cl)) { + if (auto* cir = dynamic_cast(cr)) { + if (ctx->ADD()) return static_cast(module_.GetContext().GetConstInt(cil->GetValue() + cir->GetValue())); + if (ctx->SUB()) return static_cast(module_.GetContext().GetConstInt(cil->GetValue() - cir->GetValue())); + } + } + } + + // Implicit conversion + if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) { + if (rhs->IsConstant()) { + rhs = 
module_.GetContext().GetConstFloat((float)static_cast(rhs)->GetValue()); + } else { + rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp()); + } + } else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) { + if (lhs->IsConstant()) { + lhs = module_.GetContext().GetConstFloat((float)static_cast(lhs)->GetValue()); + } else { + lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp()); + } + } + + if (lhs->IsConstant() && rhs->IsConstant()) { + auto* cl = static_cast(lhs); + auto* cr = static_cast(rhs); + if (auto* cfl = dynamic_cast(cl)) { + if (auto* cfr = dynamic_cast(cr)) { + if (ctx->ADD()) return static_cast(module_.GetContext().GetConstFloat(cfl->GetValue() + cfr->GetValue())); + if (ctx->SUB()) return static_cast(module_.GetContext().GetConstFloat(cfl->GetValue() - cfr->GetValue())); + } + } + } + ir::Opcode op = ir::Opcode::Add; if (ctx->ADD()) { op = ir::Opcode::Add; @@ -144,21 +253,99 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { } // 处理函数调用(unaryExp : ID LPAREN funcRParams? 
RPAREN) - // 当前暂不支持,留给后续扩展 if (ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "暂不支持函数调用")); + std::string func_name = ctx->ID()->getText(); + + // 从 Sema 或 Module 中查找函数 + // 目前简化处理,直接从 Module 中查找(如果是当前文件内定义的) + // 或者依赖 Sema 给出解析结果 + const FunctionBinding* func_binding = sema_.ResolveFunctionCall(ctx); + if (!func_binding) { + throw std::runtime_error(FormatError("irgen", "未找到函数声明:" + func_name)); + } + + // 假设 func_binding 能够找到对应的 ir::Function* + // 这里如果 sema 不提供直接拿 ir::Function 的方式,需要遍历 module_.GetFunctions() 查找 + ir::Function* target_func = nullptr; + for (const auto& f : module_.GetFunctions()) { + if (f->GetName() == func_name) { + target_func = f.get(); + break; + } + } + + if (!target_func) { + // 可能是外部函数如 putint, getint 等 + // 如果没有在 module_ 中,则需要创建一个只有声明的 Function + std::shared_ptr ret_ty; + if (func_binding->return_type == SemanticType::Int) { + ret_ty = ir::Type::GetInt32Type(); + } else if (func_binding->return_type == SemanticType::Float) { + ret_ty = ir::Type::GetFloatType(); + } else { + ret_ty = ir::Type::GetVoidType(); + } + target_func = module_.CreateFunction(func_name, ret_ty); + // 对于外部函数,需要传递参数,可能还需要在 target_func 中 AddArgument + for (const auto& param : func_binding->params) { + std::shared_ptr p_ty; + if (param.type == SemanticType::Int) { + p_ty = param.dimensions.empty() && !param.is_array_param ? ir::Type::GetInt32Type() : ir::Type::GetPtrInt32Type(); + } else { + p_ty = param.dimensions.empty() && !param.is_array_param ? 
ir::Type::GetFloatType() : ir::Type::GetPtrFloatType(); + } + target_func->AddArgument(p_ty, param.name); + } + } + + std::vector args; + if (ctx->funcRParams()) { + args = std::any_cast>(ctx->funcRParams()->accept(this)); + } + + // Implicit conversion for function arguments + const auto& formal_args = target_func->GetArgs(); + for (size_t i = 0; i < std::min(args.size(), formal_args.size()); ++i) { + if (formal_args[i]->GetType()->IsFloat() && !args[i]->GetType()->IsFloat()) { + args[i] = builder_.CreateSIToFP(args[i], module_.GetContext().NextTemp()); + } else if (formal_args[i]->GetType()->IsInt32() && args[i]->GetType()->IsFloat()) { + args[i] = builder_.CreateFPToSI(args[i], module_.GetContext().NextTemp()); + } + } + + return static_cast(builder_.CreateCall(target_func, args, module_.GetContext().NextTemp())); } // 处理一元运算符(unaryExp : addUnaryOp unaryExp) if (ctx->addUnaryOp() && ctx->unaryExp()) { ir::Value* operand = std::any_cast(ctx->unaryExp()->accept(this)); + // Constant folding for unary op + if (operand->IsConstant()) { + if (ctx->addUnaryOp()->SUB()) { + if (auto* ci = dynamic_cast(operand)) { + return static_cast(module_.GetContext().GetConstInt(-ci->GetValue())); + } else if (auto* cf = dynamic_cast(operand)) { + return static_cast(module_.GetContext().GetConstFloat(-cf->GetValue())); + } + } else { + return operand; + } + } + // 判断是正号还是负号 if (ctx->addUnaryOp()->SUB()) { - // 负号:生成 sub 0, operand(LLVM IR 中没有 neg 指令) - ir::Value* zero = builder_.CreateConstInt(0); - return static_cast( - builder_.CreateSub(zero, operand, module_.GetContext().NextTemp())); + // 负号:如果是整数生成 sub 0, operand,浮点数生成 fsub 0.0, operand + if (operand->GetType()->IsFloat()) { + ir::Value* zero = module_.GetContext().GetConstFloat(0.0f); + // 此处暂且假设 CreateSub 可以处理浮点数(如果底层有 fsub 则更好) + return static_cast( + builder_.CreateSub(zero, operand, module_.GetContext().NextTemp())); + } else { + ir::Value* zero = builder_.CreateConstInt(0); + return static_cast( + 
builder_.CreateSub(zero, operand, module_.GetContext().NextTemp())); + } } else if (ctx->addUnaryOp()->ADD()) { // 正号:直接返回操作数(+x 等价于 x) return operand; @@ -188,6 +375,45 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { ir::Value* lhs = std::any_cast(ctx->mulExp()->accept(this)); ir::Value* rhs = std::any_cast(ctx->unaryExp()->accept(this)); + // Constant folding + if (lhs->IsConstant() && rhs->IsConstant()) { + auto* cl = static_cast(lhs); + auto* cr = static_cast(rhs); + if (auto* cil = dynamic_cast(cl)) { + if (auto* cir = dynamic_cast(cr)) { + if (ctx->MUL()) return static_cast(module_.GetContext().GetConstInt(cil->GetValue() * cir->GetValue())); + if (ctx->DIV()) return static_cast(module_.GetContext().GetConstInt(cil->GetValue() / cir->GetValue())); + if (ctx->MOD()) return static_cast(module_.GetContext().GetConstInt(cil->GetValue() % cir->GetValue())); + } + } + } + + // Implicit conversion + if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) { + if (rhs->IsConstant()) { + rhs = module_.GetContext().GetConstFloat((float)static_cast(rhs)->GetValue()); + } else { + rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp()); + } + } else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) { + if (lhs->IsConstant()) { + lhs = module_.GetContext().GetConstFloat((float)static_cast(lhs)->GetValue()); + } else { + lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp()); + } + } + + if (lhs->IsConstant() && rhs->IsConstant()) { + auto* cl = static_cast(lhs); + auto* cr = static_cast(rhs); + if (auto* cfl = dynamic_cast(cl)) { + if (auto* cfr = dynamic_cast(cr)) { + if (ctx->MUL()) return static_cast(module_.GetContext().GetConstFloat(cfl->GetValue() * cfr->GetValue())); + if (ctx->DIV()) return static_cast(module_.GetContext().GetConstFloat(cfl->GetValue() / cfr->GetValue())); + } + } + } + ir::Opcode op = ir::Opcode::Mul; if (ctx->MUL()) { op = ir::Opcode::Mul; @@ -212,6 +438,13 @@ std::any 
IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp()); + // Implicit conversion + if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) { + rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp()); + } else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) { + lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp()); + } + ir::CmpOp op; if (ctx->LT()) op = ir::CmpOp::Lt; else if (ctx->GT()) op = ir::CmpOp::Gt; @@ -231,6 +464,13 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp()); + // Implicit conversion + if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) { + rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp()); + } else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) { + lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp()); + } + ir::CmpOp op; if (ctx->EQ()) op = ir::CmpOp::Eq; else if (ctx->NE()) op = ir::CmpOp::Ne; @@ -248,8 +488,13 @@ std::any IRGenImpl::visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) { if (operand->GetType()->IsInt1()) { operand = builder_.CreateZext(operand, module_.GetContext().NextTemp()); } - ir::Value* zero = builder_.CreateConstInt(0); - return static_cast(builder_.CreateCmp(ir::CmpOp::Eq, operand, zero, module_.GetContext().NextTemp())); + if (operand->GetType()->IsFloat()) { + ir::Value* zero = module_.GetContext().GetConstFloat(0.0f); + return static_cast(builder_.CreateCmp(ir::CmpOp::Eq, operand, zero, module_.GetContext().NextTemp())); + } else { + ir::Value* zero = builder_.CreateConstInt(0); + return static_cast(builder_.CreateCmp(ir::CmpOp::Eq, 
operand, zero, module_.GetContext().NextTemp())); + } } throw std::runtime_error(FormatError("irgen", "非法条件一元表达式")); } @@ -326,4 +571,12 @@ std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { throw std::runtime_error(FormatError("irgen", "非法条件表达式")); } return ctx->lOrExp()->accept(this); +} + +std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { + std::vector args; + for (auto* exp : ctx->exp()) { + args.push_back(EvalExpr(*exp)); + } + return args; } \ No newline at end of file diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4ee5b3e..95a9ad0 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -38,15 +38,25 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } - // 遍历所有 topLevelItem,找到 funcDef + // 初始化全局作用域 + storage_map_stack_.push_back({}); + const_values_stack_.push_back({}); + + // 遍历所有 topLevelItem for (auto* item : ctx->topLevelItem()) { - if (item && item->funcDef()) { + if (!item) continue; + if (item->funcDef()) { item->funcDef()->accept(this); - // 当前只支持单个函数,找到第一个后就返回 - return {}; + } else if (item->decl()) { + item->decl()->accept(this); } } - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + + // 退出全局作用域 + storage_map_stack_.pop_back(); + const_values_stack_.pop_back(); + + return {}; } // 函数 IR 生成当前实现了: @@ -74,16 +84,88 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { if (!ctx->ID()) { throw std::runtime_error(FormatError("irgen", "缺少函数名")); } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); + + std::shared_ptr ret_type; + if (ctx->funcType()->INT()) { + ret_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->FLOAT()) { + ret_type = ir::Type::GetFloatType(); + } else if (ctx->funcType()->VOID()) { + ret_type = ir::Type::GetVoidType(); + } else { + throw 
std::runtime_error(FormatError("irgen", "未知的函数返回类型")); } - func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); - builder_.SetInsertPoint(func_->GetEntry()); - storage_map_.clear(); + func_ = module_.CreateFunction(ctx->ID()->getText(), ret_type); + ir::BasicBlock* alloca_bb = func_->CreateBlock("alloca"); + ir::BasicBlock* entry_bb = func_->CreateBlock("entry"); + builder_.SetInsertPoint(entry_bb); + + // 进入函数作用域,压入一个新的 map + storage_map_stack_.push_back({}); + const_values_stack_.push_back({}); + + if (ctx->funcFParams()) { + for (auto* paramCtx : ctx->funcFParams()->funcFParam()) { + std::shared_ptr param_type; + bool is_array = !paramCtx->LBRACK().empty(); + + auto base_sema_ty = paramCtx->bType()->INT() ? SemanticType::Int : SemanticType::Float; + auto base_ir_ty = (base_sema_ty == SemanticType::Int) ? ir::Type::GetInt32Type() : ir::Type::GetFloatType(); + + if (is_array) { + std::shared_ptr elem_ty = base_ir_ty; + auto exps = paramCtx->exp(); + for (auto it = exps.rbegin(); it != exps.rend(); ++it) { + int dim = EvaluateConstInt(*it); + elem_ty = ir::Type::GetArrayType(elem_ty, dim); + } + param_type = ir::Type::GetPointerType(elem_ty); + } else { + param_type = base_ir_ty; + } + + std::string arg_name = paramCtx->ID()->getText(); + auto* arg = func_->AddArgument(param_type, "%arg" + std::to_string(func_->GetArgs().size())); + + // Ensure param alloca is in alloca_bb + auto* current_bb = builder_.GetInsertBlock(); + builder_.SetInsertPoint(alloca_bb); + ir::Instruction* alloca_inst = builder_.CreateAlloca(param_type, module_.GetContext().NextTemp()); + builder_.SetInsertPoint(current_bb); + + builder_.CreateStore(arg, alloca_inst); + storage_map_stack_.back()[arg_name] = alloca_inst; + } + } ctx->block()->accept(this); - // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 + + // Implicit return for void functions or main + if (!builder_.GetInsertBlock()->HasTerminator()) { + if (func_->GetType()->IsVoid()) { + builder_.CreateRet(nullptr); + } 
else if (func_->GetName() == "main") { + builder_.CreateRet(builder_.CreateConstInt(0)); + } else { + if (func_->GetType()->IsFloat()) { + builder_.CreateRet(module_.GetContext().GetConstFloat(0.0f)); + } else { + builder_.CreateRet(builder_.CreateConstInt(0)); + } + } + } + + // Branch from alloca_bb to entry_bb + builder_.SetInsertPoint(alloca_bb); + builder_.CreateBr(entry_bb); + VerifyFunctionStructure(*func_); + func_ = nullptr; + + // 退出函数作用域,弹出 map + storage_map_stack_.pop_back(); + const_values_stack_.pop_back(); + return {}; } diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index e44bd0a..3b4ff09 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -35,12 +35,44 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { FormatError("irgen", "变量使用缺少语义绑定:" + lval_ctx->ID()->getText())); } std::string var_name = lval_ctx->ID()->getText(); - auto it = storage_map_.find(var_name); - if (it == storage_map_.end()) { + ir::Value* slot = FindStorage(var_name); + if (!slot) { throw std::runtime_error( FormatError("irgen", "变量声明缺少存储槽位:" + var_name)); } - builder_.CreateStore(rhs, it->second); + + ir::Value* ptr = slot; + auto ptr_ty = ptr->GetType(); + bool is_param = false; + // If it's a pointer to a pointer (function parameter case), load the pointer value first + if (ptr_ty->IsPointer() && ptr_ty->GetPointedType()->IsPointer()) { + ptr = builder_.CreateLoad(ptr, module_.GetContext().NextTemp()); + is_param = true; + } + if (ptr->IsArgument()) is_param = true; + + if (!lval_ctx->exp().empty()) { + std::vector indices; + if (ptr->GetType()->IsPointer() && ptr->GetType()->GetPointedType()->IsArray()) { + if (!is_param) { + indices.push_back(builder_.CreateConstInt(0)); + } + } + for (auto* exp_ctx : lval_ctx->exp()) { + indices.push_back(EvalExpr(*exp_ctx)); + } + auto res_ptr_ty = GetGEPResultType(ptr, indices); + ptr = builder_.CreateGEP(res_ptr_ty, ptr, indices, module_.GetContext().NextTemp()); + } + + // Implicit conversion 
for assignment + if ((ptr->GetType()->IsPtrFloat() || (ptr->GetType()->IsArray() && ptr->GetType()->GetElementType()->IsFloat())) && !rhs->GetType()->IsFloat()) { + rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp()); + } else if ((ptr->GetType()->IsPtrInt32() || (ptr->GetType()->IsArray() && ptr->GetType()->GetElementType()->IsInt32())) && rhs->GetType()->IsFloat()) { + rhs = builder_.CreateFPToSI(rhs, module_.GetContext().NextTemp()); + } + + builder_.CreateStore(rhs, ptr); return BlockFlow::Continue; } @@ -50,6 +82,9 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { if (cond_val->GetType()->IsInt32()) { ir::Value* zero = builder_.CreateConstInt(0); cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp()); + } else if (cond_val->GetType()->IsFloat()) { + ir::Value* zero = module_.GetContext().GetConstFloat(0.0f); + cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp()); } ir::BasicBlock* then_bb = func_->CreateBlock(NextBlockName("if_then")); @@ -88,6 +123,9 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { if (cond_val->GetType()->IsInt32()) { ir::Value* zero = builder_.CreateConstInt(0); cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp()); + } else if (cond_val->GetType()->IsFloat()) { + ir::Value* zero = module_.GetContext().GetConstFloat(0.0f); + cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp()); } builder_.CreateCondBr(cond_val, body_bb, exit_bb); @@ -128,9 +166,15 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { if (ctx->RETURN()) { if (ctx->exp()) { ir::Value* v = EvalExpr(*ctx->exp()); + // Handle return type conversion if necessary + if (func_->GetType()->IsFloat() && !v->GetType()->IsFloat()) { + v = builder_.CreateSIToFP(v, module_.GetContext().NextTemp()); + } else if (func_->GetType()->IsInt32() && v->GetType()->IsFloat()) { + v = 
builder_.CreateFPToSI(v, module_.GetContext().NextTemp()); + } builder_.CreateRet(v); } else { - throw std::runtime_error(FormatError("irgen", "暂不支持 void return")); + builder_.CreateRet(nullptr); // nullptr for void ret } return BlockFlow::Terminated; } diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 9a18396..6753a77 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -87,9 +87,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, case ir::Opcode::Sub: case ir::Opcode::Mul: throw std::runtime_error(FormatError("mir", "暂不支持该二元运算")); + default: + throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); } - - throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); } } // namespace diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 95f0629..9f7657c 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -156,12 +156,21 @@ class SemaVisitor final : public SysYBaseVisitor { ThrowSemaError(ctx, "缺少编译单元"); } + // 第一遍:处理所有全局声明 + for (auto* item : ctx->topLevelItem()) { + if (item && item->decl()) { + item->decl()->accept(this); + } + } + + // 第二遍:收集所有函数签名 CollectFunctions(*ctx); + + // 第三遍:处理所有函数体 for (auto* item : ctx->topLevelItem()) { - if (!item) { - continue; + if (item && item->funcDef()) { + item->funcDef()->accept(this); } - item->accept(this); } const FunctionBinding* main = sema_.ResolveFunction("main"); diff --git a/sylib/sylib.c b/sylib/sylib.c index 7f26d0b..1bc04ca 100644 --- a/sylib/sylib.c +++ b/sylib/sylib.c @@ -1,4 +1,49 @@ -// SysY 运行库实现: -// - 按实验/评测规范提供 I/O 等函数实现 -// - 与编译器生成的目标代码链接,支撑运行时行为 +#include +#include +#include +/* Input functions */ +int getint() { int t; scanf("%d", &t); return t; } +int getch() { char t; scanf("%c", &t); return (int)t; } +float getfloat() { float t; scanf("%f", &t); return t; } + +int getarray(int a[]) { + int n; + scanf("%d", &n); + for (int i = 0; i < n; i++) scanf("%d", &a[i]); + return n; +} + +int getfarray(float a[]) { + int n; + scanf("%d", &n); + 
for (int i = 0; i < n; i++) scanf("%f", &a[i]); + return n; +} + +/* Output functions */ +void putint(int a) { printf("%d", a); } +void putch(int a) { printf("%c", (char)a); } +void putfloat(float a) { printf("%a", a); } + +void putarray(int n, int a[]) { + printf("%d:", n); + for (int i = 0; i < n; i++) printf(" %d", a[i]); + printf("\n"); +} + +void putfarray(int n, float a[]) { + printf("%d:", n); + for (int i = 0; i < n; i++) printf(" %a", a[i]); + printf("\n"); +} + +/* Timing functions */ +struct timeval _sysy_start, _sysy_end; +void starttime() { gettimeofday(&_sysy_start, NULL); } +void stoptime() { + gettimeofday(&_sysy_end, NULL); + int millis = (_sysy_end.tv_sec - _sysy_start.tv_sec) * 1000 + + (_sysy_end.tv_usec - _sysy_start.tv_usec) / 1000; + fprintf(stderr, "Timer: %d ms\n", millis); +} diff --git a/test_all.sh b/test_all.sh new file mode 100644 index 0000000..36c5d4d --- /dev/null +++ b/test_all.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# 批量测试脚本 +# 遍历 test/test_case 目录下所有的 .sy 文件,并验证解析是否成功 + +if [ ! -f "./build/bin/compiler" ]; then + echo "Compiler executable not found at ./build/bin/compiler. Please build the project first." + exit 1 +fi + +FAIL_COUNT=0 +PASS_COUNT=0 +FAILED_FILES=() + +echo "开始批量测试解析..." +echo "=========================================" + +# 查找所有 .sy 文件并进行测试 +while IFS= read -r file; do + # 运行解析器,将正常输出重定向到 /dev/null,保留错误输出用于判断 + ./build/bin/compiler --emit-parse-tree "$file" > /dev/null 2>&1 + + if [ $? -ne 0 ]; then + echo "❌ 解析失败: $file" + FAIL_COUNT=$((FAIL_COUNT+1)) + FAILED_FILES+=("$file") + else + echo "✅ 解析成功: $file" + PASS_COUNT=$((PASS_COUNT+1)) + fi +done < <(find test/test_case -type f -name "*.sy" | sort) + +echo "=========================================" +echo "测试完成!" +echo "成功: $PASS_COUNT" +echo "失败: $FAIL_COUNT" + +if [ $FAIL_COUNT -ne 0 ]; then + echo "失败的文件列表:" + for f in "${FAILED_FILES[@]}"; do + echo " - $f" + done + exit 1 +else + echo "🎉 所有测试用例均解析成功!" 
+ exit 0 +fi diff --git a/终端信息.txt b/终端信息.txt new file mode 100644 index 0000000..ae54c24 --- /dev/null +++ b/终端信息.txt @@ -0,0 +1,19 @@ +(base) root@HP:/home/hp/nudt-compiler-cpp/build# make -j$(nproc) +[ 2%] Built target utils +[ 2%] Building CXX object src/ir/CMakeFiles/ir_core.dir/Type.cpp.o +[ 3%] Building CXX object src/ir/CMakeFiles/ir_core.dir/Value.cpp.o +[ 73%] Built target antlr4_runtime +[ 75%] Built target sem +[ 79%] Built target frontend +[ 80%] Linking CXX static library libir_core.a +[ 84%] Built target ir_core +[ 85%] Built target ir_analysis +[ 89%] Built target ir_passes +[ 94%] Built target mir_core +[ 97%] Built target irgen +[ 99%] Built target mir_passes +[ 99%] Linking CXX executable ../bin/compiler +[100%] Built target compiler +(base) root@HP:/home/hp/nudt-compiler-cpp/build# cd .. +(base) root@HP:/home/hp/nudt-compiler-cpp# ./scripts/verify_ir.sh test/test_case/functional/09_func_defn.sy --run +[error] [irgen] 变量声明缺少存储槽位:a \ No newline at end of file