Compare commits

..

9 Commits
hyz ... master

5
.gitignore vendored

@ -68,3 +68,8 @@ Thumbs.db
# Project outputs # Project outputs
# ========================= # =========================
test/test_result/ test/test_result/
# Added by cargo
/target

@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.20)
project(compiler LANGUAGES C CXX) project(compiler LANGUAGES C CXX)
find_package(Java REQUIRED COMPONENTS Runtime)
# C++ # C++
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
@ -39,6 +41,11 @@ endif()
option(COMPILER_PARSE_ONLY "Build only the frontend parser pipeline" OFF) option(COMPILER_PARSE_ONLY "Build only the frontend parser pipeline" OFF)
set(ANTLR4_JAR "${PROJECT_SOURCE_DIR}/third_party/antlr-4.13.2-complete.jar")
if(NOT EXISTS "${ANTLR4_JAR}")
message(FATAL_ERROR "ANTLR jar not found: ${ANTLR4_JAR}")
endif()
# 使 third_party ANTLR4 C++ runtime # 使 third_party ANTLR4 C++ runtime
# third_party runtime third_party/antlr4-runtime-4.13.2/runtime/src # third_party runtime third_party/antlr4-runtime-4.13.2/runtime/src
set(ANTLR4_RUNTIME_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/antlr4-runtime-4.13.2/runtime/src") set(ANTLR4_RUNTIME_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/antlr4-runtime-4.13.2/runtime/src")

7
Cargo.lock generated

@ -0,0 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "nudt-compiler-cpp"
version = "0.1.0"

@ -0,0 +1,6 @@
[package]
name = "nudt-compiler-cpp"
version = "0.1.0"
edition = "2024"
[dependencies]

