diff --git a/include/ir/IR.h b/include/ir/IR.h index b961192..26bd510 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -1,35 +1,7 @@ -// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。 -// -// 当前已经实现: -// 1. 基础类型系统:void / i32 / i32* -// 2. Value 体系:Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction -// 3. 最小指令集:Add / Alloca / Load / Store / Ret -// 4. BasicBlock / Function / Module 三层组织结构 -// 5. IRBuilder:便捷创建常量和最小指令 -// 6. def-use 关系的轻量实现: -// - Instruction 保存 operand 列表 -// - Value 保存 uses -// - 支持 ReplaceAllUsesWith 的简化实现 -// -// 当前尚未实现或只做了最小占位: -// 1. 完整类型系统:数组、函数类型、label 类型等 -// 2. 更完整的指令系统:br / condbr / call / phi / gep 等 -// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构) -// 4. 更完整的 IR verifier 和优化基础设施 -// -// 当前需要特别说明的两个简化点: -// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位, -// 后续如果补 label type,可以再改成更合理的块标签类型。 -// 2. ConstantValue 体系目前只实现了 ConstantInt,后续可以继续补 ConstantFloat、 -// ConstantArray等更完整的常量种类。 -// -// 建议的扩展顺序: -// 1. 先补更多指令和类型 -// 2. 再补控制流相关 IR -// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架 - #pragma once +#include +#include #include #include #include @@ -45,17 +17,27 @@ class Value; class User; class ConstantValue; class ConstantInt; +class ConstantFloat; +class ConstantZero; +class ConstantArray; class GlobalValue; +class GlobalVariable; +class Argument; class Instruction; +class BinaryInst; +class CompareInst; +class ReturnInst; +class AllocaInst; +class LoadInst; +class StoreInst; +class BranchInst; +class CondBranchInst; +class CallInst; +class GetElementPtrInst; +class CastInst; class BasicBlock; class Function; -// Use 表示一个 Value 的一次使用记录。 -// 当前实现设计: -// - value:被使用的值 -// - user:使用该值的 User -// - operand_index:该值在 user 操作数列表中的位置 - class Use { public: Use() = default; @@ -66,64 +48,111 @@ class Use { User* GetUser() const { return user_; } size_t GetOperandIndex() const { return operand_index_; } - void SetValue(Value* value) { value_ = value; } - void SetUser(User* user) { user_ = user; } - void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; } - private: Value* value_ = nullptr; User* user_ = nullptr; size_t operand_index_ = 0; }; -// IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。 class Context { public: Context() = default; ~Context(); - // 去重创建 i32 常量。 + ConstantInt* GetConstInt(int v); + ConstantFloat* GetConstFloat(float v); + + template + T* CreateOwnedConstant(Args&&... args) { + auto value = std::make_unique(std::forward(args)...); + auto* ptr = value.get(); + owned_constants_.push_back(std::move(value)); + return ptr; + } std::string NextTemp(); + std::string NextBlock(const std::string& prefix); private: std::unordered_map> const_ints_; + std::unordered_map> const_floats_; + std::vector> owned_constants_; int temp_index_ = -1; + int block_index_ = -1; }; class Type { public: - enum class Kind { Void, Int32, PtrInt32 }; - explicit Type(Kind k); - // 使用静态共享对象获取类型。 - // 同一类型可直接比较返回值是否相等,例如: - // Type::GetInt32Type() == Type::GetInt32Type() + enum class Kind { Void, Int1, Int32, Float32, Pointer, Array, Function }; + + explicit Type(Kind kind); + Type(Kind kind, std::shared_ptr element_type); + Type(Kind kind, std::shared_ptr element_type, size_t array_size); + Type(std::shared_ptr return_type, std::vector> params); + static const std::shared_ptr& GetVoidType(); + static const std::shared_ptr& GetInt1Type(); static const std::shared_ptr& GetInt32Type(); + static const std::shared_ptr& GetFloatType(); + static std::shared_ptr GetPointerType(std::shared_ptr element_type); + static std::shared_ptr GetArrayType(std::shared_ptr element_type, + size_t array_size); + static std::shared_ptr GetFunctionType( + std::shared_ptr return_type, + std::vector> param_types); static const std::shared_ptr& GetPtrInt32Type(); + Kind GetKind() const; + const std::shared_ptr& GetElementType() const; + size_t GetArraySize() const; + const std::shared_ptr& GetReturnType() const; + const std::vector>& GetParamTypes() const; + bool IsVoid() const; + bool IsInt1() const; bool IsInt32() const; + bool IsFloat32() const; + bool IsPointer() const; + bool IsArray() const; + bool IsFunction() const; + bool IsScalar() const; + bool IsInteger() const; + bool IsNumeric() const; bool IsPtrInt32() const; + bool Equals(const Type& other) const; private: Kind kind_; + std::shared_ptr element_type_; + size_t array_size_ = 0; + std::shared_ptr return_type_; + std::vector> param_types_; }; class Value { public: Value(std::shared_ptr ty, std::string name); virtual ~Value() = default; + const std::shared_ptr& GetType() const; const std::string& GetName() const; - void SetName(std::string n); + void SetName(std::string name); + bool IsVoid() const; + bool IsInt1() const; bool IsInt32() const; + bool IsFloat32() const; + bool IsPointer() const; + bool IsArray() const; + bool IsFunctionValue() const; bool IsPtrInt32() const; bool IsConstant() const; bool IsInstruction() const; bool IsUser() const; bool IsFunction() const; + bool IsGlobalVariable() const; + bool IsArgument() const; + void AddUse(User* user, size_t operand_index); void RemoveUse(User* user, size_t operand_index); const std::vector& GetUses() const; @@ -135,52 +164,116 @@ class Value { std::vector uses_; }; -// ConstantValue 是常量体系的基类。 -// 当前只实现了 ConstantInt,后续可继续扩展更多常量种类。 class ConstantValue : public Value { public: ConstantValue(std::shared_ptr ty, std::string name = ""); + virtual bool IsZeroValue() const = 0; }; class ConstantInt : public ConstantValue { public: - ConstantInt(std::shared_ptr ty, int v); + ConstantInt(std::shared_ptr ty, int value); int GetValue() const { return value_; } + bool IsZeroValue() const override { return value_ == 0; } + + private: + int value_ = 0; +}; + +class ConstantFloat : public ConstantValue { + public: + ConstantFloat(std::shared_ptr ty, float value); + float GetValue() const { return value_; } + bool IsZeroValue() const override { return value_ == 0.0f; } private: - int value_{}; + float value_ = 0.0f; }; -// 后续还需要扩展更多指令类型。 -enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; +class ConstantZero : public ConstantValue { + public: + explicit ConstantZero(std::shared_ptr ty); + bool IsZeroValue() const override { return true; } +}; + +class ConstantArray : public ConstantValue { + public: + ConstantArray(std::shared_ptr ty, std::vector elements); + + const std::vector& GetElements() const { return elements_; } + bool IsZeroValue() const override; + + private: + std::vector elements_; +}; + +enum class Opcode { + Add, + Sub, + Mul, + SDiv, + SRem, + FAdd, + FSub, + FMul, + FDiv, + Alloca, + Load, + Store, + ICmp, + FCmp, + Br, + CondBr, + Call, + GEP, + SIToFP, + FPToSI, + ZExt, + Ret, +}; + +enum class ICmpPred { Eq, Ne, Slt, Sle, Sgt, Sge }; +enum class FCmpPred { Oeq, One, Olt, Ole, Ogt, Oge }; -// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 -// 当前实现中只有 Instruction 继承自 User。 class User : public Value { public: User(std::shared_ptr ty, std::string name); + size_t GetNumOperands() const; Value* GetOperand(size_t index) const; void SetOperand(size_t index, Value* value); protected: - // 统一的 operand 入口。 void AddOperand(Value* value); private: std::vector operands_; }; -// GlobalValue 是全局值/全局变量体系的空壳占位类。 -// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。 -class GlobalValue : public User { +class GlobalValue : public Value { public: GlobalValue(std::shared_ptr ty, std::string name); }; +class GlobalVariable : public GlobalValue { + public: + GlobalVariable(std::string name, std::shared_ptr value_type, + ConstantValue* initializer, bool is_constant); + + const std::shared_ptr& GetValueType() const { return value_type_; } + ConstantValue* GetInitializer() const { return initializer_; } + bool IsConstant() const { return is_constant_; } + + private: + std::shared_ptr value_type_; + ConstantValue* initializer_ = nullptr; + bool is_constant_ = false; +}; + class Instruction : public User { public: Instruction(Opcode op, std::shared_ptr ty, std::string name = ""); + Opcode GetOpcode() const; bool IsTerminator() const; BasicBlock* GetParent() const; @@ -195,45 +288,116 @@ class BinaryInst : public Instruction { public: BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name); + Value* GetLhs() const; - Value* GetRhs() const; + Value* GetRhs() const; +}; + +class CompareInst : public Instruction { + public: + CompareInst(ICmpPred pred, Value* lhs, Value* rhs, std::string name); + CompareInst(FCmpPred pred, Value* lhs, Value* rhs, std::string name); + + bool IsFloatCompare() const { return is_float_compare_; } + ICmpPred GetICmpPred() const { return icmp_pred_; } + FCmpPred GetFCmpPred() const { return fcmp_pred_; } + Value* GetLhs() const; + Value* GetRhs() const; + + private: + bool is_float_compare_ = false; + ICmpPred icmp_pred_ = ICmpPred::Eq; + FCmpPred fcmp_pred_ = FCmpPred::Oeq; }; class ReturnInst : public Instruction { public: - ReturnInst(std::shared_ptr void_ty, Value* val); + explicit ReturnInst(Value* value); + ReturnInst(); + Value* GetValue() const; }; class AllocaInst : public Instruction { public: - AllocaInst(std::shared_ptr ptr_ty, std::string name); + AllocaInst(std::shared_ptr allocated_type, std::string name); + + const std::shared_ptr& GetAllocatedType() const { return allocated_type_; } + + private: + std::shared_ptr allocated_type_; }; class LoadInst : public Instruction { public: - LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name); + LoadInst(Value* ptr, std::shared_ptr value_type, std::string name); + Value* GetPtr() const; }; class StoreInst : public Instruction { public: - StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr); + StoreInst(Value* value, Value* ptr); + Value* GetValue() const; Value* GetPtr() const; }; -// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 -// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 +class BranchInst : public Instruction { + public: + explicit BranchInst(BasicBlock* target); + + BasicBlock* GetTarget() const; +}; + +class CondBranchInst : public Instruction { + public: + CondBranchInst(Value* cond, BasicBlock* true_block, BasicBlock* false_block); + + Value* GetCond() const; + BasicBlock* GetTrueBlock() const; + BasicBlock* GetFalseBlock() const; +}; + +class CallInst : public Instruction { + public: + CallInst(Function* callee, std::vector args, std::string name); + + Function* GetCallee() const; + std::vector GetArgs() const; +}; + +class GetElementPtrInst : public Instruction { + public: + GetElementPtrInst(Value* base_ptr, std::vector indices, + std::shared_ptr result_type, std::string name); + + Value* GetBasePtr() const; + std::vector GetIndices() const; + std::shared_ptr GetSourceElementType() const; +}; + +class CastInst : public Instruction { + public: + CastInst(Opcode op, Value* value, std::shared_ptr dst_type, + std::string name); + + Value* GetValue() const; +}; + class BasicBlock : public Value { public: explicit BasicBlock(std::string name); + Function* GetParent() const; void SetParent(Function* parent); bool HasTerminator() const; + void AddSuccessor(BasicBlock* succ); + const std::vector>& GetInstructions() const; const std::vector& GetPredecessors() const; const std::vector& GetSuccessors() const; + template T* Append(Args&&... args) { if (HasTerminator()) { @@ -254,60 +418,105 @@ class BasicBlock : public Value { std::vector successors_; }; -// Function 当前也采用了最小实现。 -// 需要特别注意:由于项目里还没有单独的 FunctionType, -// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”, -// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。 -// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、 -// 形参和调用,通常需要引入专门的函数类型表示。 -class Function : public Value { +class Argument : public Value { public: - // 当前构造函数接收的也是返回类型,而不是完整函数类型。 - Function(std::string name, std::shared_ptr ret_type); + Argument(std::shared_ptr ty, std::string name, size_t index, + Function* parent); + + size_t GetIndex() const { return index_; } + Function* GetParent() const { return parent_; } + + private: + size_t index_ = 0; + Function* parent_ = nullptr; +}; + +class Function : public GlobalValue { + public: + Function(std::string name, std::shared_ptr function_type, + bool is_declaration); + + const std::shared_ptr& GetFunctionType() const; + const std::shared_ptr& GetReturnType() const; + const std::vector>& GetArguments() const; + bool IsDeclaration() const { return is_declaration_; } + + Argument* AddArgument(std::shared_ptr ty, const std::string& name); BasicBlock* CreateBlock(const std::string& name); BasicBlock* GetEntry(); const BasicBlock* GetEntry() const; const std::vector>& GetBlocks() const; private: + bool is_declaration_ = false; BasicBlock* entry_ = nullptr; + std::vector> arguments_; std::vector> blocks_; }; class Module { public: Module() = default; + Context& GetContext(); const Context& GetContext() const; - // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 + + GlobalVariable* CreateGlobal(std::string name, std::shared_ptr value_type, + ConstantValue* initializer, bool is_constant); Function* CreateFunction(const std::string& name, - std::shared_ptr ret_type); + std::shared_ptr function_type, + bool is_declaration = false); + Function* FindFunction(const std::string& name) const; + GlobalVariable* FindGlobal(const std::string& name) const; + + const std::vector>& GetGlobals() const; const std::vector>& GetFunctions() const; private: Context context_; + std::vector> globals_; std::vector> functions_; }; class IRBuilder { public: IRBuilder(Context& ctx, BasicBlock* bb); + void SetInsertPoint(BasicBlock* bb); BasicBlock* GetInsertBlock() const; - // 构造常量、二元运算、返回指令的最小集合。 ConstantInt* CreateConstInt(int v); + ConstantFloat* CreateConstFloat(float v); + ConstantValue* CreateZero(std::shared_ptr type); + BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name); - BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); + AllocaInst* CreateAlloca(std::shared_ptr allocated_type, + const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); - ReturnInst* CreateRet(Value* v); + CompareInst* CreateICmp(ICmpPred pred, Value* lhs, Value* rhs, + const std::string& name); + CompareInst* CreateFCmp(FCmpPred pred, Value* lhs, Value* rhs, + const std::string& name); + BranchInst* CreateBr(BasicBlock* target); + CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_block, + BasicBlock* false_block); + CallInst* CreateCall(Function* callee, const std::vector& args, + const std::string& name); + GetElementPtrInst* CreateGEP(Value* base_ptr, const std::vector& indices, + const std::string& name); + CastInst* CreateSIToFP(Value* value, const std::string& name); + CastInst* CreateFPToSI(Value* value, const std::string& name); + CastInst* CreateZExt(Value* value, std::shared_ptr dst_type, + const std::string& name); + ReturnInst* CreateRet(Value* value); + ReturnInst* CreateRetVoid(); private: Context& ctx_; - BasicBlock* insert_block_; + BasicBlock* insert_block_ = nullptr; }; class IRPrinter { diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index a76a3cc..34f9280 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -1,23 +1,14 @@ -// 将语法树翻译为 IR。 -// 实现拆分在 IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl。 - #pragma once #include #include #include +#include #include "SysYParser.h" #include "ir/IR.h" #include "sem/Sema.h" -namespace ir { -class Module; -class Function; -class IRBuilder; -class Value; -} - class IRGenImpl { public: IRGenImpl(ir::Module& module, const SemanticContext& sema); @@ -25,26 +16,84 @@ class IRGenImpl { void Gen(SysYParser::CompUnitContext& cu); private: + struct StorageEntry { + ir::Value* storage = nullptr; + std::shared_ptr declared_type; + bool is_array_param = false; + bool is_global = false; + bool is_const = false; + }; + + void DeclareBuiltins(); + void GenGlobals(SysYParser::CompUnitContext& cu); + void GenFunctionDecls(SysYParser::CompUnitContext& cu); + void GenFunctionBodies(SysYParser::CompUnitContext& cu); + void GenFuncDef(SysYParser::FuncDefContext& func); void GenBlock(SysYParser::BlockContext& block); - bool GenBlockItem(SysYParser::BlockItemContext& item); + void GenBlockItem(SysYParser::BlockItemContext& item); void GenDecl(SysYParser::DeclContext& decl); - bool GenStmt(SysYParser::StmtContext& stmt); + void GenConstDecl(SysYParser::ConstDeclContext& decl); void GenVarDecl(SysYParser::VarDeclContext& decl); - void GenReturnStmt(SysYParser::ReturnStmtContext& ret); + void GenStmt(SysYParser::StmtContext& stmt); ir::Value* GenExpr(SysYParser::ExpContext& expr); ir::Value* GenAddExpr(SysYParser::AddExpContext& add); ir::Value* GenMulExpr(SysYParser::MulExpContext& mul); ir::Value* GenUnaryExpr(SysYParser::UnaryExpContext& unary); ir::Value* GenPrimary(SysYParser::PrimaryContext& primary); + ir::Value* GenRelExpr(SysYParser::RelExpContext& rel); + ir::Value* GenEqExpr(SysYParser::EqExpContext& eq); + + ir::Value* GenLValueAddress(SysYParser::LValContext& lval); + ir::Value* GenLValueValue(SysYParser::LValContext& lval); + + void GenCond(SysYParser::CondContext& cond, ir::BasicBlock* true_block, + ir::BasicBlock* false_block); + void GenLOrCond(SysYParser::LOrExpContext& expr, ir::BasicBlock* true_block, + ir::BasicBlock* false_block); + void GenLAndCond(SysYParser::LAndExpContext& expr, ir::BasicBlock* true_block, + ir::BasicBlock* false_block); + + ir::Value* CastValue(ir::Value* value, const std::shared_ptr& dst_type); + ir::Value* ToBool(ir::Value* value); + ir::Value* DecayArrayPointer(ir::Value* array_ptr); + + void EnterScope(); + void ExitScope(); + void EnsureInsertableBlock(); + void DeclareLocal(const std::string& name, StorageEntry entry); + StorageEntry* LookupStorage(const std::string& name); + const StorageEntry* LookupStorage(const std::string& name) const; + + size_t CountScalars(const std::shared_ptr& type) const; + std::vector FlatIndexToIndices(const std::shared_ptr& type, + size_t flat_index) const; + void EmitArrayStore(ir::Value* base_ptr, const std::shared_ptr& array_type, + size_t flat_index, ir::Value* value); + void ZeroInitializeLocalArray(ir::Value* base_ptr, + const std::shared_ptr& array_type); + void EmitLocalArrayInit(ir::Value* base_ptr, const std::shared_ptr& array_type, + SysYParser::InitValContext& init); + void EmitLocalConstArrayInit(ir::Value* base_ptr, + const std::shared_ptr& array_type, + SysYParser::ConstInitValContext& init); + + ir::ConstantValue* BuildGlobalInitializer(const std::shared_ptr& type, + SysYParser::InitValContext* init); + ir::ConstantValue* BuildGlobalConstInitializer( + const std::shared_ptr& type, SysYParser::ConstInitValContext* init); ir::Module& module_; const SemanticContext& sema_; - ir::Function* func_; + ir::Function* current_function_ = nullptr; + std::shared_ptr current_return_type_; ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 - std::unordered_map storage_map_; + std::vector> local_scopes_; + std::unordered_map globals_; + std::vector break_targets_; + std::vector continue_targets_; + std::unordered_map global_const_values_; }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 2f0499f..a2d99be 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -1,30 +1,77 @@ -// 基于语法树的语义检查与名称绑定。 #pragma once +#include +#include #include +#include #include "SysYParser.h" +#include "ir/IR.h" + +enum class SymbolKind { Object, Function }; + +struct ConstantData { + enum class Kind { Int, Float }; + + Kind kind = Kind::Int; + int int_value = 0; + float float_value = 0.0f; + + static ConstantData FromInt(int value); + static ConstantData FromFloat(float value); + + bool IsInt() const { return kind == Kind::Int; } + bool IsFloat() const { return kind == Kind::Float; } + int AsInt() const; + float AsFloat() const; + ConstantData CastTo(const std::shared_ptr& dst_type) const; + std::shared_ptr GetType() const; +}; + +struct SymbolInfo { + std::string name; + SymbolKind kind = SymbolKind::Object; + std::shared_ptr type; + bool is_const = false; + bool is_global = false; + bool is_parameter = false; + bool is_array_parameter = false; + bool is_builtin = false; + + SysYParser::ConstDefContext* const_def = nullptr; + SysYParser::VarDefContext* var_def = nullptr; + SysYParser::FuncDefContext* func_def = nullptr; + + bool has_const_value = false; + ConstantData const_value{}; +}; class SemanticContext { public: - void BindVarUse(SysYParser::LValContext* use, - SysYParser::VarDefContext* decl) { - var_uses_[use] = decl; - } + SymbolInfo* CreateSymbol(SymbolInfo symbol); + + void BindConstDef(SysYParser::ConstDefContext* node, const SymbolInfo* symbol); + void BindVarDef(SysYParser::VarDefContext* node, const SymbolInfo* symbol); + void BindFuncDef(SysYParser::FuncDefContext* node, const SymbolInfo* symbol); + void BindLVal(SysYParser::LValContext* node, const SymbolInfo* symbol); + void BindCall(SysYParser::UnaryExpContext* node, const SymbolInfo* symbol); + void SetExprType(const void* node, std::shared_ptr type); - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::LValContext* use) const { - auto it = var_uses_.find(use); - return it == var_uses_.end() ? nullptr : it->second; - } + const SymbolInfo* ResolveConstDef(const SysYParser::ConstDefContext* node) const; + const SymbolInfo* ResolveVarDef(const SysYParser::VarDefContext* node) const; + const SymbolInfo* ResolveFuncDef(const SysYParser::FuncDefContext* node) const; + const SymbolInfo* ResolveLVal(const SysYParser::LValContext* node) const; + const SymbolInfo* ResolveCall(const SysYParser::UnaryExpContext* node) const; + std::shared_ptr ResolveExprType(const void* node) const; private: - std::unordered_map - var_uses_; + std::vector> owned_symbols_; + std::unordered_map const_defs_; + std::unordered_map var_defs_; + std::unordered_map func_defs_; + std::unordered_map lvals_; + std::unordered_map calls_; + std::unordered_map> expr_types_; }; -// 目前仅检查: -// - 变量先声明后使用 -// - 局部变量不允许重复定义 SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index c9396dd..6b0440b 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -1,17 +1,20 @@ -// 极简符号表:记录局部变量定义点。 #pragma once #include #include +#include -#include "SysYParser.h" +#include "sem/Sema.h" class SymbolTable { public: - void Add(const std::string& name, SysYParser::VarDefContext* decl); - bool Contains(const std::string& name) const; - SysYParser::VarDefContext* Lookup(const std::string& name) const; + void EnterScope(); + void ExitScope(); + + bool Declare(const std::string& name, const SymbolInfo* symbol); + const SymbolInfo* Lookup(const std::string& name) const; + const SymbolInfo* LookupCurrent(const std::string& name) const; private: - std::unordered_map table_; + std::vector> scopes_; }; diff --git a/scripts/verify_ir.sh b/scripts/verify_ir.sh index f41f6b3..83801be 100755 --- a/scripts/verify_ir.sh +++ b/scripts/verify_ir.sh @@ -37,6 +37,13 @@ if [[ ! -x "$compiler" ]]; then exit 1 fi +runtime_src="./sylib/sylib.c" +runtime_hdr="./sylib/sylib.h" +if [[ ! -f "$runtime_src" || ! -f "$runtime_hdr" ]]; then + echo "未找到 SysY 运行库: $runtime_src / $runtime_hdr" >&2 + exit 1 +fi + mkdir -p "$out_dir" base=$(basename "$input") stem=${base%.sy} @@ -56,11 +63,13 @@ if [[ "$run_exec" == true ]]; then exit 1 fi obj="$out_dir/$stem.o" + runtime_obj="$out_dir/sylib.o" exe="$out_dir/$stem" stdout_file="$out_dir/$stem.stdout" actual_file="$out_dir/$stem.actual.out" llc -filetype=obj "$out_file" -o "$obj" - clang "$obj" -o "$exe" + clang -c "$runtime_src" -o "$runtime_obj" + clang "$obj" "$runtime_obj" -o "$exe" echo "运行 $exe ..." set +e if [[ -f "$stdin_file" ]]; then @@ -77,11 +86,15 @@ if [[ "$run_exec" == true ]]; then if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then printf '\n' fi - printf '%s\n' "$status" + printf '%s' "$status" } > "$actual_file" if [[ -f "$expected_file" ]]; then - if diff -u "$expected_file" "$actual_file"; then + expected_cmp="$out_dir/$stem.expected.norm" + actual_cmp="$out_dir/$stem.actual.norm" + perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n\z//' "$expected_file" > "$expected_cmp" + perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n\z//' "$actual_file" > "$actual_cmp" + if diff -u "$expected_cmp" "$actual_cmp"; then echo "输出匹配: $expected_file" else echo "输出不匹配: $expected_file" >&2 diff --git a/solution/Lab2-修改记录.md b/solution/Lab2-修改记录.md new file mode 100644 index 0000000..3a56976 --- /dev/null +++ b/solution/Lab2-修改记录.md @@ -0,0 +1,272 @@ +# Lab2 修改记录 + +## 1. 修改目标 + +根据 [doc/Lab2-中间表示生成.md](../doc/Lab2-中间表示生成.md) 的要求,完成 SysY 前端到 LLVM 风格 IR 的主链路扩展,使编译器能够: + +1. 基于现有 ANTLR parse tree 完成语义分析。 +2. 生成可被 `llc` / `clang` 接受的 IR。 +3. 通过运行库和验证脚本完成 “生成 IR -> 链接运行 -> 输出比对”。 + +本次实现继续沿用: + +1. `parse tree -> Sema -> IRGen -> IRPrinter` +2. 局部变量采用 `alloca/store/load` 内存模型 +3. 不在 Lab2 中引入独立 AST + +## 2. 设计修订 + +在实现前,对 [Lab2-设计方案.md](./Lab2-设计方案.md) 做了以下修订: + +1. 明确 SysY 源语言继续只接受 `funcDef`,不额外引入用户自定义函数声明语法。 +2. 将“模块级外部函数声明支持”与“源语言语法支持”区分开。 +3. 将 `sylib` 运行库补全和 `verify_ir.sh` 自动链接运行库纳入阶段 0 前置。 +4. 将 `functional` 与 `performance` 全量通过定义为阶段 C 收口后的总目标,不作为 A1/A2/B 的单阶段硬门槛。 +5. 统一错误归因口径: +- `parse` +- `sema` +- `irgen` +- `llvm-link/run` + +## 3. 代码改动 + +### 3.1 IR 层扩展 + +修改文件: + +1. `include/ir/IR.h` +2. `src/ir/Type.cpp` +3. `src/ir/Value.cpp` +4. `src/ir/Context.cpp` +5. `src/ir/GlobalValue.cpp` +6. `src/ir/Function.cpp` +7. `src/ir/Module.cpp` +8. `src/ir/BasicBlock.cpp` +9. `src/ir/Instruction.cpp` +10. `src/ir/IRBuilder.cpp` +11. `src/ir/IRPrinter.cpp` + +主要改动: + +1. 类型系统从最小 `void/i32/i32*` 扩展到: +- `void` +- `i1` +- `i32` +- `float` +- `pointer` +- `array` +- `function` + +2. 值系统新增: +- `ConstantFloat` +- `ConstantArray` +- `Argument` +- `GlobalVariable` + +3. 指令系统补齐: +- 整数算术:`add/sub/mul/sdiv/srem` +- 浮点算术:`fadd/fsub/fmul/fdiv` +- 比较:`icmp/fcmp` +- 控制流:`br/condbr` +- 调用:`call` +- 地址计算:`gep` +- 类型转换:`sitofp/fptosi/zext` +- 存储与返回:`alloca/load/store/ret` + +4. `IRBuilder` 从按 `i32/i32*` 写死的专用接口改为按 `Type` 驱动的通用接口。 +5. `IRPrinter` 输出调整为 LLVM 可接受文本格式。 +6. SSA 临时名生成改为 `%t0/%t1/...`,避免 LLVM 对纯数字 SSA 命名的歧义。 +7. 浮点常量打印改为 LLVM 可接受的十六进制形式。 +8. `alloca` 统一插入函数入口块,避免循环内重复分配导致的栈增长问题。 +9. `GEP` 结果类型推导修正,支持数组对象、数组指针和多维数组访问。 + +### 3.2 Sema 重构 + +修改文件: + +1. `include/sem/Sema.h` +2. `include/sem/SymbolTable.h` +3. `src/sem/SymbolTable.cpp` +4. `src/sem/Sema.cpp` + +主要改动: + +1. `SemanticContext` 从“变量 use -> decl”扩展为统一语义结果容器,记录: +- 声明绑定 +- 函数绑定 +- 调用绑定 +- 表达式静态类型 + +2. `SymbolTable` 升级为作用域栈,支持: +- 全局作用域 +- 函数作用域 +- 块作用域 +- 同层去重和内层遮蔽 + +3. `RunSema` 改为两遍式: +- 第一遍收集顶层对象和函数签名 +- 第二遍检查函数体 + +4. 注入运行库函数签名,包括: +- `getint/getch/getfloat/getarray/getfarray` +- `putint/putch/putfloat/putarray/putfarray` +- `starttime/stoptime` + +5. 增加语义检查: +- 函数调用实参数量与类型匹配 +- 返回值类型匹配 +- 赋值左值合法性 +- 数组维度和下标检查 +- `break/continue` 循环上下文检查 +- 表达式类型推导和 `int/float` 转换规则 + +6. 常量表达式求值整合到 `Sema.cpp`,用于: +- 数组维度 +- `const` 初始化 +- 全局初始化 + +7. 修正常量数组初始化检查,允许花括号内部出现标量叶子表达式。 + +### 3.3 IRGen 扩展 + +修改文件: + +1. `include/irgen/IRGen.h` +2. `src/irgen/IRGenDriver.cpp` +3. `src/irgen/IRGenFunc.cpp` +4. `src/irgen/IRGenDecl.cpp` +5. `src/irgen/IRGenStmt.cpp` +6. `src/irgen/IRGenExp.cpp` + +主要改动: + +1. 顶层生成分成两步: +- 先建立函数签名、全局对象和运行库声明 +- 再逐函数填充函数体 + +2. 支持: +- 全局变量与全局常量 +- 局部变量与局部常量 +- 数组对象与数组初始化 +- 数组形参 +- 普通函数调用与运行库调用 +- `if/else` +- `while` +- `break/continue` +- `return` + +3. 表达式生成拆分为: +- `GenExpr` +- `GenLValueAddress` +- `GenCond` + +4. 条件表达式和短路逻辑直接降到控制流,不走“先算整型值再判断”的路径。 +5. 多维数组访问统一走逐维 `GEP`。 +6. `int/float` 混合表达式按规则插入 `sitofp/fptosi`。 +7. 修正一元逻辑非 `!` 的 IR 生成,保证其语义为真正的布尔取反。 + +### 3.4 运行库与验证脚本 + +修改文件: + +1. `sylib/sylib.h` +2. `sylib/sylib.c` +3. `scripts/verify_ir.sh` +4. `solution/run_lab2_batch.sh` + +主要改动: + +1. 补全 `sylib` 头文件与 C 实现。 +2. `verify_ir.sh` 在链接时自动编译并链接 `sylib/sylib.c`。 +3. 输出比对增加换行归一化,兼容测试集中的 `CRLF/LF` 差异和末尾换行差异。 +4. 新增 `run_lab2_batch.sh`,用于 Lab2 的全量构建、批量回归和结果统计。 + +## 4. 覆盖的阶段目标 + +本次实现已覆盖设计方案中的全部阶段目标: + +1. 阶段 0:IR 基础设施、运行库、验证链路 +2. 阶段 A1:函数、调用、全局 `int` +3. 阶段 A2:控制流、比较、短路、循环跳转 +4. 阶段 B:数组、初始化、多维下标、数组运行库 +5. 阶段 C:`float`、浮点比较、`int <-> float` 转换、浮点运行库 + +## 5. 验证记录 + +### 5.1 构建验证 + +执行命令: + +```bash +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF +cmake --build build -j 4 +``` + +结果: + +1. 构建成功。 +2. `./build/bin/compiler --emit-ir` 可正常生成 IR。 + +### 5.2 单样例和阶段样例验证 + +执行过的阶段代表样例包括: + +1. `simple_add.sy` +2. `09_func_defn.sy` +3. `29_break.sy` +4. `36_op_priority2.sy` +5. `if-combine3.sy` +6. `22_matrix_multiply.sy` +7. `15_graph_coloring.sy` +8. `01_mm2.sy` +9. `02_mv3.sy` +10. `03_sort1.sy` +11. `transpose0.sy` +12. `95_float.sy` +13. `large_loop_array_2.sy` +14. `vector_mul3.sy` + +结果: + +1. `--emit-ir` 可生成合法 IR。 +2. `verify_ir.sh --run` 可完成链接、运行与输出比对。 + +### 5.3 全量正例回归 + +执行命令: + +```bash +for case in $(find test/test_case/functional test/test_case/performance -maxdepth 1 -name '*.sy' | sort); do + ./scripts/verify_ir.sh "$case" test/test_result/lab2_ir --run || exit 1 +done +``` + +以及新增批量脚本: + +```bash +./solution/run_lab2_batch.sh +``` + +结果: + +1. `functional`:11/11 通过 +2. `performance`:10/10 通过 +3. 总计:21/21 通过 + +### 5.4 额外自检 + +1. 运行库调用自检: +- `putint(42)` 可正常生成 IR、链接运行并输出 `42` + +2. 语义错误归因自检: +- `break` 出现在循环外时,能够在 `sema` 阶段报错,而不是落到 `irgen` 或 LLVM 工具链 + +## 6. 当前边界说明 + +1. Lab2 的目标是 `--emit-ir` 链路,不是后端汇编链路。 +2. MIR/后端没有同步扩展完整功能,只保持了工程可编译。 +3. 本次实现未引入独立 AST,也未实现 SSA/phi 构造和优化。 + +## 7. 结论 + +本次修改后,Lab2 已完成从 SysY 语法树到 LLVM 风格 IR 的主链路扩展,支持函数、控制流、数组、初始化、浮点与运行库调用,并且通过了当前仓库 `functional` 与 `performance` 正例全集的运行验证。 diff --git a/solution/Lab2-设计方案.md b/solution/Lab2-设计方案.md new file mode 100644 index 0000000..ae618b1 --- /dev/null +++ b/solution/Lab2-设计方案.md @@ -0,0 +1,437 @@ +# Lab2 设计方案(修订版) + +## 1. 目标 + +根据 [doc/Lab2-中间表示生成.md](../doc/Lab2-中间表示生成.md) 的要求,在当前最小编译器框架上扩展 Sema -> IRGen -> IRPrinter 链路,使更多 SysY 语法能够被正确翻译为 LLVM 风格 IR,并通过运行验证完成 IR -> 目标程序 -> 输出比对。 + +本次 Lab2 采用分阶段、可回归、可归因的推进策略,核心原则如下: + +1. 先补基础设施,再补语法覆盖,避免阶段跳步。 +2. 每阶段都定义最小样例集与退出条件,避免只做点测。 +3. 错误分类保持一致: +- 语法错误归 parse +- 语义错误归 sema +- 生成能力缺口归 irgen +- LLVM 文本、运行库链接或运行结果问题归 llvm-link/run + +## 2. 当前实现现状与约束 + +结合当前代码,现状可概括为: + +1. Sema 只覆盖最小名称绑定,范围偏向 main 函数内局部变量。 +2. IRGen 只覆盖最小顺序语句流程,核心是局部 int、基础算术、return。 +3. IR 类型与指令集合都是教学最小子集,无法直接承载完整 Lab2 功能。 + +因此,Lab2 不能只改某一个目录,必须协同扩展: + +1. IR 层 +- include/ir/IR.h +- src/ir + +2. 语义层 +- include/sem/Sema.h +- include/sem/SymbolTable.h +- src/sem + +3. 生成层 +- include/irgen/IRGen.h +- src/irgen + +## 3. 总体设计原则 + +### 3.1 保持 parse tree 直连方案 + +继续基于 ANTLR parse tree,不引入独立 AST。接口仍保持: + +1. RunSema(CompUnit) -> SemanticContext +2. GenerateIR(CompUnit, SemanticContext) -> Module + +理由:降低结构性重构成本,把精力聚焦在 Lab2 的语义补全与 IR 生成补全。 + +### 3.2 继续采用内存模型 + +局部变量和形参默认走 alloca/store/load 模型,不在 Lab2 引入 SSA 构造与 phi 优化。理由:优先保证正确性与可运行性,优化类目标留给后续实验。 + +### 3.3 分阶段门禁 + +每阶段必须满足三类门禁: + +1. 该阶段目标样例通过。 +2. 前阶段样例无回归。 +3. 失败能快速归因到 parse/sema/irgen/llvm-link/run。 + +## 4. 阶段划分(重排后) + +### 4.1 阶段 0:基础设施硬前置 + +这是后续所有阶段的阻塞前置阶段,未完成不得进入 A1。 + +目标: + +1. 扩展 IR 类型系统到最小可用集合: +- void +- i1 +- i32 +- float +- pointer +- array +- function + +2. 扩展关键指令集合: +- 算术补齐 sdiv、srem +- 比较补齐 icmp、fcmp +- 控制流补齐 br、condbr +- 调用补齐 call +- 地址计算补齐 gep +- 转换补齐 sitofp、fptosi、zext + +3. IRBuilder 与 IRPrinter 同步扩展,避免出现能生成但不能打印、或能打印但 LLVM 不接受。 + +4. Sema 架构改为两遍式骨架: +- 第一遍收集顶层符号(函数签名、全局对象、运行库函数) +- 第二遍检查函数体(类型、调用、控制流上下文等) + +5. SymbolTable 升级为作用域栈,支持全局/函数/块作用域和遮蔽规则。 + +6. 运行库与验证环境前置补齐: +- 完整提供 `sylib/sylib.h` 与 `sylib/sylib.c` +- `verify_ir.sh` 在链接阶段自动带上运行库 +- 运行结果比对需要容忍测试集中的换行风格差异 + +阶段样例: + +1. simple_add + +退出条件: + +1. simple_add 不回归。 +2. 新增 IR 元素可被 llc/clang 接受。 +3. parse/sema/irgen 错误分类可区分。 + +### 4.2 阶段 A1:函数与调用主链路(依赖阶段 0) + +目标: + +1. 用户函数定义支持,以及 IR/Module 层的外部函数声明支持。 +2. 形参与返回类型检查。 +3. 函数调用与实参数量/类型检查。 +4. 全局 int 标量与全局初始化。 +5. 运行库函数声明注册与调用生成。 + +实现要点: + +1. SysY 源语言继续只接受 `funcDef`,不额外引入用户自定义函数声明语法。 +2. Module 区分函数声明和函数定义。 +3. 运行库函数和其他外部函数通过模块级声明接入,而不是扩展源语言语法。 +4. 形参映射为 Argument,再按内存模型落地到槽位。 +5. Sema 在调用点完成签名匹配,不把类型错误拖到 IRGen。 + +阶段样例: + +1. simple_add +2. 09_func_defn + +退出条件: + +1. 阶段样例 --emit-ir 成功。 +2. 阶段样例 --run 输出与退出码匹配。 +3. 无阶段 0 回归。 + +### 4.3 阶段 A2:控制流与条件主链路(依赖 A1) + +目标: + +1. 支持赋值语句、表达式语句、块语句。 +2. 支持 if/else。 +3. 支持 while。 +4. 支持 break/continue(含循环嵌套场景)。 +5. 支持比较与逻辑条件生成。 + +实现要点: + +1. 明确三类表达式接口职责: +- GenRValue +- GenLValueAddr +- GenCond + +2. 控制流模板固定化: +- if:cond -> then -> else(可选) -> merge +- while:cond -> body -> exit +- break 绑定 exit +- continue 绑定 cond + +阶段样例: + +1. 29_break +2. 36_op_priority2 +3. if-combine3 + +退出条件: + +1. 阶段样例 --run 全通过。 +2. 短路与循环跳转行为正确。 +3. 无 A1 与阶段 0 回归。 + +### 4.4 阶段 B:数组与初始化(依赖 A2) + +目标: + +1. 一维/多维数组类型与对象表示。 +2. 全局数组与局部数组支持。 +3. 数组形参支持。 +4. 下标访问通过 GEP 生成。 +5. 初始化器递归展开与补零规则落地。 +6. getarray/putarray 相关调用与类型检查支持。 + +实现要点: + +1. 数组对象与数组指针区分清晰。 +2. 下标访问逐维计算,避免扁平化误用。 +3. 局部数组与全局数组初始化路径分离。 + +阶段样例: + +1. 22_matrix_multiply +2. 15_graph_coloring +3. 01_mm2 +4. 02_mv3 +5. transpose0 +6. 03_sort1 + +退出条件: + +1. 数组样例链路通过。 +2. 初始化补零行为与预期一致。 +3. 无 A2 及之前回归。 + +### 4.5 阶段 C:float 与混合类型(依赖 B) + +目标: + +1. float 类型与浮点常量。 +2. 浮点运算与浮点比较。 +3. int <-> float 隐式转换。 +4. getfloat/putfloat/getfarray/putfarray 支持。 + +实现要点: + +1. 明确定义类型提升规则,避免不同模块各自推断。 +2. 转换插入策略统一: +- 算术场景的提升 +- 赋值场景的收窄/转换 +- 调用实参与形参匹配转换 + +阶段样例: + +1. 95_float +2. large_loop_array_2 +3. vector_mul3 + +退出条件: + +1. 浮点样例链路通过。 +2. 类型错误优先在 sema 阶段暴露。 +3. 无 B 及之前回归。 + +## 5. IR 层详细设计 + +### 5.1 类型系统 + +类型至少覆盖: + +1. Void +2. Int1 +3. Int32 +4. Float32 +5. Pointer(element_type) +6. Array(element_type, extent) +7. Function(return_type, param_types) + +要求: + +1. 类型构造和查询接口统一。 +2. 现有按 `i32/i32*` 写死的接口需要升级为按 `Type` 驱动的通用实现。 +3. IRPrinter 打印格式与 LLVM 文本兼容。 +4. 函数签名可完整表达返回值与参数列表。 + +### 5.2 值与对象系统 + +至少补齐: + +1. ConstantFloat +2. ConstantArray +3. Argument +4. GlobalVariable 或等价全局对象表示 + +Module 层至少支持: + +1. 函数声明集合 +2. 函数定义集合 +3. 全局变量/常量对象集合 + +### 5.3 指令与 Builder + +Builder 最小接口建议包括: + +1. CreateBr +2. CreateCondBr +3. CreateCall +4. CreateICmp +5. CreateFCmp +6. CreateGEP +7. CreateSIToFP +8. CreateFPToSI +9. CreateZExt +10. CreateAlloca(type) + +要求: + +1. 新增指令必须同步到 IRPrinter。 +2. 输出 IR 必须可被 llc/clang 接受。 + +## 6. Sema 详细设计 + +### 6.1 SemanticContext 扩展 + +除变量绑定外,至少包含: + +1. 函数绑定信息 +2. 表达式静态类型 +3. 左值可赋值性 +4. 数组维度/退化信息 +5. 调用点签名匹配结果 + +### 6.2 符号表规则 + +采用作用域栈,支持: + +1. Declare(同层去重) +2. Lookup(由内向外) +3. EnterScope / ExitScope + +覆盖范围: + +1. 全局作用域 +2. 函数作用域 +3. 块作用域 + +### 6.3 两遍式语义流程 + +第一遍: + +1. 收集顶层函数签名 +2. 收集全局变量/常量 +3. 注入运行库函数签名 + +第二遍: + +1. 校验函数体 +2. 校验 return 与函数返回类型 +3. 校验调用参数个数与类型 +4. 校验数组下标与维度 +5. 校验 break/continue 上下文 +6. 计算常量表达式(用于维度与初始化) + +## 7. IRGen 详细设计 + +### 7.1 生成流程 + +两阶段生成: + +1. 顶层扫描,建立函数与全局对象骨架。 +2. 逐函数填充基本块和指令。 + +### 7.2 函数状态 + +函数级状态建议包括: + +1. current_func +2. current_bb +3. return_bb +4. return_slot(非 void 可选) +5. break_targets 栈 +6. continue_targets 栈 +7. 局部存储槽位环境 + +### 7.3 表达式与语句职责拆分 + +表达式: + +1. GenRValue +2. GenLValueAddr +3. GenCond + +语句: + +1. 声明 +2. 赋值 +3. 表达式语句 +4. return +5. if/else +6. while +7. break +8. continue +9. block + +## 8. 验证与回归方案 + +### 8.1 单样例验证 + +用于快速定位: + +1. 编译器生成 IR 是否成功 +2. IR 文本是否基本正确 + +### 8.2 阶段样例回归 + +每阶段必须执行对应样例集,不得只跑一个样例。 + +### 8.3 全量回归 + +阶段内只要求回归相关子集并记录失败样例。 + +当前仓库 `functional` 与 `performance` 正例全集覆盖,属于阶段 C 完成后的总目标,不作为 A1/A2/B 的单阶段硬门槛。 + +### 8.4 失败归因矩阵 + +1. parse 失败:语法规则或词法/语法处理问题。 +2. sema 失败:名称绑定、类型检查、上下文约束问题。 +3. irgen 失败:语义到 IR 映射未实现或实现错误。 +4. llvm-link/run 失败:IR 文本不合法、链接缺失、运行行为错误。 + +### 8.5 建议验证命令模板 + +单样例: + +```bash +./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy +./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/function/ir --run +``` + +阶段样例循环回归(示意): + +```bash +for f in test/test_case/functional/09_func_defn.sy test/test_case/functional/29_break.sy; do + ./scripts/verify_ir.sh "$f" test/test_result/function/ir --run || exit 1 +done +``` + +## 9. 设计取舍 + +1. 不引入独立 AST。优先保证 Lab2 可落地与可验证,降低重构成本。 +2. 继续采用内存模型。减少实现复杂度,先确保正确性。 +3. 优先保证 LLVM 可接受性。内部抽象服从外部工具链约束。 +4. 分阶段推进。降低单次改动规模,便于调试与协作。 +5. 明确排除范围。Lab2 不承担 SSA/phi 构造和优化类目标,相关工作放到后续实验。 + +## 10. 最终验收目标 + +Lab2 完成后应达到: + +1. Sema 能完成核心名称绑定与类型检查。 +2. IRGen 能覆盖 Lab2 目标语法并生成合法 LLVM 风格 IR。 +3. 关键样例能通过运行比对。 +4. 形成稳定回归流程,支持后续 Lab3 对接。 +5. 阶段 C 收口后,当前仓库 `functional` 与 `performance` 正例全集应能完成 IR 生成、链接与运行比对。 + +在此基础上,Lab3 再继续推进后端相关能力,包括指令选择、栈帧与寄存器分配。 diff --git a/solution/run_lab2_batch.sh b/solution/run_lab2_batch.sh new file mode 100755 index 0000000..4740674 --- /dev/null +++ b/solution/run_lab2_batch.sh @@ -0,0 +1,173 @@ +#!/usr/bin/env bash + +set -euo pipefail +shopt -s nullglob + +ROOT_DIR="$(cd "$(dirname "$0")/.." && pwd)" +BUILD_DIR="$ROOT_DIR/build" +ANTLR_DIR="$BUILD_DIR/generated/antlr4" +JAR_PATH="$ROOT_DIR/third_party/antlr-4.13.2-complete.jar" +GRAMMAR_PATH="$ROOT_DIR/src/antlr4/SysY.g4" +OUT_ROOT="$ROOT_DIR/test/test_result/lab2_ir_batch" + +RUN_FUNCTIONAL=true +RUN_PERFORMANCE=true +DO_BUILD=true + +functional_total=0 +functional_passed=0 +functional_failed=0 +performance_total=0 +performance_passed=0 +performance_failed=0 +failed_cases=() + +usage() { + cat <<'EOF' +Usage: ./solution/run_lab2_batch.sh [options] + +Options: + --no-build Skip ANTLR generation and project rebuild + --functional-only Run only test/test_case/functional/*.sy + --performance-only Run only test/test_case/performance/*.sy + --output-dir Set output directory for generated IR and logs + --help Show this help message +EOF +} + +print_summary() { + local total passed failed + total=$((functional_total + performance_total)) + passed=$((functional_passed + performance_passed)) + failed=$((functional_failed + performance_failed)) + + echo + echo "Summary:" + echo " Functional cases: total=$functional_total, passed=$functional_passed, failed=$functional_failed" + echo " Performance cases: total=$performance_total, passed=$performance_passed, failed=$performance_failed" + echo " Overall: total=$total, passed=$passed, failed=$failed" + + if (( ${#failed_cases[@]} > 0 )); then + echo "Failed cases:" + printf ' - %s\n' "${failed_cases[@]}" + fi +} + +run_case() { + local case_file=$1 + local group=$2 + local stem out_dir log_file + + stem="$(basename "${case_file%.sy}")" + out_dir="$OUT_ROOT/$group" + log_file="$out_dir/$stem.verify.log" + mkdir -p "$out_dir" + + if [[ "$group" == "functional" ]]; then + ((functional_total += 1)) + else + ((performance_total += 1)) + fi + + if ./scripts/verify_ir.sh "$case_file" "$out_dir" --run >"$log_file" 2>&1; then + echo "PASS: $case_file" + if [[ "$group" == "functional" ]]; then + ((functional_passed += 1)) + else + ((performance_passed += 1)) + fi + else + echo "FAIL: $case_file" + cat "$log_file" + if [[ "$group" == "functional" ]]; then + ((functional_failed += 1)) + else + ((performance_failed += 1)) + fi + failed_cases+=("$case_file") + fi +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --no-build) + DO_BUILD=false + ;; + --functional-only) + RUN_FUNCTIONAL=true + RUN_PERFORMANCE=false + ;; + --performance-only) + RUN_FUNCTIONAL=false + RUN_PERFORMANCE=true + ;; + --output-dir) + shift + if [[ $# -eq 0 ]]; then + echo "Missing value for --output-dir" >&2 + usage + exit 1 + fi + if [[ "$1" = /* ]]; then + OUT_ROOT="$1" + else + OUT_ROOT="$ROOT_DIR/$1" + fi + ;; + --help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + usage + exit 1 + ;; + esac + shift +done + +if [[ "$RUN_FUNCTIONAL" == false && "$RUN_PERFORMANCE" == false ]]; then + echo "No test set selected." >&2 + exit 1 +fi + +if [[ "$DO_BUILD" == true ]]; then + echo "[1/4] Generating ANTLR sources..." + mkdir -p "$ANTLR_DIR" + java -jar "$JAR_PATH" \ + -Dlanguage=Cpp \ + -visitor -no-listener \ + -Xexact-output-dir \ + -o "$ANTLR_DIR" \ + "$GRAMMAR_PATH" + + echo "[2/4] Configuring CMake..." + cmake -S "$ROOT_DIR" -B "$BUILD_DIR" -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF + + echo "[3/4] Building project..." + cmake --build "$BUILD_DIR" -j "$(nproc)" +fi + +echo "[4/4] Running IR batch tests..." + +if [[ "$RUN_FUNCTIONAL" == true ]]; then + for case_file in "$ROOT_DIR"/test/test_case/functional/*.sy; do + run_case "$case_file" "functional" + done +fi + +if [[ "$RUN_PERFORMANCE" == true ]]; then + for case_file in "$ROOT_DIR"/test/test_case/performance/*.sy; do + run_case "$case_file" "performance" + done +fi + +print_summary + +if (( functional_failed + performance_failed > 0 )); then + echo "Batch test finished with failures." + exit 1 +fi + +echo "Batch test passed." diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp index b18502c..b8634a7 100644 --- a/src/ir/BasicBlock.cpp +++ b/src/ir/BasicBlock.cpp @@ -1,19 +1,10 @@ -// IR 基本块: -// - 保存指令序列 -// - 为后续 CFG 分析预留前驱/后继接口 -// -// 当前仍是最小实现: -// - BasicBlock 已纳入 Value 体系,但类型先用 void 占位; -// - 指令追加与 terminator 约束主要在头文件中的 Append 模板里处理; -// - 前驱/后继容器已经预留,但当前项目里还没有分支指令与自动维护逻辑。 - #include "ir/IR.h" +#include #include namespace ir { -// 当前 BasicBlock 还没有专门的 label type,因此先用 void 作为占位类型。 BasicBlock::BasicBlock(std::string name) : Value(Type::GetVoidType(), std::move(name)) {} @@ -21,19 +12,29 @@ Function* BasicBlock::GetParent() const { return parent_; } void BasicBlock::SetParent(Function* parent) { parent_ = parent; } - bool BasicBlock::HasTerminator() const { return !instructions_.empty() && instructions_.back()->IsTerminator(); } -// 按插入顺序返回块内指令序列。 +void BasicBlock::AddSuccessor(BasicBlock* succ) { + if (!succ) { + return; + } + if (std::find(successors_.begin(), successors_.end(), succ) == + successors_.end()) { + successors_.push_back(succ); + } + if (std::find(succ->predecessors_.begin(), succ->predecessors_.end(), this) == + succ->predecessors_.end()) { + succ->predecessors_.push_back(this); + } +} + const std::vector>& BasicBlock::GetInstructions() const { return instructions_; } -// 前驱/后继接口先保留给后续 CFG 扩展使用。 -// 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。 const std::vector& BasicBlock::GetPredecessors() const { return predecessors_; } diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 16c982c..87be281 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -1,6 +1,6 @@ -// 管理基础类型、整型常量池和临时名生成。 #include "ir/IR.h" +#include #include namespace ir { @@ -9,15 +9,38 @@ Context::~Context() = default; ConstantInt* Context::GetConstInt(int v) { auto it = const_ints_.find(v); - if (it != const_ints_.end()) return it->second.get(); + if (it != const_ints_.end()) { + return it->second.get(); + } auto inserted = - const_ints_.emplace(v, std::make_unique(Type::GetInt32Type(), v)).first; + const_ints_.emplace(v, std::make_unique(Type::GetInt32Type(), v)) + .first; + return inserted->second.get(); +} + +ConstantFloat* Context::GetConstFloat(float v) { + uint32_t bits = 0; + std::memcpy(&bits, &v, sizeof(bits)); + auto it = const_floats_.find(bits); + if (it != const_floats_.end()) { + return it->second.get(); + } + auto inserted = const_floats_ + .emplace(bits, std::make_unique( + Type::GetFloatType(), v)) + .first; return inserted->second.get(); } std::string Context::NextTemp() { std::ostringstream oss; - oss << "%" << ++temp_index_; + oss << "%t" << ++temp_index_; + return oss.str(); +} + +std::string Context::NextBlock(const std::string& prefix) { + std::ostringstream oss; + oss << prefix << "." << ++block_index_; return oss.str(); } diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index cf14d48..7fea573 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -1,16 +1,39 @@ -// IR Function: -// - 保存参数列表、基本块列表 -// - 记录函数属性/元信息(按需要扩展) #include "ir/IR.h" +#include + namespace ir { -Function::Function(std::string name, std::shared_ptr ret_type) - : Value(std::move(ret_type), std::move(name)) { - entry_ = CreateBlock("entry"); +Function::Function(std::string name, std::shared_ptr function_type, + bool is_declaration) + : GlobalValue(std::move(function_type), std::move(name)), + is_declaration_(is_declaration) { + if (!type_ || !type_->IsFunction()) { + throw std::runtime_error("Function 需要 function type"); + } +} + +const std::shared_ptr& Function::GetFunctionType() const { return type_; } + +const std::shared_ptr& Function::GetReturnType() const { + return type_->GetReturnType(); +} + +const std::vector>& Function::GetArguments() const { + return arguments_; +} + +Argument* Function::AddArgument(std::shared_ptr ty, const std::string& name) { + auto arg = std::make_unique(std::move(ty), name, arguments_.size(), this); + auto* ptr = arg.get(); + arguments_.push_back(std::move(arg)); + return ptr; } BasicBlock* Function::CreateBlock(const std::string& name) { + if (is_declaration_) { + throw std::runtime_error("声明函数不能创建基本块"); + } auto block = std::make_unique(name); auto* ptr = block.get(); ptr->SetParent(this); diff --git a/src/ir/GlobalValue.cpp b/src/ir/GlobalValue.cpp index 7c2abe1..d2fea2d 100644 --- a/src/ir/GlobalValue.cpp +++ b/src/ir/GlobalValue.cpp @@ -1,11 +1,19 @@ -// GlobalValue 占位实现: -// - 具体的全局初始化器、打印和链接语义需要自行补全 - #include "ir/IR.h" namespace ir { GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) - : User(std::move(ty), std::move(name)) {} + : Value(std::move(ty), std::move(name)) {} + +GlobalVariable::GlobalVariable(std::string name, std::shared_ptr value_type, + ConstantValue* initializer, bool is_constant) + : GlobalValue(Type::GetPointerType(value_type), std::move(name)), + value_type_(std::move(value_type)), + initializer_(initializer), + is_constant_(is_constant) {} + +Argument::Argument(std::shared_ptr ty, std::string name, size_t index, + Function* parent) + : Value(std::move(ty), std::move(name)), index_(index), parent_(parent) {} } // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 90f03c4..cb637b5 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -1,89 +1,178 @@ -// IR 构建工具: -// - 管理插入点(当前基本块/位置) -// - 提供创建各类指令的便捷接口,降低 IRGen 复杂度 - #include "ir/IR.h" #include -#include "utils/Log.h" - namespace ir { -IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) - : ctx_(ctx), insert_block_(bb) {} +namespace { + +void RequireInsertBlock(BasicBlock* bb) { + if (!bb) { + throw std::runtime_error("IRBuilder 未设置插入点"); + } +} + +std::shared_ptr InferLoadType(Value* ptr) { + if (!ptr || !ptr->GetType() || !ptr->GetType()->IsPointer()) { + throw std::runtime_error("CreateLoad 需要指针"); + } + return ptr->GetType()->GetElementType(); +} + +std::shared_ptr InferGEPResultType(Value* base_ptr, + const std::vector& indices) { + if (!base_ptr || !base_ptr->GetType() || !base_ptr->GetType()->IsPointer()) { + throw std::runtime_error("CreateGEP 需要指针基址"); + } + auto current = base_ptr->GetType()->GetElementType(); + for (size_t i = 0; i < indices.size(); ++i) { + auto* index = indices[i]; + (void)index; + if (!current) { + throw std::runtime_error("CreateGEP 遇到空类型"); + } + if (i == 0) { + continue; + } + if (current->IsArray()) { + current = current->GetElementType(); + continue; + } + if (current->IsPointer()) { + current = current->GetElementType(); + continue; + } + break; + } + return Type::GetPointerType(current); +} + +} // namespace + +IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) : ctx_(ctx), insert_block_(bb) {} void IRBuilder::SetInsertPoint(BasicBlock* bb) { insert_block_ = bb; } BasicBlock* IRBuilder::GetInsertBlock() const { return insert_block_; } -ConstantInt* IRBuilder::CreateConstInt(int v) { - // 常量不需要挂在基本块里,由 Context 负责去重与生命周期。 - return ctx_.GetConstInt(v); +ConstantInt* IRBuilder::CreateConstInt(int v) { return ctx_.GetConstInt(v); } + +ConstantFloat* IRBuilder::CreateConstFloat(float v) { return ctx_.GetConstFloat(v); } + +ConstantValue* IRBuilder::CreateZero(std::shared_ptr type) { + if (!type) { + throw std::runtime_error("CreateZero 缺少类型"); + } + if (type->IsInt1() || type->IsInt32()) { + return CreateConstInt(0); + } + if (type->IsFloat32()) { + return CreateConstFloat(0.0f); + } + return ctx_.CreateOwnedConstant(type); } BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name) { - if (!insert_block_) { - throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); - } - if (!lhs) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateBinary 缺少 lhs")); - } - if (!rhs) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateBinary 缺少 rhs")); + RequireInsertBlock(insert_block_); + if (!lhs || !rhs) { + throw std::runtime_error("CreateBinary 缺少操作数"); } return insert_block_->Append(op, lhs->GetType(), lhs, rhs, name); } -BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, - const std::string& name) { - return CreateBinary(Opcode::Add, lhs, rhs, name); +AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr allocated_type, + const std::string& name) { + RequireInsertBlock(insert_block_); + auto* parent = insert_block_->GetParent(); + if (!parent || !parent->GetEntry()) { + throw std::runtime_error("CreateAlloca 需要所在函数入口块"); + } + return parent->GetEntry()->Append(std::move(allocated_type), name); } AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { - if (!insert_block_) { - throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); - } - return insert_block_->Append(Type::GetPtrInt32Type(), name); + return CreateAlloca(Type::GetInt32Type(), name); } LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { - if (!insert_block_) { - throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); - } - if (!ptr) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); - } - return insert_block_->Append(Type::GetInt32Type(), ptr, name); + RequireInsertBlock(insert_block_); + return insert_block_->Append(ptr, InferLoadType(ptr), name); } StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { - if (!insert_block_) { - throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); - } - if (!val) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateStore 缺少 val")); - } - if (!ptr) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateStore 缺少 ptr")); - } - return insert_block_->Append(Type::GetVoidType(), val, ptr); + RequireInsertBlock(insert_block_); + return insert_block_->Append(val, ptr); } -ReturnInst* IRBuilder::CreateRet(Value* v) { - if (!insert_block_) { - throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); - } - if (!v) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateRet 缺少返回值")); +CompareInst* IRBuilder::CreateICmp(ICmpPred pred, Value* lhs, Value* rhs, + const std::string& name) { + RequireInsertBlock(insert_block_); + return insert_block_->Append(pred, lhs, rhs, name); +} + +CompareInst* IRBuilder::CreateFCmp(FCmpPred pred, Value* lhs, Value* rhs, + const std::string& name) { + RequireInsertBlock(insert_block_); + return insert_block_->Append(pred, lhs, rhs, name); +} + +BranchInst* IRBuilder::CreateBr(BasicBlock* target) { + RequireInsertBlock(insert_block_); + return insert_block_->Append(target); +} + +CondBranchInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_block, + BasicBlock* false_block) { + RequireInsertBlock(insert_block_); + return insert_block_->Append(cond, true_block, false_block); +} + +CallInst* IRBuilder::CreateCall(Function* callee, const std::vector& args, + const std::string& name) { + RequireInsertBlock(insert_block_); + std::string actual_name = name; + if (callee && callee->GetReturnType()->IsVoid()) { + actual_name.clear(); } - return insert_block_->Append(Type::GetVoidType(), v); + return insert_block_->Append(callee, args, actual_name); +} + +GetElementPtrInst* IRBuilder::CreateGEP(Value* base_ptr, + const std::vector& indices, + const std::string& name) { + RequireInsertBlock(insert_block_); + return insert_block_->Append( + base_ptr, indices, InferGEPResultType(base_ptr, indices), name); +} + +CastInst* IRBuilder::CreateSIToFP(Value* value, const std::string& name) { + RequireInsertBlock(insert_block_); + return insert_block_->Append(Opcode::SIToFP, value, + Type::GetFloatType(), name); +} + +CastInst* IRBuilder::CreateFPToSI(Value* value, const std::string& name) { + RequireInsertBlock(insert_block_); + return insert_block_->Append(Opcode::FPToSI, value, + Type::GetInt32Type(), name); +} + +CastInst* IRBuilder::CreateZExt(Value* value, std::shared_ptr dst_type, + const std::string& name) { + RequireInsertBlock(insert_block_); + return insert_block_->Append(Opcode::ZExt, value, std::move(dst_type), + name); +} + +ReturnInst* IRBuilder::CreateRet(Value* value) { + RequireInsertBlock(insert_block_); + return value ? insert_block_->Append(value) + : insert_block_->Append(); +} + +ReturnInst* IRBuilder::CreateRetVoid() { + RequireInsertBlock(insert_block_); + return insert_block_->Append(); } } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 30efbb6..35cd1d9 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -1,30 +1,127 @@ -// IR 文本输出: -// - 将 IR 打印为 .ll 风格的文本 -// - 支撑调试与测试对比(diff) - #include "ir/IR.h" +#include +#include +#include +#include #include +#include #include -#include - -#include "utils/Log.h" namespace ir { +namespace { + +std::string TypeToString(const std::shared_ptr& ty); +std::string ConstantToString(const ConstantValue* value); -static const char* TypeToString(const Type& ty) { - switch (ty.GetKind()) { +std::string TypeToString(const std::shared_ptr& ty) { + if (!ty) { + throw std::runtime_error("空类型无法打印"); + } + switch (ty->GetKind()) { case Type::Kind::Void: return "void"; + case Type::Kind::Int1: + return "i1"; case Type::Kind::Int32: return "i32"; - case Type::Kind::PtrInt32: - return "i32*"; + case Type::Kind::Float32: + return "float"; + case Type::Kind::Pointer: + return TypeToString(ty->GetElementType()) + "*"; + case Type::Kind::Array: { + std::ostringstream oss; + oss << "[" << ty->GetArraySize() << " x " + << TypeToString(ty->GetElementType()) << "]"; + return oss.str(); + } + case Type::Kind::Function: { + std::ostringstream oss; + oss << TypeToString(ty->GetReturnType()) << " ("; + const auto& params = ty->GetParamTypes(); + for (size_t i = 0; i < params.size(); ++i) { + if (i != 0) { + oss << ", "; + } + oss << TypeToString(params[i]); + } + oss << ")"; + return oss.str(); + } } - throw std::runtime_error(FormatError("ir", "未知类型")); + throw std::runtime_error("未知类型"); } -static const char* OpcodeToString(Opcode op) { +std::string FloatLiteral(float value) { + std::ostringstream oss; + double widened = static_cast(value); + std::uint64_t bits = 0; + std::memcpy(&bits, &widened, sizeof(bits)); + oss << "0x" << std::uppercase << std::hex << std::setw(16) << std::setfill('0') + << bits; + return oss.str(); +} + +std::string ValueRef(const Value* value) { + if (!value) { + return ""; + } + if (auto* ci = dynamic_cast(value)) { + return std::to_string(ci->GetValue()); + } + if (auto* cf = dynamic_cast(value)) { + return FloatLiteral(cf->GetValue()); + } + if (auto* cz = dynamic_cast(value)) { + if (cz->GetType()->IsFloat32()) { + return FloatLiteral(0.0f); + } + return "0"; + } + if (dynamic_cast(value) != nullptr || + dynamic_cast(value) != nullptr) { + return "@" + value->GetName(); + } + return value->GetName(); +} + +std::string ConstantToString(const ConstantValue* value) { + if (!value) { + throw std::runtime_error("空常量无法打印"); + } + if (auto* ci = dynamic_cast(value)) { + return std::to_string(ci->GetValue()); + } + if (auto* cf = dynamic_cast(value)) { + return FloatLiteral(cf->GetValue()); + } + if (auto* cz = dynamic_cast(value)) { + if (cz->GetType()->IsScalar()) { + return ValueRef(cz); + } + return "zeroinitializer"; + } + if (auto* array = dynamic_cast(value)) { + if (array->IsZeroValue()) { + return "zeroinitializer"; + } + std::ostringstream oss; + oss << "["; + const auto& elements = array->GetElements(); + for (size_t i = 0; i < elements.size(); ++i) { + if (i != 0) { + oss << ", "; + } + oss << TypeToString(elements[i]->GetType()) << " " + << ConstantToString(elements[i]); + } + oss << "]"; + return oss.str(); + } + throw std::runtime_error("未知常量类型"); +} + +const char* BinaryOpcodeName(Opcode op) { switch (op) { case Opcode::Add: return "add"; @@ -32,69 +129,241 @@ static const char* OpcodeToString(Opcode op) { return "sub"; case Opcode::Mul: return "mul"; - case Opcode::Alloca: - return "alloca"; - case Opcode::Load: - return "load"; - case Opcode::Store: - return "store"; - case Opcode::Ret: - return "ret"; - } - return "?"; + case Opcode::SDiv: + return "sdiv"; + case Opcode::SRem: + return "srem"; + case Opcode::FAdd: + return "fadd"; + case Opcode::FSub: + return "fsub"; + case Opcode::FMul: + return "fmul"; + case Opcode::FDiv: + return "fdiv"; + default: + throw std::runtime_error("不是二元算术 opcode"); + } } -static std::string ValueToString(const Value* v) { - if (auto* ci = dynamic_cast(v)) { - return std::to_string(ci->GetValue()); +const char* ICmpPredName(ICmpPred pred) { + switch (pred) { + case ICmpPred::Eq: + return "eq"; + case ICmpPred::Ne: + return "ne"; + case ICmpPred::Slt: + return "slt"; + case ICmpPred::Sle: + return "sle"; + case ICmpPred::Sgt: + return "sgt"; + case ICmpPred::Sge: + return "sge"; } - return v ? v->GetName() : ""; + throw std::runtime_error("未知 ICmp 谓词"); } +const char* FCmpPredName(FCmpPred pred) { + switch (pred) { + case FCmpPred::Oeq: + return "oeq"; + case FCmpPred::One: + return "one"; + case FCmpPred::Olt: + return "olt"; + case FCmpPred::Ole: + return "ole"; + case FCmpPred::Ogt: + return "ogt"; + case FCmpPred::Oge: + return "oge"; + } + throw std::runtime_error("未知 FCmp 谓词"); +} + +void PrintFunctionHeader(const Function& func, std::ostream& os, bool define) { + os << (define ? "define " : "declare ") + << TypeToString(func.GetReturnType()) << " @" << func.GetName() << "("; + const auto& args = func.GetArguments(); + const auto& params = func.GetFunctionType()->GetParamTypes(); + for (size_t i = 0; i < params.size(); ++i) { + if (i != 0) { + os << ", "; + } + os << TypeToString(params[i]); + if (define) { + os << " " << args[i]->GetName(); + } + } + os << ")"; +} + +} // namespace + void IRPrinter::Print(const Module& module, std::ostream& os) { + for (const auto& global : module.GetGlobals()) { + if (!global) { + continue; + } + os << "@" << global->GetName() << " = " + << (global->IsConstant() ? "constant " : "global ") + << TypeToString(global->GetValueType()) << " "; + auto* init = global->GetInitializer(); + if (!init) { + ConstantZero zero(global->GetValueType()); + os << ConstantToString(&zero); + } else { + os << ConstantToString(init); + } + os << "\n"; + } + + if (!module.GetGlobals().empty() && !module.GetFunctions().empty()) { + os << "\n"; + } + + bool first_function = true; for (const auto& func : module.GetFunctions()) { - os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() - << "() {\n"; + if (!func) { + continue; + } + if (!first_function) { + os << "\n"; + } + first_function = false; + + if (func->IsDeclaration()) { + PrintFunctionHeader(*func, os, false); + os << "\n"; + continue; + } + + PrintFunctionHeader(*func, os, true); + os << " {\n"; for (const auto& bb : func->GetBlocks()) { if (!bb) { continue; } os << bb->GetName() << ":\n"; - for (const auto& instPtr : bb->GetInstructions()) { - const auto* inst = instPtr.get(); + for (const auto& inst_ptr : bb->GetInstructions()) { + const auto* inst = inst_ptr.get(); switch (inst->GetOpcode()) { case Opcode::Add: case Opcode::Sub: - case Opcode::Mul: { - auto* bin = static_cast(inst); - os << " " << bin->GetName() << " = " - << OpcodeToString(bin->GetOpcode()) << " " - << TypeToString(*bin->GetLhs()->GetType()) << " " - << ValueToString(bin->GetLhs()) << ", " - << ValueToString(bin->GetRhs()) << "\n"; + case Opcode::Mul: + case Opcode::SDiv: + case Opcode::SRem: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: { + const auto* bin = static_cast(inst); + os << " " << bin->GetName() << " = " << BinaryOpcodeName(inst->GetOpcode()) + << " " << TypeToString(bin->GetType()) << " " + << ValueRef(bin->GetLhs()) << ", " << ValueRef(bin->GetRhs()) + << "\n"; break; } case Opcode::Alloca: { - auto* alloca = static_cast(inst); - os << " " << alloca->GetName() << " = alloca i32\n"; + const auto* alloca = static_cast(inst); + os << " " << alloca->GetName() << " = alloca " + << TypeToString(alloca->GetAllocatedType()) << "\n"; break; } case Opcode::Load: { - auto* load = static_cast(inst); - os << " " << load->GetName() << " = load i32, i32* " - << ValueToString(load->GetPtr()) << "\n"; + const auto* load = static_cast(inst); + os << " " << load->GetName() << " = load " + << TypeToString(load->GetType()) << ", " + << TypeToString(load->GetPtr()->GetType()) << " " + << ValueRef(load->GetPtr()) << "\n"; break; } case Opcode::Store: { - auto* store = static_cast(inst); - os << " store i32 " << ValueToString(store->GetValue()) - << ", i32* " << ValueToString(store->GetPtr()) << "\n"; + const auto* store = static_cast(inst); + os << " store " << TypeToString(store->GetValue()->GetType()) << " " + << ValueRef(store->GetValue()) << ", " + << TypeToString(store->GetPtr()->GetType()) << " " + << ValueRef(store->GetPtr()) << "\n"; + break; + } + case Opcode::ICmp: + case Opcode::FCmp: { + const auto* cmp = static_cast(inst); + os << " " << cmp->GetName() << " = " + << (cmp->IsFloatCompare() ? "fcmp " : "icmp ") + << (cmp->IsFloatCompare() ? FCmpPredName(cmp->GetFCmpPred()) + : ICmpPredName(cmp->GetICmpPred())) + << " " << TypeToString(cmp->GetLhs()->GetType()) << " " + << ValueRef(cmp->GetLhs()) << ", " << ValueRef(cmp->GetRhs()) + << "\n"; + break; + } + case Opcode::Br: { + const auto* br = static_cast(inst); + os << " br label %" << br->GetTarget()->GetName() << "\n"; + break; + } + case Opcode::CondBr: { + const auto* br = static_cast(inst); + os << " br i1 " << ValueRef(br->GetCond()) << ", label %" + << br->GetTrueBlock()->GetName() << ", label %" + << br->GetFalseBlock()->GetName() << "\n"; + break; + } + case Opcode::Call: { + const auto* call = static_cast(inst); + if (!call->GetType()->IsVoid()) { + os << " " << call->GetName() << " = "; + } else { + os << " "; + } + os << "call " << TypeToString(call->GetCallee()->GetReturnType()) + << " @" << call->GetCallee()->GetName() << "("; + auto args = call->GetArgs(); + for (size_t i = 0; i < args.size(); ++i) { + if (i != 0) { + os << ", "; + } + os << TypeToString(args[i]->GetType()) << " " << ValueRef(args[i]); + } + os << ")\n"; + break; + } + case Opcode::GEP: { + const auto* gep = static_cast(inst); + os << " " << gep->GetName() << " = getelementptr " + << TypeToString(gep->GetSourceElementType()) << ", " + << TypeToString(gep->GetBasePtr()->GetType()) << " " + << ValueRef(gep->GetBasePtr()); + for (auto* index : gep->GetIndices()) { + os << ", " << TypeToString(index->GetType()) << " " << ValueRef(index); + } + os << "\n"; + break; + } + case Opcode::SIToFP: + case Opcode::FPToSI: + case Opcode::ZExt: { + const auto* cast = static_cast(inst); + const char* opname = inst->GetOpcode() == Opcode::SIToFP + ? "sitofp" + : inst->GetOpcode() == Opcode::FPToSI ? "fptosi" + : "zext"; + os << " " << cast->GetName() << " = " << opname << " " + << TypeToString(cast->GetValue()->GetType()) << " " + << ValueRef(cast->GetValue()) << " to " + << TypeToString(cast->GetType()) << "\n"; break; } case Opcode::Ret: { - auto* ret = static_cast(inst); - os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " - << ValueToString(ret->GetValue()) << "\n"; + const auto* ret = static_cast(inst); + if (auto* value = ret->GetValue()) { + os << " ret " << TypeToString(value->GetType()) << " " + << ValueRef(value) << "\n"; + } else { + os << " ret void\n"; + } break; } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 7928716..e0b7623 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -1,13 +1,27 @@ -// IR 指令体系: -// - 二元运算/比较、load/store、call、br/condbr、ret、phi、alloca 等 -// - 指令操作数与结果类型管理,支持打印与优化 #include "ir/IR.h" #include -#include "utils/Log.h" - namespace ir { +namespace { + +void Require(bool condition, const std::string& message) { + if (!condition) { + throw std::runtime_error(message); + } +} + +bool SameType(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + return lhs && rhs && lhs->Equals(*rhs); +} + +std::shared_ptr GetPointeeType(Value* ptr) { + Require(ptr && ptr->GetType() && ptr->GetType()->IsPointer(), "期望指针类型"); + return ptr->GetType()->GetElementType(); +} + +} // namespace + User::User(std::shared_ptr ty, std::string name) : Value(std::move(ty), std::move(name)) {} @@ -24,9 +38,7 @@ void User::SetOperand(size_t index, Value* value) { if (index >= operands_.size()) { throw std::out_of_range("User operand index out of range"); } - if (!value) { - throw std::runtime_error(FormatError("ir", "User operand 不能为空")); - } + Require(value != nullptr, "User operand 不能为空"); auto* old = operands_[index]; if (old == value) { return; @@ -39,10 +51,8 @@ void User::SetOperand(size_t index, Value* value) { } void User::AddOperand(Value* value) { - if (!value) { - throw std::runtime_error(FormatError("ir", "User operand 不能为空")); - } - size_t operand_index = operands_.size(); + Require(value != nullptr, "User operand 不能为空"); + const size_t operand_index = operands_.size(); operands_.push_back(value); value->AddUse(this, operand_index); } @@ -52,30 +62,49 @@ Instruction::Instruction(Opcode op, std::shared_ptr ty, std::string name) Opcode Instruction::GetOpcode() const { return opcode_; } -bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; } +bool Instruction::IsTerminator() const { + return opcode_ == Opcode::Ret || opcode_ == Opcode::Br || + opcode_ == Opcode::CondBr; +} BasicBlock* Instruction::GetParent() const { return parent_; } -void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } - -BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, - Value* rhs, std::string name) - : Instruction(op, std::move(ty), std::move(name)) { - if (op != Opcode::Add) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); - } - if (!lhs || !rhs) { - throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); - } - if (!type_ || !lhs->GetType() || !rhs->GetType()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息")); +void Instruction::SetParent(BasicBlock* parent) { + parent_ = parent; + if (!parent_) { + return; } - if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() || - type_->GetKind() != lhs->GetType()->GetKind()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); + if (auto* br = dynamic_cast(this)) { + parent_->AddSuccessor(br->GetTarget()); + } else if (auto* cond = dynamic_cast(this)) { + parent_->AddSuccessor(cond->GetTrueBlock()); + parent_->AddSuccessor(cond->GetFalseBlock()); } - if (!type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); +} + +BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, + std::string name) + : Instruction(op, std::move(ty), std::move(name)) { + Require(lhs && rhs, "BinaryInst 缺少操作数"); + Require(type_ && lhs->GetType() && rhs->GetType(), "BinaryInst 缺少类型信息"); + Require(SameType(lhs->GetType(), rhs->GetType()), "BinaryInst 操作数类型不匹配"); + Require(SameType(type_, lhs->GetType()), "BinaryInst 结果类型不匹配"); + switch (op) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::SDiv: + case Opcode::SRem: + Require(type_->IsInt32(), "整数 BinaryInst 只支持 i32"); + break; + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + Require(type_->IsFloat32(), "浮点 BinaryInst 只支持 float"); + break; + default: + throw std::runtime_error("BinaryInst 不支持该 opcode"); } AddOperand(lhs); AddOperand(rhs); @@ -85,67 +114,189 @@ Value* BinaryInst::GetLhs() const { return GetOperand(0); } Value* BinaryInst::GetRhs() const { return GetOperand(1); } -ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) - : Instruction(Opcode::Ret, std::move(void_ty), "") { - if (!val) { - throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值")); - } - if (!type_ || !type_->IsVoid()) { - throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void")); - } - AddOperand(val); +CompareInst::CompareInst(ICmpPred pred, Value* lhs, Value* rhs, std::string name) + : Instruction(Opcode::ICmp, Type::GetInt1Type(), std::move(name)), + icmp_pred_(pred) { + Require(lhs && rhs, "ICmp 缺少操作数"); + Require(lhs->GetType() && rhs->GetType(), "ICmp 缺少类型信息"); + Require(lhs->GetType()->IsInt32() && rhs->GetType()->IsInt32(), + "ICmp 只支持 i32"); + AddOperand(lhs); + AddOperand(rhs); +} + +CompareInst::CompareInst(FCmpPred pred, Value* lhs, Value* rhs, std::string name) + : Instruction(Opcode::FCmp, Type::GetInt1Type(), std::move(name)), + is_float_compare_(true), + fcmp_pred_(pred) { + Require(lhs && rhs, "FCmp 缺少操作数"); + Require(lhs->GetType() && rhs->GetType(), "FCmp 缺少类型信息"); + Require(lhs->GetType()->IsFloat32() && rhs->GetType()->IsFloat32(), + "FCmp 只支持 float"); + AddOperand(lhs); + AddOperand(rhs); } -Value* ReturnInst::GetValue() const { return GetOperand(0); } +Value* CompareInst::GetLhs() const { return GetOperand(0); } -AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) - : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { - if (!type_ || !type_->IsPtrInt32()) { - throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); - } +Value* CompareInst::GetRhs() const { return GetOperand(1); } + +ReturnInst::ReturnInst(Value* value) + : Instruction(Opcode::Ret, Type::GetVoidType(), "") { + Require(value != nullptr, "ret 缺少返回值"); + AddOperand(value); } -LoadInst::LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name) - : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) { - if (!ptr) { - throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); - } - if (!type_ || !type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32")); - } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "LoadInst 当前只支持从 i32* 加载")); - } +ReturnInst::ReturnInst() : Instruction(Opcode::Ret, Type::GetVoidType(), "") {} + +Value* ReturnInst::GetValue() const { + return GetNumOperands() == 0 ? nullptr : GetOperand(0); +} + +AllocaInst::AllocaInst(std::shared_ptr allocated_type, std::string name) + : Instruction(Opcode::Alloca, Type::GetPointerType(allocated_type), + std::move(name)), + allocated_type_(std::move(allocated_type)) { + Require(allocated_type_ != nullptr, "alloca 缺少目标类型"); +} + +LoadInst::LoadInst(Value* ptr, std::shared_ptr value_type, std::string name) + : Instruction(Opcode::Load, std::move(value_type), std::move(name)) { + Require(ptr != nullptr, "load 缺少 ptr"); + Require(type_ != nullptr, "load 缺少 value type"); + Require(ptr->GetType() && ptr->GetType()->IsPointer(), "load 需要指针操作数"); + Require(SameType(GetPointeeType(ptr), type_), "load 类型不匹配"); AddOperand(ptr); } Value* LoadInst::GetPtr() const { return GetOperand(0); } -StoreInst::StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr) - : Instruction(Opcode::Store, std::move(void_ty), "") { - if (!val) { - throw std::runtime_error(FormatError("ir", "StoreInst 缺少 value")); - } - if (!ptr) { - throw std::runtime_error(FormatError("ir", "StoreInst 缺少 ptr")); +StoreInst::StoreInst(Value* value, Value* ptr) + : Instruction(Opcode::Store, Type::GetVoidType(), "") { + Require(value != nullptr, "store 缺少 value"); + Require(ptr != nullptr, "store 缺少 ptr"); + Require(ptr->GetType() && ptr->GetType()->IsPointer(), "store 需要指针操作数"); + Require(SameType(value->GetType(), GetPointeeType(ptr)), "store 类型不匹配"); + AddOperand(value); + AddOperand(ptr); +} + +Value* StoreInst::GetValue() const { return GetOperand(0); } + +Value* StoreInst::GetPtr() const { return GetOperand(1); } + +BranchInst::BranchInst(BasicBlock* target) + : Instruction(Opcode::Br, Type::GetVoidType(), "") { + Require(target != nullptr, "br 缺少目标块"); + AddOperand(target); +} + +BasicBlock* BranchInst::GetTarget() const { + return static_cast(GetOperand(0)); +} + +CondBranchInst::CondBranchInst(Value* cond, BasicBlock* true_block, + BasicBlock* false_block) + : Instruction(Opcode::CondBr, Type::GetVoidType(), "") { + Require(cond != nullptr, "condbr 缺少条件"); + Require(cond->GetType() && cond->GetType()->IsInt1(), "condbr 条件必须为 i1"); + Require(true_block != nullptr && false_block != nullptr, + "condbr 缺少目标块"); + AddOperand(cond); + AddOperand(true_block); + AddOperand(false_block); +} + +Value* CondBranchInst::GetCond() const { return GetOperand(0); } + +BasicBlock* CondBranchInst::GetTrueBlock() const { + return static_cast(GetOperand(1)); +} + +BasicBlock* CondBranchInst::GetFalseBlock() const { + return static_cast(GetOperand(2)); +} + +CallInst::CallInst(Function* callee, std::vector args, std::string name) + : Instruction(Opcode::Call, callee ? callee->GetReturnType() : Type::GetVoidType(), + std::move(name)) { + Require(callee != nullptr, "call 缺少 callee"); + AddOperand(callee); + const auto& params = callee->GetFunctionType()->GetParamTypes(); + Require(params.size() == args.size(), "call 参数个数不匹配"); + for (size_t i = 0; i < args.size(); ++i) { + Require(args[i] != nullptr, "call 缺少实参"); + Require(SameType(params[i], args[i]->GetType()), "call 参数类型不匹配"); + AddOperand(args[i]); } - if (!type_ || !type_->IsVoid()) { - throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); +} + +Function* CallInst::GetCallee() const { + return static_cast(GetOperand(0)); +} + +std::vector CallInst::GetArgs() const { + std::vector args; + for (size_t i = 1; i < GetNumOperands(); ++i) { + args.push_back(GetOperand(i)); } - if (!val->GetType() || !val->GetType()->IsInt32()) { - throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32")); + return args; +} + +GetElementPtrInst::GetElementPtrInst(Value* base_ptr, std::vector indices, + std::shared_ptr result_type, + std::string name) + : Instruction(Opcode::GEP, std::move(result_type), std::move(name)) { + Require(base_ptr != nullptr, "gep 缺少 base_ptr"); + Require(base_ptr->GetType() && base_ptr->GetType()->IsPointer(), + "gep 需要指针基址"); + Require(type_ != nullptr && type_->IsPointer(), "gep 结果必须是指针"); + AddOperand(base_ptr); + for (auto* index : indices) { + Require(index != nullptr, "gep 缺少索引"); + Require(index->GetType() && index->GetType()->IsInt32(), "gep 索引必须为 i32"); + AddOperand(index); } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "StoreInst 当前只支持写入 i32*")); +} + +Value* GetElementPtrInst::GetBasePtr() const { return GetOperand(0); } + +std::vector GetElementPtrInst::GetIndices() const { + std::vector indices; + for (size_t i = 1; i < GetNumOperands(); ++i) { + indices.push_back(GetOperand(i)); } - AddOperand(val); - AddOperand(ptr); + return indices; } -Value* StoreInst::GetValue() const { return GetOperand(0); } +std::shared_ptr GetElementPtrInst::GetSourceElementType() const { + return GetBasePtr()->GetType()->GetElementType(); +} -Value* StoreInst::GetPtr() const { return GetOperand(1); } +CastInst::CastInst(Opcode op, Value* value, std::shared_ptr dst_type, + std::string name) + : Instruction(op, std::move(dst_type), std::move(name)) { + Require(value != nullptr, "cast 缺少 value"); + Require(type_ != nullptr, "cast 缺少目标类型"); + switch (op) { + case Opcode::SIToFP: + Require(value->GetType() && value->GetType()->IsInt32() && type_->IsFloat32(), + "sitofp 需要 i32 -> float"); + break; + case Opcode::FPToSI: + Require(value->GetType() && value->GetType()->IsFloat32() && type_->IsInt32(), + "fptosi 需要 float -> i32"); + break; + case Opcode::ZExt: + Require(value->GetType() && value->GetType()->IsInt1() && type_->IsInt32(), + "zext 需要 i1 -> i32"); + break; + default: + throw std::runtime_error("不支持的 cast opcode"); + } + AddOperand(value); +} + +Value* CastInst::GetValue() const { return GetOperand(0); } } // namespace ir diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index 928efdc..ac0f60f 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -1,5 +1,3 @@ -// 保存函数列表并提供模块级上下文访问。 - #include "ir/IR.h" namespace ir { @@ -8,12 +6,45 @@ Context& Module::GetContext() { return context_; } const Context& Module::GetContext() const { return context_; } +GlobalVariable* Module::CreateGlobal(std::string name, + std::shared_ptr value_type, + ConstantValue* initializer, + bool is_constant) { + globals_.push_back(std::make_unique( + std::move(name), std::move(value_type), initializer, is_constant)); + return globals_.back().get(); +} + Function* Module::CreateFunction(const std::string& name, - std::shared_ptr ret_type) { - functions_.push_back(std::make_unique(name, std::move(ret_type))); + std::shared_ptr function_type, + bool is_declaration) { + functions_.push_back( + std::make_unique(name, std::move(function_type), is_declaration)); return functions_.back().get(); } +Function* Module::FindFunction(const std::string& name) const { + for (const auto& func : functions_) { + if (func && func->GetName() == name) { + return func.get(); + } + } + return nullptr; +} + +GlobalVariable* Module::FindGlobal(const std::string& name) const { + for (const auto& global : globals_) { + if (global && global->GetName() == name) { + return global.get(); + } + } + return nullptr; +} + +const std::vector>& Module::GetGlobals() const { + return globals_; +} + const std::vector>& Module::GetFunctions() const { return functions_; } diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index 3e1684d..8602d17 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -1,31 +1,141 @@ -// 当前仅支持 void、i32 和 i32*。 #include "ir/IR.h" +#include + namespace ir { -Type::Type(Kind k) : kind_(k) {} +Type::Type(Kind kind) : kind_(kind) {} + +Type::Type(Kind kind, std::shared_ptr element_type) + : kind_(kind), element_type_(std::move(element_type)) {} + +Type::Type(Kind kind, std::shared_ptr element_type, size_t array_size) + : kind_(kind), + element_type_(std::move(element_type)), + array_size_(array_size) {} + +Type::Type(std::shared_ptr return_type, + std::vector> params) + : kind_(Kind::Function), + return_type_(std::move(return_type)), + param_types_(std::move(params)) {} const std::shared_ptr& Type::GetVoidType() { - static const std::shared_ptr type = std::make_shared(Kind::Void); + static const auto type = std::make_shared(Kind::Void); + return type; +} + +const std::shared_ptr& Type::GetInt1Type() { + static const auto type = std::make_shared(Kind::Int1); return type; } const std::shared_ptr& Type::GetInt32Type() { - static const std::shared_ptr type = std::make_shared(Kind::Int32); + static const auto type = std::make_shared(Kind::Int32); + return type; +} + +const std::shared_ptr& Type::GetFloatType() { + static const auto type = std::make_shared(Kind::Float32); return type; } +std::shared_ptr Type::GetPointerType(std::shared_ptr element_type) { + if (!element_type) { + throw std::runtime_error("GetPointerType 缺少 element_type"); + } + return std::make_shared(Kind::Pointer, std::move(element_type)); +} + +std::shared_ptr Type::GetArrayType(std::shared_ptr element_type, + size_t array_size) { + if (!element_type) { + throw std::runtime_error("GetArrayType 缺少 element_type"); + } + return std::make_shared(Kind::Array, std::move(element_type), array_size); +} + +std::shared_ptr Type::GetFunctionType( + std::shared_ptr return_type, + std::vector> param_types) { + if (!return_type) { + throw std::runtime_error("GetFunctionType 缺少 return_type"); + } + return std::make_shared(std::move(return_type), std::move(param_types)); +} + const std::shared_ptr& Type::GetPtrInt32Type() { - static const std::shared_ptr type = std::make_shared(Kind::PtrInt32); + static const auto type = GetPointerType(GetInt32Type()); return type; } Type::Kind Type::GetKind() const { return kind_; } +const std::shared_ptr& Type::GetElementType() const { return element_type_; } + +size_t Type::GetArraySize() const { return array_size_; } + +const std::shared_ptr& Type::GetReturnType() const { return return_type_; } + +const std::vector>& Type::GetParamTypes() const { + return param_types_; +} + bool Type::IsVoid() const { return kind_ == Kind::Void; } +bool Type::IsInt1() const { return kind_ == Kind::Int1; } + bool Type::IsInt32() const { return kind_ == Kind::Int32; } -bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } +bool Type::IsFloat32() const { return kind_ == Kind::Float32; } + +bool Type::IsPointer() const { return kind_ == Kind::Pointer; } + +bool Type::IsArray() const { return kind_ == Kind::Array; } + +bool Type::IsFunction() const { return kind_ == Kind::Function; } + +bool Type::IsScalar() const { return IsInt1() || IsInt32() || IsFloat32(); } + +bool Type::IsInteger() const { return IsInt1() || IsInt32(); } + +bool Type::IsNumeric() const { return IsInteger() || IsFloat32(); } + +bool Type::IsPtrInt32() const { + return IsPointer() && element_type_ && element_type_->IsInt32(); +} + +bool Type::Equals(const Type& other) const { + if (kind_ != other.kind_) { + return false; + } + switch (kind_) { + case Kind::Void: + case Kind::Int1: + case Kind::Int32: + case Kind::Float32: + return true; + case Kind::Pointer: + return element_type_ && other.element_type_ && + element_type_->Equals(*other.element_type_); + case Kind::Array: + return array_size_ == other.array_size_ && element_type_ && + other.element_type_ && element_type_->Equals(*other.element_type_); + case Kind::Function: + if (!return_type_ || !other.return_type_ || + !return_type_->Equals(*other.return_type_) || + param_types_.size() != other.param_types_.size()) { + return false; + } + for (size_t i = 0; i < param_types_.size(); ++i) { + if (!param_types_[i] || !other.param_types_[i] || + !param_types_[i]->Equals(*other.param_types_[i])) { + return false; + } + } + return true; + } + return false; +} } // namespace ir diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 2e9f4c1..d5a291c 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -1,9 +1,7 @@ -// SSA 值体系抽象: -// - 常量、参数、指令结果等统一为 Value -// - 提供类型信息与使用/被使用关系(按需要实现) #include "ir/IR.h" #include +#include namespace ir { @@ -14,12 +12,22 @@ const std::shared_ptr& Value::GetType() const { return type_; } const std::string& Value::GetName() const { return name_; } -void Value::SetName(std::string n) { name_ = std::move(n); } +void Value::SetName(std::string name) { name_ = std::move(name); } bool Value::IsVoid() const { return type_ && type_->IsVoid(); } +bool Value::IsInt1() const { return type_ && type_->IsInt1(); } + bool Value::IsInt32() const { return type_ && type_->IsInt32(); } +bool Value::IsFloat32() const { return type_ && type_->IsFloat32(); } + +bool Value::IsPointer() const { return type_ && type_->IsPointer(); } + +bool Value::IsArray() const { return type_ && type_->IsArray(); } + +bool Value::IsFunctionValue() const { return type_ && type_->IsFunction(); } + bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } bool Value::IsConstant() const { @@ -30,27 +38,34 @@ bool Value::IsInstruction() const { return dynamic_cast(this) != nullptr; } -bool Value::IsUser() const { - return dynamic_cast(this) != nullptr; -} +bool Value::IsUser() const { return dynamic_cast(this) != nullptr; } bool Value::IsFunction() const { return dynamic_cast(this) != nullptr; } +bool Value::IsGlobalVariable() const { + return dynamic_cast(this) != nullptr; +} + +bool Value::IsArgument() const { + return dynamic_cast(this) != nullptr; +} + void Value::AddUse(User* user, size_t operand_index) { - if (!user) return; - uses_.push_back(Use(this, user, operand_index)); + if (!user) { + return; + } + uses_.emplace_back(this, user, operand_index); } void Value::RemoveUse(User* user, size_t operand_index) { - uses_.erase( - std::remove_if(uses_.begin(), uses_.end(), - [&](const Use& use) { - return use.GetUser() == user && - use.GetOperandIndex() == operand_index; - }), - uses_.end()); + uses_.erase(std::remove_if(uses_.begin(), uses_.end(), + [&](const Use& use) { + return use.GetUser() == user && + use.GetOperandIndex() == operand_index; + }), + uses_.end()); } const std::vector& Value::GetUses() const { return uses_; } @@ -62,22 +77,39 @@ void Value::ReplaceAllUsesWith(Value* new_value) { if (new_value == this) { return; } - - auto uses = uses_; - for (const auto& use : uses) { + auto snapshot = uses_; + for (const auto& use : snapshot) { auto* user = use.GetUser(); - if (!user) continue; - size_t operand_index = use.GetOperandIndex(); - if (user->GetOperand(operand_index) == this) { - user->SetOperand(operand_index, new_value); + if (!user) { + continue; } + user->SetOperand(use.GetOperandIndex(), new_value); } } ConstantValue::ConstantValue(std::shared_ptr ty, std::string name) : Value(std::move(ty), std::move(name)) {} -ConstantInt::ConstantInt(std::shared_ptr ty, int v) - : ConstantValue(std::move(ty), ""), value_(v) {} +ConstantInt::ConstantInt(std::shared_ptr ty, int value) + : ConstantValue(std::move(ty), ""), value_(value) {} + +ConstantFloat::ConstantFloat(std::shared_ptr ty, float value) + : ConstantValue(std::move(ty), ""), value_(value) {} + +ConstantZero::ConstantZero(std::shared_ptr ty) + : ConstantValue(std::move(ty), "") {} + +ConstantArray::ConstantArray(std::shared_ptr ty, + std::vector elements) + : ConstantValue(std::move(ty), ""), elements_(std::move(elements)) {} + +bool ConstantArray::IsZeroValue() const { + for (auto* element : elements_) { + if (!element || !element->IsZeroValue()) { + return false; + } + } + return true; +} } // namespace ir diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 9b7c2d9..0e3aadf 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -1,69 +1,479 @@ #include "irgen/IRGen.h" +#include +#include #include -#include "SysYParser.h" -#include "ir/IR.h" #include "utils/Log.h" -void IRGenImpl::GenBlock(SysYParser::BlockContext& block) { - for (auto* item : block.blockItem()) { - if (item) { - if (GenBlockItem(*item)) { - // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 - break; +namespace { + +using ir::Type; + +size_t ScalarCount(const std::shared_ptr& type) { + return type->IsArray() ? type->GetArraySize() * ScalarCount(type->GetElementType()) : 1; +} + +std::shared_ptr ScalarLeafType(const std::shared_ptr& type) { + auto current = type; + while (current->IsArray()) { + current = current->GetElementType(); + } + return current; +} + +ConstantData ZeroForType(const std::shared_ptr& type) { + return type->IsFloat32() ? ConstantData::FromFloat(0.0f) + : ConstantData::FromInt(0); +} + +ConstantData ParseNumberValue(const std::string& text) { + if (text.find_first_of(".pPeE") == std::string::npos) { + return ConstantData::FromInt(static_cast(std::strtoll(text.c_str(), nullptr, 0))); + } + return ConstantData::FromFloat(std::strtof(text.c_str(), nullptr)); +} + +bool SameType(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + return lhs && rhs && lhs->Equals(*rhs); +} + +ConstantData EvalGlobalConstAddExp( + SysYParser::AddExpContext& add, + const std::unordered_map& const_values); + +ConstantData EvalGlobalConstPrimary( + SysYParser::PrimaryContext& primary, + const std::unordered_map& const_values) { + if (primary.Number()) { + return ParseNumberValue(primary.Number()->getText()); + } + if (primary.exp()) { + return EvalGlobalConstAddExp(*primary.exp()->addExp(), const_values); + } + if (primary.lVal() && primary.lVal()->Ident() && primary.lVal()->exp().empty()) { + auto found = const_values.find(primary.lVal()->Ident()->getText()); + if (found == const_values.end()) { + throw std::runtime_error( + FormatError("irgen", "全局初始化器引用了非常量符号: " + + primary.lVal()->Ident()->getText())); + } + return found->second; + } + throw std::runtime_error( + FormatError("irgen", "全局初始化器暂不支持该常量表达式")); +} + +ConstantData EvalGlobalConstUnaryExp( + SysYParser::UnaryExpContext& unary, + const std::unordered_map& const_values) { + if (unary.primary()) { + return EvalGlobalConstPrimary(*unary.primary(), const_values); + } + if (unary.unaryExp()) { + ConstantData value = EvalGlobalConstUnaryExp(*unary.unaryExp(), const_values); + const std::string op = unary.unaryOp()->getText(); + if (op == "+") { + return value; + } + if (op == "-") { + return value.IsFloat() ? ConstantData::FromFloat(-value.AsFloat()) + : ConstantData::FromInt(-value.AsInt()); + } + if (op == "!") { + return ConstantData::FromInt(value.IsFloat() ? (value.AsFloat() == 0.0f) + : (value.AsInt() == 0)); + } + } + throw std::runtime_error(FormatError("irgen", "全局初始化器不支持函数调用")); +} + +ConstantData EvalGlobalConstMulExp( + SysYParser::MulExpContext& mul, + const std::unordered_map& const_values) { + ConstantData acc = EvalGlobalConstUnaryExp(*mul.unaryExp(0), const_values); + for (size_t i = 1; i < mul.unaryExp().size(); ++i) { + ConstantData rhs = EvalGlobalConstUnaryExp(*mul.unaryExp(i), const_values); + const std::string op = mul.children[2 * i - 1]->getText(); + if (op == "%") { + acc = ConstantData::FromInt(acc.AsInt() % rhs.AsInt()); + continue; + } + auto result_type = + (acc.GetType()->IsFloat32() || rhs.GetType()->IsFloat32()) ? Type::GetFloatType() + : Type::GetInt32Type(); + acc = acc.CastTo(result_type); + rhs = rhs.CastTo(result_type); + if (result_type->IsFloat32()) { + float value = op == "*" ? acc.AsFloat() * rhs.AsFloat() + : acc.AsFloat() / rhs.AsFloat(); + acc = ConstantData::FromFloat(value); + } else { + int value = op == "*" ? acc.AsInt() * rhs.AsInt() + : acc.AsInt() / rhs.AsInt(); + acc = ConstantData::FromInt(value); + } + } + return acc; +} + +ConstantData EvalGlobalConstAddExp( + SysYParser::AddExpContext& add, + const std::unordered_map& const_values) { + ConstantData acc = EvalGlobalConstMulExp(*add.mulExp(0), const_values); + for (size_t i = 1; i < add.mulExp().size(); ++i) { + ConstantData rhs = EvalGlobalConstMulExp(*add.mulExp(i), const_values); + auto result_type = + (acc.GetType()->IsFloat32() || rhs.GetType()->IsFloat32()) ? Type::GetFloatType() + : Type::GetInt32Type(); + acc = acc.CastTo(result_type); + rhs = rhs.CastTo(result_type); + if (result_type->IsFloat32()) { + float value = add.children[2 * i - 1]->getText() == "+" + ? acc.AsFloat() + rhs.AsFloat() + : acc.AsFloat() - rhs.AsFloat(); + acc = ConstantData::FromFloat(value); + } else { + int value = add.children[2 * i - 1]->getText() == "+" + ? acc.AsInt() + rhs.AsInt() + : acc.AsInt() - rhs.AsInt(); + acc = ConstantData::FromInt(value); + } + } + return acc; +} + +void FlattenInitValue(const std::shared_ptr& type, SysYParser::InitValContext& init, + std::vector& leaves, + size_t& cursor, size_t start) { + if (!type->IsArray()) { + if (cursor >= leaves.size()) { + throw std::runtime_error(FormatError("irgen", "初始化器过长")); + } + leaves[cursor++] = &init; + return; + } + if (init.exp()) { + if (cursor >= leaves.size()) { + throw std::runtime_error(FormatError("irgen", "初始化器过长")); + } + leaves[cursor++] = &init; + return; + } + auto elem_type = type->GetElementType(); + const size_t elem_span = ScalarCount(elem_type); + for (auto* child : init.initVal()) { + if (!child) { + continue; + } + if (child->L_BRACE()) { + size_t rel = cursor - start; + if (rel % elem_span != 0) { + cursor += elem_span - (rel % elem_span); } + size_t child_start = cursor; + FlattenInitValue(elem_type, *child, leaves, cursor, child_start); + cursor = child_start + elem_span; + } else { + if (cursor >= leaves.size()) { + throw std::runtime_error(FormatError("irgen", "初始化器过长")); + } + leaves[cursor++] = child; } } } -bool IRGenImpl::GenBlockItem(SysYParser::BlockItemContext& item) { - if (item.decl()) { - GenDecl(*item.decl()); - return false; +void FlattenConstInitValue(const std::shared_ptr& type, + SysYParser::ConstInitValContext& init, + std::vector& leaves, + size_t& cursor, size_t start) { + if (!type->IsArray()) { + if (cursor >= leaves.size()) { + throw std::runtime_error(FormatError("irgen", "初始化器过长")); + } + leaves[cursor++] = &init; + return; } - if (item.stmt()) { - return GenStmt(*item.stmt()); + if (init.constExp()) { + if (cursor >= leaves.size()) { + throw std::runtime_error(FormatError("irgen", "初始化器过长")); + } + leaves[cursor++] = &init; + return; + } + auto elem_type = type->GetElementType(); + const size_t elem_span = ScalarCount(elem_type); + for (auto* child : init.constInitVal()) { + if (!child) { + continue; + } + if (child->L_BRACE()) { + size_t rel = cursor - start; + if (rel % elem_span != 0) { + cursor += elem_span - (rel % elem_span); + } + size_t child_start = cursor; + FlattenConstInitValue(elem_type, *child, leaves, cursor, child_start); + cursor = child_start + elem_span; + } else { + if (cursor >= leaves.size()) { + throw std::runtime_error(FormatError("irgen", "初始化器过长")); + } + leaves[cursor++] = child; + } + } +} + +} // namespace + +void IRGenImpl::GenGlobals(SysYParser::CompUnitContext& cu) { + for (auto* decl : cu.decl()) { + if (!decl) { + continue; + } + if (decl->constDecl()) { + for (auto* def : decl->constDecl()->constDef()) { + auto* symbol = sema_.ResolveConstDef(def); + auto* global = module_.CreateGlobal( + symbol->name, symbol->type, + BuildGlobalConstInitializer(symbol->type, def->constInitVal()), true); + globals_[symbol->name] = {global, symbol->type, false, true, true}; + if (symbol->has_const_value) { + global_const_values_[symbol->name] = symbol->const_value; + } + } + } else if (decl->varDecl()) { + for (auto* def : decl->varDecl()->varDef()) { + auto* symbol = sema_.ResolveVarDef(def); + auto* global = + module_.CreateGlobal(symbol->name, symbol->type, + BuildGlobalInitializer(symbol->type, def->initVal()), false); + globals_[symbol->name] = {global, symbol->type, false, true, false}; + } + } } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); } void IRGenImpl::GenDecl(SysYParser::DeclContext& decl) { + if (decl.constDecl()) { + GenConstDecl(*decl.constDecl()); + return; + } if (decl.varDecl()) { GenVarDecl(*decl.varDecl()); return; } - throw std::runtime_error(FormatError("irgen", "暂不支持的声明类型")); + throw std::runtime_error(FormatError("irgen", "未知声明类型")); } -void IRGenImpl::GenVarDecl(SysYParser::VarDeclContext& decl) { - if (!decl.bType() || !decl.bType()->Int()) { - throw std::runtime_error(FormatError("irgen", "当前 IR 仅支持 int 标量局部变量")); +void IRGenImpl::GenConstDecl(SysYParser::ConstDeclContext& decl) { + for (auto* def : decl.constDef()) { + auto* symbol = sema_.ResolveConstDef(def); + if (!symbol) { + throw std::runtime_error(FormatError("irgen", "const 声明缺少语义绑定")); + } + + auto* slot = + builder_.CreateAlloca(symbol->type, module_.GetContext().NextTemp()); + if (symbol->type->IsArray()) { + EmitLocalConstArrayInit(slot, symbol->type, *def->constInitVal()); + } else { + ir::Value* value = GenAddExpr(*def->constInitVal()->constExp()->addExp()); + value = CastValue(value, symbol->type); + builder_.CreateStore(value, slot); + } + DeclareLocal(symbol->name, {slot, symbol->type, false, false, true}); } +} +void IRGenImpl::GenVarDecl(SysYParser::VarDeclContext& decl) { for (auto* def : decl.varDef()) { - if (!def) { - continue; + auto* symbol = sema_.ResolveVarDef(def); + if (!symbol) { + throw std::runtime_error(FormatError("irgen", "变量声明缺少语义绑定")); } - if (storage_map_.find(def) != storage_map_.end()) { - throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); + + auto* slot = + builder_.CreateAlloca(symbol->type, module_.GetContext().NextTemp()); + if (symbol->type->IsArray()) { + if (def->initVal()) { + EmitLocalArrayInit(slot, symbol->type, *def->initVal()); + } + } else { + ir::Value* init = symbol->type->IsFloat32() + ? static_cast(builder_.CreateConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0)); + if (auto* init_val = def->initVal()) { + init = GenExpr(*init_val->exp()); + init = CastValue(init, symbol->type); + } + builder_.CreateStore(init, slot); } - if (!def->constExp().empty()) { - throw std::runtime_error( - FormatError("irgen", "当前 IR 仅支持 int 标量局部变量")); + DeclareLocal(symbol->name, {slot, symbol->type, false, false, false}); + } +} + +void IRGenImpl::EmitArrayStore(ir::Value* base_ptr, + const std::shared_ptr& array_type, + size_t flat_index, ir::Value* value) { + auto indices = FlatIndexToIndices(array_type, flat_index); + std::vector gep_indices; + gep_indices.push_back(builder_.CreateConstInt(0)); + for (int index : indices) { + gep_indices.push_back(builder_.CreateConstInt(index)); + } + auto* addr = + builder_.CreateGEP(base_ptr, gep_indices, module_.GetContext().NextTemp()); + builder_.CreateStore(CastValue(value, addr->GetType()->GetElementType()), addr); +} + +void IRGenImpl::ZeroInitializeLocalArray(ir::Value* base_ptr, + const std::shared_ptr& array_type) { + const auto scalar_type = ScalarLeafType(array_type); + for (size_t i = 0; i < CountScalars(array_type); ++i) { + ir::Value* zero = scalar_type->IsFloat32() + ? static_cast(builder_.CreateConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0)); + EmitArrayStore(base_ptr, array_type, i, zero); + } +} + +void IRGenImpl::EmitLocalArrayInit(ir::Value* base_ptr, + const std::shared_ptr& array_type, + SysYParser::InitValContext& init) { + ZeroInitializeLocalArray(base_ptr, array_type); + std::vector leaves(CountScalars(array_type), nullptr); + size_t cursor = 0; + FlattenInitValue(array_type, init, leaves, cursor, 0); + for (size_t i = 0; i < leaves.size(); ++i) { + if (!leaves[i] || !leaves[i]->exp()) { + continue; + } + EmitArrayStore(base_ptr, array_type, i, GenExpr(*leaves[i]->exp())); + } +} + +void IRGenImpl::EmitLocalConstArrayInit(ir::Value* base_ptr, + const std::shared_ptr& array_type, + SysYParser::ConstInitValContext& init) { + ZeroInitializeLocalArray(base_ptr, array_type); + std::vector leaves(CountScalars(array_type), + nullptr); + size_t cursor = 0; + FlattenConstInitValue(array_type, init, leaves, cursor, 0); + for (size_t i = 0; i < leaves.size(); ++i) { + if (!leaves[i] || !leaves[i]->constExp()) { + continue; } + EmitArrayStore(base_ptr, array_type, i, GenAddExpr(*leaves[i]->constExp()->addExp())); + } +} - auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - storage_map_[def] = slot; +ir::ConstantValue* IRGenImpl::BuildGlobalInitializer(const std::shared_ptr& type, + SysYParser::InitValContext* init) { + if (!init) { + return builder_.CreateZero(type); + } + if (!type->IsArray()) { + auto value = EvalGlobalConstAddExp(*init->exp()->addExp(), global_const_values_) + .CastTo(type); + return type->IsFloat32() + ? static_cast( + module_.GetContext().GetConstFloat(value.AsFloat())) + : static_cast( + module_.GetContext().GetConstInt(value.AsInt())); + } - ir::Value* init = builder_.CreateConstInt(0); - if (auto* init_val = def->initVal()) { - if (!init_val->exp()) { - throw std::runtime_error( - FormatError("irgen", "当前 IR 仅支持表达式初始化")); + const auto scalar_type = ScalarLeafType(type); + std::vector flat(CountScalars(type), ZeroForType(scalar_type)); + if (init->L_BRACE()) { + std::vector leaves(flat.size(), nullptr); + size_t cursor = 0; + FlattenInitValue(type, *init, leaves, cursor, 0); + for (size_t i = 0; i < leaves.size(); ++i) { + if (leaves[i] && leaves[i]->exp()) { + flat[i] = EvalGlobalConstAddExp(*leaves[i]->exp()->addExp(), global_const_values_) + .CastTo(scalar_type); } - init = GenExpr(*init_val->exp()); } - builder_.CreateStore(init, slot); } + + size_t offset = 0; + std::function&)> build = + [&](const std::shared_ptr& current) -> ir::ConstantValue* { + if (!current->IsArray()) { + ConstantData value = flat[offset++].CastTo(current); + return current->IsFloat32() + ? static_cast( + module_.GetContext().GetConstFloat(value.AsFloat())) + : static_cast( + module_.GetContext().GetConstInt(value.AsInt())); + } + std::vector elements; + bool all_zero = true; + for (size_t i = 0; i < current->GetArraySize(); ++i) { + auto* child = build(current->GetElementType()); + all_zero = all_zero && child->IsZeroValue(); + elements.push_back(child); + } + if (all_zero) { + return module_.GetContext().CreateOwnedConstant(current); + } + return module_.GetContext().CreateOwnedConstant(current, + elements); + }; + return build(type); +} + +ir::ConstantValue* IRGenImpl::BuildGlobalConstInitializer( + const std::shared_ptr& type, SysYParser::ConstInitValContext* init) { + if (!type->IsArray()) { + auto value = + EvalGlobalConstAddExp(*init->constExp()->addExp(), global_const_values_) + .CastTo(type); + return type->IsFloat32() + ? static_cast( + module_.GetContext().GetConstFloat(value.AsFloat())) + : static_cast( + module_.GetContext().GetConstInt(value.AsInt())); + } + + const auto scalar_type = ScalarLeafType(type); + std::vector flat(CountScalars(type), ZeroForType(scalar_type)); + std::vector leaves(flat.size(), nullptr); + size_t cursor = 0; + FlattenConstInitValue(type, *init, leaves, cursor, 0); + for (size_t i = 0; i < leaves.size(); ++i) { + if (leaves[i] && leaves[i]->constExp()) { + flat[i] = + EvalGlobalConstAddExp(*leaves[i]->constExp()->addExp(), global_const_values_) + .CastTo(scalar_type); + } + } + + size_t offset = 0; + std::function&)> build = + [&](const std::shared_ptr& current) -> ir::ConstantValue* { + if (!current->IsArray()) { + ConstantData value = flat[offset++].CastTo(current); + return current->IsFloat32() + ? static_cast( + module_.GetContext().GetConstFloat(value.AsFloat())) + : static_cast( + module_.GetContext().GetConstInt(value.AsInt())); + } + std::vector elements; + bool all_zero = true; + for (size_t i = 0; i < current->GetArraySize(); ++i) { + auto* child = build(current->GetElementType()); + all_zero = all_zero && child->IsZeroValue(); + elements.push_back(child); + } + if (all_zero) { + return module_.GetContext().CreateOwnedConstant(current); + } + return module_.GetContext().CreateOwnedConstant(current, + elements); + }; + return build(type); } diff --git a/src/irgen/IRGenDriver.cpp b/src/irgen/IRGenDriver.cpp index 6f2a775..b338946 100644 --- a/src/irgen/IRGenDriver.cpp +++ b/src/irgen/IRGenDriver.cpp @@ -2,10 +2,6 @@ #include -#include "SysYParser.h" -#include "ir/IR.h" -#include "utils/Log.h" - std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, const SemanticContext& sema) { auto module = std::make_unique(); diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 2c57209..6822fdb 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -1,65 +1,118 @@ #include "irgen/IRGen.h" +#include #include -#include "SysYParser.h" -#include "ir/IR.h" #include "utils/Log.h" -ir::Value* IRGenImpl::GenExpr(SysYParser::ExpContext& expr) { - if (!expr.addExp()) { - throw std::runtime_error(FormatError("irgen", "非法表达式")); +namespace { + +using ir::FCmpPred; +using ir::ICmpPred; +using ir::Opcode; +using ir::Type; + +bool SameType(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + return lhs && rhs && lhs->Equals(*rhs); +} + +std::shared_ptr ArithmeticType(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return (lhs->IsFloat32() || rhs->IsFloat32()) ? Type::GetFloatType() + : Type::GetInt32Type(); +} + +} // namespace + +ir::Value* IRGenImpl::CastValue(ir::Value* value, + const std::shared_ptr& dst_type) { + if (!value || !dst_type) { + throw std::runtime_error(FormatError("irgen", "CastValue 缺少参数")); + } + if (SameType(value->GetType(), dst_type)) { + return value; + } + if (value->GetType()->IsInt1() && dst_type->IsInt32()) { + return builder_.CreateZExt(value, dst_type, module_.GetContext().NextTemp()); + } + if (value->GetType()->IsInt32() && dst_type->IsFloat32()) { + return builder_.CreateSIToFP(value, module_.GetContext().NextTemp()); + } + if (value->GetType()->IsFloat32() && dst_type->IsInt32()) { + return builder_.CreateFPToSI(value, module_.GetContext().NextTemp()); + } + throw std::runtime_error(FormatError("irgen", "不支持的类型转换")); +} + +ir::Value* IRGenImpl::ToBool(ir::Value* value) { + if (!value) { + throw std::runtime_error(FormatError("irgen", "ToBool 缺少 value")); + } + if (value->GetType()->IsInt1()) { + return value; + } + if (value->GetType()->IsInt32()) { + return builder_.CreateICmp(ICmpPred::Ne, value, builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); } + if (value->GetType()->IsFloat32()) { + return builder_.CreateFCmp(FCmpPred::One, value, builder_.CreateConstFloat(0.0f), + module_.GetContext().NextTemp()); + } + throw std::runtime_error(FormatError("irgen", "条件表达式只支持 int/float")); +} + +ir::Value* IRGenImpl::DecayArrayPointer(ir::Value* array_ptr) { + return builder_.CreateGEP(array_ptr, + {builder_.CreateConstInt(0), builder_.CreateConstInt(0)}, + module_.GetContext().NextTemp()); +} + +ir::Value* IRGenImpl::GenExpr(SysYParser::ExpContext& expr) { return GenAddExpr(*expr.addExp()); } ir::Value* IRGenImpl::GenAddExpr(SysYParser::AddExpContext& add) { - const auto& terms = add.mulExp(); - if (terms.empty()) { - throw std::runtime_error(FormatError("irgen", "空加法表达式")); - } - - ir::Value* acc = GenMulExpr(*terms[0]); - for (size_t i = 1; i < terms.size(); ++i) { - ir::Value* rhs = GenMulExpr(*terms[i]); - std::string name = module_.GetContext().NextTemp(); - auto* op = add.children[2 * i - 1]; - if (!op) { - throw std::runtime_error(FormatError("irgen", "加法表达式缺少运算符")); - } - const std::string text = op->getText(); - if (text == "+") { - acc = builder_.CreateBinary(ir::Opcode::Add, acc, rhs, name); - } else if (text == "-") { - acc = builder_.CreateBinary(ir::Opcode::Sub, acc, rhs, name); + ir::Value* acc = GenMulExpr(*add.mulExp(0)); + for (size_t i = 1; i < add.mulExp().size(); ++i) { + ir::Value* rhs = GenMulExpr(*add.mulExp(i)); + auto result_type = ArithmeticType(acc->GetType(), rhs->GetType()); + acc = CastValue(acc, result_type); + rhs = CastValue(rhs, result_type); + const std::string op = add.children[2 * i - 1]->getText(); + if (result_type->IsFloat32()) { + acc = builder_.CreateBinary(op == "+" ? Opcode::FAdd : Opcode::FSub, acc, rhs, + module_.GetContext().NextTemp()); } else { - throw std::runtime_error(FormatError("irgen", "暂不支持的加法运算符: " + text)); + acc = builder_.CreateBinary(op == "+" ? Opcode::Add : Opcode::Sub, acc, rhs, + module_.GetContext().NextTemp()); } } return acc; } ir::Value* IRGenImpl::GenMulExpr(SysYParser::MulExpContext& mul) { - const auto& terms = mul.unaryExp(); - if (terms.empty()) { - throw std::runtime_error(FormatError("irgen", "空乘法表达式")); - } - - ir::Value* acc = GenUnaryExpr(*terms[0]); - for (size_t i = 1; i < terms.size(); ++i) { - ir::Value* rhs = GenUnaryExpr(*terms[i]); - std::string name = module_.GetContext().NextTemp(); - auto* op = mul.children[2 * i - 1]; - if (!op) { - throw std::runtime_error(FormatError("irgen", "乘法表达式缺少运算符")); - } - const std::string text = op->getText(); - if (text == "*") { - acc = builder_.CreateBinary(ir::Opcode::Mul, acc, rhs, name); + ir::Value* acc = GenUnaryExpr(*mul.unaryExp(0)); + for (size_t i = 1; i < mul.unaryExp().size(); ++i) { + ir::Value* rhs = GenUnaryExpr(*mul.unaryExp(i)); + const std::string op = mul.children[2 * i - 1]->getText(); + if (op == "%") { + acc = CastValue(acc, Type::GetInt32Type()); + rhs = CastValue(rhs, Type::GetInt32Type()); + acc = builder_.CreateBinary(Opcode::SRem, acc, rhs, + module_.GetContext().NextTemp()); continue; } - throw std::runtime_error( - FormatError("irgen", "当前 IR 暂不支持的乘法类运算符: " + text)); + auto result_type = ArithmeticType(acc->GetType(), rhs->GetType()); + acc = CastValue(acc, result_type); + rhs = CastValue(rhs, result_type); + Opcode opcode = Opcode::Mul; + if (result_type->IsFloat32()) { + opcode = op == "*" ? Opcode::FMul : Opcode::FDiv; + } else { + opcode = op == "*" ? Opcode::Mul : Opcode::SDiv; + } + acc = builder_.CreateBinary(opcode, acc, rhs, module_.GetContext().NextTemp()); } return acc; } @@ -68,50 +121,173 @@ ir::Value* IRGenImpl::GenUnaryExpr(SysYParser::UnaryExpContext& unary) { if (unary.primary()) { return GenPrimary(*unary.primary()); } - - if (unary.unaryExp()) { - if (!unary.unaryOp()) { - throw std::runtime_error(FormatError("irgen", "一元表达式缺少运算符")); + if (unary.Ident()) { + auto* symbol = sema_.ResolveCall(&unary); + auto* callee = symbol ? module_.FindFunction(symbol->name) : nullptr; + if (!callee) { + throw std::runtime_error(FormatError("irgen", "函数声明缺失")); + } + std::vector args; + const auto& params = callee->GetFunctionType()->GetParamTypes(); + if (unary.funcRParams()) { + for (size_t i = 0; i < unary.funcRParams()->exp().size(); ++i) { + auto* value = GenExpr(*unary.funcRParams()->exp(i)); + args.push_back(CastValue(value, params[i])); + } + } + std::string name; + if (!callee->GetReturnType()->IsVoid()) { + name = module_.GetContext().NextTemp(); } + return builder_.CreateCall(callee, args, name); + } + if (unary.unaryExp()) { const std::string op = unary.unaryOp()->getText(); + auto* value = GenUnaryExpr(*unary.unaryExp()); if (op == "+") { - return GenUnaryExpr(*unary.unaryExp()); + return value; } if (op == "-") { - auto* rhs = GenUnaryExpr(*unary.unaryExp()); - return builder_.CreateBinary(ir::Opcode::Sub, builder_.CreateConstInt(0), - rhs, module_.GetContext().NextTemp()); + if (value->GetType()->IsFloat32()) { + return builder_.CreateBinary(Opcode::FSub, builder_.CreateConstFloat(0.0f), + value, module_.GetContext().NextTemp()); + } + value = CastValue(value, Type::GetInt32Type()); + return builder_.CreateBinary(Opcode::Sub, builder_.CreateConstInt(0), value, + module_.GetContext().NextTemp()); + } + if (op == "!") { + auto* bool_value = ToBool(value); + auto* as_i32 = builder_.CreateZExt(bool_value, Type::GetInt32Type(), + module_.GetContext().NextTemp()); + auto* is_zero = builder_.CreateICmp(ICmpPred::Eq, as_i32, + builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + return builder_.CreateZExt(is_zero, Type::GetInt32Type(), + module_.GetContext().NextTemp()); } - throw std::runtime_error( - FormatError("irgen", "当前 IR 暂不支持的一元运算符: " + op)); } - - throw std::runtime_error(FormatError("irgen", "当前 IR 暂不支持函数调用")); + throw std::runtime_error(FormatError("irgen", "非法一元表达式")); } ir::Value* IRGenImpl::GenPrimary(SysYParser::PrimaryContext& primary) { if (primary.Number()) { - return builder_.CreateConstInt(std::stoi(primary.Number()->getText(), nullptr, 0)); + const std::string text = primary.Number()->getText(); + if (text.find_first_of(".pPeE") == std::string::npos) { + return builder_.CreateConstInt(static_cast(std::strtoll(text.c_str(), nullptr, 0))); + } + return builder_.CreateConstFloat(std::strtof(text.c_str(), nullptr)); + } + if (primary.exp()) { + return GenExpr(*primary.exp()); } if (primary.lVal()) { - if (!primary.lVal()->exp().empty()) { - throw std::runtime_error( - FormatError("irgen", "当前 IR 暂不支持数组取值表达式")); + return GenLValueValue(*primary.lVal()); + } + throw std::runtime_error(FormatError("irgen", "非法 primary 表达式")); +} + +ir::Value* IRGenImpl::GenRelExpr(SysYParser::RelExpContext& rel) { + ir::Value* acc = GenAddExpr(*rel.addExp(0)); + for (size_t i = 1; i < rel.addExp().size(); ++i) { + ir::Value* rhs = GenAddExpr(*rel.addExp(i)); + const std::string op = rel.children[2 * i - 1]->getText(); + ir::Value* cmp = nullptr; + if (acc->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) { + acc = CastValue(acc, Type::GetFloatType()); + rhs = CastValue(rhs, Type::GetFloatType()); + FCmpPred pred = FCmpPred::Olt; + if (op == "<") pred = FCmpPred::Olt; + if (op == "<=") pred = FCmpPred::Ole; + if (op == ">") pred = FCmpPred::Ogt; + if (op == ">=") pred = FCmpPred::Oge; + cmp = builder_.CreateFCmp(pred, acc, rhs, module_.GetContext().NextTemp()); + } else { + acc = CastValue(acc, Type::GetInt32Type()); + rhs = CastValue(rhs, Type::GetInt32Type()); + ICmpPred pred = ICmpPred::Slt; + if (op == "<") pred = ICmpPred::Slt; + if (op == "<=") pred = ICmpPred::Sle; + if (op == ">") pred = ICmpPred::Sgt; + if (op == ">=") pred = ICmpPred::Sge; + cmp = builder_.CreateICmp(pred, acc, rhs, module_.GetContext().NextTemp()); } - auto* decl = sema_.ResolveVarUse(primary.lVal()); - if (!decl || !primary.lVal()->Ident()) { - throw std::runtime_error(FormatError("irgen", "变量使用缺少语义绑定")); + acc = builder_.CreateZExt(cmp, Type::GetInt32Type(), module_.GetContext().NextTemp()); + } + return acc; +} + +ir::Value* IRGenImpl::GenEqExpr(SysYParser::EqExpContext& eq) { + ir::Value* acc = GenRelExpr(*eq.relExp(0)); + for (size_t i = 1; i < eq.relExp().size(); ++i) { + ir::Value* rhs = GenRelExpr(*eq.relExp(i)); + const std::string op = eq.children[2 * i - 1]->getText(); + ir::Value* cmp = nullptr; + if (acc->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) { + acc = CastValue(acc, Type::GetFloatType()); + rhs = CastValue(rhs, Type::GetFloatType()); + cmp = builder_.CreateFCmp(op == "==" ? FCmpPred::Oeq : FCmpPred::One, acc, rhs, + module_.GetContext().NextTemp()); + } else { + acc = CastValue(acc, Type::GetInt32Type()); + rhs = CastValue(rhs, Type::GetInt32Type()); + cmp = builder_.CreateICmp(op == "==" ? ICmpPred::Eq : ICmpPred::Ne, acc, rhs, + module_.GetContext().NextTemp()); + } + acc = builder_.CreateZExt(cmp, Type::GetInt32Type(), module_.GetContext().NextTemp()); + } + return acc; +} + +ir::Value* IRGenImpl::GenLValueAddress(SysYParser::LValContext& lval) { + auto* symbol = sema_.ResolveLVal(&lval); + if (!symbol) { + throw std::runtime_error(FormatError("irgen", "左值缺少语义绑定")); + } + auto* entry = LookupStorage(symbol->name); + if (!entry || !entry->storage) { + throw std::runtime_error(FormatError("irgen", "找不到变量存储: " + symbol->name)); + } + + auto current_type = entry->declared_type; + ir::Value* current_ptr = entry->storage; + if (entry->is_array_param) { + current_ptr = builder_.CreateLoad(entry->storage, module_.GetContext().NextTemp()); + } + + for (auto* index_expr : lval.exp()) { + auto* index = CastValue(GenExpr(*index_expr), Type::GetInt32Type()); + if (current_type->IsArray()) { + current_ptr = builder_.CreateGEP( + current_ptr, {builder_.CreateConstInt(0), index}, + module_.GetContext().NextTemp()); + current_type = current_type->GetElementType(); + continue; } - auto it = storage_map_.find(decl); - if (it == storage_map_.end()) { - throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位: " + primary.lVal()->Ident()->getText())); + if (current_type->IsPointer()) { + current_ptr = + builder_.CreateGEP(current_ptr, {index}, module_.GetContext().NextTemp()); + current_type = current_type->GetElementType(); + continue; } - return builder_.CreateLoad(it->second, module_.GetContext().NextTemp()); + throw std::runtime_error(FormatError("irgen", "非法下标访问")); } - if (primary.exp()) { - return GenExpr(*primary.exp()); + return current_ptr; +} + +ir::Value* IRGenImpl::GenLValueValue(SysYParser::LValContext& lval) { + auto result_type = sema_.ResolveExprType(&lval); + auto* addr = GenLValueAddress(lval); + if (!result_type) { + throw std::runtime_error(FormatError("irgen", "左值缺少结果类型")); + } + if (result_type->IsPointer()) { + if (SameType(addr->GetType(), result_type)) { + return addr; + } + if (addr->GetType()->GetElementType()->IsArray()) { + return DecayArrayPointer(addr); + } } - throw std::runtime_error(FormatError("irgen", "暂不支持的表达式形式")); + return builder_.CreateLoad(addr, module_.GetContext().NextTemp()); } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4571c14..4a664bc 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -2,30 +2,19 @@ #include -#include "SysYParser.h" -#include "ir/IR.h" #include "utils/Log.h" namespace { -void VerifyFunctionStructure(const ir::Function& func) { - // 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。 - for (const auto& bb : func.GetBlocks()) { - if (!bb || !bb->HasTerminator()) { - throw std::runtime_error( - FormatError("irgen", "基本块未正确终结: " + - (bb ? bb->GetName() : std::string("")))); - } - } +using ir::Type; + +std::shared_ptr BuiltinFn(std::shared_ptr ret, + std::vector> params) { + return Type::GetFunctionType(std::move(ret), std::move(params)); } -SysYParser::FuncDefContext* FindMainFunc(SysYParser::CompUnitContext& cu) { - for (auto* func : cu.funcDef()) { - if (func && func->Ident() && func->Ident()->getText() == "main") { - return func; - } - } - return nullptr; +bool SameType(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + return lhs && rhs && lhs->Equals(*rhs); } } // namespace @@ -33,34 +22,187 @@ SysYParser::FuncDefContext* FindMainFunc(SysYParser::CompUnitContext& cu) { IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) : module_(module), sema_(sema), - func_(nullptr), + current_return_type_(Type::GetVoidType()), builder_(module.GetContext(), nullptr) {} void IRGenImpl::Gen(SysYParser::CompUnitContext& cu) { - auto* main_func = FindMainFunc(cu); - if (!main_func) { - throw std::runtime_error(FormatError("irgen", "缺少 main 定义")); + DeclareBuiltins(); + GenGlobals(cu); + GenFunctionDecls(cu); + GenFunctionBodies(cu); +} + +void IRGenImpl::DeclareBuiltins() { + const auto i32 = Type::GetInt32Type(); + const auto f32 = Type::GetFloatType(); + const auto void_ty = Type::GetVoidType(); + + const struct { + const char* name; + std::shared_ptr type; + } builtins[] = { + {"getint", BuiltinFn(i32, {})}, + {"getch", BuiltinFn(i32, {})}, + {"getfloat", BuiltinFn(f32, {})}, + {"getarray", BuiltinFn(i32, {Type::GetPointerType(i32)})}, + {"getfarray", BuiltinFn(i32, {Type::GetPointerType(f32)})}, + {"putint", BuiltinFn(void_ty, {i32})}, + {"putch", BuiltinFn(void_ty, {i32})}, + {"putfloat", BuiltinFn(void_ty, {f32})}, + {"putarray", BuiltinFn(void_ty, {i32, Type::GetPointerType(i32)})}, + {"putfarray", BuiltinFn(void_ty, {i32, Type::GetPointerType(f32)})}, + {"starttime", BuiltinFn(void_ty, {})}, + {"stoptime", BuiltinFn(void_ty, {})}, + }; + + for (const auto& builtin : builtins) { + if (!module_.FindFunction(builtin.name)) { + module_.CreateFunction(builtin.name, builtin.type, true); + } } - GenFuncDef(*main_func); } -void IRGenImpl::GenFuncDef(SysYParser::FuncDefContext& func) { - if (!func.block()) { - throw std::runtime_error(FormatError("irgen", "函数体为空")); +void IRGenImpl::GenFunctionDecls(SysYParser::CompUnitContext& cu) { + for (auto* func : cu.funcDef()) { + if (!func || !func->Ident()) { + continue; + } + auto* symbol = sema_.ResolveFuncDef(func); + if (!symbol) { + throw std::runtime_error(FormatError("irgen", "缺少函数语义信息")); + } + auto* ir_func = module_.FindFunction(symbol->name); + if (ir_func) { + continue; + } + ir_func = module_.CreateFunction(symbol->name, symbol->type, false); + const auto& params = symbol->type->GetParamTypes(); + for (size_t i = 0; i < params.size(); ++i) { + ir_func->AddArgument(params[i], "%arg" + std::to_string(i)); + } + } +} + +void IRGenImpl::GenFunctionBodies(SysYParser::CompUnitContext& cu) { + for (auto* func : cu.funcDef()) { + if (func) { + GenFuncDef(*func); + } } - if (!func.Ident()) { - throw std::runtime_error(FormatError("irgen", "缺少函数名")); +} + +void IRGenImpl::GenFuncDef(SysYParser::FuncDefContext& func) { + auto* symbol = sema_.ResolveFuncDef(&func); + if (!symbol) { + throw std::runtime_error(FormatError("irgen", "函数缺少语义绑定")); } - if (!func.funcType() || !func.funcType()->Int()) { - throw std::runtime_error( - FormatError("irgen", "当前 IR 仅支持返回 int 的 main 函数")); + + current_function_ = module_.FindFunction(symbol->name); + if (!current_function_) { + throw std::runtime_error(FormatError("irgen", "函数声明缺失: " + symbol->name)); } + current_return_type_ = symbol->type->GetReturnType(); + auto* entry = current_function_->CreateBlock("entry"); + auto* body = current_function_->CreateBlock("entry.body"); + builder_.SetInsertPoint(body); + local_scopes_.clear(); + break_targets_.clear(); + continue_targets_.clear(); - func_ = module_.CreateFunction(func.Ident()->getText(), ir::Type::GetInt32Type()); - builder_.SetInsertPoint(func_->GetEntry()); - storage_map_.clear(); + EnterScope(); + if (auto* params = func.funcFParams()) { + const auto& args = current_function_->GetArguments(); + for (size_t i = 0; i < params->funcFParam().size(); ++i) { + auto* param = params->funcFParam(i); + const auto* arg = args.at(i).get(); + const std::string name = param->Ident()->getText(); + auto* slot = builder_.CreateAlloca(arg->GetType(), module_.GetContext().NextTemp()); + builder_.CreateStore(const_cast(arg), slot); + DeclareLocal(name, {slot, arg->GetType(), !param->L_BRACK().empty(), false, false}); + } + } GenBlock(*func.block()); - // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 - VerifyFunctionStructure(*func_); + ExitScope(); + + ir::IRBuilder entry_builder(module_.GetContext(), entry); + entry_builder.CreateBr(body); + + if (builder_.GetInsertBlock() && !builder_.GetInsertBlock()->HasTerminator()) { + if (current_return_type_->IsVoid()) { + builder_.CreateRetVoid(); + } else if (current_return_type_->IsFloat32()) { + builder_.CreateRet(builder_.CreateConstFloat(0.0f)); + } else { + builder_.CreateRet(builder_.CreateConstInt(0)); + } + } +} + +void IRGenImpl::EnterScope() { local_scopes_.emplace_back(); } + +void IRGenImpl::ExitScope() { + if (!local_scopes_.empty()) { + local_scopes_.pop_back(); + } +} + +void IRGenImpl::EnsureInsertableBlock() { + if (!builder_.GetInsertBlock()) { + auto* block = current_function_->CreateBlock(module_.GetContext().NextBlock("dead")); + builder_.SetInsertPoint(block); + return; + } + if (builder_.GetInsertBlock()->HasTerminator()) { + auto* block = current_function_->CreateBlock(module_.GetContext().NextBlock("dead")); + builder_.SetInsertPoint(block); + } +} + +void IRGenImpl::DeclareLocal(const std::string& name, StorageEntry entry) { + if (local_scopes_.empty()) { + EnterScope(); + } + local_scopes_.back()[name] = std::move(entry); +} + +IRGenImpl::StorageEntry* IRGenImpl::LookupStorage(const std::string& name) { + for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) { + auto found = it->find(name); + if (found != it->end()) { + return &found->second; + } + } + auto global = globals_.find(name); + return global == globals_.end() ? nullptr : &global->second; +} + +const IRGenImpl::StorageEntry* IRGenImpl::LookupStorage(const std::string& name) const { + for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) { + auto found = it->find(name); + if (found != it->end()) { + return &found->second; + } + } + auto global = globals_.find(name); + return global == globals_.end() ? nullptr : &global->second; +} + +size_t IRGenImpl::CountScalars(const std::shared_ptr& type) const { + if (!type->IsArray()) { + return 1; + } + return type->GetArraySize() * CountScalars(type->GetElementType()); +} + +std::vector IRGenImpl::FlatIndexToIndices(const std::shared_ptr& type, + size_t flat_index) const { + if (!type->IsArray()) { + return {}; + } + size_t inner = CountScalars(type->GetElementType()); + int current = static_cast(flat_index / inner); + auto tail = FlatIndexToIndices(type->GetElementType(), flat_index % inner); + tail.insert(tail.begin(), current); + return tail; } diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 67ce213..5a2714d 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -2,22 +2,164 @@ #include -#include "SysYParser.h" -#include "ir/IR.h" #include "utils/Log.h" -bool IRGenImpl::GenStmt(SysYParser::StmtContext& stmt) { +void IRGenImpl::GenBlock(SysYParser::BlockContext& block) { + EnterScope(); + for (auto* item : block.blockItem()) { + if (!item) { + continue; + } + EnsureInsertableBlock(); + GenBlockItem(*item); + } + ExitScope(); +} + +void IRGenImpl::GenBlockItem(SysYParser::BlockItemContext& item) { + if (item.decl()) { + GenDecl(*item.decl()); + return; + } + if (item.stmt()) { + GenStmt(*item.stmt()); + return; + } + throw std::runtime_error(FormatError("irgen", "未知 block item")); +} + +void IRGenImpl::GenStmt(SysYParser::StmtContext& stmt) { + if (stmt.assignStmt()) { + auto* assign = stmt.assignStmt(); + auto* addr = GenLValueAddress(*assign->lVal()); + auto* value = CastValue(GenExpr(*assign->exp()), addr->GetType()->GetElementType()); + builder_.CreateStore(value, addr); + return; + } + if (stmt.expStmt()) { + if (stmt.expStmt()->exp()) { + (void)GenExpr(*stmt.expStmt()->exp()); + } + return; + } + if (stmt.block()) { + GenBlock(*stmt.block()); + return; + } + if (stmt.ifStmt()) { + auto* then_block = + current_function_->CreateBlock(module_.GetContext().NextBlock("if.then")); + ir::BasicBlock* else_block = nullptr; + ir::BasicBlock* merge_block = nullptr; + if (stmt.ifStmt()->Else()) { + else_block = + current_function_->CreateBlock(module_.GetContext().NextBlock("if.else")); + merge_block = + current_function_->CreateBlock(module_.GetContext().NextBlock("if.end")); + GenCond(*stmt.ifStmt()->cond(), then_block, else_block); + + builder_.SetInsertPoint(then_block); + GenStmt(*stmt.ifStmt()->stmt(0)); + if (!builder_.GetInsertBlock()->HasTerminator()) { + builder_.CreateBr(merge_block); + } + + builder_.SetInsertPoint(else_block); + GenStmt(*stmt.ifStmt()->stmt(1)); + if (!builder_.GetInsertBlock()->HasTerminator()) { + builder_.CreateBr(merge_block); + } + + builder_.SetInsertPoint(merge_block); + } else { + merge_block = + current_function_->CreateBlock(module_.GetContext().NextBlock("if.end")); + GenCond(*stmt.ifStmt()->cond(), then_block, merge_block); + + builder_.SetInsertPoint(then_block); + GenStmt(*stmt.ifStmt()->stmt(0)); + if (!builder_.GetInsertBlock()->HasTerminator()) { + builder_.CreateBr(merge_block); + } + + builder_.SetInsertPoint(merge_block); + } + return; + } + if (stmt.whileStmt()) { + auto* cond_block = + current_function_->CreateBlock(module_.GetContext().NextBlock("while.cond")); + auto* body_block = + current_function_->CreateBlock(module_.GetContext().NextBlock("while.body")); + auto* exit_block = + current_function_->CreateBlock(module_.GetContext().NextBlock("while.end")); + + builder_.CreateBr(cond_block); + builder_.SetInsertPoint(cond_block); + GenCond(*stmt.whileStmt()->cond(), body_block, exit_block); + + break_targets_.push_back(exit_block); + continue_targets_.push_back(cond_block); + builder_.SetInsertPoint(body_block); + GenStmt(*stmt.whileStmt()->stmt()); + if (!builder_.GetInsertBlock()->HasTerminator()) { + builder_.CreateBr(cond_block); + } + break_targets_.pop_back(); + continue_targets_.pop_back(); + + builder_.SetInsertPoint(exit_block); + return; + } + if (stmt.breakStmt()) { + builder_.CreateBr(break_targets_.back()); + return; + } + if (stmt.continueStmt()) { + builder_.CreateBr(continue_targets_.back()); + return; + } if (stmt.returnStmt()) { - GenReturnStmt(*stmt.returnStmt()); - return true; + if (!stmt.returnStmt()->exp()) { + builder_.CreateRetVoid(); + return; + } + auto* value = GenExpr(*stmt.returnStmt()->exp()); + builder_.CreateRet(CastValue(value, current_return_type_)); + return; } throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); } -void IRGenImpl::GenReturnStmt(SysYParser::ReturnStmtContext& ret) { - if (!ret.exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); +void IRGenImpl::GenCond(SysYParser::CondContext& cond, ir::BasicBlock* true_block, + ir::BasicBlock* false_block) { + GenLOrCond(*cond.lOrExp(), true_block, false_block); +} + +void IRGenImpl::GenLOrCond(SysYParser::LOrExpContext& expr, + ir::BasicBlock* true_block, + ir::BasicBlock* false_block) { + const auto& terms = expr.lAndExp(); + for (size_t i = 0; i + 1 < terms.size(); ++i) { + auto* next_block = + current_function_->CreateBlock(module_.GetContext().NextBlock("lor.rhs")); + GenLAndCond(*terms[i], true_block, next_block); + builder_.SetInsertPoint(next_block); + } + GenLAndCond(*terms.back(), true_block, false_block); +} + +void IRGenImpl::GenLAndCond(SysYParser::LAndExpContext& expr, + ir::BasicBlock* true_block, + ir::BasicBlock* false_block) { + const auto& terms = expr.eqExp(); + for (size_t i = 0; i + 1 < terms.size(); ++i) { + auto* next_block = + current_function_->CreateBlock(module_.GetContext().NextBlock("land.rhs")); + auto* value = ToBool(GenEqExpr(*terms[i])); + builder_.CreateCondBr(value, next_block, false_block); + builder_.SetInsertPoint(next_block); } - ir::Value* v = GenExpr(*ret.exp()); - builder_.CreateRet(v); + auto* value = ToBool(GenEqExpr(*terms.back())); + builder_.CreateCondBr(value, true_block, false_block); } diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index e4b4015..856a4f1 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -1,5 +1,6 @@ #include "sem/Sema.h" +#include #include #include @@ -8,183 +9,944 @@ namespace { -SysYParser::FuncDefContext* FindMainFunc(SysYParser::CompUnitContext& comp_unit) { - SysYParser::FuncDefContext* main_func = nullptr; - for (auto* func : comp_unit.funcDef()) { - if (!func || !func->Ident()) { - continue; +using ir::Type; + +bool SameType(const std::shared_ptr& lhs, const std::shared_ptr& rhs) { + return lhs && rhs && lhs->Equals(*rhs); +} + +std::shared_ptr MakePointer(std::shared_ptr element) { + return Type::GetPointerType(std::move(element)); +} + +bool IsScalar(const std::shared_ptr& type) { + return type && (type->IsInt32() || type->IsFloat32()); +} + +bool IsTruthyType(const std::shared_ptr& type) { return IsScalar(type); } + +bool CanImplicitConvert(const std::shared_ptr& src, + const std::shared_ptr& dst) { + if (!src || !dst) { + return false; + } + if (SameType(src, dst)) { + return true; + } + if (src->IsInt32() && dst->IsFloat32()) { + return true; + } + if (src->IsFloat32() && dst->IsInt32()) { + return true; + } + return false; +} + +std::shared_ptr ArithmeticResultType(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + if (!IsScalar(lhs) || !IsScalar(rhs)) { + throw std::runtime_error(FormatError("sema", "算术表达式需要 int/float 操作数")); + } + return (lhs->IsFloat32() || rhs->IsFloat32()) ? Type::GetFloatType() + : Type::GetInt32Type(); +} + +std::shared_ptr ParseNumberType(const std::string& text) { + return text.find_first_of(".pPeE") == std::string::npos ? Type::GetInt32Type() + : Type::GetFloatType(); +} + +ConstantData ParseNumberValue(const std::string& text) { + if (ParseNumberType(text)->IsInt32()) { + long long value = std::strtoll(text.c_str(), nullptr, 0); + return ConstantData::FromInt(static_cast(value)); + } + return ConstantData::FromFloat(std::strtof(text.c_str(), nullptr)); +} + +std::string TypeName(const std::shared_ptr& type) { + if (!type) { + return ""; + } + if (type->IsVoid()) { + return "void"; + } + if (type->IsInt32()) { + return "int"; + } + if (type->IsFloat32()) { + return "float"; + } + if (type->IsPointer()) { + return TypeName(type->GetElementType()) + "*"; + } + if (type->IsArray()) { + return "array"; + } + if (type->IsFunction()) { + return "function"; + } + return ""; +} + +class SemaAnalyzer { + public: + explicit SemaAnalyzer(SysYParser::CompUnitContext& comp_unit) + : comp_unit_(comp_unit) {} + + SemanticContext Run() { + table_.EnterScope(); + DeclareBuiltins(); + ProcessGlobalDecls(); + CollectFunctionSignatures(); + CheckFunctionBodies(); + table_.ExitScope(); + return std::move(sema_); + } + + private: + std::shared_ptr BaseTypeFromBType(SysYParser::BTypeContext* btype) { + if (!btype) { + throw std::runtime_error(FormatError("sema", "缺少基础类型")); } - if (func->Ident()->getText() != "main") { - continue; + if (btype->Int()) { + return Type::GetInt32Type(); } - if (main_func) { - throw std::runtime_error(FormatError("sema", "main 函数定义重复")); + if (btype->Float()) { + return Type::GetFloatType(); } - main_func = func; + throw std::runtime_error(FormatError("sema", "未知基础类型")); } - return main_func; -} -void CheckExpr(SysYParser::ExpContext& exp, const SymbolTable& table, - SemanticContext& sema); + std::shared_ptr BaseTypeFromFuncType(SysYParser::FuncTypeContext* func_type) { + if (!func_type) { + throw std::runtime_error(FormatError("sema", "缺少函数返回类型")); + } + if (func_type->Void()) { + return Type::GetVoidType(); + } + if (func_type->Int()) { + return Type::GetInt32Type(); + } + if (func_type->Float()) { + return Type::GetFloatType(); + } + throw std::runtime_error(FormatError("sema", "未知函数返回类型")); + } -void CheckLVal(SysYParser::LValContext& lval, const SymbolTable& table, - SemanticContext& sema) { - if (!lval.Ident()) { - throw std::runtime_error(FormatError("sema", "左值缺少标识符")); + ConstantData EvalConstExp(SysYParser::ExpContext& exp) { + if (!exp.addExp()) { + throw std::runtime_error(FormatError("sema", "非法常量表达式")); + } + return EvalConstAddExp(*exp.addExp()); } - const std::string name = lval.Ident()->getText(); - auto* decl = table.Lookup(name); - if (!decl) { - throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); + + ConstantData EvalConstExp(SysYParser::ConstExpContext& exp) { + if (!exp.addExp()) { + throw std::runtime_error(FormatError("sema", "非法常量表达式")); + } + return EvalConstAddExp(*exp.addExp()); + } + + ConstantData EvalConstAddExp(SysYParser::AddExpContext& add) { + const auto& terms = add.mulExp(); + if (terms.empty()) { + throw std::runtime_error(FormatError("sema", "空加法表达式")); + } + ConstantData acc = EvalConstMulExp(*terms[0]); + for (size_t i = 1; i < terms.size(); ++i) { + ConstantData rhs = EvalConstMulExp(*terms[i]); + auto result_type = ArithmeticResultType(acc.GetType(), rhs.GetType()); + ConstantData lhs_cast = acc.CastTo(result_type); + ConstantData rhs_cast = rhs.CastTo(result_type); + const std::string op = add.children[2 * i - 1]->getText(); + if (result_type->IsFloat32()) { + float value = op == "+" ? lhs_cast.AsFloat() + rhs_cast.AsFloat() + : lhs_cast.AsFloat() - rhs_cast.AsFloat(); + acc = ConstantData::FromFloat(value); + } else { + int value = op == "+" ? lhs_cast.AsInt() + rhs_cast.AsInt() + : lhs_cast.AsInt() - rhs_cast.AsInt(); + acc = ConstantData::FromInt(value); + } + } + return acc; } - sema.BindVarUse(&lval, decl); - for (auto* index : lval.exp()) { - if (index) { - CheckExpr(*index, table, sema); + + ConstantData EvalConstMulExp(SysYParser::MulExpContext& mul) { + const auto& terms = mul.unaryExp(); + if (terms.empty()) { + throw std::runtime_error(FormatError("sema", "空乘法表达式")); } + ConstantData acc = EvalConstUnaryExp(*terms[0]); + for (size_t i = 1; i < terms.size(); ++i) { + ConstantData rhs = EvalConstUnaryExp(*terms[i]); + const std::string op = mul.children[2 * i - 1]->getText(); + if (op == "%") { + acc = ConstantData::FromInt(acc.AsInt() % rhs.AsInt()); + continue; + } + auto result_type = ArithmeticResultType(acc.GetType(), rhs.GetType()); + ConstantData lhs_cast = acc.CastTo(result_type); + ConstantData rhs_cast = rhs.CastTo(result_type); + if (result_type->IsFloat32()) { + float value = 0.0f; + if (op == "*") { + value = lhs_cast.AsFloat() * rhs_cast.AsFloat(); + } else { + value = lhs_cast.AsFloat() / rhs_cast.AsFloat(); + } + acc = ConstantData::FromFloat(value); + } else { + int value = 0; + if (op == "*") { + value = lhs_cast.AsInt() * rhs_cast.AsInt(); + } else { + value = lhs_cast.AsInt() / rhs_cast.AsInt(); + } + acc = ConstantData::FromInt(value); + } + } + return acc; + } + + ConstantData EvalConstUnaryExp(SysYParser::UnaryExpContext& unary) { + if (unary.primary()) { + return EvalConstPrimary(*unary.primary()); + } + if (unary.unaryExp()) { + ConstantData value = EvalConstUnaryExp(*unary.unaryExp()); + const std::string op = unary.unaryOp()->getText(); + if (op == "+") { + return value; + } + if (op == "-") { + return value.IsFloat() ? ConstantData::FromFloat(-value.AsFloat()) + : ConstantData::FromInt(-value.AsInt()); + } + if (op == "!") { + return ConstantData::FromInt( + value.IsFloat() ? (value.AsFloat() == 0.0f) : (value.AsInt() == 0)); + } + } + throw std::runtime_error(FormatError("sema", "常量表达式不支持函数调用")); } -} -void CheckPrimary(SysYParser::PrimaryContext& primary, const SymbolTable& table, - SemanticContext& sema) { - if (primary.Number()) { - return; + ConstantData EvalConstPrimary(SysYParser::PrimaryContext& primary) { + if (primary.Number()) { + return ParseNumberValue(primary.Number()->getText()); + } + if (primary.exp()) { + return EvalConstExp(*primary.exp()); + } + if (primary.lVal()) { + auto* ident = primary.lVal()->Ident(); + const std::string name = ident ? ident->getText() : ""; + auto* symbol = table_.Lookup(name); + if (!symbol || symbol->kind != SymbolKind::Object || !symbol->has_const_value || + !primary.lVal()->exp().empty()) { + throw std::runtime_error( + FormatError("sema", "常量求值需要可用的标量常量: " + name)); + } + return symbol->const_value; + } + throw std::runtime_error(FormatError("sema", "非法常量表达式")); } - if (primary.lVal()) { - CheckLVal(*primary.lVal(), table, sema); - return; + std::vector EvalArrayDims( + const std::vector& dims) { + std::vector values; + for (auto* dim : dims) { + ConstantData value = EvalConstExp(*dim); + int int_value = value.AsInt(); + if (int_value <= 0) { + throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); + } + values.push_back(int_value); + } + return values; } - if (primary.exp()) { - CheckExpr(*primary.exp(), table, sema); - return; + std::shared_ptr BuildArrayType(std::shared_ptr element_type, + const std::vector& dims) { + auto type = std::move(element_type); + for (auto it = dims.rbegin(); it != dims.rend(); ++it) { + type = Type::GetArrayType(type, static_cast(*it)); + } + return type; } - throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); -} + std::shared_ptr BuildFuncParamType(SysYParser::FuncFParamContext& param) { + auto base_type = BaseTypeFromBType(param.bType()); + if (param.L_BRACK().empty()) { + return base_type; + } + std::vector dims; + for (auto* dim_exp : param.exp()) { + ConstantData dim_value = EvalConstExp(*dim_exp); + int int_value = dim_value.AsInt(); + if (int_value <= 0) { + throw std::runtime_error(FormatError("sema", "数组形参维度必须为正整数")); + } + dims.push_back(int_value); + } + auto array_element = BuildArrayType(base_type, dims); + return MakePointer(array_element); + } -void CheckUnaryExpr(SysYParser::UnaryExpContext& unary, const SymbolTable& table, - SemanticContext& sema) { - if (unary.primary()) { - CheckPrimary(*unary.primary(), table, sema); - return; + void DeclareBuiltins() { + DeclareBuiltin("getint", Type::GetFunctionType(Type::GetInt32Type(), {})); + DeclareBuiltin("getch", Type::GetFunctionType(Type::GetInt32Type(), {})); + DeclareBuiltin("getfloat", Type::GetFunctionType(Type::GetFloatType(), {})); + DeclareBuiltin("getarray", + Type::GetFunctionType(Type::GetInt32Type(), {MakePointer(Type::GetInt32Type())})); + DeclareBuiltin("getfarray", + Type::GetFunctionType(Type::GetInt32Type(), {MakePointer(Type::GetFloatType())})); + DeclareBuiltin("putint", Type::GetFunctionType(Type::GetVoidType(), {Type::GetInt32Type()})); + DeclareBuiltin("putch", Type::GetFunctionType(Type::GetVoidType(), {Type::GetInt32Type()})); + DeclareBuiltin("putfloat", + Type::GetFunctionType(Type::GetVoidType(), {Type::GetFloatType()})); + DeclareBuiltin("putarray", + Type::GetFunctionType(Type::GetVoidType(), + {Type::GetInt32Type(), MakePointer(Type::GetInt32Type())})); + DeclareBuiltin("putfarray", + Type::GetFunctionType(Type::GetVoidType(), + {Type::GetInt32Type(), MakePointer(Type::GetFloatType())})); + DeclareBuiltin("starttime", Type::GetFunctionType(Type::GetVoidType(), {})); + DeclareBuiltin("stoptime", Type::GetFunctionType(Type::GetVoidType(), {})); } - if (unary.unaryExp()) { - CheckUnaryExpr(*unary.unaryExp(), table, sema); - return; + void DeclareBuiltin(const std::string& name, std::shared_ptr type) { + SymbolInfo symbol; + symbol.name = name; + symbol.kind = SymbolKind::Function; + symbol.type = std::move(type); + symbol.is_global = true; + symbol.is_builtin = true; + auto* info = sema_.CreateSymbol(std::move(symbol)); + if (!table_.Declare(name, info)) { + throw std::runtime_error(FormatError("sema", "内建函数声明重复: " + name)); + } } - if (unary.funcRParams()) { - for (auto* arg : unary.funcRParams()->exp()) { - if (arg) { - CheckExpr(*arg, table, sema); + void ProcessGlobalDecls() { + for (auto* decl : comp_unit_.decl()) { + if (decl) { + CheckDecl(*decl, true); } } } -} -void CheckMulExpr(SysYParser::MulExpContext& mul, const SymbolTable& table, - SemanticContext& sema) { - for (auto* unary : mul.unaryExp()) { - if (unary) { - CheckUnaryExpr(*unary, table, sema); + void CollectFunctionSignatures() { + for (auto* func : comp_unit_.funcDef()) { + if (!func || !func->Ident()) { + continue; + } + const std::string name = func->Ident()->getText(); + if (table_.LookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义函数: " + name)); + } + std::vector> params; + if (auto* func_params = func->funcFParams()) { + for (auto* param : func_params->funcFParam()) { + params.push_back(BuildFuncParamType(*param)); + } + } + SymbolInfo symbol; + symbol.name = name; + symbol.kind = SymbolKind::Function; + symbol.type = Type::GetFunctionType(BaseTypeFromFuncType(func->funcType()), + std::move(params)); + symbol.is_global = true; + symbol.func_def = func; + auto* info = sema_.CreateSymbol(std::move(symbol)); + table_.Declare(name, info); + sema_.BindFuncDef(func, info); } } -} -void CheckAddExpr(SysYParser::AddExpContext& add, const SymbolTable& table, - SemanticContext& sema) { - for (auto* mul : add.mulExp()) { - if (mul) { - CheckMulExpr(*mul, table, sema); + void CheckFunctionBodies() { + for (auto* func : comp_unit_.funcDef()) { + if (!func) { + continue; + } + auto* symbol = sema_.ResolveFuncDef(func); + if (!symbol) { + throw std::runtime_error(FormatError("sema", "函数签名缺失")); + } + current_return_type_ = symbol->type->GetReturnType(); + loop_depth_ = 0; + + table_.EnterScope(); + if (auto* params = func->funcFParams()) { + for (auto* param : params->funcFParam()) { + if (!param || !param->Ident()) { + continue; + } + const std::string name = param->Ident()->getText(); + if (table_.LookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义形参: " + name)); + } + SymbolInfo param_symbol; + param_symbol.name = name; + param_symbol.kind = SymbolKind::Object; + param_symbol.type = BuildFuncParamType(*param); + param_symbol.is_parameter = true; + param_symbol.is_array_parameter = !param->L_BRACK().empty(); + auto* info = sema_.CreateSymbol(std::move(param_symbol)); + table_.Declare(name, info); + } + } + CheckBlock(*func->block()); + table_.ExitScope(); } } -} -void CheckExpr(SysYParser::ExpContext& exp, const SymbolTable& table, - SemanticContext& sema) { - if (!exp.addExp()) { - throw std::runtime_error(FormatError("sema", "非法表达式")); + void CheckBlock(SysYParser::BlockContext& block) { + table_.EnterScope(); + for (auto* item : block.blockItem()) { + if (!item) { + continue; + } + if (item->decl()) { + CheckDecl(*item->decl(), false); + } else if (item->stmt()) { + CheckStmt(*item->stmt()); + } + } + table_.ExitScope(); } - CheckAddExpr(*exp.addExp(), table, sema); -} -} // namespace + void CheckDecl(SysYParser::DeclContext& decl, bool is_global) { + if (decl.constDecl()) { + CheckConstDecl(*decl.constDecl(), is_global); + return; + } + if (decl.varDecl()) { + CheckVarDecl(*decl.varDecl(), is_global); + return; + } + throw std::runtime_error(FormatError("sema", "未知声明类型")); + } -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { - auto* func = FindMainFunc(comp_unit); - if (!func || !func->block()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + void CheckConstDecl(SysYParser::ConstDeclContext& decl, bool is_global) { + auto base_type = BaseTypeFromBType(decl.bType()); + for (auto* def : decl.constDef()) { + if (!def || !def->Ident()) { + continue; + } + const std::string name = def->Ident()->getText(); + if (table_.LookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + } + + auto dims = EvalArrayDims(def->constExp()); + auto object_type = BuildArrayType(base_type, dims); + CheckConstInitializer(*def->constInitVal(), object_type); + + SymbolInfo symbol; + symbol.name = name; + symbol.kind = SymbolKind::Object; + symbol.type = object_type; + symbol.is_const = true; + symbol.is_global = is_global; + symbol.const_def = def; + if (object_type->IsInt32() || object_type->IsFloat32()) { + symbol.has_const_value = true; + symbol.const_value = EvalConstInitScalar(*def->constInitVal()).CastTo(object_type); + } + + auto* info = sema_.CreateSymbol(std::move(symbol)); + table_.Declare(name, info); + sema_.BindConstDef(def, info); + } } - SymbolTable table; - SemanticContext sema; - bool seen_return = false; + void CheckVarDecl(SysYParser::VarDeclContext& decl, bool is_global) { + auto base_type = BaseTypeFromBType(decl.bType()); + for (auto* def : decl.varDef()) { + if (!def || !def->Ident()) { + continue; + } + const std::string name = def->Ident()->getText(); + if (table_.LookupCurrent(name)) { + throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + } + + auto dims = EvalArrayDims(def->constExp()); + auto object_type = BuildArrayType(base_type, dims); + if (auto* init = def->initVal()) { + CheckInitializer(*init, object_type, is_global); + } + + SymbolInfo symbol; + symbol.name = name; + symbol.kind = SymbolKind::Object; + symbol.type = object_type; + symbol.is_global = is_global; + symbol.var_def = def; - const auto& items = func->block()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); + auto* info = sema_.CreateSymbol(std::move(symbol)); + table_.Declare(name, info); + sema_.BindVarDef(def, info); + } + } + + ConstantData EvalConstInitScalar(SysYParser::ConstInitValContext& init) { + if (!init.constExp()) { + throw std::runtime_error(FormatError("sema", "标量常量初始化器必须是表达式")); + } + return EvalConstExp(*init.constExp()); + } + + void CheckConstInitializer(SysYParser::ConstInitValContext& init, + const std::shared_ptr& object_type) { + if (object_type->IsArray()) { + if (init.constExp()) { + EvalConstExp(*init.constExp()); + return; + } + for (auto* child : init.constInitVal()) { + if (child) { + CheckConstInitializer(*child, object_type->GetElementType()); + } + } + return; + } + if (!IsScalar(object_type) || !init.constExp()) { + throw std::runtime_error(FormatError("sema", "非法常量初始化器")); + } + EvalConstExp(*init.constExp()); } - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; + void CheckInitializer(SysYParser::InitValContext& init, + const std::shared_ptr& object_type, + bool require_const) { + if (object_type->IsArray()) { + if (init.exp()) { + if (require_const) { + EvalConstExp(*init.exp()); + } else { + CheckExp(*init.exp()); + } + return; + } + for (auto* child : init.initVal()) { + if (child) { + CheckInitializer(*child, object_type->GetElementType(), require_const); + } + } + return; } - if (seen_return) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); + if (!IsScalar(object_type) || !init.exp()) { + throw std::runtime_error(FormatError("sema", "非法初始化器")); } + auto expr_type = + require_const ? EvalConstExp(*init.exp()).GetType() : CheckExp(*init.exp()); + if (!CanImplicitConvert(expr_type, object_type)) { + throw std::runtime_error(FormatError( + "sema", "初始化器类型不匹配: " + TypeName(expr_type) + " -> " + + TypeName(object_type))); + } + } - if (auto* decl = item->decl() ? item->decl()->varDecl() : nullptr) { - for (auto* def : decl->varDef()) { - if (!def || !def->Ident()) { - continue; + std::shared_ptr CheckStmt(SysYParser::StmtContext& stmt) { + if (stmt.assignStmt()) { + auto* assign = stmt.assignStmt(); + auto lhs_type = CheckLVal(*assign->lVal(), true); + auto rhs_type = CheckExp(*assign->exp()); + if (!CanImplicitConvert(rhs_type, lhs_type)) { + throw std::runtime_error(FormatError( + "sema", "赋值类型不匹配: " + TypeName(rhs_type) + " -> " + + TypeName(lhs_type))); + } + return Type::GetVoidType(); + } + if (stmt.expStmt()) { + if (stmt.expStmt()->exp()) { + return CheckExp(*stmt.expStmt()->exp()); + } + return Type::GetVoidType(); + } + if (stmt.block()) { + CheckBlock(*stmt.block()); + return Type::GetVoidType(); + } + if (stmt.ifStmt()) { + CheckCond(*stmt.ifStmt()->cond()); + CheckStmt(*stmt.ifStmt()->stmt(0)); + if (stmt.ifStmt()->stmt().size() > 1) { + CheckStmt(*stmt.ifStmt()->stmt(1)); + } + return Type::GetVoidType(); + } + if (stmt.whileStmt()) { + CheckCond(*stmt.whileStmt()->cond()); + ++loop_depth_; + CheckStmt(*stmt.whileStmt()->stmt()); + --loop_depth_; + return Type::GetVoidType(); + } + if (stmt.breakStmt()) { + if (loop_depth_ <= 0) { + throw std::runtime_error(FormatError("sema", "break 必须出现在循环中")); + } + return Type::GetVoidType(); + } + if (stmt.continueStmt()) { + if (loop_depth_ <= 0) { + throw std::runtime_error(FormatError("sema", "continue 必须出现在循环中")); + } + return Type::GetVoidType(); + } + if (stmt.returnStmt()) { + auto* ret = stmt.returnStmt(); + if (current_return_type_->IsVoid()) { + if (ret->exp()) { + throw std::runtime_error(FormatError("sema", "void 函数不能返回值")); } - const std::string name = def->Ident()->getText(); - if (table.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + } else { + if (!ret->exp()) { + throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值")); } - if (!def->constExp().empty()) { + auto value_type = CheckExp(*ret->exp()); + if (!CanImplicitConvert(value_type, current_return_type_)) { throw std::runtime_error( - FormatError("sema", "当前 IR 仅支持标量局部变量")); + FormatError("sema", "return 类型不匹配: " + TypeName(value_type) + + " -> " + TypeName(current_return_type_))); } - if (auto* init = def->initVal()) { - if (!init->exp()) { - throw std::runtime_error( - FormatError("sema", "当前 IR 仅支持标量表达式初始化")); - } - CheckExpr(*init->exp(), table, sema); + } + return Type::GetVoidType(); + } + throw std::runtime_error(FormatError("sema", "暂不支持的语句类型")); + } + + void CheckCond(SysYParser::CondContext& cond) { + auto type = CheckLOrExp(*cond.lOrExp()); + if (!IsTruthyType(type)) { + throw std::runtime_error(FormatError("sema", "条件表达式必须是 int/float")); + } + } + + std::shared_ptr CheckExp(SysYParser::ExpContext& exp) { + auto type = CheckAddExp(*exp.addExp()); + sema_.SetExprType(&exp, type); + return type; + } + + std::shared_ptr CheckLOrExp(SysYParser::LOrExpContext& expr) { + auto type = CheckLAndExp(*expr.lAndExp(0)); + for (size_t i = 1; i < expr.lAndExp().size(); ++i) { + auto rhs = CheckLAndExp(*expr.lAndExp(i)); + if (!IsTruthyType(type) || !IsTruthyType(rhs)) { + throw std::runtime_error(FormatError("sema", "|| 两侧必须是 int/float")); + } + type = Type::GetInt32Type(); + } + sema_.SetExprType(&expr, type); + return type; + } + + std::shared_ptr CheckLAndExp(SysYParser::LAndExpContext& expr) { + auto type = CheckEqExp(*expr.eqExp(0)); + for (size_t i = 1; i < expr.eqExp().size(); ++i) { + auto rhs = CheckEqExp(*expr.eqExp(i)); + if (!IsTruthyType(type) || !IsTruthyType(rhs)) { + throw std::runtime_error(FormatError("sema", "&& 两侧必须是 int/float")); + } + type = Type::GetInt32Type(); + } + sema_.SetExprType(&expr, type); + return type; + } + + std::shared_ptr CheckEqExp(SysYParser::EqExpContext& expr) { + auto type = CheckRelExp(*expr.relExp(0)); + for (size_t i = 1; i < expr.relExp().size(); ++i) { + auto rhs = CheckRelExp(*expr.relExp(i)); + if (!IsScalar(type) || !IsScalar(rhs)) { + throw std::runtime_error(FormatError("sema", "==/!= 两侧必须是 int/float")); + } + type = Type::GetInt32Type(); + } + sema_.SetExprType(&expr, type); + return type; + } + + std::shared_ptr CheckRelExp(SysYParser::RelExpContext& expr) { + auto type = CheckAddExp(*expr.addExp(0)); + for (size_t i = 1; i < expr.addExp().size(); ++i) { + auto rhs = CheckAddExp(*expr.addExp(i)); + if (!IsScalar(type) || !IsScalar(rhs)) { + throw std::runtime_error(FormatError("sema", "关系运算两侧必须是 int/float")); + } + type = Type::GetInt32Type(); + } + sema_.SetExprType(&expr, type); + return type; + } + + std::shared_ptr CheckAddExp(SysYParser::AddExpContext& expr) { + auto type = CheckMulExp(*expr.mulExp(0)); + for (size_t i = 1; i < expr.mulExp().size(); ++i) { + auto rhs = CheckMulExp(*expr.mulExp(i)); + type = ArithmeticResultType(type, rhs); + } + sema_.SetExprType(&expr, type); + return type; + } + + std::shared_ptr CheckMulExp(SysYParser::MulExpContext& expr) { + auto type = CheckUnaryExp(*expr.unaryExp(0)); + for (size_t i = 1; i < expr.unaryExp().size(); ++i) { + auto rhs = CheckUnaryExp(*expr.unaryExp(i)); + const std::string op = expr.children[2 * i - 1]->getText(); + if (op == "%") { + if (!type->IsInt32() || !rhs->IsInt32()) { + throw std::runtime_error(FormatError("sema", "% 只支持 int")); } - table.Add(name, def); + type = Type::GetInt32Type(); + } else { + type = ArithmeticResultType(type, rhs); } - continue; } + sema_.SetExprType(&expr, type); + return type; + } - if (auto* stmt = item->stmt(); stmt && stmt->returnStmt()) { - auto* ret = stmt->returnStmt(); - if (!ret->exp()) { - throw std::runtime_error(FormatError("sema", "main 函数必须返回一个值")); + std::shared_ptr CheckUnaryExp(SysYParser::UnaryExpContext& expr) { + if (expr.primary()) { + auto type = CheckPrimary(*expr.primary()); + sema_.SetExprType(&expr, type); + return type; + } + if (expr.Ident()) { + const std::string name = expr.Ident()->getText(); + auto* symbol = table_.Lookup(name); + if (!symbol || symbol->kind != SymbolKind::Function) { + throw std::runtime_error(FormatError("sema", "未定义的函数: " + name)); } - CheckExpr(*ret->exp(), table, sema); - seen_return = true; - if (i + 1 != items.size()) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); + sema_.BindCall(&expr, symbol); + const auto& params = symbol->type->GetParamTypes(); + std::vector args; + if (expr.funcRParams()) { + args = expr.funcRParams()->exp(); + } + if (params.size() != args.size()) { + throw std::runtime_error(FormatError( + "sema", "函数参数个数不匹配: " + name)); + } + for (size_t i = 0; i < args.size(); ++i) { + auto arg_type = CheckExp(*args[i]); + if (!CanImplicitConvert(arg_type, params[i]) && !SameType(arg_type, params[i])) { + throw std::runtime_error(FormatError( + "sema", "函数参数类型不匹配: " + name)); + } + } + auto ret_type = symbol->type->GetReturnType(); + sema_.SetExprType(&expr, ret_type); + return ret_type; + } + if (expr.unaryExp()) { + auto inner = CheckUnaryExp(*expr.unaryExp()); + const std::string op = expr.unaryOp()->getText(); + if (op == "+" || op == "-") { + if (!IsScalar(inner)) { + throw std::runtime_error(FormatError("sema", "一元 +/- 只支持 int/float")); + } + sema_.SetExprType(&expr, inner); + return inner; + } + if (op == "!") { + if (!IsTruthyType(inner)) { + throw std::runtime_error(FormatError("sema", "! 只支持 int/float")); + } + sema_.SetExprType(&expr, Type::GetInt32Type()); + return Type::GetInt32Type(); } - continue; + } + throw std::runtime_error(FormatError("sema", "非法一元表达式")); + } + + std::shared_ptr CheckPrimary(SysYParser::PrimaryContext& primary) { + if (primary.Number()) { + auto type = ParseNumberType(primary.Number()->getText()); + sema_.SetExprType(&primary, type); + return type; + } + if (primary.exp()) { + auto type = CheckExp(*primary.exp()); + sema_.SetExprType(&primary, type); + return type; + } + if (primary.lVal()) { + auto type = CheckLVal(*primary.lVal(), false); + sema_.SetExprType(&primary, type); + return type; + } + throw std::runtime_error(FormatError("sema", "非法 primary 表达式")); + } + + std::shared_ptr CheckLVal(SysYParser::LValContext& lval, bool is_assign_target) { + if (!lval.Ident()) { + throw std::runtime_error(FormatError("sema", "左值缺少标识符")); + } + const std::string name = lval.Ident()->getText(); + auto* symbol = table_.Lookup(name); + if (!symbol || symbol->kind != SymbolKind::Object) { + throw std::runtime_error(FormatError("sema", "未定义的变量: " + name)); + } + sema_.BindLVal(&lval, symbol); + + auto current_type = symbol->type; + for (auto* index : lval.exp()) { + auto index_type = CheckExp(*index); + if (!index_type->IsInt32()) { + throw std::runtime_error(FormatError("sema", "数组下标必须是 int")); + } + if (current_type->IsArray() || current_type->IsPointer()) { + current_type = current_type->GetElementType(); + } else { + throw std::runtime_error(FormatError("sema", "对非数组对象进行了下标访问")); + } + } + + if (is_assign_target) { + if (symbol->is_const) { + throw std::runtime_error(FormatError("sema", "不能给 const 对象赋值: " + name)); + } + if (!IsScalar(current_type)) { + throw std::runtime_error(FormatError("sema", "赋值目标必须是标量左值")); + } + sema_.SetExprType(&lval, current_type); + return current_type; } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + auto result_type = + current_type->IsArray() ? MakePointer(current_type->GetElementType()) : current_type; + sema_.SetExprType(&lval, result_type); + return result_type; } - if (!seen_return) { - throw std::runtime_error(FormatError("sema", "main 函数必须包含 return 语句")); + SysYParser::CompUnitContext& comp_unit_; + SemanticContext sema_; + SymbolTable table_; + std::shared_ptr current_return_type_ = Type::GetVoidType(); + int loop_depth_ = 0; +}; + +} // namespace + +ConstantData ConstantData::FromInt(int value) { + ConstantData data; + data.kind = Kind::Int; + data.int_value = value; + return data; +} + +ConstantData ConstantData::FromFloat(float value) { + ConstantData data; + data.kind = Kind::Float; + data.float_value = value; + return data; +} + +int ConstantData::AsInt() const { + return IsFloat() ? static_cast(float_value) : int_value; +} + +float ConstantData::AsFloat() const { + return IsFloat() ? float_value : static_cast(int_value); +} + +ConstantData ConstantData::CastTo(const std::shared_ptr& dst_type) const { + if (!dst_type) { + throw std::runtime_error("ConstantData::CastTo 缺少目标类型"); + } + if (dst_type->IsInt32()) { + return FromInt(AsInt()); + } + if (dst_type->IsFloat32()) { + return FromFloat(AsFloat()); } + throw std::runtime_error("ConstantData 只支持转换到 int/float"); +} + +std::shared_ptr ConstantData::GetType() const { + return IsFloat() ? ir::Type::GetFloatType() : ir::Type::GetInt32Type(); +} + +SymbolInfo* SemanticContext::CreateSymbol(SymbolInfo symbol) { + auto value = std::make_unique(std::move(symbol)); + auto* ptr = value.get(); + owned_symbols_.push_back(std::move(value)); + return ptr; +} + +void SemanticContext::BindConstDef(SysYParser::ConstDefContext* node, + const SymbolInfo* symbol) { + const_defs_[node] = symbol; +} + +void SemanticContext::BindVarDef(SysYParser::VarDefContext* node, + const SymbolInfo* symbol) { + var_defs_[node] = symbol; +} + +void SemanticContext::BindFuncDef(SysYParser::FuncDefContext* node, + const SymbolInfo* symbol) { + func_defs_[node] = symbol; +} + +void SemanticContext::BindLVal(SysYParser::LValContext* node, + const SymbolInfo* symbol) { + lvals_[node] = symbol; +} + +void SemanticContext::BindCall(SysYParser::UnaryExpContext* node, + const SymbolInfo* symbol) { + calls_[node] = symbol; +} - return sema; +void SemanticContext::SetExprType(const void* node, std::shared_ptr type) { + expr_types_[node] = std::move(type); +} + +const SymbolInfo* SemanticContext::ResolveConstDef( + const SysYParser::ConstDefContext* node) const { + auto it = const_defs_.find(node); + return it == const_defs_.end() ? nullptr : it->second; +} + +const SymbolInfo* SemanticContext::ResolveVarDef( + const SysYParser::VarDefContext* node) const { + auto it = var_defs_.find(node); + return it == var_defs_.end() ? nullptr : it->second; +} + +const SymbolInfo* SemanticContext::ResolveFuncDef( + const SysYParser::FuncDefContext* node) const { + auto it = func_defs_.find(node); + return it == func_defs_.end() ? nullptr : it->second; +} + +const SymbolInfo* SemanticContext::ResolveLVal( + const SysYParser::LValContext* node) const { + auto it = lvals_.find(node); + return it == lvals_.end() ? nullptr : it->second; +} + +const SymbolInfo* SemanticContext::ResolveCall( + const SysYParser::UnaryExpContext* node) const { + auto it = calls_.find(node); + return it == calls_.end() ? nullptr : it->second; +} + +std::shared_ptr SemanticContext::ResolveExprType(const void* node) const { + auto it = expr_types_.find(node); + return it == expr_types_.end() ? nullptr : it->second; +} + +SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { + return SemaAnalyzer(comp_unit).Run(); } diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index ffeea89..560f4f0 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -1,17 +1,38 @@ -// 维护局部变量声明的注册与查找。 - #include "sem/SymbolTable.h" -void SymbolTable::Add(const std::string& name, - SysYParser::VarDefContext* decl) { - table_[name] = decl; +#include + +void SymbolTable::EnterScope() { scopes_.emplace_back(); } + +void SymbolTable::ExitScope() { + if (scopes_.empty()) { + throw std::runtime_error("作用域栈为空,无法退出"); + } + scopes_.pop_back(); +} + +bool SymbolTable::Declare(const std::string& name, const SymbolInfo* symbol) { + if (scopes_.empty()) { + EnterScope(); + } + auto& scope = scopes_.back(); + return scope.emplace(name, symbol).second; } -bool SymbolTable::Contains(const std::string& name) const { - return table_.find(name) != table_.end(); +const SymbolInfo* SymbolTable::Lookup(const std::string& name) const { + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto found = it->find(name); + if (found != it->end()) { + return found->second; + } + } + return nullptr; } -SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { - auto it = table_.find(name); - return it == table_.end() ? nullptr : it->second; +const SymbolInfo* SymbolTable::LookupCurrent(const std::string& name) const { + if (scopes_.empty()) { + return nullptr; + } + auto found = scopes_.back().find(name); + return found == scopes_.back().end() ? nullptr : found->second; } diff --git a/sylib/sylib.c b/sylib/sylib.c index 7f26d0b..354cc19 100644 --- a/sylib/sylib.c +++ b/sylib/sylib.c @@ -1,4 +1,66 @@ -// SysY 运行库实现: -// - 按实验/评测规范提供 I/O 等函数实现 -// - 与编译器生成的目标代码链接,支撑运行时行为 +#include "sylib.h" +#include +#include + +static float ReadFloatToken(void) { + char buffer[128] = {0}; + if (scanf("%127s", buffer) != 1) { + return 0.0f; + } + return strtof(buffer, NULL); +} + +int getint(void) { + int value = 0; + scanf("%d", &value); + return value; +} + +int getch(void) { + return getchar(); +} + +float getfloat(void) { return ReadFloatToken(); } + +int getarray(int a[]) { + int n = getint(); + for (int i = 0; i < n; ++i) { + a[i] = getint(); + } + return n; +} + +int getfarray(float a[]) { + int n = getint(); + for (int i = 0; i < n; ++i) { + a[i] = getfloat(); + } + return n; +} + +void putint(int x) { printf("%d", x); } + +void putch(int x) { putchar(x); } + +void putfloat(float x) { printf("%a", x); } + +void putarray(int n, int a[]) { + printf("%d:", n); + for (int i = 0; i < n; ++i) { + printf(" %d", a[i]); + } + putchar('\n'); +} + +void putfarray(int n, float a[]) { + printf("%d:", n); + for (int i = 0; i < n; ++i) { + printf(" %a", a[i]); + } + putchar('\n'); +} + +void starttime(void) {} + +void stoptime(void) {} diff --git a/sylib/sylib.h b/sylib/sylib.h index 502d488..488d134 100644 --- a/sylib/sylib.h +++ b/sylib/sylib.h @@ -1,4 +1,16 @@ -// SysY 运行库头文件: -// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用) -// - 与 sylib.c 配套,按规范逐步补齐声明 +#pragma once +int getint(void); +int getch(void); +float getfloat(void); +int getarray(int a[]); +int getfarray(float a[]); + +void putint(int x); +void putch(int x); +void putfloat(float x); +void putarray(int n, int a[]); +void putfarray(int n, float a[]); + +void starttime(void); +void stoptime(void);