@ -1,37 +1,15 @@
// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。 // 扩展后的 IR 库:
// // - 完整基础类型void/i1/i32/float/ptr/array/function/label
// 当前已经实现: // - 指令算术、比较、分支、调用、phi、gep、类型转换等
// 1. 基础类型系统void / i32 / i32* // - 常量int/float/array
// 2. Value 体系Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction // - 基本块/函数/模块/IRBuilder 的完整接口
// 3. 最小指令集Add / Alloca / Load / Store / Ret
// 4. BasicBlock / Function / Module 三层组织结构
// 5. IRBuilder便捷创建常量和最小指令
// 6. def-use 关系的轻量实现:
// - Instruction 保存 operand 列表
// - Value 保存 uses
// - 支持 ReplaceAllUsesWith 的简化实现
//
// 当前尚未实现或只做了最小占位:
// 1. 完整类型系统数组、函数类型、label 类型等
// 2. 更完整的指令系统br / condbr / call / phi / gep 等
// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构)
// 4. 更完整的 IR verifier 和优化基础设施
//
// 当前需要特别说明的两个简化点:
// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位,
// 后续如果补 label type可以再改成更合理的块标签类型。
// 2. ConstantValue 体系目前只实现了 ConstantInt后续可以继续补 ConstantFloat、
// ConstantArray等更完整的常量种类。
//
// 建议的扩展顺序:
// 1. 先补更多指令和类型
// 2. 再补控制流相关 IR
// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架
#pragma once #pragma once
#include <cstdint>
#include <iosfwd> #include <iosfwd>
#include <memory> #include <memory>
#include <optional>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
@ -45,10 +23,14 @@ class Value;
class User; class User;
class ConstantValue; class ConstantValue;
class ConstantInt; class ConstantInt;
class ConstantFloat;
class ConstantArray;
class GlobalValue; class GlobalValue;
class GlobalVariable;
class Instruction; class Instruction;
class BasicBlock; class BasicBlock;
class Function; class Function;
class Argument;
// Use 表示一个 Value 的一次使用记录。 // Use 表示一个 Value 的一次使用记录。
// 当前实现设计: // 当前实现设计:
@ -83,31 +65,65 @@ class Context {
~Context(); ~Context();
// 去重创建 i32 常量。 // 去重创建 i32 常量。
ConstantInt* GetConstInt(int v); ConstantInt* GetConstInt(int v);
ConstantInt* GetConstBool(bool v);
ConstantFloat* GetConstFloat(float v);
ConstantArray* CreateConstArray(std::shared_ptr<Type> array_ty,
std::vector<ConstantValue*> elements);
std::string NextTemp(); std::string NextTemp();
private: private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_; std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_bools_;
std::unordered_map<uint32_t, std::unique_ptr<ConstantFloat>> const_floats_;
std::vector<std::unique_ptr<ConstantArray>> const_arrays_;
int temp_index_ = -1; int temp_index_ = -1;
}; };
class Type { class Type {
public: public:
enum class Kind { Void, Int32, PtrInt32 }; enum class Kind { Void, Int1, Int32, Float, Pointer, Array, Function, Label };
explicit Type(Kind k); explicit Type(Kind k);
Type(Kind k, std::shared_ptr<Type> elem, size_t count);
Type(Kind k, std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
bool is_vararg);
// 使用静态共享对象获取类型。 // 使用静态共享对象获取类型。
// 同一类型可直接比较返回值是否相等,例如: // 同一类型可直接比较返回值是否相等,例如:
// Type::GetInt32Type() == Type::GetInt32Type() // Type::GetInt32Type() == Type::GetInt32Type()
static const std::shared_ptr<Type>& GetVoidType(); static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type(); static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetPtrInt32Type(); static const std::shared_ptr<Type>& GetFloatType();
static const std::shared_ptr<Type>& GetLabelType();
static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> elem);
static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> elem,
size_t count);
static std::shared_ptr<Type> GetFunctionType(
std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
bool is_vararg = false);
Kind GetKind() const; Kind GetKind() const;
bool IsVoid() const; bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const; bool IsInt32() const;
bool IsPtrInt32() const; bool IsFloat() const;
bool IsPointer() const;
bool IsArray() const;
bool IsFunction() const;
bool IsLabel() const;
const std::shared_ptr<Type>& GetElementType() const;
size_t GetArraySize() const;
const std::shared_ptr<Type>& GetReturnType() const;
const std::vector<std::shared_ptr<Type>>& GetParamTypes() const;
bool IsVarArg() const;
bool Equals(const Type& other) const;
private: private:
Kind kind_; Kind kind_;
std::shared_ptr<Type> elem_type_;
size_t array_size_ = 0;
std::shared_ptr<Type> ret_type_;
std::vector<std::shared_ptr<Type>> param_types_;
bool is_vararg_ = false;
}; };
class Value { class Value {
@ -118,7 +134,12 @@ class Value {
const std::string& GetName() const; const std::string& GetName() const;
void SetName(std::string n); void SetName(std::string n);
bool IsVoid() const; bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const; bool IsInt32() const;
bool IsFloat() const;
bool IsPointer() const;
bool IsArray() const;
bool IsFunctionType() const;
bool IsPtrInt32() const; bool IsPtrInt32() const;
bool IsConstant() const; bool IsConstant() const;
bool IsInstruction() const; bool IsInstruction() const;
@ -151,8 +172,53 @@ class ConstantInt : public ConstantValue {
int value_{}; int value_{};
}; };
class ConstantFloat : public ConstantValue {
public:
ConstantFloat(std::shared_ptr<Type> ty, float v);
float GetValue() const { return value_; }
private:
float value_{};
};
class ConstantArray : public ConstantValue {
public:
ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements);
const std::vector<ConstantValue*>& GetElements() const { return elements_; }
private:
std::vector<ConstantValue*> elements_;
};
// 后续还需要扩展更多指令类型。 // 后续还需要扩展更多指令类型。
enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; enum class Opcode {
Add,
Sub,
Mul,
SDiv,
SRem,
FAdd,
FSub,
FMul,
FDiv,
Alloca,
Load,
Store,
Ret,
Br,
CondBr,
ICmp,
FCmp,
Call,
Phi,
Gep,
SIToFP,
FPToSI,
ZExt
};
enum class ICmpPredicate { Eq, Ne, Slt, Sle, Sgt, Sge };
enum class FCmpPredicate { Oeq, One, Olt, Ole, Ogt, Oge };
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 // User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
// 当前实现中只有 Instruction 继承自 User。 // 当前实现中只有 Instruction 继承自 User。
@ -178,6 +244,20 @@ class GlobalValue : public User {
GlobalValue(std::shared_ptr<Type> ty, std::string name); GlobalValue(std::shared_ptr<Type> ty, std::string name);
}; };
class GlobalVariable : public GlobalValue {
public:
GlobalVariable(std::shared_ptr<Type> value_ty, std::string name,
ConstantValue* init, bool is_const);
const std::shared_ptr<Type>& GetValueType() const;
ConstantValue* GetInitializer() const;
bool IsConst() const;
private:
std::shared_ptr<Type> value_type_;
ConstantValue* initializer_ = nullptr;
bool is_const_ = false;
};
class Instruction : public User { class Instruction : public User {
public: public:
Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = ""); Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
@ -199,15 +279,64 @@ class BinaryInst : public Instruction {
Value* GetRhs() const; Value* GetRhs() const;
}; };
class ICmpInst : public Instruction {
public:
ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name);
ICmpPredicate GetPredicate() const { return pred_; }
Value* GetLhs() const;
Value* GetRhs() const;
private:
ICmpPredicate pred_;
};
class FCmpInst : public Instruction {
public:
FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name);
FCmpPredicate GetPredicate() const { return pred_; }
Value* GetLhs() const;
Value* GetRhs() const;
private:
FCmpPredicate pred_;
};
class CastInst : public Instruction {
public:
CastInst(Opcode op, std::shared_ptr<Type> dst_ty, Value* src,
std::string name);
Value* GetValue() const;
};
class BranchInst : public Instruction {
public:
explicit BranchInst(BasicBlock* dest);
BasicBlock* GetDest() const;
};
class CondBrInst : public Instruction {
public:
CondBrInst(Value* cond, BasicBlock* true_dest, BasicBlock* false_dest);
Value* GetCond() const;
BasicBlock* GetTrueDest() const;
BasicBlock* GetFalseDest() const;
};
class ReturnInst : public Instruction { class ReturnInst : public Instruction {
public: public:
explicit ReturnInst(std::shared_ptr<Type> void_ty);
ReturnInst(std::shared_ptr<Type> void_ty, Value* val); ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
bool HasReturnValue() const;
Value* GetValue() const; Value* GetValue() const;
}; };
class AllocaInst : public Instruction { class AllocaInst : public Instruction {
public: public:
AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name); AllocaInst(std::shared_ptr<Type> allocated_ty, std::string name);
const std::shared_ptr<Type>& GetAllocatedType() const;
private:
std::shared_ptr<Type> allocated_type_;
}; };
class LoadInst : public Instruction { class LoadInst : public Instruction {
@ -223,8 +352,41 @@ class StoreInst : public Instruction {
Value* GetPtr() const; Value* GetPtr() const;
}; };
// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 class CallInst : public Instruction {
// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 public:
CallInst(std::shared_ptr<Type> ret_ty, Value* callee,
std::vector<Value*> args, std::string name);
Value* GetCallee() const;
const std::vector<Value*>& GetArgs() const { return args_; }
private:
std::vector<Value*> args_;
};
class PhiInst : public Instruction {
public:
PhiInst(std::shared_ptr<Type> ty, std::string name);
void AddIncoming(Value* value, BasicBlock* block);
const std::vector<Value*>& GetIncomingValues() const;
const std::vector<BasicBlock*>& GetIncomingBlocks() const;
private:
std::vector<Value*> incoming_values_;
std::vector<BasicBlock*> incoming_blocks_;
};
class GepInst : public Instruction {
public:
GepInst(std::shared_ptr<Type> result_ptr_ty, Value* base_ptr,
std::vector<Value*> indices, std::string name);
Value* GetBasePtr() const;
const std::vector<Value*>& GetIndices() const { return indices_; }
private:
std::vector<Value*> indices_;
};
// BasicBlock 已纳入 Value 体系,使用 label type。
class BasicBlock : public Value { class BasicBlock : public Value {
public: public:
explicit BasicBlock(std::string name); explicit BasicBlock(std::string name);
@ -234,6 +396,8 @@ class BasicBlock : public Value {
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const; const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
const std::vector<BasicBlock*>& GetPredecessors() const; const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const; const std::vector<BasicBlock*>& GetSuccessors() const;
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
template <typename T, typename... Args> template <typename T, typename... Args>
T* Append(Args&&... args) { T* Append(Args&&... args) {
if (HasTerminator()) { if (HasTerminator()) {
@ -244,6 +408,16 @@ class BasicBlock : public Value {
auto* ptr = inst.get(); auto* ptr = inst.get();
ptr->SetParent(this); ptr->SetParent(this);
instructions_.push_back(std::move(inst)); instructions_.push_back(std::move(inst));
LinkSuccessorsIfNeeded(ptr);
return ptr;
}
template <typename T, typename... Args>
T* Prepend(Args&&... args) {
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin(), std::move(inst));
return ptr; return ptr;
} }
@ -252,6 +426,7 @@ class BasicBlock : public Value {
std::vector<std::unique_ptr<Instruction>> instructions_; std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_; std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_; std::vector<BasicBlock*> successors_;
void LinkSuccessorsIfNeeded(Instruction* inst);
}; };
// Function 当前也采用了最小实现。 // Function 当前也采用了最小实现。
@ -262,16 +437,34 @@ class BasicBlock : public Value {
// 形参和调用,通常需要引入专门的函数类型表示。 // 形参和调用,通常需要引入专门的函数类型表示。
class Function : public Value { class Function : public Value {
public: public:
// 当前构造函数接收的也是返回类型,而不是完整函数类型。 Function(std::string name, std::shared_ptr<Type> func_type,
Function(std::string name, std::shared_ptr<Type> ret_type); bool is_declaration = false);
BasicBlock* CreateBlock(const std::string& name); BasicBlock* CreateBlock(const std::string& name);
BasicBlock* GetEntry(); BasicBlock* GetEntry();
const BasicBlock* GetEntry() const; const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const; const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
const std::vector<std::unique_ptr<Argument>>& GetArguments() const;
size_t GetNumArgs() const;
Argument* GetArg(size_t index);
std::shared_ptr<Type> GetFunctionType() const;
std::shared_ptr<Type> GetReturnType() const;
bool IsDeclaration() const;
private: private:
BasicBlock* entry_ = nullptr; BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_; std::vector<std::unique_ptr<BasicBlock>> blocks_;
std::vector<std::unique_ptr<Argument>> args_;
std::unordered_map<std::string, size_t> block_name_counts_;
bool is_declaration_ = false;
};
class Argument : public Value {
public:
Argument(std::shared_ptr<Type> ty, std::string name, size_t index);
size_t GetIndex() const { return index_; }
private:
size_t index_ = 0;
}; };
class Module { class Module {
@ -282,11 +475,20 @@ class Module {
// 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。
Function* CreateFunction(const std::string& name, Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type); std::shared_ptr<Type> ret_type);
Function* CreateFunctionWithType(const std::string& name,
std::shared_ptr<Type> func_type);
Function* CreateFunctionDecl(const std::string& name,
std::shared_ptr<Type> func_type);
GlobalVariable* CreateGlobalVariable(const std::string& name,
std::shared_ptr<Type> value_type,
ConstantValue* init, bool is_const);
const std::vector<std::unique_ptr<Function>>& GetFunctions() const; const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobals() const;
private: private:
Context context_; Context context_;
std::vector<std::unique_ptr<Function>> functions_; std::vector<std::unique_ptr<Function>> functions_;
std::vector<std::unique_ptr<GlobalVariable>> globals_;
}; };
class IRBuilder { class IRBuilder {
@ -297,13 +499,44 @@ class IRBuilder {
// 构造常量、二元运算、返回指令的最小集合。 // 构造常量、二元运算、返回指令的最小集合。
ConstantInt* CreateConstInt(int v); ConstantInt* CreateConstInt(int v);
ConstantInt* CreateConstBool(bool v);
ConstantFloat* CreateConstFloat(float v);
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name); const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSDiv(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSRem(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name);
ICmpInst* CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name);
FCmpInst* CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name);
CastInst* CreateSIToFP(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name);
CastInst* CreateFPToSI(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name);
CastInst* CreateZExt(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name);
AllocaInst* CreateAlloca(std::shared_ptr<Type> ty,
const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr); StoreInst* CreateStore(Value* val, Value* ptr);
GepInst* CreateGep(Value* base_ptr, std::vector<Value*> indices,
const std::string& name);
CallInst* CreateCall(Value* callee, std::vector<Value*> args,
const std::string& name);
PhiInst* CreatePhi(std::shared_ptr<Type> ty, const std::string& name);
BranchInst* CreateBr(BasicBlock* dest);
CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_dest,
BasicBlock* false_dest);
ReturnInst* CreateRet(Value* v); ReturnInst* CreateRet(Value* v);
ReturnInst* CreateRetVoid();
private: private:
Context& ctx_; Context& ctx_;

@ -1,57 +1,113 @@
// 将语法树翻译为 IR。
// 实现拆分在 IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl。
#pragma once #pragma once
#include <any> #include <any>
#include <memory> #include <memory>
#include <string>
#include <unordered_map> #include <unordered_map>
#include "SysYBaseVisitor.h" #include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "sem/Sema.h" #include "sem/Sema.h"
namespace ir {
class Module;
class Function;
class IRBuilder;
class Value;
}
class IRGenImpl final : public SysYBaseVisitor { class IRGenImpl final : public SysYBaseVisitor {
public: public:
IRGenImpl(ir::Module& module, const SemanticContext& sema); IRGenImpl(ir::Module& module, const SemanticContext& sema);
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; std::any visitBlock(SysYParser::BlockContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitConstDef(SysYParser::ConstDefContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
private: std::any visitExp(SysYParser::ExpContext* ctx) override;
enum class BlockFlow { std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
Continue, std::any visitMulExp(SysYParser::MulExpContext* ctx) override; // 新增
Terminated, std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; // 新增
}; std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
std::any visitCond(SysYParser::CondContext* ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitLVal(SysYParser::LValContext* ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override;
std::any visitConstInitVal(SysYParser::ConstInitValContext* ctx) override;
std::any visitInitVal(SysYParser::InitValContext* ctx) override;
private:
ir::Value* EvalExp(SysYParser::ExpContext* ctx);
ir::Value* EvalCondValue(SysYParser::CondContext* ctx);
void EmitCondBr(SysYParser::CondContext* ctx, ir::BasicBlock* true_bb,
ir::BasicBlock* false_bb);
void EmitLOrCond(SysYParser::LOrExpContext* ctx, ir::BasicBlock* true_bb,
ir::BasicBlock* false_bb);
void EmitLAndCond(SysYParser::LAndExpContext* ctx, ir::BasicBlock* true_bb,
ir::BasicBlock* false_bb);
ir::Value* EmitRelEq(SysYParser::RelExpContext* ctx);
ir::Value* EmitEq(SysYParser::EqExpContext* ctx);
ir::Value* CastToFloat(ir::Value* v);
ir::Value* CastToInt(ir::Value* v);
ir::Value* MakeBool(ir::Value* v);
ir::Value* GetLValAddress(SysYParser::LValContext* ctx);
ir::Value* LoadIfNeeded(ir::Value* addr_or_val, const TypeDesc& ty,
bool as_rvalue);
std::shared_ptr<ir::Type> ToIRType(const TypeDesc& ty);
std::shared_ptr<ir::Type> ToIRParamType(const TypeDesc& ty);
ir::Value* DefaultValue(const TypeDesc& ty);
ir::AllocaInst* CreateEntryAlloca(std::shared_ptr<ir::Type> ty,
const std::string& name);
void InitArray(ir::Value* base_ptr, const TypeDesc& ty,
SysYParser::InitValContext* init);
void InitConstArray(ir::Value* base_ptr, const TypeDesc& ty,
SysYParser::ConstInitValContext* init);
size_t FillArrayValues(const TypeDesc& ty, SysYParser::InitValContext* init,
std::vector<ir::Value*>& values, size_t base,
size_t idx, size_t dim);
size_t FillConstArrayValues(const TypeDesc& ty,
SysYParser::ConstInitValContext* init,
std::vector<ir::Value*>& values, size_t base,
size_t idx, size_t dim);
size_t ArrayStride(const TypeDesc& ty, size_t dim) const;
size_t ArrayTotalSize(const TypeDesc& ty) const;
void PushLoop(ir::BasicBlock* break_bb, ir::BasicBlock* cont_bb);
void PopLoop();
ir::BasicBlock* CurrentBreak() const;
ir::BasicBlock* CurrentContinue() const;
ir::ConstantValue* EvalConstScalar(SysYParser::ExpContext* ctx);
ir::ConstantValue* EvalConstScalar(SysYParser::ConstExpContext* ctx);
ir::ConstantValue* EvalConstAdd(SysYParser::AddExpContext* ctx);
ir::ConstantValue* EvalConstMul(SysYParser::MulExpContext* ctx);
ir::ConstantValue* EvalConstUnary(SysYParser::UnaryExpContext* ctx);
ir::ConstantValue* EvalConstPrimary(SysYParser::PrimaryExpContext* ctx);
ir::ConstantValue* EvalConstNumber(SysYParser::NumberContext* ctx);
ir::ConstantValue* EvalConstLVal(SysYParser::LValContext* ctx);
size_t InitGlobalArray(const TypeDesc& ty, SysYParser::InitValContext* init,
std::vector<ir::ConstantValue*>& values, size_t base,
size_t idx, size_t dim);
size_t InitGlobalConstArray(const TypeDesc& ty,
SysYParser::ConstInitValContext* init,
std::vector<ir::ConstantValue*>& values,
size_t base, size_t idx, size_t dim);
enum class BlockFlow { Continue, Terminated };
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
ir::Value* EvalExpr(SysYParser::ExpContext& expr);
ir::Module& module_; ir::Module& module_;
const SemanticContext& sema_; const SemanticContext& sema_;
ir::Function* func_; ir::Function* func_;
ir::IRBuilder builder_; ir::IRBuilder builder_;
// 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 std::unordered_map<const SysYParser::VarDefContext*, ir::Value*> var_storage_;
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_; std::unordered_map<const SysYParser::ConstDefContext*, ir::Value*> const_storage_;
std::unordered_map<const SysYParser::FuncFParamContext*, ir::Value*> param_storage_;
std::unordered_map<const SysYParser::FuncDefContext*, ir::Function*> func_map_;
std::unordered_map<const SysYParser::VarDefContext*, ir::GlobalVariable*>
global_var_storage_;
std::unordered_map<const SysYParser::ConstDefContext*, ir::GlobalVariable*>
global_const_storage_;
std::vector<std::pair<ir::BasicBlock*, ir::BasicBlock*>> loop_stack_;
}; };
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,

@ -1,30 +1,91 @@
// 基于语法树的语义检查与名称绑定 // 基于语法树的语义检查与名称绑定Lab2 扩展)
#pragma once #pragma once
#include <optional>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYParser.h" #include "SysYParser.h"
#include "sem/SymbolTable.h"
struct FuncTypeDesc {
TypeDesc ret;
std::vector<TypeDesc> params;
};
struct BoundDecl {
enum class Kind { Var, Const, Param } kind = Kind::Var;
SysYParser::VarDefContext* var_decl = nullptr;
SysYParser::ConstDefContext* const_decl = nullptr;
SysYParser::FuncFParamContext* param_decl = nullptr;
};
class SemanticContext { class SemanticContext {
public: public:
void BindVarUse(SysYParser::VarContext* use, void BindVarUse(SysYParser::LValContext* use, BoundDecl decl) {
SysYParser::VarDefContext* decl) {
var_uses_[use] = decl; var_uses_[use] = decl;
} }
SysYParser::VarDefContext* ResolveVarUse( BoundDecl ResolveVarUse(const SysYParser::LValContext* use) const {
const SysYParser::VarContext* use) const {
auto it = var_uses_.find(use); auto it = var_uses_.find(use);
return it == var_uses_.end() ? nullptr : it->second; return it == var_uses_.end() ? BoundDecl{} : it->second;
}
void RegisterVarDecl(SysYParser::VarDefContext* decl, TypeDesc ty) {
var_types_[decl] = std::move(ty);
}
void RegisterConstDecl(SysYParser::ConstDefContext* decl, TypeDesc ty) {
const_types_[decl] = std::move(ty);
}
void RegisterParam(SysYParser::FuncFParamContext* decl, TypeDesc ty) {
param_types_[decl] = std::move(ty);
}
void RegisterFunc(SysYParser::FuncDefContext* decl, FuncTypeDesc ty) {
func_types_[decl] = std::move(ty);
}
const TypeDesc* GetVarType(const SysYParser::VarDefContext* decl) const {
auto it = var_types_.find(decl);
return it == var_types_.end() ? nullptr : &it->second;
}
const TypeDesc* GetConstType(const SysYParser::ConstDefContext* decl) const {
auto it = const_types_.find(decl);
return it == const_types_.end() ? nullptr : &it->second;
}
const TypeDesc* GetParamType(const SysYParser::FuncFParamContext* decl) const {
auto it = param_types_.find(decl);
return it == param_types_.end() ? nullptr : &it->second;
}
const FuncTypeDesc* GetFuncType(const SysYParser::FuncDefContext* decl) const {
auto it = func_types_.find(decl);
return it == func_types_.end() ? nullptr : &it->second;
}
void BindFuncCall(SysYParser::UnaryExpContext* call,
SysYParser::FuncDefContext* decl) {
func_calls_[call] = decl;
}
SysYParser::FuncDefContext* ResolveFuncCall(
const SysYParser::UnaryExpContext* call) const {
auto it = func_calls_.find(call);
return it == func_calls_.end() ? nullptr : it->second;
} }
private: private:
std::unordered_map<const SysYParser::VarContext*, std::unordered_map<const SysYParser::LValContext*, BoundDecl> var_uses_;
SysYParser::VarDefContext*> std::unordered_map<const SysYParser::VarDefContext*, TypeDesc> var_types_;
var_uses_; std::unordered_map<const SysYParser::ConstDefContext*, TypeDesc> const_types_;
std::unordered_map<const SysYParser::FuncFParamContext*, TypeDesc> param_types_;
std::unordered_map<const SysYParser::FuncDefContext*, FuncTypeDesc> func_types_;
std::unordered_map<const SysYParser::UnaryExpContext*, SysYParser::FuncDefContext*>
func_calls_;
}; };
// 目前仅检查:
// - 变量先声明后使用
// - 局部变量不允许重复定义
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);

@ -1,17 +1,42 @@
// 极简符号表:记录局部变量定义 // 符号表:记录局部变量/常量/参数定义。
#pragma once #pragma once
#include <optional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYParser.h" #include "SysYParser.h"
enum class BaseTypeKind { Int, Float, Void };
struct TypeDesc {
BaseTypeKind base = BaseTypeKind::Int;
std::vector<int> dims; // 为空表示标量;数组维度允许首维为 -1 表示形参不定长
bool is_const = false;
};
enum class SymbolKind { Var, Const, Param };
struct SymbolEntry {
SymbolKind kind = SymbolKind::Var;
SysYParser::VarDefContext* var_decl = nullptr;
SysYParser::ConstDefContext* const_decl = nullptr;
SysYParser::FuncFParamContext* param_decl = nullptr;
TypeDesc type; // 记录类型信息
bool is_const = false;
std::optional<int> const_value;
};
class SymbolTable { class SymbolTable {
public: public:
void Add(const std::string& name, SysYParser::VarDefContext* decl); void EnterScope();
bool Contains(const std::string& name) const; void ExitScope();
SysYParser::VarDefContext* Lookup(const std::string& name) const;
bool ContainsInCurrentScope(const std::string& name) const;
void Add(const std::string& name, const SymbolEntry& entry);
const SymbolEntry* Lookup(const std::string& name) const;
private: private:
std::unordered_map<std::string, SysYParser::VarDefContext*> table_; std::vector<std::unordered_map<std::string, SymbolEntry>> scopes_;
}; };

@ -0,0 +1,19 @@
#!/usr/bin/env bash
set -euo pipefail
# Reconfigure with IR pipeline enabled, build, then run Lab2 test script.
RESULT_FILE="test/test_result/run_lab2_result.log"
mkdir -p "$(dirname \"$RESULT_FILE\")"
: > "$RESULT_FILE"
{
echo "[run_lab2] start: $(date '+%Y-%m-%d %H:%M:%S')"
echo "[run_lab2] logging to: $RESULT_FILE"
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF
cmake --build build -j "$(nproc)"
CASE_DIR=test/test_case bash scripts/test_lab2.sh
echo "[run_lab2] end: $(date '+%Y-%m-%d %H:%M:%S')"
} 2>&1 | tee "$RESULT_FILE"

@ -0,0 +1,121 @@
#!/usr/bin/env bash
set -euo pipefail
# Lab2 quick/full verification helper.
# Usage:
# bash scripts/test_lab2.sh
# Optional env vars:
# COMPILER=./build/bin/compiler
# CASE_DIR=test/test_case/functional
# OUT_DIR=test/test_result/lab2_ir
# LOG_FILE=test/test_result/lab2_test.log
COMPILER="${COMPILER:-./build/bin/compiler}"
CASE_DIR="${CASE_DIR:-test/test_case/functional}"
OUT_DIR="${OUT_DIR:-test/test_result/lab2_ir}"
LOG_FILE="${LOG_FILE:-test/test_result/lab2_test.log}"
VERIFY_SCRIPT="./scripts/verify_ir.sh"
if [[ ! -x "$COMPILER" ]]; then
echo "compiler not found or not executable: $COMPILER" >&2
echo "build first:" >&2
echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release" >&2
echo " cmake --build build -j \"\$(nproc)\"" >&2
exit 1
fi
if [[ ! -x "$VERIFY_SCRIPT" ]]; then
echo "verify script not found or not executable: $VERIFY_SCRIPT" >&2
exit 1
fi
if [[ ! -d "$CASE_DIR" ]]; then
echo "case dir not found: $CASE_DIR" >&2
exit 1
fi
mkdir -p "$OUT_DIR"
# Preflight: ensure compiler supports IR emission (not parse-only build).
probe_input="$CASE_DIR/simple_add.sy"
probe_err="$OUT_DIR/.lab2_probe.err"
if [[ -f "$probe_input" ]]; then
set +e
"$COMPILER" --emit-ir "$probe_input" > /dev/null 2> "$probe_err"
probe_rc=$?
set -e
if [[ $probe_rc -ne 0 ]] && grep -Eiq "parse-only|IR/汇编输出已禁用" "$probe_err"; then
echo "detected parse-only compiler build, cannot run Lab2 IR tests." >&2
echo "rebuild with IR enabled:" >&2
echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF" >&2
echo " cmake --build build -j \"\$(nproc)\"" >&2
rm -f "$probe_err"
exit 2
fi
rm -f "$probe_err"
fi
mkdir -p "$(dirname "$LOG_FILE")"
: > "$LOG_FILE"
echo "[Lab2] start test" | tee -a "$LOG_FILE"
echo "compiler : $COMPILER" | tee -a "$LOG_FILE"
echo "cases : $CASE_DIR" | tee -a "$LOG_FILE"
echo "out dir : $OUT_DIR" | tee -a "$LOG_FILE"
echo "[Step 1] single sample check: simple_add.sy" | tee -a "$LOG_FILE"
sample_input="$(find "$CASE_DIR" -type f -name "simple_add.sy" -print -quit)"
if [[ -z "$sample_input" ]]; then
echo "single sample: FAIL (simple_add.sy not found under $CASE_DIR)" | tee -a "$LOG_FILE"
echo "stop here. see log: $LOG_FILE" >&2
exit 1
fi
if "$VERIFY_SCRIPT" "$sample_input" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
echo "single sample: PASS" | tee -a "$LOG_FILE"
else
echo "single sample: FAIL" | tee -a "$LOG_FILE"
echo "stop here. see log: $LOG_FILE" >&2
exit 1
fi
echo "[Step 2] full functional regression" | tee -a "$LOG_FILE"
pass=0
fail=0
total=0
failed_list=()
while IFS= read -r -d '' sy; do
total=$((total + 1))
name="$(basename "$sy")"
echo "[$total] $name" | tee -a "$LOG_FILE"
if "$VERIFY_SCRIPT" "$sy" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
pass=$((pass + 1))
echo " PASS" | tee -a "$LOG_FILE"
else
fail=$((fail + 1))
failed_list+=("$sy")
echo " FAIL" | tee -a "$LOG_FILE"
fi
done < <(find "$CASE_DIR" -type f -name "*.sy" -print0 | sort -z)
echo "" | tee -a "$LOG_FILE"
echo "[Summary]" | tee -a "$LOG_FILE"
echo "total: $total" | tee -a "$LOG_FILE"
echo "pass : $pass" | tee -a "$LOG_FILE"
echo "fail : $fail" | tee -a "$LOG_FILE"
if [[ $fail -gt 0 ]]; then
echo "failed cases:" | tee -a "$LOG_FILE"
for f in "${failed_list[@]}"; do
echo " - $f" | tee -a "$LOG_FILE"
done
echo "Lab2 target is not fully met yet." | tee -a "$LOG_FILE"
echo "see details in $LOG_FILE"
exit 1
fi
echo "All functional cases passed. Lab2 target (functional regression) is met." | tee -a "$LOG_FILE"
echo "see details in $LOG_FILE"

@ -1,78 +0,0 @@
#!/bin/bash
# ================================================
# SysY 编译器 Lab1 批量解析测试脚本
# 文件名scripts/test_parse.sh
# 适用环境Arch Linuxbash 原生支持,无需额外安装)
# 功能:
# - 遍历 test/test_case 下所有 .sy 文件functional + performance
# - 执行 --emit-parse-tree 检查是否能成功解析
# - 输出简洁的 PASS/FAIL 结果 + 统计
# - 错误文件会自动打印最后 10 行报错信息(方便调试)
# - 所有结果保存到 test/test_result/parse_test.log
# ================================================
set -u # 遇到未定义变量直接报错
# ================== 配置 ==================
COMPILER="./build/bin/compiler"
TEST_DIR="test/test_case"
LOG_FILE="test/test_result/parse_test.log"
MAX_ERROR_LINES=10
# 检查编译器是否存在
if [[ ! -x "$COMPILER" ]]; then
echo "❌ 错误:找不到编译器 $COMPILER"
echo " 请先执行 Lab1 构建命令:"
echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON"
echo " cmake --build build -j \"\$(nproc)\""
exit 1
fi
# 创建日志目录(如果不存在)
mkdir -p "$(dirname "$LOG_FILE")"
> "$LOG_FILE" # 清空日志
echo "开始 Lab1 批量语法树测试..." | tee -a "$LOG_FILE"
echo "测试目录:$TEST_DIR" | tee -a "$LOG_FILE"
echo "编译器:$COMPILER" | tee -a "$LOG_FILE"
echo "========================================" | tee -a "$LOG_FILE"
pass=0
fail=0
total=0
# 遍历所有 .sy 文件(支持子目录)
while IFS= read -r -d '' sy_file; do
((total++))
echo -n "[$total] 测试: $sy_file ... " | tee -a "$LOG_FILE"
# 执行解析(把输出丢到 /dev/null防止刷屏
if "$COMPILER" --emit-parse-tree "$sy_file" > /dev/null 2>&1; then
echo "✅PASS" | tee -a "$LOG_FILE"
((pass++))
else
echo "FAIL" | tee -a "$LOG_FILE"
((fail++))
# 打印错误信息到日志(最后几行)
echo " └── 错误详情(最后 $MAX_ERROR_LINES 行):" >> "$LOG_FILE"
"$COMPILER" --emit-parse-tree "$sy_file" 2>&1 | tail -n "$MAX_ERROR_LINES" >> "$LOG_FILE"
echo "" >> "$LOG_FILE"
fi
done < <(find "$TEST_DIR" -name "*.sy" -print0 | sort -z)
# ================== 总结 ==================
echo "========================================" | tee -a "$LOG_FILE"
echo "测试完成!" | tee -a "$LOG_FILE"
echo "总文件数 : $total" | tee -a "$LOG_FILE"
echo "通过 : $pass" | tee -a "$LOG_FILE"
echo "失败 : $fail" | tee -a "$LOG_FILE"
if [[ $fail -eq 0 ]]; then
echo "恭喜Lab1 语法树构建全部通过!可以进入 Lab2 啦~" | tee -a "$LOG_FILE"
else
echo "$fail 个文件解析失败,请检查 SysY.g4 或报错日志" | tee -a "$LOG_FILE"
echo " 日志文件:$LOG_FILE" | tee -a "$LOG_FILE"
fi
echo "========================================" | tee -a "$LOG_FILE"

@ -60,7 +60,7 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout" stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out" actual_file="$out_dir/$stem.actual.out"
llc -filetype=obj "$out_file" -o "$obj" llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" -o "$exe" clang "$obj" sylib/sylib.c -o "$exe"
echo "运行 $exe ..." echo "运行 $exe ..."
set +e set +e
if [[ -f "$stdin_file" ]]; then if [[ -f "$stdin_file" ]]; then

@ -1,98 +0,0 @@
// SysY 子集语法:支持形如
// int main() { int a = 1; int b = 2; return a + b; }
// 的最小返回表达式编译。
// 后续需要自行添加
grammar SysY;
/*===-------------------------------------------===*/
/* Lexer rules */
/*===-------------------------------------------===*/
INT: 'int';
RETURN: 'return';
ASSIGN: '=';
ADD: '+';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
SEMICOLON: ';';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*===-------------------------------------------===*/
/* Syntax rules */
/*===-------------------------------------------===*/
compUnit
: funcDef EOF
;
decl
: btype varDef SEMICOLON
;
btype
: INT
;
varDef
: lValue (ASSIGN initValue)?
;
initValue
: exp
;
funcDef
: funcType ID LPAREN RPAREN blockStmt
;
funcType
: INT
;
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
stmt
: returnStmt
;
returnStmt
: RETURN exp SEMICOLON
;
exp
: LPAREN exp RPAREN # parenExp
| var # varExp
| number # numberExp
| exp ADD exp # additiveExp
;
var
: ID
;
lValue
: ID
;
number
: ILITERAL
;

@ -1,98 +0,0 @@
// SysY 子集语法:支持形如
// int main() { int a = 1; int b = 2; return a + b; }
// 的最小返回表达式编译。
// 后续需要自行添加
grammar SysY;
/*===-------------------------------------------===*/
/* Lexer rules */
/*===-------------------------------------------===*/
INT: 'int';
RETURN: 'return';
ASSIGN: '=';
ADD: '+';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
SEMICOLON: ';';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*===-------------------------------------------===*/
/* Syntax rules */
/*===-------------------------------------------===*/
compUnit
: funcDef EOF
;
decl
: btype varDef SEMICOLON
;
btype
: INT
;
varDef
: lValue (ASSIGN initValue)?
;
initValue
: exp
;
funcDef
: funcType ID LPAREN RPAREN blockStmt
;
funcType
: INT
;
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
stmt
: returnStmt
;
returnStmt
: RETURN exp SEMICOLON
;
exp
: LPAREN exp RPAREN # parenExp
| var # varExp
| number # numberExp
| exp ADD exp # additiveExp
;
var
: ID
;
lValue
: ID
;
number
: ILITERAL
;

@ -1,17 +0,0 @@
int main() {
int a = 1, b = 2;
int c;
c = a + b * 3;
if (c > 5) {
c = c - 1;
} else {
c = c + 1;
}
while (c < 10) {
c = c + 1;
}
return c;
}

@ -1,98 +0,0 @@
// SysY 子集语法:支持形如
// int main() { int a = 1; int b = 2; return a + b; }
// 的最小返回表达式编译。
// 后续需要自行添加
grammar SysY;
/*===-------------------------------------------===*/
/* Lexer rules */
/*===-------------------------------------------===*/
INT: 'int';
RETURN: 'return';
ASSIGN: '=';
ADD: '+';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
SEMICOLON: ';';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*===-------------------------------------------===*/
/* Syntax rules */
/*===-------------------------------------------===*/
compUnit
: funcDef EOF
;
decl
: btype varDef SEMICOLON
;
btype
: INT
;
varDef
: lValue (ASSIGN initValue)?
;
initValue
: exp
;
funcDef
: funcType ID LPAREN RPAREN blockStmt
;
funcType
: INT
;
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
stmt
: returnStmt
;
returnStmt
: RETURN exp SEMICOLON
;
exp
: LPAREN exp RPAREN # parenExp
| var # varExp
| number # numberExp
| exp ADD exp # additiveExp
;
var
: ID
;
lValue
: ID
;
number
: ILITERAL
;

@ -1,187 +0,0 @@
grammar SysY;
/*====================*/
/* Lexer rules */
/*====================*/
INT: 'int';
RETURN: 'return';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
ASSIGN: '=';
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
LT: '<';
GT: '>';
LE: '<=';
GE: '>=';
EQ: '==';
NEQ: '!=';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
SEMICOLON: ';';
COMMA: ',';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*====================*/
/* Syntax rules */
/*====================*/
compUnit
: (decl | funcDef)* EOF
;
/* ===== 声明 ===== */
decl
: btype varDefList SEMICOLON
;
btype
: INT
;
varDefList
: varDef (COMMA varDef)*
;
varDef
: ID (LBRACK number RBRACK)? (ASSIGN initValue)?
;
initValue
: exp
;
/* ===== 函数 ===== */
funcDef
: funcType ID LPAREN funcParams? RPAREN blockStmt
;
funcType
: INT
;
funcParams
: funcParam (COMMA funcParam)*
;
funcParam
: btype ID
;
/* ===== 语句块 ===== */
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
/* ===== 语句 ===== */
stmt
: returnStmt
| assignStmt
| expStmt
| blockStmt
| ifStmt
| whileStmt
;
returnStmt
: RETURN exp SEMICOLON
;
assignStmt
: lValue ASSIGN exp SEMICOLON
;
expStmt
: exp? SEMICOLON
;
ifStmt
: IF LPAREN exp RPAREN stmt (ELSE stmt)?
;
whileStmt
: WHILE LPAREN exp RPAREN stmt
;
/* ===== 表达式 ===== */
exp
: logicalExp
;
/* 逻辑(先简化成关系表达式) */
logicalExp
: relationalExp
;
/* 比较 */
relationalExp
: additiveExp ( (LT | GT | LE | GE | EQ | NEQ) additiveExp )*
;
/* 加减 */
additiveExp
: multiplicativeExp ( (ADD | SUB) multiplicativeExp )*
;
/* 乘除 */
multiplicativeExp
: unaryExp ( (MUL | DIV) unaryExp )*
;
/* 一元 */
unaryExp
: primaryExp
| ADD unaryExp
| SUB unaryExp
;
/* 基本表达式 */
primaryExp
: LPAREN exp RPAREN
| number
| lValue
| funcCall
;
/* 函数调用 */
funcCall
: ID LPAREN (exp (COMMA exp)*)? RPAREN
;
/* ===== 基础 ===== */
lValue
: ID (LBRACK exp RBRACK)?
;
number
: ILITERAL
;

@ -1,229 +0,0 @@
grammar SysY;
/*====================*/
/* Lexer rules */
/*====================*/
CONST: 'const';
INT: 'int';
FLOAT: 'float';
VOID: 'void';
RETURN: 'return';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
ASSIGN: '=';
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
NOT: '!';
LT: '<';
GT: '>';
LE: '<=';
GE: '>=';
EQ: '==';
NEQ: '!=';
AND: '&&';
OR: '||';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
SEMICOLON: ';';
COMMA: ',';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
/* 浮点字面量要放在整数字面量前面 */
FLOATLITERAL
: [0-9]+ '.' [0-9]* ([eE] [+\-]? [0-9]+)?
| '.' [0-9]+ ([eE] [+\-]? [0-9]+)?
| [0-9]+ [eE] [+\-]? [0-9]+
;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*====================*/
/* Parser rules */
/*====================*/
compUnit
: (decl | funcDef)* EOF
;
/* ===== 声明 ===== */
decl
: constDecl
| varDecl
;
constDecl
: CONST btype constDef (COMMA constDef)* SEMICOLON
;
varDecl
: btype varDef (COMMA varDef)* SEMICOLON
;
btype
: INT
| FLOAT
;
constDef
: ID constExpArrayDims ASSIGN constInitVal
;
varDef
: ID arrayDims? (ASSIGN initVal)?
;
/* 变量定义时数组维度一般要求是 exp */
arrayDims
: (LBRACK exp RBRACK)+
;
/* const 定义也支持多维 */
constExpArrayDims
: (LBRACK constExp RBRACK)+
;
constInitVal
: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE
;
initVal
: exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE
;
/* ===== 函数 ===== */
funcDef
: funcType ID LPAREN funcFParams? RPAREN blockStmt
;
funcType
: INT
| FLOAT
| VOID
;
funcFParams
: funcFParam (COMMA funcFParam)*
;
funcFParam
: btype ID
| btype ID LBRACK RBRACK (LBRACK exp RBRACK)*
;
/* ===== 语句块 ===== */
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
/* ===== 语句 ===== */
stmt
: lVal ASSIGN exp SEMICOLON # assignStatement
| exp? SEMICOLON # expStatement
| blockStmt # blockStatement
| IF LPAREN cond RPAREN stmt (ELSE stmt)? # ifStatement
| WHILE LPAREN cond RPAREN stmt # whileStatement
| BREAK SEMICOLON # breakStatement
| CONTINUE SEMICOLON # continueStatement
| RETURN exp? SEMICOLON # returnStatement
;
/* ===== 表达式 ===== */
exp
: addExp
;
cond
: lOrExp
;
lVal
: ID (LBRACK exp RBRACK)*
;
primaryExp
: LPAREN exp RPAREN
| lVal
| number
;
number
: ILITERAL
| FLOATLITERAL
;
funcRParams
: exp (COMMA exp)*
;
unaryExp
: primaryExp
| ID LPAREN funcRParams? RPAREN
| unaryOp unaryExp
;
unaryOp
: ADD
| SUB
| NOT
;
mulExp
: unaryExp ((MUL | DIV | MOD) unaryExp)*
;
addExp
: mulExp ((ADD | SUB) mulExp)*
;
relExp
: addExp ((LT | GT | LE | GE) addExp)*
;
eqExp
: relExp ((EQ | NEQ) relExp)*
;
lAndExp
: eqExp (AND eqExp)*
;
lOrExp
: lAndExp (OR lAndExp)*
;
constExp
: addExp
;

@ -1,229 +0,0 @@
grammar SysY;
/*====================*/
/* Lexer rules */
/*====================*/
CONST: 'const';
INT: 'int';
FLOAT: 'float';
VOID: 'void';
RETURN: 'return';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
ASSIGN: '=';
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
NOT: '!';
LT: '<';
GT: '>';
LE: '<=';
GE: '>=';
EQ: '==';
NEQ: '!=';
AND: '&&';
OR: '||';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
SEMICOLON: ';';
COMMA: ',';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
/* 浮点字面量要放在整数字面量前面 */
FLOATLITERAL
: [0-9]+ '.' [0-9]* ([eE] [+\-]? [0-9]+)?
| '.' [0-9]+ ([eE] [+\-]? [0-9]+)?
| [0-9]+ [eE] [+\-]? [0-9]+
;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*====================*/
/* Parser rules */
/*====================*/
compUnit
: (decl | funcDef)* EOF
;
/* ===== 声明 ===== */
decl
: constDecl
| varDecl
;
constDecl
: CONST btype constDef (COMMA constDef)* SEMICOLON
;
varDecl
: btype varDef (COMMA varDef)* SEMICOLON
;
btype
: INT
| FLOAT
;
constDef
: ID constExpArrayDims ASSIGN constInitVal
;
varDef
: ID arrayDims? (ASSIGN initVal)?
;
/* 变量定义时数组维度一般要求是 exp */
arrayDims
: (LBRACK exp RBRACK)+
;
/* const 定义也支持多维 */
constExpArrayDims
: (LBRACK constExp RBRACK)+
;
constInitVal
: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE
;
initVal
: exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE
;
/* ===== 函数 ===== */
funcDef
: funcType ID LPAREN funcFParams? RPAREN blockStmt
;
funcType
: INT
| FLOAT
| VOID
;
funcFParams
: funcFParam (COMMA funcFParam)*
;
funcFParam
: btype ID
| btype ID LBRACK RBRACK (LBRACK exp RBRACK)*
;
/* ===== 语句块 ===== */
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
/* ===== 语句 ===== */
stmt
: lVal ASSIGN exp SEMICOLON # assignStatement
| exp? SEMICOLON # expStatement
| blockStmt # blockStatement
| IF LPAREN cond RPAREN stmt (ELSE stmt)? # ifStatement
| WHILE LPAREN cond RPAREN stmt # whileStatement
| BREAK SEMICOLON # breakStatement
| CONTINUE SEMICOLON # continueStatement
| RETURN exp? SEMICOLON # returnStatement
;
/* ===== 表达式 ===== */
exp
: addExp
;
cond
: lOrExp
;
lVal
: ID (LBRACK exp RBRACK)*
;
primaryExp
: LPAREN exp RPAREN
| lVal
| number
;
number
: ILITERAL
| FLOATLITERAL
;
funcRParams
: exp (COMMA exp)*
;
unaryExp
: primaryExp
| ID LPAREN funcRParams? RPAREN
| unaryOp unaryExp
;
unaryOp
: ADD
| SUB
| NOT
;
mulExp
: unaryExp ((MUL | DIV | MOD) unaryExp)*
;
addExp
: mulExp ((ADD | SUB) mulExp)*
;
relExp
: addExp ((LT | GT | LE | GE) addExp)*
;
eqExp
: relExp ((EQ | NEQ) relExp)*
;
lAndExp
: eqExp (AND eqExp)*
;
lOrExp
: lAndExp (OR lAndExp)*
;
constExp
: addExp
;

@ -1,229 +0,0 @@
grammar SysY;
/*====================*/
/* Lexer rules */
/*====================*/
CONST: 'const';
INT: 'int';
FLOAT: 'float';
VOID: 'void';
RETURN: 'return';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
ASSIGN: '=';
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
NOT: '!';
LT: '<';
GT: '>';
LE: '<=';
GE: '>=';
EQ: '==';
NEQ: '!=';
AND: '&&';
OR: '||';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
SEMICOLON: ';';
COMMA: ',';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
/* 浮点字面量要放在整数字面量前面 */
FLOATLITERAL
: [0-9]+ '.' [0-9]* ([eE] [+\-]? [0-9]+)?
| '.' [0-9]+ ([eE] [+\-]? [0-9]+)?
| [0-9]+ [eE] [+\-]? [0-9]+
;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*====================*/
/* Parser rules */
/*====================*/
compUnit
: (decl | funcDef)* EOF
;
/* ===== 声明 ===== */
decl
: constDecl
| varDecl
;
constDecl
: CONST btype constDef (COMMA constDef)* SEMICOLON
;
varDecl
: btype varDef (COMMA varDef)* SEMICOLON
;
btype
: INT
| FLOAT
;
constDef
: ID constExpArrayDims ASSIGN constInitVal
;
varDef
: ID arrayDims? (ASSIGN initVal)?
;
/* 变量定义时数组维度一般要求是 exp */
arrayDims
: (LBRACK exp RBRACK)+
;
/* const 定义也支持多维 */
constExpArrayDims
: (LBRACK constExp RBRACK)+
;
constInitVal
: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE
;
initVal
: exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE
;
/* ===== 函数 ===== */
funcDef
: funcType ID LPAREN funcFParams? RPAREN blockStmt
;
funcType
: INT
| FLOAT
| VOID
;
funcFParams
: funcFParam (COMMA funcFParam)*
;
funcFParam
: btype ID
| btype ID LBRACK RBRACK (LBRACK exp RBRACK)*
;
/* ===== 语句块 ===== */
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
/* ===== 语句 ===== */
stmt
: lVal ASSIGN exp SEMICOLON # assignStatement
| exp? SEMICOLON # expStatement
| blockStmt # blockStatement
| IF LPAREN cond RPAREN stmt (ELSE stmt)? # ifStatement
| WHILE LPAREN cond RPAREN stmt # whileStatement
| BREAK SEMICOLON # breakStatement
| CONTINUE SEMICOLON # continueStatement
| RETURN exp? SEMICOLON # returnStatement
;
/* ===== 表达式 ===== */
exp
: addExp
;
cond
: lOrExp
;
lVal
: ID (LBRACK exp RBRACK)*
;
primaryExp
: LPAREN exp RPAREN
| lVal
| number
;
number
: ILITERAL
| FLOATLITERAL
;
funcRParams
: exp (COMMA exp)*
;
unaryExp
: primaryExp
| ID LPAREN funcRParams? RPAREN
| unaryOp unaryExp
;
unaryOp
: ADD
| SUB
| NOT
;
mulExp
: unaryExp ((MUL | DIV | MOD) unaryExp)*
;
addExp
: mulExp ((ADD | SUB) mulExp)*
;
relExp
: addExp ((LT | GT | LE | GE) addExp)*
;
eqExp
: relExp ((EQ | NEQ) relExp)*
;
lAndExp
: eqExp (AND eqExp)*
;
lOrExp
: lAndExp (OR lAndExp)*
;
constExp
: addExp
;

@ -1,229 +0,0 @@
grammar SysY;
/*====================*/
/* Lexer rules */
/*====================*/
CONST: 'const';
INT: 'int';
FLOAT: 'float';
VOID: 'void';
RETURN: 'return';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
ASSIGN: '=';
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
NOT: '!';
LT: '<';
GT: '>';
LE: '<=';
GE: '>=';
EQ: '==';
NEQ: '!=';
AND: '&&';
OR: '||';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
SEMICOLON: ';';
COMMA: ',';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
/* 浮点字面量要放在整数字面量前面 */
FLOATLITERAL
: [0-9]+ '.' [0-9]* ([eE] [+\-]? [0-9]+)?
| '.' [0-9]+ ([eE] [+\-]? [0-9]+)?
| [0-9]+ [eE] [+\-]? [0-9]+
;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*====================*/
/* Parser rules */
/*====================*/
compUnit
: (decl | funcDef)* EOF
;
/* ===== 声明 ===== */
decl
: constDecl
| varDecl
;
constDecl
: CONST btype constDef (COMMA constDef)* SEMICOLON
;
varDecl
: btype varDef (COMMA varDef)* SEMICOLON
;
btype
: INT
| FLOAT
;
constDef
: ID constExpArrayDims ASSIGN constInitVal
;
varDef
: ID arrayDims? (ASSIGN initVal)?
;
/* 变量定义时数组维度一般要求是 exp */
arrayDims
: (LBRACK exp RBRACK)+
;
/* const 定义也支持多维 */
constExpArrayDims
: (LBRACK constExp RBRACK)+
;
constInitVal
: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE
;
initVal
: exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE
;
/* ===== 函数 ===== */
funcDef
: funcType ID LPAREN funcFParams? RPAREN blockStmt
;
funcType
: INT
| FLOAT
| VOID
;
funcFParams
: funcFParam (COMMA funcFParam)*
;
funcFParam
: btype ID
| btype ID LBRACK RBRACK (LBRACK exp RBRACK)*
;
/* ===== 语句块 ===== */
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
/* ===== 语句 ===== */
stmt
: lVal ASSIGN exp SEMICOLON # assignStatement
| exp? SEMICOLON # expStatement
| blockStmt # blockStatement
| IF LPAREN cond RPAREN stmt (ELSE stmt)? # ifStatement
| WHILE LPAREN cond RPAREN stmt # whileStatement
| BREAK SEMICOLON # breakStatement
| CONTINUE SEMICOLON # continueStatement
| RETURN exp? SEMICOLON # returnStatement
;
/* ===== 表达式 ===== */
exp
: addExp
;
cond
: lOrExp
;
lVal
: ID (LBRACK exp RBRACK)*
;
primaryExp
: LPAREN exp RPAREN
| lVal
| number
;
number
: ILITERAL
| FLOATLITERAL
;
funcRParams
: exp (COMMA exp)*
;
unaryExp
: primaryExp
| ID LPAREN funcRParams? RPAREN
| unaryOp unaryExp
;
unaryOp
: ADD
| SUB
| NOT
;
mulExp
: unaryExp ((MUL | DIV | MOD) unaryExp)*
;
addExp
: mulExp ((ADD | SUB) mulExp)*
;
relExp
: addExp ((LT | GT | LE | GE) addExp)*
;
eqExp
: relExp ((EQ | NEQ) relExp)*
;
lAndExp
: eqExp (AND eqExp)*
;
lOrExp
: lAndExp (OR lAndExp)*
;
constExp
: addExp
;

@ -1,260 +0,0 @@
grammar SysY;
/*====================*/
/* Lexer rules */
/*====================*/
CONST: 'const';
INT: 'int';
FLOAT: 'float';
VOID: 'void';
RETURN: 'return';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
ASSIGN: '=';
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
NOT: '!';
LT: '<';
GT: '>';
LE: '<=';
GE: '>=';
EQ: '==';
NEQ: '!=';
AND: '&&';
OR: '||';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
SEMICOLON: ';';
COMMA: ',';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
/* 浮点字面量放在整数字面量前 */
FLOATLITERAL
: [0-9]+ '.' [0-9]* ([eE] [+\-]? [0-9]+)?
| '.' [0-9]+ ([eE] [+\-]? [0-9]+)?
| [0-9]+ [eE] [+\-]? [0-9]+
;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*====================*/
/* Parser rules */
/*====================*/
compUnit
: (decl | funcDef)* EOF
;
/* ===== 声明 ===== */
decl
: constDecl
| varDecl
;
constDecl
: CONST btype constDefList SEMICOLON
;
varDecl
: btype varDefList SEMICOLON
;
btype
: INT
| FLOAT
;
constDefList
: constDef (COMMA constDef)*
;
varDefList
: varDef (COMMA varDef)*
;
constDef
: ID (LBRACK constExp RBRACK)+ ASSIGN constInitVal
| ID ASSIGN constInitVal
;
varDef
: ID (LBRACK exp RBRACK)* (ASSIGN initVal)?
;
constInitVal
: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE
;
initVal
: exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE
;
/* ===== 函数 ===== */
funcDef
: funcType ID LPAREN funcParams? RPAREN blockStmt
;
funcType
: INT
| FLOAT
| VOID
;
funcParams
: funcParam (COMMA funcParam)*
;
funcParam
: btype ID
| btype ID LBRACK RBRACK (LBRACK exp RBRACK)*
;
/* ===== 语句块 ===== */
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
/* ===== 语句 ===== */
stmt
: returnStmt
| assignStmt
| expStmt
| blockStmt
| ifStmt
| whileStmt
| breakStmt
| continueStmt
;
returnStmt
: RETURN exp? SEMICOLON
;
assignStmt
: lValue ASSIGN exp SEMICOLON
;
expStmt
: exp? SEMICOLON
;
ifStmt
: IF LPAREN cond RPAREN stmt (ELSE stmt)?
;
whileStmt
: WHILE LPAREN cond RPAREN stmt
;
breakStmt
: BREAK SEMICOLON
;
continueStmt
: CONTINUE SEMICOLON
;
/* ===== 表达式 ===== */
exp
: lOrExp
;
cond
: lOrExp
;
lValue
: ID (LBRACK exp RBRACK)*
;
primaryExp
: LPAREN exp RPAREN
| number
| lValue
;
funcCall
: ID LPAREN funcRParams? RPAREN
;
funcRParams
: exp (COMMA exp)*
;
unaryExp
: primaryExp
| funcCall
| unaryOp unaryExp
;
unaryOp
: ADD
| SUB
| NOT
;
mulExp
: unaryExp ((MUL | DIV | MOD) unaryExp)*
;
addExp
: mulExp ((ADD | SUB) mulExp)*
;
relExp
: addExp ((LT | GT | LE | GE) addExp)*
;
eqExp
: relExp ((EQ | NEQ) relExp)*
;
lAndExp
: eqExp (AND eqExp)*
;
lOrExp
: lAndExp (OR lAndExp)*
;
constExp
: addExp
;
number
: ILITERAL
| FLOATLITERAL
;

@ -1,260 +0,0 @@
grammar SysY;
/*====================*/
/* Lexer rules */
/*====================*/
CONST: 'const';
INT: 'int';
FLOAT: 'float';
VOID: 'void';
RETURN: 'return';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
ASSIGN: '=';
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
NOT: '!';
LT: '<';
GT: '>';
LE: '<=';
GE: '>=';
EQ: '==';
NEQ: '!=';
AND: '&&';
OR: '||';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
SEMICOLON: ';';
COMMA: ',';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
/* 浮点字面量放在整数字面量前 */
FLOATLITERAL
: [0-9]+ '.' [0-9]* ([eE] [+\-]? [0-9]+)?
| '.' [0-9]+ ([eE] [+\-]? [0-9]+)?
| [0-9]+ [eE] [+\-]? [0-9]+
;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*====================*/
/* Parser rules */
/*====================*/
compUnit
: (decl | funcDef)* EOF
;
/* ===== 声明 ===== */
decl
: constDecl
| varDecl
;
constDecl
: CONST btype constDefList SEMICOLON
;
varDecl
: btype varDefList SEMICOLON
;
btype
: INT
| FLOAT
;
constDefList
: constDef (COMMA constDef)*
;
varDefList
: varDef (COMMA varDef)*
;
constDef
: ID (LBRACK constExp RBRACK)+ ASSIGN constInitVal
| ID ASSIGN constInitVal
;
varDef
: ID (LBRACK exp RBRACK)* (ASSIGN initVal)?
;
constInitVal
: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE
;
initVal
: exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE
;
/* ===== 函数 ===== */
funcDef
: funcType ID LPAREN funcParams? RPAREN blockStmt
;
funcType
: INT
| FLOAT
| VOID
;
funcParams
: funcParam (COMMA funcParam)*
;
funcParam
: btype ID
| btype ID LBRACK RBRACK (LBRACK exp RBRACK)*
;
/* ===== 语句块 ===== */
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
/* ===== 语句 ===== */
stmt
: returnStmt
| assignStmt
| expStmt
| blockStmt
| ifStmt
| whileStmt
| breakStmt
| continueStmt
;
returnStmt
: RETURN exp? SEMICOLON
;
assignStmt
: lValue ASSIGN exp SEMICOLON
;
expStmt
: exp? SEMICOLON
;
ifStmt
: IF LPAREN cond RPAREN stmt (ELSE stmt)?
;
whileStmt
: WHILE LPAREN cond RPAREN stmt
;
breakStmt
: BREAK SEMICOLON
;
continueStmt
: CONTINUE SEMICOLON
;
/* ===== 表达式 ===== */
exp
: lOrExp
;
cond
: lOrExp
;
lValue
: ID (LBRACK exp RBRACK)*
;
primaryExp
: LPAREN exp RPAREN
| number
| lValue
;
funcCall
: ID LPAREN funcRParams? RPAREN
;
funcRParams
: exp (COMMA exp)*
;
unaryExp
: primaryExp
| funcCall
| unaryOp unaryExp
;
unaryOp
: ADD
| SUB
| NOT
;
mulExp
: unaryExp ((MUL | DIV | MOD) unaryExp)*
;
addExp
: mulExp ((ADD | SUB) mulExp)*
;
relExp
: addExp ((LT | GT | LE | GE) addExp)*
;
eqExp
: relExp ((EQ | NEQ) relExp)*
;
lAndExp
: eqExp (AND eqExp)*
;
lOrExp
: lAndExp (OR lAndExp)*
;
constExp
: addExp
;
number
: ILITERAL
| FLOATLITERAL
;

@ -225,3 +225,4 @@ lOrExp
constExp constExp
: addExp : addExp
; ;

@ -3,15 +3,44 @@ add_library(frontend STATIC
SyntaxTreePrinter.cpp SyntaxTreePrinter.cpp
) )
set(ANTLR4_GRAMMAR "${PROJECT_SOURCE_DIR}/src/antlr4/SysY.g4")
set(ANTLR4_GENERATED_FILES
"${ANTLR4_GENERATED_DIR}/SysYLexer.cpp"
"${ANTLR4_GENERATED_DIR}/SysYLexer.h"
"${ANTLR4_GENERATED_DIR}/SysYLexer.interp"
"${ANTLR4_GENERATED_DIR}/SysYLexer.tokens"
"${ANTLR4_GENERATED_DIR}/SysYParser.cpp"
"${ANTLR4_GENERATED_DIR}/SysYParser.h"
"${ANTLR4_GENERATED_DIR}/SysY.interp"
"${ANTLR4_GENERATED_DIR}/SysY.tokens"
"${ANTLR4_GENERATED_DIR}/SysYBaseVisitor.h"
"${ANTLR4_GENERATED_DIR}/SysYVisitor.h"
)
add_custom_command(
OUTPUT ${ANTLR4_GENERATED_FILES}
COMMAND ${CMAKE_COMMAND} -E make_directory "${ANTLR4_GENERATED_DIR}"
COMMAND ${Java_JAVA_EXECUTABLE} -jar "${ANTLR4_JAR}"
-Dlanguage=Cpp
-visitor
-no-listener
-Xexact-output-dir
-o "${ANTLR4_GENERATED_DIR}"
"${ANTLR4_GRAMMAR}"
DEPENDS "${ANTLR4_GRAMMAR}" "${ANTLR4_JAR}"
COMMENT "Generating ANTLR4 parser sources from SysY.g4"
VERBATIM
)
add_custom_target(antlr4_generated DEPENDS ${ANTLR4_GENERATED_FILES})
add_dependencies(frontend antlr4_generated)
target_sources(frontend PRIVATE
"${ANTLR4_GENERATED_DIR}/SysYLexer.cpp"
"${ANTLR4_GENERATED_DIR}/SysYParser.cpp"
)
target_link_libraries(frontend PUBLIC target_link_libraries(frontend PUBLIC
build_options build_options
${ANTLR4_RUNTIME_TARGET} ${ANTLR4_RUNTIME_TARGET}
) )
# Lexer/Parser
file(GLOB_RECURSE ANTLR4_GENERATED_SOURCES CONFIGURE_DEPENDS
"${ANTLR4_GENERATED_DIR}/*.cpp"
)
if(ANTLR4_GENERATED_SOURCES)
target_sources(frontend PRIVATE ${ANTLR4_GENERATED_SOURCES})
endif()

@ -13,9 +13,9 @@
namespace ir { namespace ir {
// 当前 BasicBlock 还没有专门的 label type因此先用 void 作为占位类型 // BasicBlock 使用 label type
BasicBlock::BasicBlock(std::string name) BasicBlock::BasicBlock(std::string name)
: Value(Type::GetVoidType(), std::move(name)) {} : Value(Type::GetLabelType(), std::move(name)) {}
Function* BasicBlock::GetParent() const { return parent_; } Function* BasicBlock::GetParent() const { return parent_; }
@ -42,4 +42,38 @@ const std::vector<BasicBlock*>& BasicBlock::GetSuccessors() const {
return successors_; return successors_;
} }
void BasicBlock::AddPredecessor(BasicBlock* pred) {
if (!pred) return;
for (auto* p : predecessors_) {
if (p == pred) return;
}
predecessors_.push_back(pred);
}
void BasicBlock::AddSuccessor(BasicBlock* succ) {
if (!succ) return;
for (auto* s : successors_) {
if (s == succ) return;
}
successors_.push_back(succ);
}
void BasicBlock::LinkSuccessorsIfNeeded(Instruction* inst) {
if (!inst) return;
if (auto* br = dynamic_cast<BranchInst*>(inst)) {
auto* dest = br->GetDest();
AddSuccessor(dest);
dest->AddPredecessor(this);
return;
}
if (auto* cbr = dynamic_cast<CondBrInst*>(inst)) {
auto* t = cbr->GetTrueDest();
auto* f = cbr->GetFalseDest();
AddSuccessor(t);
AddSuccessor(f);
t->AddPredecessor(this);
f->AddPredecessor(this);
}
}
} // namespace ir } // namespace ir

@ -1,6 +1,7 @@
// 管理基础类型、整型常量池和临时名生成。 // 管理基础类型、整型常量池和临时名生成。
#include "ir/IR.h" #include "ir/IR.h"
#include <cstring>
#include <sstream> #include <sstream>
namespace ir { namespace ir {
@ -15,9 +16,43 @@ ConstantInt* Context::GetConstInt(int v) {
return inserted->second.get(); return inserted->second.get();
} }
ConstantInt* Context::GetConstBool(bool v) {
int iv = v ? 1 : 0;
auto it = const_bools_.find(iv);
if (it != const_bools_.end()) return it->second.get();
auto inserted = const_bools_.emplace(
iv, std::make_unique<ConstantInt>(Type::GetInt1Type(), iv)).first;
return inserted->second.get();
}
static uint32_t FloatToBits(float v) {
uint32_t bits = 0;
std::memcpy(&bits, &v, sizeof(float));
return bits;
}
ConstantFloat* Context::GetConstFloat(float v) {
uint32_t bits = FloatToBits(v);
auto it = const_floats_.find(bits);
if (it != const_floats_.end()) return it->second.get();
auto inserted = const_floats_.emplace(
bits, std::make_unique<ConstantFloat>(Type::GetFloatType(), v)).first;
return inserted->second.get();
}
ConstantArray* Context::CreateConstArray(std::shared_ptr<Type> array_ty,
std::vector<ConstantValue*> elements) {
if (!array_ty || !array_ty->IsArray()) {
throw std::runtime_error("CreateConstArray 需要 array type");
}
const_arrays_.push_back(
std::make_unique<ConstantArray>(std::move(array_ty), std::move(elements)));
return const_arrays_.back().get();
}
std::string Context::NextTemp() { std::string Context::NextTemp() {
std::ostringstream oss; std::ostringstream oss;
oss << "%" << ++temp_index_; oss << "%t" << ++temp_index_;
return oss.str(); return oss.str();
} }

@ -5,13 +5,32 @@
namespace ir { namespace ir {
Function::Function(std::string name, std::shared_ptr<Type> ret_type) Function::Function(std::string name, std::shared_ptr<Type> func_type,
: Value(std::move(ret_type), std::move(name)) { bool is_declaration)
: Value(std::move(func_type), std::move(name)),
is_declaration_(is_declaration) {
if (!type_ || !type_->IsFunction()) {
throw std::runtime_error("Function 需要 function type");
}
const auto& params = type_->GetParamTypes();
args_.reserve(params.size());
for (size_t i = 0; i < params.size(); ++i) {
args_.push_back(std::make_unique<Argument>(params[i], "%arg" + std::to_string(i), i));
}
if (!is_declaration_) {
entry_ = CreateBlock("entry"); entry_ = CreateBlock("entry");
} }
}
BasicBlock* Function::CreateBlock(const std::string& name) { BasicBlock* Function::CreateBlock(const std::string& name) {
auto block = std::make_unique<BasicBlock>(name); std::string base = name.empty() ? "bb" : name;
auto& count = block_name_counts_[base];
std::string final_name = base;
if (count > 0) {
final_name = base + "." + std::to_string(count);
}
++count;
auto block = std::make_unique<BasicBlock>(final_name);
auto* ptr = block.get(); auto* ptr = block.get();
ptr->SetParent(this); ptr->SetParent(this);
blocks_.push_back(std::move(block)); blocks_.push_back(std::move(block));
@ -29,4 +48,31 @@ const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_; return blocks_;
} }
const std::vector<std::unique_ptr<Argument>>& Function::GetArguments() const {
return args_;
}
size_t Function::GetNumArgs() const { return args_.size(); }
Argument* Function::GetArg(size_t index) {
if (index >= args_.size()) {
throw std::out_of_range("Function arg index out of range");
}
return args_[index].get();
}
std::shared_ptr<Type> Function::GetFunctionType() const { return type_; }
std::shared_ptr<Type> Function::GetReturnType() const {
if (!type_ || !type_->IsFunction()) {
throw std::runtime_error("Function type 缺失");
}
return type_->GetReturnType();
}
bool Function::IsDeclaration() const { return is_declaration_; }
Argument::Argument(std::shared_ptr<Type> ty, std::string name, size_t index)
: Value(std::move(ty), std::move(name)), index_(index) {}
} // namespace ir } // namespace ir

@ -8,4 +8,23 @@ namespace ir {
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name) GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {} : User(std::move(ty), std::move(name)) {}
GlobalVariable::GlobalVariable(std::shared_ptr<Type> value_ty, std::string name,
ConstantValue* init, bool is_const)
: GlobalValue(Type::GetPointerType(value_ty), std::move(name)),
value_type_(std::move(value_ty)),
initializer_(init),
is_const_(is_const) {
if (!value_type_) {
throw std::runtime_error("GlobalVariable 缺少 value type");
}
}
const std::shared_ptr<Type>& GlobalVariable::GetValueType() const {
return value_type_;
}
ConstantValue* GlobalVariable::GetInitializer() const { return initializer_; }
bool GlobalVariable::IsConst() const { return is_const_; }
} // namespace ir } // namespace ir

@ -9,6 +9,42 @@
#include "utils/Log.h" #include "utils/Log.h"
namespace ir { namespace ir {
static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void:
return "void";
case Type::Kind::Int1:
return "i1";
case Type::Kind::Int32:
return "i32";
case Type::Kind::Float:
return "float";
case Type::Kind::Label:
return "label";
case Type::Kind::Pointer:
return TypeToString(*ty.GetElementType()) + "*";
case Type::Kind::Array: {
return "[" + std::to_string(ty.GetArraySize()) + " x " +
TypeToString(*ty.GetElementType()) + "]";
}
case Type::Kind::Function: {
std::string out = TypeToString(*ty.GetReturnType()) + " (";
const auto& params = ty.GetParamTypes();
for (size_t i = 0; i < params.size(); ++i) {
if (i > 0) out += ", ";
out += TypeToString(*params[i]);
}
if (ty.IsVarArg()) {
if (!params.empty()) out += ", ";
out += "...";
}
out += ")";
return out;
}
}
return "?";
}
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb)
: ctx_(ctx), insert_block_(bb) {} : ctx_(ctx), insert_block_(bb) {}
@ -42,11 +78,107 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
return CreateBinary(Opcode::Add, lhs, rhs, name); return CreateBinary(Opcode::Add, lhs, rhs, name);
} }
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Sub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Mul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateSDiv(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::SDiv, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateSRem(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::SRem, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFAdd(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FAdd, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFSub(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FSub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFMul(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FMul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFDiv(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FDiv, lhs, rhs, name);
}
ICmpInst* IRBuilder::CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
} }
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name); return insert_block_->Append<ICmpInst>(pred, lhs, rhs, name);
}
FCmpInst* IRBuilder::CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<FCmpInst>(pred, lhs, rhs, name);
}
CastInst* IRBuilder::CreateSIToFP(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::SIToFP, std::move(dst_ty), src,
name);
}
CastInst* IRBuilder::CreateFPToSI(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::FPToSI, std::move(dst_ty), src,
name);
}
CastInst* IRBuilder::CreateZExt(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::ZExt, std::move(dst_ty), src,
name);
}
ConstantInt* IRBuilder::CreateConstBool(bool v) {
return ctx_.GetConstBool(v);
}
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
return ctx_.GetConstFloat(v);
}
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(std::move(ty), name);
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
return CreateAlloca(Type::GetInt32Type(), name);
} }
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
@ -57,7 +189,11 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
throw std::runtime_error( throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
} }
return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name); if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateLoad ptr 不是指针"));
}
auto val_ty = ptr->GetType()->GetElementType();
return insert_block_->Append<LoadInst>(val_ty, ptr, name);
} }
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
@ -75,6 +211,95 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr); return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr);
} }
static std::shared_ptr<Type> ResolveGepResultType(const std::shared_ptr<Type>& base_ptr_ty,
size_t index_count) {
if (!base_ptr_ty || !base_ptr_ty->IsPointer()) {
throw std::runtime_error("GEP base type 必须是指针");
}
auto cur = base_ptr_ty->GetElementType();
for (size_t i = 0; i < index_count; ++i) {
if (cur->IsArray()) {
cur = cur->GetElementType();
continue;
}
if (cur->IsPointer()) {
cur = cur->GetElementType();
continue;
}
}
return Type::GetPointerType(cur);
}
GepInst* IRBuilder::CreateGep(Value* base_ptr, std::vector<Value*> indices,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!base_ptr || !base_ptr->GetType() || !base_ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateGep base_ptr 非指针"));
}
auto result_ty = ResolveGepResultType(base_ptr->GetType(), indices.size());
return insert_block_->Append<GepInst>(result_ty, base_ptr, std::move(indices),
name);
}
CallInst* IRBuilder::CreateCall(Value* callee, std::vector<Value*> args,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!callee || !callee->GetType()) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 缺少 callee"));
}
std::shared_ptr<Type> func_ty;
if (callee->GetType()->IsFunction()) {
func_ty = callee->GetType();
} else if (callee->GetType()->IsPointer() &&
callee->GetType()->GetElementType()->IsFunction()) {
func_ty = callee->GetType()->GetElementType();
} else {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall callee 非函数"));
}
const auto& params = func_ty->GetParamTypes();
if (!func_ty->IsVarArg() && params.size() != args.size()) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 参数数量不匹配"));
}
for (size_t i = 0; i < params.size() && i < args.size(); ++i) {
if (!args[i] || !args[i]->GetType() ||
!args[i]->GetType()->Equals(*params[i])) {
std::string msg = "IRBuilder::CreateCall 参数类型不匹配: arg" +
std::to_string(i) + " got " +
TypeToString(*args[i]->GetType()) + ", expect " +
TypeToString(*params[i]);
throw std::runtime_error(FormatError("ir", msg));
}
}
auto ret_ty = func_ty->GetReturnType();
return insert_block_->Append<CallInst>(ret_ty, callee, std::move(args), name);
}
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> ty, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<PhiInst>(std::move(ty), name);
}
BranchInst* IRBuilder::CreateBr(BasicBlock* dest) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<BranchInst>(dest);
}
CondBrInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_dest,
BasicBlock* false_dest) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CondBrInst>(cond, true_dest, false_dest);
}
ReturnInst* IRBuilder::CreateRet(Value* v) { ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -86,4 +311,11 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v); return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
} }
ReturnInst* IRBuilder::CreateRetVoid() {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<ReturnInst>(Type::GetVoidType());
}
} // namespace ir } // namespace ir

@ -5,6 +5,11 @@
#include "ir/IR.h" #include "ir/IR.h"
#include <ostream> #include <ostream>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <limits>
#include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
@ -12,14 +17,41 @@
namespace ir { namespace ir {
static const char* TypeToString(const Type& ty) { static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) { switch (ty.GetKind()) {
case Type::Kind::Void: case Type::Kind::Void:
return "void"; return "void";
case Type::Kind::Int1:
return "i1";
case Type::Kind::Int32: case Type::Kind::Int32:
return "i32"; return "i32";
case Type::Kind::PtrInt32: case Type::Kind::Float:
return "i32*"; return "float";
case Type::Kind::Label:
return "label";
case Type::Kind::Pointer:
return TypeToString(*ty.GetElementType()) + "*";
case Type::Kind::Array: {
std::ostringstream oss;
oss << "[" << ty.GetArraySize() << " x "
<< TypeToString(*ty.GetElementType()) << "]";
return oss.str();
}
case Type::Kind::Function: {
std::ostringstream oss;
oss << TypeToString(*ty.GetReturnType()) << " (";
const auto& params = ty.GetParamTypes();
for (size_t i = 0; i < params.size(); ++i) {
if (i > 0) oss << ", ";
oss << TypeToString(*params[i]);
}
if (ty.IsVarArg()) {
if (!params.empty()) oss << ", ";
oss << "...";
}
oss << ")";
return oss.str();
}
} }
throw std::runtime_error(FormatError("ir", "未知类型")); throw std::runtime_error(FormatError("ir", "未知类型"));
} }
@ -32,6 +64,18 @@ static const char* OpcodeToString(Opcode op) {
return "sub"; return "sub";
case Opcode::Mul: case Opcode::Mul:
return "mul"; return "mul";
case Opcode::SDiv:
return "sdiv";
case Opcode::SRem:
return "srem";
case Opcode::FAdd:
return "fadd";
case Opcode::FSub:
return "fsub";
case Opcode::FMul:
return "fmul";
case Opcode::FDiv:
return "fdiv";
case Opcode::Alloca: case Opcode::Alloca:
return "alloca"; return "alloca";
case Opcode::Load: case Opcode::Load:
@ -40,21 +84,161 @@ static const char* OpcodeToString(Opcode op) {
return "store"; return "store";
case Opcode::Ret: case Opcode::Ret:
return "ret"; return "ret";
case Opcode::Br:
return "br";
case Opcode::CondBr:
return "br";
case Opcode::ICmp:
return "icmp";
case Opcode::FCmp:
return "fcmp";
case Opcode::Call:
return "call";
case Opcode::Phi:
return "phi";
case Opcode::Gep:
return "getelementptr";
case Opcode::SIToFP:
return "sitofp";
case Opcode::FPToSI:
return "fptosi";
case Opcode::ZExt:
return "zext";
} }
return "?"; return "?";
} }
static std::string ValueToString(const Value* v) { static std::string FloatToString(float v) {
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) { std::uint32_t bits = 0;
static_assert(sizeof(bits) == sizeof(v), "float size mismatch");
std::memcpy(&bits, &v, sizeof(bits));
std::ostringstream oss;
oss << "bitcast (i32 " << std::dec << static_cast<std::uint64_t>(bits)
<< " to float)";
return oss.str();
}
static std::string ConstantToString(const ConstantValue* c) {
if (auto* ci = dynamic_cast<const ConstantInt*>(c)) {
return std::to_string(ci->GetValue()); return std::to_string(ci->GetValue());
} }
if (auto* cf = dynamic_cast<const ConstantFloat*>(c)) {
return FloatToString(cf->GetValue());
}
if (auto* ca = dynamic_cast<const ConstantArray*>(c)) {
std::ostringstream oss;
oss << "[";
const auto& elems = ca->GetElements();
for (size_t i = 0; i < elems.size(); ++i) {
if (i > 0) oss << ", ";
oss << TypeToString(*elems[i]->GetType()) << " "
<< ConstantToString(elems[i]);
}
oss << "]";
return oss.str();
}
return "<const>";
}
static std::string ValueToString(const Value* v) {
if (auto* c = dynamic_cast<const ConstantValue*>(v)) {
return ConstantToString(c);
}
if (auto* func = dynamic_cast<const Function*>(v)) {
const auto& name = func->GetName();
if (!name.empty() && name[0] == '@') return name;
return "@" + name;
}
if (auto* gv = dynamic_cast<const GlobalValue*>(v)) {
const auto& name = gv->GetName();
if (!name.empty() && name[0] == '@') return name;
return "@" + name;
}
return v ? v->GetName() : "<null>"; return v ? v->GetName() : "<null>";
} }
static std::string LabelToString(const BasicBlock* bb) {
if (!bb) return "%<null>";
const auto& name = bb->GetName();
if (!name.empty() && name[0] == '%') return name;
return "%" + name;
}
static const char* ICmpPredToString(ICmpPredicate pred) {
switch (pred) {
case ICmpPredicate::Eq:
return "eq";
case ICmpPredicate::Ne:
return "ne";
case ICmpPredicate::Slt:
return "slt";
case ICmpPredicate::Sle:
return "sle";
case ICmpPredicate::Sgt:
return "sgt";
case ICmpPredicate::Sge:
return "sge";
}
return "?";
}
static const char* FCmpPredToString(FCmpPredicate pred) {
switch (pred) {
case FCmpPredicate::Oeq:
return "oeq";
case FCmpPredicate::One:
return "one";
case FCmpPredicate::Olt:
return "olt";
case FCmpPredicate::Ole:
return "ole";
case FCmpPredicate::Ogt:
return "ogt";
case FCmpPredicate::Oge:
return "oge";
}
return "?";
}
void IRPrinter::Print(const Module& module, std::ostream& os) { void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& g : module.GetGlobals()) {
if (!g) continue;
os << "@" << g->GetName() << " = "
<< (g->IsConst() ? "constant " : "global ")
<< TypeToString(*g->GetValueType()) << " ";
if (auto* init = g->GetInitializer()) {
os << ConstantToString(init);
} else {
if (g->GetValueType()->IsArray()) {
os << "zeroinitializer";
} else if (g->GetValueType()->IsFloat()) {
os << "0.0";
} else {
os << "0";
}
}
os << "\n";
}
for (const auto& func : module.GetFunctions()) { for (const auto& func : module.GetFunctions()) {
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() if (func->IsDeclaration()) {
<< "() {\n"; os << "declare " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType());
}
os << ")\n";
continue;
}
os << "define " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType()) << " " << args[i]->GetName();
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) { for (const auto& bb : func->GetBlocks()) {
if (!bb) { if (!bb) {
continue; continue;
@ -65,7 +249,13 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
switch (inst->GetOpcode()) { switch (inst->GetOpcode()) {
case Opcode::Add: case Opcode::Add:
case Opcode::Sub: case Opcode::Sub:
case Opcode::Mul: { case Opcode::Mul:
case Opcode::SDiv:
case Opcode::SRem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst); auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = " os << " " << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " " << OpcodeToString(bin->GetOpcode()) << " "
@ -74,27 +264,122 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
<< ValueToString(bin->GetRhs()) << "\n"; << ValueToString(bin->GetRhs()) << "\n";
break; break;
} }
case Opcode::ICmp: {
auto* cmp = static_cast<const ICmpInst*>(inst);
os << " " << cmp->GetName() << " = icmp "
<< ICmpPredToString(cmp->GetPredicate()) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst);
os << " " << cmp->GetName() << " = fcmp "
<< FCmpPredToString(cmp->GetPredicate()) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::SIToFP:
case Opcode::FPToSI:
case Opcode::ZExt: {
auto* cast = static_cast<const CastInst*>(inst);
os << " " << cast->GetName() << " = "
<< OpcodeToString(cast->GetOpcode()) << " "
<< TypeToString(*cast->GetValue()->GetType()) << " "
<< ValueToString(cast->GetValue()) << " to "
<< TypeToString(*cast->GetType()) << "\n";
break;
}
case Opcode::Alloca: { case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst); auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << alloca->GetName() << " = alloca i32\n"; os << " " << alloca->GetName() << " = alloca "
<< TypeToString(*alloca->GetAllocatedType()) << "\n";
break; break;
} }
case Opcode::Load: { case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst); auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load i32, i32* " os << " " << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n"; << ValueToString(load->GetPtr()) << "\n";
break; break;
} }
case Opcode::Store: { case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst); auto* store = static_cast<const StoreInst*>(inst);
os << " store i32 " << ValueToString(store->GetValue()) os << " store " << TypeToString(*store->GetValue()->GetType())
<< ", i32* " << ValueToString(store->GetPtr()) << "\n"; << " " << ValueToString(store->GetValue()) << ", "
<< TypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BranchInst*>(inst);
os << " br label " << LabelToString(br->GetDest()) << "\n";
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<const CondBrInst*>(inst);
os << " br i1 " << ValueToString(cbr->GetCond())
<< ", label " << LabelToString(cbr->GetTrueDest())
<< ", label " << LabelToString(cbr->GetFalseDest()) << "\n";
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
const auto& args = call->GetArgs();
if (!call->GetType()->IsVoid()) {
os << " " << call->GetName() << " = ";
} else {
os << " ";
}
os << "call " << TypeToString(*call->GetType()) << " "
<< ValueToString(call->GetCallee()) << "(";
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType()) << " "
<< ValueToString(args[i]);
}
os << ")\n";
break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " " << phi->GetName() << " = phi "
<< TypeToString(*phi->GetType()) << " ";
const auto& values = phi->GetIncomingValues();
const auto& blocks = phi->GetIncomingBlocks();
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) os << ", ";
os << "[ " << ValueToString(values[i]) << ", "
<< LabelToString(blocks[i]) << " ]";
}
os << "\n";
break;
}
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
os << " " << gep->GetName() << " = getelementptr "
<< TypeToString(*gep->GetBasePtr()->GetType()->GetElementType())
<< ", " << TypeToString(*gep->GetBasePtr()->GetType()) << " "
<< ValueToString(gep->GetBasePtr());
const auto& idx = gep->GetIndices();
for (auto* v : idx) {
os << ", i32 " << ValueToString(v);
}
os << "\n";
break; break;
} }
case Opcode::Ret: { case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst); auto* ret = static_cast<const ReturnInst*>(inst);
if (ret->HasReturnValue()) {
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n"; << ValueToString(ret->GetValue()) << "\n";
} else {
os << " ret void\n";
}
break; break;
} }
} }

@ -52,17 +52,30 @@ Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
Opcode Instruction::GetOpcode() const { return opcode_; } Opcode Instruction::GetOpcode() const { return opcode_; }
bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; } bool Instruction::IsTerminator() const {
return opcode_ == Opcode::Ret || opcode_ == Opcode::Br ||
opcode_ == Opcode::CondBr;
}
BasicBlock* Instruction::GetParent() const { return parent_; } BasicBlock* Instruction::GetParent() const { return parent_; }
void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
static bool IsIntBinaryOp(Opcode op) {
return op == Opcode::Add || op == Opcode::Sub || op == Opcode::Mul ||
op == Opcode::SDiv || op == Opcode::SRem;
}
static bool IsFloatBinaryOp(Opcode op) {
return op == Opcode::FAdd || op == Opcode::FSub || op == Opcode::FMul ||
op == Opcode::FDiv;
}
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name) Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) { : Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Add) { if (!IsIntBinaryOp(op) && !IsFloatBinaryOp(op)) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); throw std::runtime_error(FormatError("ir", "BinaryInst 非算术 op"));
} }
if (!lhs || !rhs) { if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
@ -70,12 +83,15 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
if (!type_ || !lhs->GetType() || !rhs->GetType()) { if (!type_ || !lhs->GetType() || !rhs->GetType()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息")); throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息"));
} }
if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() || if (!lhs->GetType()->Equals(*rhs->GetType()) ||
type_->GetKind() != lhs->GetType()->GetKind()) { !type_->Equals(*lhs->GetType())) {
throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
} }
if (!type_->IsInt32()) { if (IsIntBinaryOp(op) && !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); throw std::runtime_error(FormatError("ir", "整数二元只支持 i32"));
}
if (IsFloatBinaryOp(op) && !type_->IsFloat()) {
throw std::runtime_error(FormatError("ir", "浮点二元只支持 float"));
} }
AddOperand(lhs); AddOperand(lhs);
AddOperand(rhs); AddOperand(rhs);
@ -85,6 +101,127 @@ Value* BinaryInst::GetLhs() const { return GetOperand(0); }
Value* BinaryInst::GetRhs() const { return GetOperand(1); } Value* BinaryInst::GetRhs() const { return GetOperand(1); }
ICmpInst::ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name)
: Instruction(Opcode::ICmp, Type::GetInt1Type(), std::move(name)),
pred_(pred) {
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "ICmpInst 缺少操作数"));
}
if (!lhs->GetType() || !rhs->GetType() ||
!lhs->GetType()->Equals(*rhs->GetType())) {
throw std::runtime_error(FormatError("ir", "ICmpInst 类型不匹配"));
}
if (!lhs->GetType()->IsInt1() && !lhs->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "ICmpInst 仅支持整型"));
}
AddOperand(lhs);
AddOperand(rhs);
}
Value* ICmpInst::GetLhs() const { return GetOperand(0); }
Value* ICmpInst::GetRhs() const { return GetOperand(1); }
FCmpInst::FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name)
: Instruction(Opcode::FCmp, Type::GetInt1Type(), std::move(name)),
pred_(pred) {
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "FCmpInst 缺少操作数"));
}
if (!lhs->GetType() || !rhs->GetType() ||
!lhs->GetType()->Equals(*rhs->GetType())) {
throw std::runtime_error(FormatError("ir", "FCmpInst 类型不匹配"));
}
if (!lhs->GetType()->IsFloat()) {
throw std::runtime_error(FormatError("ir", "FCmpInst 仅支持 float"));
}
AddOperand(lhs);
AddOperand(rhs);
}
Value* FCmpInst::GetLhs() const { return GetOperand(0); }
Value* FCmpInst::GetRhs() const { return GetOperand(1); }
CastInst::CastInst(Opcode op, std::shared_ptr<Type> dst_ty, Value* src,
std::string name)
: Instruction(op, std::move(dst_ty), std::move(name)) {
if (op != Opcode::SIToFP && op != Opcode::FPToSI && op != Opcode::ZExt) {
throw std::runtime_error(FormatError("ir", "CastInst 不支持的 op"));
}
if (!src) {
throw std::runtime_error(FormatError("ir", "CastInst 缺少 src"));
}
if (op == Opcode::SIToFP) {
if (!src->GetType()->IsInt32() && !src->GetType()->IsInt1()) {
throw std::runtime_error(FormatError("ir", "SIToFP 仅支持整型"));
}
if (!type_ || !type_->IsFloat()) {
throw std::runtime_error(FormatError("ir", "SIToFP 目标必须是 float"));
}
} else if (op == Opcode::FPToSI) {
if (!src->GetType()->IsFloat()) {
throw std::runtime_error(FormatError("ir", "FPToSI 仅支持 float"));
}
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "FPToSI 目标必须是 i32"));
}
} else {
if (!src->GetType()->IsInt1()) {
throw std::runtime_error(FormatError("ir", "ZExt 仅支持 i1"));
}
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "ZExt 目标必须是 i32"));
}
}
AddOperand(src);
}
Value* CastInst::GetValue() const { return GetOperand(0); }
BranchInst::BranchInst(BasicBlock* dest)
: Instruction(Opcode::Br, Type::GetVoidType(), "") {
if (!dest) {
throw std::runtime_error(FormatError("ir", "BranchInst 缺少目标块"));
}
AddOperand(dest);
}
BasicBlock* BranchInst::GetDest() const {
return static_cast<BasicBlock*>(GetOperand(0));
}
CondBrInst::CondBrInst(Value* cond, BasicBlock* true_dest,
BasicBlock* false_dest)
: Instruction(Opcode::CondBr, Type::GetVoidType(), "") {
if (!cond || !true_dest || !false_dest) {
throw std::runtime_error(FormatError("ir", "CondBrInst 缺少参数"));
}
if (!cond->GetType() || !cond->GetType()->IsInt1()) {
throw std::runtime_error(FormatError("ir", "CondBrInst cond 必须是 i1"));
}
AddOperand(cond);
AddOperand(true_dest);
AddOperand(false_dest);
}
Value* CondBrInst::GetCond() const { return GetOperand(0); }
BasicBlock* CondBrInst::GetTrueDest() const {
return static_cast<BasicBlock*>(GetOperand(1));
}
BasicBlock* CondBrInst::GetFalseDest() const {
return static_cast<BasicBlock*>(GetOperand(2));
}
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty)
: Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
}
}
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val) ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), "") { : Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!val) { if (!val) {
@ -96,13 +233,24 @@ ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
AddOperand(val); AddOperand(val);
} }
Value* ReturnInst::GetValue() const { return GetOperand(0); } bool ReturnInst::HasReturnValue() const { return GetNumOperands() > 0; }
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name) Value* ReturnInst::GetValue() const {
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { if (!HasReturnValue()) return nullptr;
if (!type_ || !type_->IsPtrInt32()) { return GetOperand(0);
throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*"));
} }
AllocaInst::AllocaInst(std::shared_ptr<Type> allocated_ty, std::string name)
: Instruction(Opcode::Alloca, Type::GetPointerType(allocated_ty),
std::move(name)),
allocated_type_(std::move(allocated_ty)) {
if (!allocated_type_) {
throw std::runtime_error(FormatError("ir", "AllocaInst 缺少类型"));
}
}
const std::shared_ptr<Type>& AllocaInst::GetAllocatedType() const {
return allocated_type_;
} }
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name) LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
@ -110,12 +258,11 @@ LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
if (!ptr) { if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
} }
if (!type_ || !type_->IsInt32()) { if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32")); throw std::runtime_error(FormatError("ir", "LoadInst ptr 不是指针"));
} }
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { if (!type_ || !ptr->GetType()->GetElementType()->Equals(*type_)) {
throw std::runtime_error( throw std::runtime_error(FormatError("ir", "LoadInst 类型不匹配"));
FormatError("ir", "LoadInst 当前只支持从 i32* 加载"));
} }
AddOperand(ptr); AddOperand(ptr);
} }
@ -133,12 +280,11 @@ StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
if (!type_ || !type_->IsVoid()) { if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
} }
if (!val->GetType() || !val->GetType()->IsInt32()) { if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32")); throw std::runtime_error(FormatError("ir", "StoreInst ptr 不是指针"));
} }
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { if (!ptr->GetType()->GetElementType()->Equals(*val->GetType())) {
throw std::runtime_error( throw std::runtime_error(FormatError("ir", "StoreInst 类型不匹配"));
FormatError("ir", "StoreInst 当前只支持写入 i32*"));
} }
AddOperand(val); AddOperand(val);
AddOperand(ptr); AddOperand(ptr);
@ -148,4 +294,70 @@ Value* StoreInst::GetValue() const { return GetOperand(0); }
Value* StoreInst::GetPtr() const { return GetOperand(1); } Value* StoreInst::GetPtr() const { return GetOperand(1); }
CallInst::CallInst(std::shared_ptr<Type> ret_ty, Value* callee,
std::vector<Value*> args, std::string name)
: Instruction(Opcode::Call, std::move(ret_ty), std::move(name)),
args_(std::move(args)) {
if (!callee) {
throw std::runtime_error(FormatError("ir", "CallInst 缺少 callee"));
}
AddOperand(callee);
for (auto* arg : args_) {
if (!arg) {
throw std::runtime_error(FormatError("ir", "CallInst arg 为空"));
}
AddOperand(arg);
}
}
Value* CallInst::GetCallee() const { return GetOperand(0); }
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
void PhiInst::AddIncoming(Value* value, BasicBlock* block) {
if (!value || !block) {
throw std::runtime_error(FormatError("ir", "PhiInst incoming 为空"));
}
if (!value->GetType() || !type_ || !value->GetType()->Equals(*type_)) {
throw std::runtime_error(FormatError("ir", "PhiInst 类型不匹配"));
}
incoming_values_.push_back(value);
incoming_blocks_.push_back(block);
AddOperand(value);
AddOperand(block);
}
const std::vector<Value*>& PhiInst::GetIncomingValues() const {
return incoming_values_;
}
const std::vector<BasicBlock*>& PhiInst::GetIncomingBlocks() const {
return incoming_blocks_;
}
GepInst::GepInst(std::shared_ptr<Type> result_ptr_ty, Value* base_ptr,
std::vector<Value*> indices, std::string name)
: Instruction(Opcode::Gep, std::move(result_ptr_ty), std::move(name)),
indices_(std::move(indices)) {
if (!base_ptr) {
throw std::runtime_error(FormatError("ir", "GepInst 缺少 base_ptr"));
}
if (!base_ptr->GetType() || !base_ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "GepInst base_ptr 不是指针"));
}
if (!type_ || !type_->IsPointer()) {
throw std::runtime_error(FormatError("ir", "GepInst 结果必须是指针"));
}
AddOperand(base_ptr);
for (auto* idx : indices_) {
if (!idx || !idx->GetType() || !idx->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "GepInst index 必须是 i32"));
}
AddOperand(idx);
}
}
Value* GepInst::GetBasePtr() const { return GetOperand(0); }
} // namespace ir } // namespace ir

@ -10,12 +10,39 @@ const Context& Module::GetContext() const { return context_; }
Function* Module::CreateFunction(const std::string& name, Function* Module::CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type) { std::shared_ptr<Type> ret_type) {
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type))); auto func_ty = Type::GetFunctionType(std::move(ret_type), {});
functions_.push_back(std::make_unique<Function>(name, std::move(func_ty)));
return functions_.back().get(); return functions_.back().get();
} }
Function* Module::CreateFunctionWithType(const std::string& name,
std::shared_ptr<Type> func_type) {
functions_.push_back(
std::make_unique<Function>(name, std::move(func_type), false));
return functions_.back().get();
}
Function* Module::CreateFunctionDecl(const std::string& name,
std::shared_ptr<Type> func_type) {
functions_.push_back(
std::make_unique<Function>(name, std::move(func_type), true));
return functions_.back().get();
}
GlobalVariable* Module::CreateGlobalVariable(const std::string& name,
std::shared_ptr<Type> value_type,
ConstantValue* init, bool is_const) {
globals_.push_back(std::make_unique<GlobalVariable>(
std::move(value_type), name, init, is_const));
return globals_.back().get();
}
const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const { const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
return functions_; return functions_;
} }
const std::vector<std::unique_ptr<GlobalVariable>>& Module::GetGlobals() const {
return globals_;
}
} // namespace ir } // namespace ir

@ -1,31 +1,148 @@
// 当前仅支持 void、i32 和 i32* // 支持 void/i1/i32/float/ptr/array/function/label
#include "ir/IR.h" #include "ir/IR.h"
namespace ir { namespace ir {
Type::Type(Kind k) : kind_(k) {} Type::Type(Kind k) : kind_(k) {}
Type::Type(Kind k, std::shared_ptr<Type> elem, size_t count)
: kind_(k), elem_type_(std::move(elem)), array_size_(count) {}
Type::Type(Kind k, std::shared_ptr<Type> ret,
std::vector<std::shared_ptr<Type>> params, bool is_vararg)
: kind_(k), ret_type_(std::move(ret)), param_types_(std::move(params)),
is_vararg_(is_vararg) {}
const std::shared_ptr<Type>& Type::GetVoidType() { const std::shared_ptr<Type>& Type::GetVoidType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void);
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetInt1Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int1);
return type;
}
const std::shared_ptr<Type>& Type::GetInt32Type() { const std::shared_ptr<Type>& Type::GetInt32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32);
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetPtrInt32Type() { const std::shared_ptr<Type>& Type::GetFloatType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrInt32); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Float);
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetLabelType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Label);
return type;
}
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> elem) {
if (!elem) {
throw std::runtime_error("PointerType 缺少 element type");
}
return std::make_shared<Type>(Kind::Pointer, std::move(elem), 0);
}
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> elem,
size_t count) {
if (!elem) {
throw std::runtime_error("ArrayType 缺少 element type");
}
return std::make_shared<Type>(Kind::Array, std::move(elem), count);
}
std::shared_ptr<Type> Type::GetFunctionType(
std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
bool is_vararg) {
if (!ret) {
throw std::runtime_error("FunctionType 缺少 return type");
}
return std::make_shared<Type>(Kind::Function, std::move(ret),
std::move(params), is_vararg);
}
Type::Kind Type::GetKind() const { return kind_; } Type::Kind Type::GetKind() const { return kind_; }
bool Type::IsVoid() const { return kind_ == Kind::Void; } bool Type::IsVoid() const { return kind_ == Kind::Void; }
bool Type::IsInt1() const { return kind_ == Kind::Int1; }
bool Type::IsInt32() const { return kind_ == Kind::Int32; } bool Type::IsInt32() const { return kind_ == Kind::Int32; }
bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } bool Type::IsFloat() const { return kind_ == Kind::Float; }
bool Type::IsPointer() const { return kind_ == Kind::Pointer; }
bool Type::IsArray() const { return kind_ == Kind::Array; }
bool Type::IsFunction() const { return kind_ == Kind::Function; }
bool Type::IsLabel() const { return kind_ == Kind::Label; }
const std::shared_ptr<Type>& Type::GetElementType() const {
if (!elem_type_) {
throw std::runtime_error("Type 没有 element type");
}
return elem_type_;
}
size_t Type::GetArraySize() const {
if (!IsArray()) {
throw std::runtime_error("Type 不是 array");
}
return array_size_;
}
const std::shared_ptr<Type>& Type::GetReturnType() const {
if (!IsFunction()) {
throw std::runtime_error("Type 不是 function");
}
return ret_type_;
}
const std::vector<std::shared_ptr<Type>>& Type::GetParamTypes() const {
if (!IsFunction()) {
throw std::runtime_error("Type 不是 function");
}
return param_types_;
}
bool Type::IsVarArg() const {
if (!IsFunction()) {
throw std::runtime_error("Type 不是 function");
}
return is_vararg_;
}
bool Type::Equals(const Type& other) const {
if (kind_ != other.kind_) return false;
switch (kind_) {
case Kind::Pointer:
return elem_type_ && other.elem_type_ &&
elem_type_->Equals(*other.elem_type_);
case Kind::Array:
return array_size_ == other.array_size_ && elem_type_ &&
other.elem_type_ && elem_type_->Equals(*other.elem_type_);
case Kind::Function: {
if (!ret_type_ || !other.ret_type_ ||
!ret_type_->Equals(*other.ret_type_) ||
is_vararg_ != other.is_vararg_ ||
param_types_.size() != other.param_types_.size()) {
return false;
}
for (size_t i = 0; i < param_types_.size(); ++i) {
if (!param_types_[i] || !other.param_types_[i] ||
!param_types_[i]->Equals(*other.param_types_[i])) {
return false;
}
}
return true;
}
default:
return true;
}
}
} // namespace ir } // namespace ir

@ -18,9 +18,21 @@ void Value::SetName(std::string n) { name_ = std::move(n); }
bool Value::IsVoid() const { return type_ && type_->IsVoid(); } bool Value::IsVoid() const { return type_ && type_->IsVoid(); }
bool Value::IsInt1() const { return type_ && type_->IsInt1(); }
bool Value::IsInt32() const { return type_ && type_->IsInt32(); } bool Value::IsInt32() const { return type_ && type_->IsInt32(); }
bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } bool Value::IsFloat() const { return type_ && type_->IsFloat(); }
bool Value::IsPointer() const { return type_ && type_->IsPointer(); }
bool Value::IsArray() const { return type_ && type_->IsArray(); }
bool Value::IsFunctionType() const { return type_ && type_->IsFunction(); }
bool Value::IsPtrInt32() const {
return type_ && type_->IsPointer() && type_->GetElementType()->IsInt32();
}
bool Value::IsConstant() const { bool Value::IsConstant() const {
return dynamic_cast<const ConstantValue*>(this) != nullptr; return dynamic_cast<const ConstantValue*>(this) != nullptr;
@ -78,6 +90,25 @@ ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {} : Value(std::move(ty), std::move(name)) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v) ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
: ConstantValue(std::move(ty), ""), value_(v) {} : ConstantValue(std::move(ty), ""), value_(v) {
if (!type_ || (!type_->IsInt32() && !type_->IsInt1())) {
throw std::runtime_error("ConstantInt 需要 i1/i32 类型");
}
}
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float v)
: ConstantValue(std::move(ty), ""), value_(v) {
if (!type_ || !type_->IsFloat()) {
throw std::runtime_error("ConstantFloat 需要 float 类型");
}
}
ConstantArray::ConstantArray(std::shared_ptr<Type> ty,
std::vector<ConstantValue*> elements)
: ConstantValue(std::move(ty), ""), elements_(std::move(elements)) {
if (!type_ || !type_->IsArray()) {
throw std::runtime_error("ConstantArray 需要 array 类型");
}
}
} // namespace ir } // namespace ir

@ -1,46 +1,32 @@
#include "irgen/IRGen.h" #include "irgen/IRGen.h"
#include <any>
#include <stdexcept> #include <stdexcept>
#include "SysYParser.h" #include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
namespace { std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
if (!ctx) return BlockFlow::Continue;
std::string GetLValueName(SysYParser::LValueContext& lvalue) { BlockFlow flow = BlockFlow::Continue;
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
}
return lvalue.ID()->getText();
}
} // namespace
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
}
for (auto* item : ctx->blockItem()) { for (auto* item : ctx->blockItem()) {
if (item) { if (item) {
if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
// 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 flow = BlockFlow::Terminated;
break; break;
} }
} }
} }
return {}; return flow;
} }
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(SysYParser::BlockItemContext& item) {
SysYParser::BlockItemContext& item) {
return std::any_cast<BlockFlow>(item.accept(this)); return std::any_cast<BlockFlow>(item.accept(this));
} }
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
if (!ctx) { if (!ctx) return BlockFlow::Continue;
throw std::runtime_error(FormatError("irgen", "缺少块内项"));
}
if (ctx->decl()) { if (ctx->decl()) {
ctx->decl()->accept(this); ctx->decl()->accept(this);
return BlockFlow::Continue; return BlockFlow::Continue;
@ -48,60 +34,169 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
if (ctx->stmt()) { if (ctx->stmt()) {
return ctx->stmt()->accept(this); return ctx->stmt()->accept(this);
} }
throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); return BlockFlow::Continue;
} }
// 变量声明的 IR 生成目前也是最小实现:
// - 先检查声明的基础类型,当前仅支持局部 int
// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。
//
// 和更完整的版本相比,这里还没有:
// - 一个 Decl 中多个变量定义的顺序处理;
// - const、数组、全局变量等不同声明形态
// - 更丰富的类型系统。
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("irgen", "缺少变量声明")); if (auto* constDecl = ctx->constDecl()) {
for (auto* def : constDecl->constDef()) {
def->accept(this);
} }
if (!ctx->btype() || !ctx->btype()->INT()) { return {};
throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
} }
auto* var_def = ctx->varDef(); if (auto* varDecl = ctx->varDecl()) {
if (!var_def) { for (auto* varDef : varDecl->varDef()) {
throw std::runtime_error(FormatError("irgen", "非法变量声明")); varDef->accept(this);
}
return {};
} }
var_def->accept(this);
return {}; return {};
} }
// 当前仍是教学用的最小版本,因此这里只支持:
// - 局部 int 变量;
// - 标量初始化;
// - 一个 VarDef 对应一个槽位。
std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("irgen", "缺少变量定义")); if (!ctx->ID()) {
}
if (!ctx->lValue()) {
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
} }
GetLValueName(*ctx->lValue()); if (!func_) {
if (storage_map_.find(ctx) != storage_map_.end()) { const TypeDesc* ty = sema_.GetVarType(ctx);
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); if (!ty) {
throw std::runtime_error(FormatError("irgen", "全局变量类型缺失"));
}
if (global_var_storage_.find(ctx) != global_var_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成全局变量"));
}
ir::ConstantValue* init = nullptr;
if (ty->dims.empty()) {
if (auto* initVal = ctx->initVal()) {
if (!initVal->exp()) {
throw std::runtime_error(FormatError("irgen", "全局变量初始化非法"));
}
init = EvalConstScalar(initVal->exp());
if (ty->base == BaseTypeKind::Int &&
dynamic_cast<ir::ConstantFloat*>(init)) {
auto* cf = static_cast<ir::ConstantFloat*>(init);
init = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
} else if (ty->base == BaseTypeKind::Float &&
dynamic_cast<ir::ConstantInt*>(init)) {
auto* ci = static_cast<ir::ConstantInt*>(init);
init = module_.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
}
}
} else if (auto* initVal = ctx->initVal()) {
size_t total = ArrayTotalSize(*ty);
std::vector<ir::ConstantValue*> values(
total,
ty->base == BaseTypeKind::Float
? static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(0.0f))
: static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(0)));
InitGlobalArray(*ty, initVal, values, 0, 0, 0);
init = module_.GetContext().CreateConstArray(ToIRType(*ty), values);
}
auto* gv = module_.CreateGlobalVariable(ctx->ID()->getText(),
ToIRType(*ty), init, false);
global_var_storage_[ctx] = gv;
return {};
}
if (var_storage_.find(ctx) != var_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成存储槽位"));
}
const TypeDesc* ty = sema_.GetVarType(ctx);
if (!ty) {
throw std::runtime_error(FormatError("irgen", "变量类型缺失"));
} }
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); auto* slot = CreateEntryAlloca(ToIRType(*ty), module_.GetContext().NextTemp());
storage_map_[ctx] = slot; var_storage_[ctx] = slot;
if (ty->dims.empty()) {
ir::Value* init = nullptr; ir::Value* init = nullptr;
if (auto* init_value = ctx->initValue()) { if (auto* initVal = ctx->initVal()) {
if (!init_value->exp()) { if (!initVal->exp()) {
throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); throw std::runtime_error(FormatError("irgen", "标量初始化非法"));
} }
init = EvalExpr(*init_value->exp()); init = EvalExp(initVal->exp());
} else { } else {
init = builder_.CreateConstInt(0); init = DefaultValue(*ty);
} }
builder_.CreateStore(init, slot); builder_.CreateStore(init, slot);
} else {
InitArray(slot, *ty, ctx->initVal());
}
return {};
}
std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "常量声明缺少名称"));
}
if (!func_) {
const TypeDesc* ty = sema_.GetConstType(ctx);
if (!ty) {
throw std::runtime_error(FormatError("irgen", "全局常量类型缺失"));
}
if (global_const_storage_.find(ctx) != global_const_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成全局常量"));
}
ir::ConstantValue* init = nullptr;
if (ty->dims.empty()) {
if (auto* initVal = ctx->constInitVal()) {
if (!initVal->constExp()) {
throw std::runtime_error(FormatError("irgen", "全局常量初始化非法"));
}
init = EvalConstScalar(initVal->constExp());
if (ty->base == BaseTypeKind::Int &&
dynamic_cast<ir::ConstantFloat*>(init)) {
auto* cf = static_cast<ir::ConstantFloat*>(init);
init = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
} else if (ty->base == BaseTypeKind::Float &&
dynamic_cast<ir::ConstantInt*>(init)) {
auto* ci = static_cast<ir::ConstantInt*>(init);
init = module_.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
}
}
} else if (auto* initVal = ctx->constInitVal()) {
size_t total = ArrayTotalSize(*ty);
std::vector<ir::ConstantValue*> values(
total,
ty->base == BaseTypeKind::Float
? static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(0.0f))
: static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(0)));
InitGlobalConstArray(*ty, initVal, values, 0, 0, 0);
init = module_.GetContext().CreateConstArray(ToIRType(*ty), values);
}
auto* gv = module_.CreateGlobalVariable(ctx->ID()->getText(),
ToIRType(*ty), init, true);
global_const_storage_[ctx] = gv;
return {};
}
if (const_storage_.find(ctx) != const_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成常量存储"));
}
const TypeDesc* ty = sema_.GetConstType(ctx);
if (!ty) {
throw std::runtime_error(FormatError("irgen", "常量类型缺失"));
}
auto* slot = CreateEntryAlloca(ToIRType(*ty), module_.GetContext().NextTemp());
const_storage_[ctx] = slot;
if (ty->dims.empty()) {
ir::Value* init = nullptr;
if (auto* initVal = ctx->constInitVal()) {
if (!initVal->constExp()) {
throw std::runtime_error(FormatError("irgen", "常量初始化非法"));
}
init = std::any_cast<ir::Value*>(initVal->constExp()->accept(this));
} else {
init = DefaultValue(*ty);
}
builder_.CreateStore(init, slot);
} else {
InitConstArray(slot, *ty, ctx->constInitVal());
}
return {}; return {};
} }

@ -4,12 +4,11 @@
#include "SysYParser.h" #include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h"
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema) { const SemanticContext& sema) {
auto module = std::make_unique<ir::Module>(); auto module = std::make_unique<ir::Module>(); // 无参构造
IRGenImpl gen(*module, sema); IRGenImpl visitor(*module, sema);
tree.accept(&gen); tree.accept(&visitor);
return module; return module;
} }

File diff suppressed because it is too large Load Diff

@ -6,82 +6,116 @@
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
namespace {
void VerifyFunctionStructure(const ir::Function& func) {
// 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。
for (const auto& bb : func.GetBlocks()) {
if (!bb || !bb->HasTerminator()) {
throw std::runtime_error(
FormatError("irgen", "基本块未正确终结: " +
(bb ? bb->GetName() : std::string("<null>"))));
}
}
}
} // namespace
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module), : module_(module),
sema_(sema), sema_(sema),
func_(nullptr), func_(nullptr),
builder_(module.GetContext(), nullptr) {} builder_(module.GetContext(), nullptr) {}
// 编译单元的 IR 生成当前只实现了最小功能:
// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容;
// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR
//
// 当前还没有实现:
// - 多个函数定义的遍历与生成;
// - 全局变量、全局常量的 IR 生成。
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
func_map_.clear();
global_var_storage_.clear();
global_const_storage_.clear();
func_ = nullptr;
for (auto* decl : ctx->decl()) {
if (decl) decl->accept(this);
} }
auto* func = ctx->funcDef(); for (auto* funcDef : ctx->funcDef()) {
if (!func) { if (!funcDef || !funcDef->ID()) continue;
throw std::runtime_error(FormatError("irgen", "缺少函数定义")); const auto* fty = sema_.GetFuncType(funcDef);
if (!fty) {
throw std::runtime_error(FormatError("irgen", "缺少函数类型"));
} }
func->accept(this); std::vector<std::shared_ptr<ir::Type>> params;
return {}; for (const auto& p : fty->params) {
params.push_back(ToIRParamType(p));
}
auto ret = ToIRType(fty->ret);
auto func_ty = ir::Type::GetFunctionType(ret, params);
auto* fn = module_.CreateFunctionWithType(funcDef->ID()->getText(), func_ty);
func_map_[funcDef] = fn;
}
auto declare_builtin = [&](const std::string& name,
std::shared_ptr<ir::Type> ret,
std::vector<std::shared_ptr<ir::Type>> params) {
for (const auto& fn : module_.GetFunctions()) {
if (fn && fn->GetName() == name) return;
} }
auto fty = ir::Type::GetFunctionType(ret, params);
module_.CreateFunctionDecl(name, fty);
};
auto i32 = ir::Type::GetInt32Type();
auto f32 = ir::Type::GetFloatType();
declare_builtin("getint", i32, {});
declare_builtin("getch", i32, {});
declare_builtin("getarray", i32, {ir::Type::GetPointerType(i32)});
declare_builtin("putint", ir::Type::GetVoidType(), {i32});
declare_builtin("putch", ir::Type::GetVoidType(), {i32});
declare_builtin("putarray", ir::Type::GetVoidType(),
{i32, ir::Type::GetPointerType(i32)});
declare_builtin("getfloat", f32, {});
declare_builtin("getfarray", i32, {ir::Type::GetPointerType(f32)});
declare_builtin("putfloat", ir::Type::GetVoidType(), {f32});
declare_builtin("putfarray", ir::Type::GetVoidType(),
{i32, ir::Type::GetPointerType(f32)});
declare_builtin("starttime", ir::Type::GetVoidType(), {});
declare_builtin("stoptime", ir::Type::GetVoidType(), {});
// 函数 IR 生成当前实现了: for (auto* funcDef : ctx->funcDef()) {
// 1. 获取函数名; if (funcDef) funcDef->accept(this);
// 2. 检查函数返回类型; }
// 3. 在 Module 中创建 Function return {};
// 4. 将 builder 插入点设置到入口基本块; }
// 5. 继续生成函数体。
//
// 当前还没有实现:
// - 通用函数返回类型处理;
// - 形参列表遍历与参数类型收集;
// - FunctionType 这样的函数类型对象;
// - Argument/形式参数 IR 对象;
// - 入口块中的参数初始化逻辑。
// ...
// 因此这里目前只支持最小的“无参 int 函数”生成。
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx) { if (!ctx || !ctx->block()) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
if (!ctx->blockStmt()) {
throw std::runtime_error(FormatError("irgen", "函数体为空")); throw std::runtime_error(FormatError("irgen", "函数体为空"));
} }
if (!ctx->ID()) { auto it = func_map_.find(ctx);
throw std::runtime_error(FormatError("irgen", "缺少函数名")); if (it == func_map_.end()) {
throw std::runtime_error(FormatError("irgen", "函数未注册"));
}
func_ = it->second;
auto* entry = func_->GetEntry();
builder_.SetInsertPoint(entry);
var_storage_.clear();
const_storage_.clear();
param_storage_.clear();
loop_stack_.clear();
const auto* fty = sema_.GetFuncType(ctx);
if (!fty) {
throw std::runtime_error(FormatError("irgen", "缺少函数类型"));
}
if (ctx->funcFParams()) {
auto params = ctx->funcFParams()->funcFParam();
for (size_t i = 0; i < params.size(); ++i) {
auto* param_ctx = params[i];
auto* arg = func_->GetArg(i);
const TypeDesc* pty = sema_.GetParamType(param_ctx);
if (!pty) {
throw std::runtime_error(FormatError("irgen", "缺少参数类型"));
}
auto slot = CreateEntryAlloca(ToIRParamType(*pty),
module_.GetContext().NextTemp());
builder_.CreateStore(arg, slot);
param_storage_[param_ctx] = slot;
} }
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数"));
} }
func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); ctx->block()->accept(this);
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
ctx->blockStmt()->accept(this); if (!builder_.GetInsertBlock()->HasTerminator()) {
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 if (func_->GetReturnType()->IsVoid()) {
VerifyFunctionStructure(*func_); builder_.CreateRetVoid();
} else {
TypeDesc ret = fty->ret;
builder_.CreateRet(DefaultValue(ret));
}
}
return {}; return {};
} }

@ -1,39 +1,170 @@
#include "irgen/IRGen.h" #include "irgen/IRGen.h"
#include <any>
#include <stdexcept> #include <stdexcept>
#include "SysYParser.h" #include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
// 语句生成当前只实现了最小子集。
// 目前支持:
// - return <exp>;
//
// 还未支持:
// - 赋值语句
// - if / while 等控制流
// - 空语句、块语句嵌套分发之外的更多语句形态
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("irgen", "缺少语句")); if (ctx->lVal() && ctx->ASSIGN()) {
ir::Value* addr = GetLValAddress(ctx->lVal());
ir::Value* val = EvalExp(ctx->exp());
BoundDecl bound = sema_.ResolveVarUse(ctx->lVal());
const TypeDesc* ty = nullptr;
if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) {
ty = sema_.GetVarType(bound.var_decl);
} else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) {
ty = sema_.GetParamType(bound.param_decl);
} else if (bound.kind == BoundDecl::Kind::Const) {
throw std::runtime_error(FormatError("irgen", "不能给常量赋值"));
}
if (!ty && ctx->lVal()->ID()) {
const auto name = ctx->lVal()->ID()->getText();
for (const auto& kv : var_storage_) {
auto* def = const_cast<SysYParser::VarDefContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
ty = sema_.GetVarType(kv.first);
break;
}
}
if (!ty) {
for (const auto& kv : const_storage_) {
auto* def = const_cast<SysYParser::ConstDefContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
throw std::runtime_error(FormatError("irgen", "不能给常量赋值"));
}
}
}
if (!ty) {
for (const auto& kv : param_storage_) {
auto* def = const_cast<SysYParser::FuncFParamContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
ty = sema_.GetParamType(kv.first);
break;
}
}
}
if (!ty) {
for (const auto& kv : global_var_storage_) {
auto* def = const_cast<SysYParser::VarDefContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
ty = sema_.GetVarType(kv.first);
break;
}
}
}
if (!ty) {
for (const auto& kv : global_const_storage_) {
auto* def = const_cast<SysYParser::ConstDefContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
throw std::runtime_error(FormatError("irgen", "不能给常量赋值"));
}
}
}
}
if (!ty) {
throw std::runtime_error(FormatError("irgen", "无法解析赋值类型"));
}
if (ty->base == BaseTypeKind::Float) {
if (val->IsInt1()) {
val = CastToFloat(CastToInt(val));
} else if (val->IsInt32()) {
val = CastToFloat(val);
}
} else if (ty->base == BaseTypeKind::Int) {
if (val->IsFloat()) {
val = CastToInt(val);
} else if (val->IsInt1()) {
val = CastToInt(val);
}
} }
if (ctx->returnStmt()) { builder_.CreateStore(val, addr);
return ctx->returnStmt()->accept(this); return BlockFlow::Continue;
} }
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); if (ctx->block()) {
return ctx->block()->accept(this);
} }
if (ctx->IF()) {
auto* then_bb = func_->CreateBlock("if.then");
auto* else_bb = func_->CreateBlock("if.else");
auto* merge_bb = func_->CreateBlock("if.end");
EmitCondBr(ctx->cond(), then_bb, else_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
if (then_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
builder_.SetInsertPoint(else_bb);
if (ctx->stmt(1)) {
auto else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
if (else_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
} else {
builder_.CreateBr(merge_bb);
}
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
if (ctx->WHILE()) {
auto* cond_bb = func_->CreateBlock("while.cond");
auto* body_bb = func_->CreateBlock("while.body");
auto* end_bb = func_->CreateBlock("while.end");
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(cond_bb);
EmitCondBr(ctx->cond(), body_bb, end_bb);
std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { builder_.SetInsertPoint(body_bb);
if (!ctx) { PushLoop(end_bb, cond_bb);
throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
PopLoop();
if (body_flow != BlockFlow::Terminated) {
builder_.CreateBr(cond_bb);
}
builder_.SetInsertPoint(end_bb);
return BlockFlow::Continue;
}
if (ctx->BREAK()) {
auto* target = CurrentBreak();
if (!target) {
throw std::runtime_error(FormatError("irgen", "break 不在循环内"));
} }
builder_.CreateBr(target);
return BlockFlow::Terminated;
}
if (ctx->CONTINUE()) {
auto* target = CurrentContinue();
if (!target) {
throw std::runtime_error(FormatError("irgen", "continue 不在循环内"));
}
builder_.CreateBr(target);
return BlockFlow::Terminated;
}
if (ctx->RETURN()) {
if (!ctx->exp()) { if (!ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); builder_.CreateRetVoid();
return BlockFlow::Terminated;
}
ir::Value* v = EvalExp(ctx->exp());
auto ret_ty = func_->GetReturnType();
if (ret_ty->IsFloat() && v->IsInt32()) {
v = CastToFloat(v);
} else if (ret_ty->IsInt32() && v->IsFloat()) {
v = CastToInt(v);
} }
ir::Value* v = EvalExpr(*ctx->exp());
builder_.CreateRet(v); builder_.CreateRet(v);
return BlockFlow::Terminated; return BlockFlow::Terminated;
} }
if (ctx->exp()) {
EvalExp(ctx->exp());
}
return BlockFlow::Continue;
}

@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}

@ -3,6 +3,7 @@
#include <any> #include <any>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <unordered_set>
#include "SysYBaseVisitor.h" #include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h" #include "sem/SymbolTable.h"
@ -10,12 +11,122 @@
namespace { namespace {
std::string GetLValueName(SysYParser::LValueContext& lvalue) { static BaseTypeKind BaseTypeFromBType(SysYParser::BTypeContext* ctx) {
if (!lvalue.ID()) { if (!ctx) {
throw std::runtime_error(FormatError("sema", "非法左值")); throw std::runtime_error(FormatError("sema", "缺少 bType"));
}
if (ctx->INT()) return BaseTypeKind::Int;
if (ctx->FLOAT()) return BaseTypeKind::Float;
throw std::runtime_error(FormatError("sema", "未知基础类型"));
}
static BaseTypeKind BaseTypeFromFuncType(SysYParser::FuncTypeContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少 funcType"));
}
if (ctx->VOID()) return BaseTypeKind::Void;
if (ctx->INT()) return BaseTypeKind::Int;
if (ctx->FLOAT()) return BaseTypeKind::Float;
throw std::runtime_error(FormatError("sema", "未知函数返回类型"));
} }
return lvalue.ID()->getText();
class ConstEvalVisitor final : public SysYBaseVisitor {
public:
explicit ConstEvalVisitor(const SymbolTable& table) : table_(table) {}
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override {
return visitAddExp(ctx->addExp());
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
auto muls = ctx->mulExp();
if (muls.empty()) return 0;
int value = std::any_cast<int>(muls[0]->accept(this));
for (size_t i = 1; i < muls.size(); ++i) {
int rhs = std::any_cast<int>(muls[i]->accept(this));
auto* node = ctx->children.at(2 * i - 1);
auto text = node ? node->getText() : "+";
if (text == "+") {
value += rhs;
} else if (text == "-") {
value -= rhs;
} else {
throw std::runtime_error(FormatError("sema", "非法加法运算符"));
}
}
return value;
}
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
auto unaries = ctx->unaryExp();
if (unaries.empty()) return 0;
int value = std::any_cast<int>(unaries[0]->accept(this));
for (size_t i = 1; i < unaries.size(); ++i) {
int rhs = std::any_cast<int>(unaries[i]->accept(this));
auto* node = ctx->children.at(2 * i - 1);
auto text = node ? node->getText() : "*";
if (text == "*") {
value *= rhs;
} else if (text == "/") {
value /= rhs;
} else if (text == "%") {
value %= rhs;
} else {
throw std::runtime_error(FormatError("sema", "非法乘法运算符"));
}
}
return value;
}
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
if (ctx->unaryOp() && ctx->unaryExp()) {
int val = std::any_cast<int>(ctx->unaryExp()->accept(this));
auto op = ctx->unaryOp()->getText();
if (op == "+") return val;
if (op == "-") return -val;
throw std::runtime_error(FormatError("sema", "constExp 不支持 !"));
} }
throw std::runtime_error(FormatError("sema", "constExp 不支持函数调用"));
}
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
if (ctx->exp()) return ctx->exp()->accept(this);
if (ctx->lVal()) return ctx->lVal()->accept(this);
if (ctx->number()) return ctx->number()->accept(this);
return 0;
}
std::any visitNumber(SysYParser::NumberContext* ctx) override {
if (ctx->INT_CONST()) {
const std::string text = ctx->getText();
size_t idx = 0;
long long val = std::stoll(text, &idx, 0);
if (idx != text.size()) {
throw std::runtime_error(FormatError("sema", "非法整数常量"));
}
return static_cast<int>(val);
}
if (ctx->FLOAT_CONST()) {
return static_cast<int>(std::stof(ctx->getText()));
}
throw std::runtime_error(FormatError("sema", "constExp 仅支持整数"));
}
std::any visitLVal(SysYParser::LValContext* ctx) override {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "constExp 非法变量"));
}
const auto* entry = table_.Lookup(ctx->ID()->getText());
if (!entry || !entry->is_const || !entry->const_value.has_value()) {
throw std::runtime_error(FormatError("sema", "constExp 使用了非常量"));
}
return entry->const_value.value();
}
private:
const SymbolTable& table_;
};
class SemaVisitor final : public SysYBaseVisitor { class SemaVisitor final : public SysYBaseVisitor {
public: public:
@ -23,172 +134,372 @@ class SemaVisitor final : public SysYBaseVisitor {
if (!ctx) { if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元")); throw std::runtime_error(FormatError("sema", "缺少编译单元"));
} }
auto* func = ctx->funcDef(); for (auto* func : ctx->funcDef()) {
if (!func || !func->blockStmt()) { if (!func || !func->ID()) continue;
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); std::string name = func->ID()->getText();
if (func_table_.find(name) != func_table_.end()) {
throw std::runtime_error(FormatError("sema", "重复定义函数: " + name));
} }
if (!func->ID() || func->ID()->getText() != "main") { func_table_[name] = func;
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
} }
func->accept(this);
if (!seen_return_) { for (auto* decl : ctx->decl()) {
throw std::runtime_error( if (decl) decl->accept(this);
FormatError("sema", "main 函数必须包含 return 语句")); }
for (auto* func : ctx->funcDef()) {
if (func) func->accept(this);
}
if (func_table_.find("main") == func_table_.end()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
} }
return {}; return {};
} }
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->blockStmt()) { if (!ctx || !ctx->block()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); throw std::runtime_error(FormatError("sema", "函数体为空"));
} }
if (!ctx->funcType() || !ctx->funcType()->INT()) { if (!ctx->ID()) {
throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); throw std::runtime_error(FormatError("sema", "缺少函数名"));
} }
const auto& items = ctx->blockStmt()->blockItem(); FuncTypeDesc fty;
if (items.empty()) { fty.ret.base = BaseTypeFromFuncType(ctx->funcType());
throw std::runtime_error( if (ctx->funcFParams()) {
FormatError("sema", "main 函数不能为空,且必须以 return 结束")); for (auto* param : ctx->funcFParams()->funcFParam()) {
fty.params.push_back(BuildParamType(param));
} }
ctx->blockStmt()->accept(this);
return {};
} }
sema_.RegisterFunc(ctx, fty);
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { current_ret_ = fty.ret.base;
if (!ctx) { seen_return_ = false;
throw std::runtime_error(FormatError("sema", "缺少语句块"));
table_.EnterScope();
if (ctx->funcFParams()) {
for (auto* param : ctx->funcFParams()->funcFParam()) {
RegisterParam(param);
} }
const auto& items = ctx->blockItem();
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
} }
if (seen_return_) { ctx->block()->accept(this);
throw std::runtime_error( table_.ExitScope();
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
if (current_ret_ != BaseTypeKind::Void && !seen_return_) {
throw std::runtime_error(FormatError("sema", "非 void 函数缺少 return"));
} }
current_item_index_ = i; return {};
total_items_ = items.size();
item->accept(this);
} }
std::any visitBlock(SysYParser::BlockContext* ctx) override {
if (!ctx) return {};
table_.EnterScope();
for (auto* item : ctx->blockItem()) {
if (item) item->accept(this);
}
table_.ExitScope();
return {}; return {};
} }
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); if (ctx->decl()) return ctx->decl()->accept(this);
if (ctx->stmt()) return ctx->stmt()->accept(this);
return {};
} }
if (ctx->decl()) {
ctx->decl()->accept(this); std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (!ctx) return {};
if (auto* c = ctx->constDecl()) return c->accept(this);
if (auto* v = ctx->varDecl()) return v->accept(this);
return {}; return {};
} }
if (ctx->stmt()) {
ctx->stmt()->accept(this); std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
if (!ctx || !ctx->bType()) return {};
BaseTypeKind base = BaseTypeFromBType(ctx->bType());
for (auto* def : ctx->constDef()) {
RegisterConst(def, base);
}
return {}; return {};
} }
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
if (!ctx || !ctx->bType()) return {};
BaseTypeKind base = BaseTypeFromBType(ctx->bType());
for (auto* def : ctx->varDef()) {
RegisterVar(def, base);
}
return {};
} }
std::any visitDecl(SysYParser::DeclContext* ctx) override { std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("sema", "非法变量声明")); if (ctx->lVal() && ctx->ASSIGN()) {
ctx->lVal()->accept(this);
if (ctx->exp()) ctx->exp()->accept(this);
return {};
} }
if (!ctx->btype() || !ctx->btype()->INT()) { if (ctx->block()) return ctx->block()->accept(this);
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); if (ctx->IF()) {
if (ctx->cond()) ctx->cond()->accept(this);
if (ctx->stmt(0)) ctx->stmt(0)->accept(this);
if (ctx->stmt(1)) ctx->stmt(1)->accept(this);
return {};
} }
auto* var_def = ctx->varDef(); if (ctx->WHILE()) {
if (!var_def || !var_def->lValue()) { loop_depth_++;
throw std::runtime_error(FormatError("sema", "非法变量声明")); if (ctx->cond()) ctx->cond()->accept(this);
if (ctx->stmt(0)) ctx->stmt(0)->accept(this);
loop_depth_--;
return {};
} }
const std::string name = GetLValueName(*var_def->lValue()); if (ctx->BREAK()) {
if (table_.Contains(name)) { if (loop_depth_ == 0) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); throw std::runtime_error(FormatError("sema", "break 不在循环内"));
} }
if (auto* init = var_def->initValue()) { return {};
if (!init->exp()) {
throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
} }
init->exp()->accept(this); if (ctx->CONTINUE()) {
if (loop_depth_ == 0) {
throw std::runtime_error(FormatError("sema", "continue 不在循环内"));
} }
table_.Add(name, var_def); return {};
}
if (ctx->RETURN()) {
if (ctx->exp()) ctx->exp()->accept(this);
if (current_ret_ == BaseTypeKind::Void && ctx->exp()) {
throw std::runtime_error(FormatError("sema", "void 函数不能返回值"));
}
if (current_ret_ != BaseTypeKind::Void && !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值"));
}
seen_return_ = true;
return {};
}
if (ctx->exp()) ctx->exp()->accept(this);
return {}; return {};
} }
std::any visitStmt(SysYParser::StmtContext* ctx) override { std::any visitExp(SysYParser::ExpContext* ctx) override {
if (!ctx || !ctx->returnStmt()) { if (ctx->addExp()) return ctx->addExp()->accept(this);
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); return {};
} }
ctx->returnStmt()->accept(this);
std::any visitCond(SysYParser::CondContext* ctx) override {
if (ctx->lOrExp()) return ctx->lOrExp()->accept(this);
return {}; return {};
} }
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
if (!ctx || !ctx->exp()) { for (auto* e : ctx->lAndExp()) e->accept(this);
throw std::runtime_error(FormatError("sema", "return 缺少表达式")); return {};
} }
ctx->exp()->accept(this);
seen_return_ = true; std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
if (current_item_index_ + 1 != total_items_) { for (auto* e : ctx->eqExp()) e->accept(this);
throw std::runtime_error( return {};
FormatError("sema", "return 必须是 main 函数中的最后一条语句")); }
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
for (auto* e : ctx->relExp()) e->accept(this);
return {};
} }
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
for (auto* e : ctx->addExp()) e->accept(this);
return {}; return {};
} }
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
if (!ctx || !ctx->exp()) { for (auto* mul : ctx->mulExp()) mul->accept(this);
throw std::runtime_error(FormatError("sema", "非法括号表达式")); return {};
} }
ctx->exp()->accept(this);
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
for (auto* unary : ctx->unaryExp()) unary->accept(this);
return {}; return {};
} }
std::any visitVarExp(SysYParser::VarExpContext* ctx) override { std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
if (!ctx || !ctx->var()) { if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
throw std::runtime_error(FormatError("sema", "非法变量表达式")); if (ctx->ID() && ctx->LPAREN()) {
std::string name = ctx->ID()->getText();
auto it = func_table_.find(name);
if (it == func_table_.end()) {
if (builtin_funcs_.find(name) == builtin_funcs_.end()) {
throw std::runtime_error(FormatError("sema", "未定义的函数: " + name));
}
} else {
sema_.BindFuncCall(ctx, it->second);
}
if (ctx->funcRParams()) ctx->funcRParams()->accept(this);
return {};
} }
ctx->var()->accept(this); if (ctx->unaryExp()) return ctx->unaryExp()->accept(this);
return {}; return {};
} }
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { for (auto* e : ctx->exp()) e->accept(this);
throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); return {};
} }
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
if (ctx->exp()) return ctx->exp()->accept(this);
if (ctx->lVal()) return ctx->lVal()->accept(this);
if (ctx->number()) return ctx->number()->accept(this);
return {}; return {};
} }
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { std::any visitNumber(SysYParser::NumberContext* ctx) override {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { if (!ctx->INT_CONST() && !ctx->FLOAT_CONST()) {
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); throw std::runtime_error(FormatError("sema", "非法常量"));
} }
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {}; return {};
} }
std::any visitVar(SysYParser::VarContext* ctx) override { std::any visitLVal(SysYParser::LValContext* ctx) override {
if (!ctx || !ctx->ID()) { if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "非法变量引用")); throw std::runtime_error(FormatError("sema", "非法变量引用"));
} }
const std::string name = ctx->ID()->getText(); std::string name = ctx->ID()->getText();
auto* decl = table_.Lookup(name); const SymbolEntry* entry = table_.Lookup(name);
if (!decl) { if (!entry) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
} }
sema_.BindVarUse(ctx, decl); BoundDecl bound;
if (entry->kind == SymbolKind::Var) {
bound.kind = BoundDecl::Kind::Var;
bound.var_decl = entry->var_decl;
} else if (entry->kind == SymbolKind::Const) {
bound.kind = BoundDecl::Kind::Const;
bound.const_decl = entry->const_decl;
} else {
bound.kind = BoundDecl::Kind::Param;
bound.param_decl = entry->param_decl;
}
sema_.BindVarUse(ctx, bound);
return {}; return {};
} }
SemanticContext TakeSemanticContext() { return std::move(sema_); } SemanticContext TakeSemanticContext() { return std::move(sema_); }
private:
TypeDesc BuildParamType(SysYParser::FuncFParamContext* ctx) {
if (!ctx || !ctx->bType()) {
throw std::runtime_error(FormatError("sema", "非法参数"));
}
TypeDesc ty;
ty.base = BaseTypeFromBType(ctx->bType());
if (ctx->LBRACK().size() > 0) {
ty.dims.push_back(-1);
for (auto* exp : ctx->exp()) {
ty.dims.push_back(EvalConstExp(exp));
}
}
return ty;
}
void RegisterParam(SysYParser::FuncFParamContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "参数缺少名称"));
}
std::string name = ctx->ID()->getText();
if (table_.ContainsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义参数: " + name));
}
TypeDesc ty = BuildParamType(ctx);
SymbolEntry entry;
entry.kind = SymbolKind::Param;
entry.param_decl = ctx;
entry.is_const = false;
entry.type = ty;
table_.Add(name, entry);
sema_.RegisterParam(ctx, ty);
}
void RegisterVar(SysYParser::VarDefContext* ctx, BaseTypeKind base) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "变量声明缺少名称"));
}
std::string name = ctx->ID()->getText();
if (table_.ContainsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
}
TypeDesc ty;
ty.base = base;
for (auto* dim : ctx->constExp()) {
ty.dims.push_back(EvalConstExp(dim));
}
SymbolEntry entry;
entry.kind = SymbolKind::Var;
entry.var_decl = ctx;
entry.is_const = false;
entry.type = ty;
table_.Add(name, entry);
sema_.RegisterVarDecl(ctx, ty);
if (auto* init = ctx->initVal()) {
init->accept(this);
}
}
void RegisterConst(SysYParser::ConstDefContext* ctx, BaseTypeKind base) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "常量声明缺少名称"));
}
std::string name = ctx->ID()->getText();
if (table_.ContainsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
}
TypeDesc ty;
ty.base = base;
ty.is_const = true;
for (auto* dim : ctx->constExp()) {
ty.dims.push_back(EvalConstExp(dim));
}
SymbolEntry entry;
entry.kind = SymbolKind::Const;
entry.const_decl = ctx;
entry.is_const = true;
entry.type = ty;
if (ctx->constInitVal() && ty.dims.empty() && ty.base == BaseTypeKind::Int) {
if (auto* exp = ctx->constInitVal()->constExp()) {
entry.const_value = EvalConstExp(exp);
}
}
table_.Add(name, entry);
sema_.RegisterConstDecl(ctx, ty);
if (auto* init = ctx->constInitVal()) {
init->accept(this);
}
}
int EvalConstExp(SysYParser::ConstExpContext* ctx) {
ConstEvalVisitor visitor(table_);
return std::any_cast<int>(ctx->accept(&visitor));
}
int EvalConstExp(SysYParser::ExpContext* ctx) {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("sema", "非法常量表达式"));
}
ConstEvalVisitor visitor(table_);
return std::any_cast<int>(ctx->addExp()->accept(&visitor));
}
private: private:
SymbolTable table_; SymbolTable table_;
SemanticContext sema_; SemanticContext sema_;
std::unordered_map<std::string, SysYParser::FuncDefContext*> func_table_;
const std::unordered_set<std::string> builtin_funcs_ = {
"getint", "getch", "getarray", "putint", "putch", "putarray",
"getfloat", "getfarray", "putfloat", "putfarray", "starttime",
"stoptime"};
BaseTypeKind current_ret_ = BaseTypeKind::Void;
bool seen_return_ = false; bool seen_return_ = false;
size_t current_item_index_ = 0; int loop_depth_ = 0;
size_t total_items_ = 0;
}; };
} // namespace } // namespace

@ -2,16 +2,34 @@
#include "sem/SymbolTable.h" #include "sem/SymbolTable.h"
void SymbolTable::Add(const std::string& name, void SymbolTable::EnterScope() { scopes_.emplace_back(); }
SysYParser::VarDefContext* decl) {
table_[name] = decl; void SymbolTable::ExitScope() {
if (!scopes_.empty()) {
scopes_.pop_back();
}
}
bool SymbolTable::ContainsInCurrentScope(const std::string& name) const {
if (scopes_.empty()) {
return false;
}
return scopes_.back().find(name) != scopes_.back().end();
} }
bool SymbolTable::Contains(const std::string& name) const { void SymbolTable::Add(const std::string& name, const SymbolEntry& entry) {
return table_.find(name) != table_.end(); if (scopes_.empty()) {
EnterScope();
}
scopes_.back()[name] = entry;
} }
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { const SymbolEntry* SymbolTable::Lookup(const std::string& name) const {
auto it = table_.find(name); for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
return it == table_.end() ? nullptr : it->second; auto found = it->find(name);
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
} }

@ -1,4 +1,61 @@
// SysY 运行库实现: #include <stdio.h>
// - 按实验/评测规范提供 I/O 等函数实现
// - 与编译器生成的目标代码链接,支撑运行时行为 int getint() {
int v = 0;
if (scanf("%d", &v) != 1) return 0;
return v;
}
int getch() { return getchar(); }
int getarray(int a[]) {
int n = 0;
if (scanf("%d", &n) != 1) return 0;
for (int i = 0; i < n; ++i) {
scanf("%d", &a[i]);
}
return n;
}
void putint(int x) { printf("%d", x); }
void putch(int x) { putchar(x); }
void putarray(int n, int a[]) {
printf("%d:", n);
for (int i = 0; i < n; ++i) {
printf(" %d", a[i]);
}
printf("\n");
}
float getfloat() {
float v = 0.0f;
if (scanf("%f", &v) != 1) return 0.0f;
return v;
}
int getfarray(float a[]) {
int n = 0;
if (scanf("%d", &n) != 1) return 0;
for (int i = 0; i < n; ++i) {
scanf("%f", &a[i]);
}
return n;
}
void putfloat(float x) { printf("%a", x); }
void putfarray(int n, float a[]) {
printf("%d:", n);
for (int i = 0; i < n; ++i) {
printf(" %a", a[i]);
}
printf("\n");
}
// Performance timing hooks (no-op stubs for correctness testing).
void starttime() {}
void stoptime() {}

Loading…
Cancel
Save