From 308bcac3fa5468fb1813ad48a6fdefa76068f266 Mon Sep 17 00:00:00 2001 From: Su Xing Date: Thu, 6 Apr 2023 14:28:33 +0800 Subject: [PATCH] Partial implementation of IR generator. Now can generate a single block function within +/-/*// and return. --- src/CMakeLists.txt | 1 + src/IR.cpp | 378 +++++++++++++++++++++++++++++++++++++--- src/IR.h | 168 ++++++++++++++---- src/SysY.g4 | 4 +- src/SysYFormatter.h | 44 ++--- src/SysYIRGenerator.cpp | 191 +++++++++++++++++--- src/SysYIRGenerator.h | 123 +++++++++---- src/sysyc.cpp | 11 +- 8 files changed, 780 insertions(+), 140 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 856a640..84d5d47 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -15,6 +15,7 @@ add_executable(sysyc sysyc.cpp IR.cpp SysYIRGenerator.cpp + Diagnostic.cpp ) target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(sysyc PRIVATE SysYParser) diff --git a/src/IR.cpp b/src/IR.cpp index 941cead..acd8455 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -1,28 +1,50 @@ #include "IR.h" +#include "range.h" #include #include #include +#include #include #include #include #include +#include #include +#include +#include #include using namespace std; namespace sysy { template -std::ostream &interleave(std::ostream &os, const T &container, - const std::string sep = ", ") { +ostream &interleave(ostream &os, const T &container, const string sep = ", ") { auto b = container.begin(), e = container.end(); if (b == e) return os; os << *b; - for (b = std::next(b); b != e; b = std::next(b)) + for (b = next(b); b != e; b = next(b)) os << sep << *b; return os; } +static inline ostream &printVarName(ostream &os, const Value *var) { + return os << (dynamic_cast(var) ? '@' : '%') + << var->getName(); +} +static inline ostream &printBlockName(ostream &os, const BasicBlock *block) { + return os << '^' << block->getName(); +} +static inline ostream &printFunctionName(ostream &os, const Function *fn) { + return os << '@' << fn->getName(); +} +static inline ostream &printOperand(ostream &os, const Value *value) { + auto constant = dynamic_cast(value); + if (constant) { + constant->print(os); + return os; + } + return printVarName(os, value); +} //===----------------------------------------------------------------------===// // Types //===----------------------------------------------------------------------===// @@ -53,7 +75,7 @@ Type *Type::getPointerType(Type *baseType) { } Type *Type::getFunctionType(Type *returnType, - const std::vector ¶mTypes) { + const vector ¶mTypes) { // forward to FunctionType return FunctionType::get(returnType, paramTypes); } @@ -74,27 +96,28 @@ int Type::getSize() const { } void Type::print(ostream &os) const { - switch (getKind()) { - kInt: + auto kind = getKind(); + switch (kind) { + case kInt: os << "int"; break; - kFloat: + case kFloat: os << "float"; break; - kVoid: + case kVoid: os << "void"; break; - kPointer: + case kPointer: static_cast(this)->getBaseType()->print(os); os << "*"; break; - kFunction: + case kFunction: static_cast(this)->getReturnType()->print(os); os << "("; interleave(os, static_cast(this)->getParamTypes()); os << ")"; break; - kLabel: + case kLabel: default: cerr << "Unexpected type!\n"; break; @@ -138,28 +161,343 @@ void Value::replaceAllUsesWith(Value *value) { uses.clear(); } -ConstantValue *ConstantValue::get(int value, const std::string &name) { +bool Value::isConstant() const { + if (dynamic_cast(this)) + return true; + if (dynamic_cast(this) or + dynamic_cast(this)) + return true; + if (auto array = dynamic_cast(this)) { + auto elements = array->getValues(); + return all_of(elements.begin(), elements.end(), + [](Value *v) -> bool { return v->isConstant(); }); + } + return false; +} + +ConstantValue *ConstantValue::get(int value) { static std::map> intConstants; auto iter = intConstants.find(value); if (iter != intConstants.end()) return iter->second.get(); - auto inst = new ConstantValue(value); - assert(inst); - auto result = intConstants.emplace(value, inst); + auto constant = new ConstantValue(value); + assert(constant); + auto result = intConstants.emplace(value, constant); return result.first->second.get(); } -ConstantValue *ConstantValue::get(float value, const std::string &name) { +ConstantValue *ConstantValue::get(float value) { static std::map> floatConstants; auto iter = floatConstants.find(value); if (iter != floatConstants.end()) return iter->second.get(); - auto inst = new ConstantValue(value); - assert(inst); - auto result = floatConstants.emplace(value, inst); + auto constant = new ConstantValue(value); + assert(constant); + auto result = floatConstants.emplace(value, constant); return result.first->second.get(); } +void ConstantValue::print(ostream &os) const { + if (isInt()) + os << getInt(); + else + os << getFloat(); +} + +Argument::Argument(Type *type, BasicBlock *block, int index, + const std::string &name) + : Value(type, name), block(block), index(index) { + if (not hasName()) + setName(to_string(block->getParent()->allocateVariableID())); +} + +void Argument::print(std::ostream &os) const { + assert(hasName()); + printVarName(os, this) << ": " << *getType(); +} + +BasicBlock::BasicBlock(Function *parent, const std::string &name) + : Value(Type::getLabelType(), name), parent(parent), instructions(), + arguments(), successors(), predecessors() { + if (not hasName()) + setName("bb" + to_string(getParent()->allocateblockID())); +} + +void BasicBlock::print(std::ostream &os) const { + assert(hasName()); + os << " "; + printBlockName(os, this); + auto args = getArguments(); + auto b = args.begin(), e = args.end(); + if (b != e) { + printVarName(os, &*b) << ": " << *b->getType(); + for (auto arg : make_range(std::next(b), e)) { + os << ", "; + printVarName(os, &arg) << ": " << arg.getType(); + } + } + os << ":\n"; + for (auto &inst : instructions) { + os << " " << *inst << '\n'; + } +} + +Instruction::Instruction(Kind kind, Type *type, BasicBlock *parent, + const std::string &name) + : User(type, name), kind(kind), parent(parent) { + if (not type->isVoid() and not hasName()) + setName(to_string(getFunction()->allocateVariableID())); +} + +void CallInst::print(std::ostream &os) const { + if (not getType()->isVoid()) + printVarName(os, this) << " = "; + printFunctionName(os, getCallee()) << '('; + auto args = getArguments(); + auto b = args.begin(), e = args.end(); + if (b != e) { + printOperand(os, *b); + for (auto arg : make_range(std::next(b), e)) { + os << ", "; + printOperand(os, arg); + } + } + os << ") : " << *getType(); +} + +void UnaryInst::print(std::ostream &os) const { + printVarName(os, this) << " = "; + switch (getKind()) { + case kNeg: + os << "neg"; + break; + case kNot: + os << "not"; + break; + case kFNeg: + os << "fneg"; + break; + case kFtoI: + os << "ftoi"; + break; + case kIToF: + os << "itof"; + break; + default: + assert(false); + } + printOperand(os, getOperand()) << " : " << *getType(); +} + +void BinaryInst::print(std::ostream &os) const { + printVarName(os, this) << " = "; + switch (getKind()) { + case kAdd: + os << "add"; + break; + case kSub: + os << "sub"; + break; + case kMul: + os << "mul"; + break; + case kDiv: + os << "div"; + break; + case kRem: + os << "rem"; + break; + case kICmpEQ: + os << "icmpeq"; + break; + case kICmpNE: + os << "icmpne"; + break; + case kICmpLT: + os << "icmplt"; + break; + case kICmpGT: + os << "icmpgt"; + break; + case kICmpLE: + os << "icmple"; + break; + case kICmpGE: + os << "icmpge"; + break; + case kFAdd: + os << "fadd"; + break; + case kFSub: + os << "fsub"; + break; + case kFMul: + os << "fmul"; + break; + case kFDiv: + os << "fdiv"; + break; + case kFRem: + os << "frem"; + break; + case kFCmpEQ: + os << "fcmpeq"; + break; + case kFCmpNE: + os << "fcmpne"; + break; + case kFCmpLT: + os << "fcmplt"; + break; + case kFCmpGT: + os << "fcmpgt"; + break; + case kFCmpLE: + os << "fcmple"; + break; + case kFCmpGE: + os << "fcmpge"; + break; + default: + assert(false); + } + printOperand(os, getLhs()) << ", "; + printOperand(os, getRhs()) << " : " << *getType(); +} + +void ReturnInst::print(std::ostream &os) const { + os << "return"; + if (auto value = getReturnValue()) { + os << ' '; + printOperand(os, value) << " : " << *value->getType(); + } +} + +void UncondBrInst::print(std::ostream &os) const { + os << "br "; + printBlockName(os, getBlock()); + auto args = getArguments(); + auto b = args.begin(), e = args.end(); + if (b != e) { + os << '('; + printOperand(os, *b); + for (auto arg : make_range(std::next(b), e)) { + os << ", "; + printOperand(os, arg); + } + os << ')'; + } +} + +void CondBrInst::print(std::ostream &os) const { + os << "condbr "; + printOperand(os, getCondition()) << ", "; + printBlockName(os, getThenBlock()); + { + auto args = getThenArguments(); + auto b = args.begin(), e = args.end(); + if (b != e) { + os << '('; + printOperand(os, *b); + for (auto arg : make_range(std::next(b), e)) { + os << ", "; + printOperand(os, arg); + } + os << ')'; + } + } + os << ", "; + printBlockName(os, getElseBlock()); + { + auto args = getElseArguments(); + auto b = args.begin(), e = args.end(); + if (b != e) { + os << '('; + printOperand(os, *b); + for (auto arg : make_range(std::next(b), e)) { + os << ", "; + printOperand(os, arg); + } + os << ')'; + } + } +} + +void AllocaInst::print(std::ostream &os) const { + if (getNumDims()) + cerr << "not implemented yet\n"; + printVarName(os, this) << " = "; + os << "alloca " + << *static_cast(getType())->getBaseType(); + os << " : " << *getType(); +} + +void LoadInst::print(std::ostream &os) const { + if (getNumIndices()) + cerr << "not implemented yet\n"; + printVarName(os, this) << " = "; + os << "load "; + printOperand(os, getPointer()) << " : " << *getType(); +} + +void StoreInst::print(std::ostream &os) const { + if (getNumIndices()) + cerr << "not implemented yet\n"; + os << "store "; + printOperand(os, getValue()) << ", "; + printOperand(os, getPointer()) << " : " << *getValue()->getType(); +} + +void Function::print(std::ostream &os) const { + auto returnType = getReturnType(); + auto paramTypes = getParamTypes(); + os << *returnType << ' '; + printFunctionName(os, this) << '('; + interleave(os, paramTypes) << ')'; + os << "{\n"; + for (auto &bb : getBasicBlocks()) { + os << *bb << '\n'; + } + os << "}"; +} + +void Module::print(std::ostream &os) const { + for (auto &g : globals) + os << *g.second << '\n'; + for (auto &f : functions) + os << *f.second << '\n'; +} + +ArrayValue *ArrayValue::get(Type *type, const vector &values) { + static map, unique_ptr> arrayConstants; + hash hasher; + auto key = make_pair( + type, hasher(string(reinterpret_cast(values.data()), + values.size() * sizeof(Value *)))); + + auto iter = arrayConstants.find(key); + if (iter != arrayConstants.end()) + return iter->second.get(); + auto constant = new ArrayValue(type, values); + assert(constant); + auto result = arrayConstants.emplace(key, constant); + return result.first->second.get(); +} + +ArrayValue *ArrayValue::get(const std::vector &values) { + vector vals(values.size(), nullptr); + std::transform(values.begin(), values.end(), vals.begin(), + [](int v) { return ConstantValue::get(v); }); + return get(Type::getIntType(), vals); +} + +ArrayValue *ArrayValue::get(const std::vector &values) { + vector vals(values.size(), nullptr); + std::transform(values.begin(), values.end(), vals.begin(), + [](float v) { return ConstantValue::get(v); }); + return get(Type::getFloatType(), vals); +} + void User::setOperand(int index, Value *value) { assert(index < getNumOperands()); operands[index].setValue(value); @@ -180,7 +518,7 @@ CallInst::CallInst(Function *callee, const std::vector args, addOperand(arg); } -Function *CallInst::getCallee() { +Function *CallInst::getCallee() const { return dynamic_cast(getOperand(0)); } diff --git a/src/IR.h b/src/IR.h index 22540c9..9591c49 100644 --- a/src/IR.h +++ b/src/IR.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -66,6 +67,7 @@ public: bool isLabel() const { return kind == kLabel; } bool isPointer() const { return kind == kPointer; } bool isFunction() const { return kind == kFunction; } + bool isIntOrFloat() const { return kind == kInt or kind == kFloat; } int getSize() const; template std::enable_if_t, T *> as() const { @@ -185,6 +187,9 @@ protected: public: Type *getType() const { return type; } + const std::string &getName() const { return name; } + void setName(const std::string &n) { name = n; } + bool hasName() const { return not name.empty(); } bool isInt() const { return type->isInt(); } bool isFloat() const { return type->isFloat(); } bool isPointer() const { return type->isPointer(); } @@ -192,6 +197,10 @@ public: void addUse(Use *use) { uses.push_back(use); } void replaceAllUsesWith(Value *value); void removeUse(Use *use) { uses.remove(use); } + bool isConstant() const; + +public: + virtual void print(std::ostream &os) const = 0; }; // class Value /*! @@ -208,14 +217,13 @@ protected: }; protected: - ConstantValue(int value, const std::string &name = "") - : Value(Type::getIntType(), name), iScalar(value) {} - ConstantValue(float value, const std::string &name = "") - : Value(Type::getFloatType(), name), fScalar(value) {} + ConstantValue(int value) : Value(Type::getIntType(), ""), iScalar(value) {} + ConstantValue(float value) + : Value(Type::getFloatType(), ""), fScalar(value) {} public: - static ConstantValue *get(int value, const std::string &name = ""); - static ConstantValue *get(float value, const std::string &name = ""); + static ConstantValue *get(int value); + static ConstantValue *get(float value); public: int getInt() const { @@ -226,6 +234,9 @@ public: assert(isFloat()); return fScalar; } + +public: + virtual void print(std::ostream &os) const override; }; // class ConstantValue class BasicBlock; @@ -246,8 +257,14 @@ protected: public: Argument(Type *type, BasicBlock *block, int index, - const std::string &name = "") - : Value(type, name), block(block), index(index) {} + const std::string &name = ""); + +public: + BasicBlock *getParent() const { return block; } + int getIndex() const { return index; } + +public: + virtual void print(std::ostream &os) const override; }; class Instruction; @@ -276,9 +293,7 @@ protected: block_list predecessors; protected: - explicit BasicBlock(Function *parent, const std::string &name = "") - : Value(Type::getLabelType(), name), parent(parent), instructions(), - arguments(), successors(), predecessors() {} + explicit BasicBlock(Function *parent, const std::string &name = ""); public: int getNumInstructions() const { return instructions.size(); } @@ -287,7 +302,7 @@ public: int getNumSuccessors() const { return successors.size(); } Function *getParent() const { return parent; } inst_list &getInstructions() { return instructions; } - arg_list &getArguments() { return arguments; } + auto getArguments() const { return make_range(arguments); } block_list &getPredecessors() { return predecessors; } block_list &getSuccessors() { return successors; } iterator begin() { return instructions.begin(); } @@ -297,6 +312,9 @@ public: arguments.emplace_back(type, this, arguments.size(), name); return &arguments.back(); }; + +public: + virtual void print(std::ostream &os) const override; }; // class BasicBlock //! User is the abstract base type of `Value` types which use other `Value` as @@ -311,17 +329,25 @@ protected: : Value(type, name), operands() {} public: - struct operand_iterator : std::vector::iterator { - using Base = std::vector::iterator; - using Base::Base; + using use_iterator = std::vector::const_iterator; + struct operand_iterator : public std::vector::const_iterator { + using Base = std::vector::const_iterator; + operand_iterator(const Base &iter) : Base(iter) {} using value_type = Value *; - value_type operator->() { return operator*().getValue(); } + value_type operator->() { return Base::operator*().getValue(); } + value_type operator*() { return Base::operator*().getValue(); } }; + // struct const_operand_iterator : std::vector::const_iterator { + // using Base = std::vector::const_iterator; + // const_operand_iterator(const Base &iter) : Base(iter) {} + // using value_type = Value *; + // value_type operator->() { return operator*().getValue(); } + // }; public: int getNumOperands() const { return operands.size(); } - auto operand_begin() const { return operands.begin(); } - auto operand_end() const { return operands.end(); } + operand_iterator operand_begin() const { return operands.begin(); } + operand_iterator operand_end() const { return operands.end(); } auto getOperands() const { return make_range(operand_begin(), operand_end()); } @@ -336,7 +362,6 @@ public: } void replaceOperand(int index, Value *value); void setOperand(int index, Value *value); - const std::string &getName() const { return name; } }; // class User /*! @@ -372,7 +397,7 @@ public: // Unary kNeg = 0x1UL << 25, kNot = 0x1UL << 26, - kFNeg = 0x1UL << 26, + kFNeg = 0x1UL << 27, kFtoI = 0x1UL << 28, kIToF = 0x1UL << 29, // call @@ -395,12 +420,12 @@ protected: protected: Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr, - const std::string &name = "") - : User(type, name), kind(kind), parent(parent) {} + const std::string &name = ""); public: Kind getKind() const { return kind; } BasicBlock *getParent() const { return parent; } + Function *getFunction() const { return parent->getParent(); } void setParent(BasicBlock *bb) { parent = bb; } bool isBinary() const { @@ -452,10 +477,13 @@ protected: BasicBlock *parent = nullptr, const std::string &name = ""); public: - Function *getCallee(); - auto getArguments() { + Function *getCallee() const; + auto getArguments() const { return make_range(std::next(operand_begin()), operand_end()); } + +public: + virtual void print(std::ostream &os) const override; }; // class CallInst //! Unary instruction, includes '!', '-' and type conversion. @@ -471,6 +499,9 @@ protected: public: Value *getOperand() const { return User::getOperand(0); } + +public: + virtual void print(std::ostream &os) const override; }; // class UnaryInst //! Binary instruction, e.g., arithmatic, relation, logic, etc. @@ -488,6 +519,9 @@ protected: public: Value *getLhs() const { return getOperand(0); } Value *getRhs() const { return getOperand(1); } + +public: + virtual void print(std::ostream &os) const override; }; // class BinaryInst //! The return statement @@ -506,6 +540,9 @@ public: Value *getReturnValue() const { return hasReturnValue() ? getOperand(0) : nullptr; } + +public: + virtual void print(std::ostream &os) const override; }; // class ReturnInst //! Unconditional branch @@ -526,8 +563,11 @@ public: return dynamic_cast(getOperand(0)); } auto getArguments() const { - return make_range(std::next(operands.begin()), operands.end()); + return make_range(std::next(operand_begin()), operand_end()); } + +public: + virtual void print(std::ostream &os) const override; }; // class UncondBrInst //! Conditional branch @@ -549,22 +589,27 @@ protected: } public: + Value *getCondition() const { return getOperand(0); } BasicBlock *getThenBlock() const { - return dynamic_cast(getOperand(0)); + return dynamic_cast(getOperand(1)); } BasicBlock *getElseBlock() const { - return dynamic_cast(getOperand(1)); + return dynamic_cast(getOperand(2)); } auto getThenArguments() const { - auto begin = operands.begin() + 2; - auto end = begin + getThenBlock()->getNumArguments(); + auto begin = std::next(operand_begin(), 3); + auto end = std::next(begin, getThenBlock()->getNumArguments()); return make_range(begin, end); } auto getElseArguments() const { - auto begin = operands.begin() + 2 + getThenBlock()->getNumArguments(); - auto end = operands.end(); + auto begin = + std::next(operand_begin(), 3 + getThenBlock()->getNumArguments()); + auto end = operand_end(); return make_range(begin, end); } + +public: + virtual void print(std::ostream &os) const override; }; // class CondBrInst //! Allocate memory for stack variables, used for non-global variable declartion @@ -582,6 +627,8 @@ public: int getNumDims() const { return getNumOperands(); } auto getDims() const { return getOperands(); } Value *getDim(int index) { return getOperand(index); } +public: + virtual void print(std::ostream &os) const override; }; // class AllocaInst //! Load a value from memory address specified by a pointer value @@ -593,6 +640,7 @@ protected: BasicBlock *parent = nullptr, const std::string &name = "") : Instruction(kLoad, pointer->getType()->as()->getBaseType(), parent, name) { + addOperand(pointer); addOperands(indices); } @@ -603,6 +651,8 @@ public: return make_range(std::next(operand_begin()), operand_end()); } Value *getIndex(int index) const { return getOperand(index + 1); } +public: + virtual void print(std::ostream &os) const override; }; // class LoadInst //! Store a value to memory address specified by a pointer value @@ -624,9 +674,11 @@ public: Value *getValue() const { return getOperand(0); } Value *getPointer() const { return getOperand(1); } auto getIndices() const { - return make_range(operand_begin() + 2, operand_end()); + return make_range(std::next(operand_begin(), 2), operand_end()); } Value *getIndex(int index) const { return getOperand(index + 2); } +public: + virtual void print(std::ostream &os) const override; }; // class StoreInst class Module; @@ -636,7 +688,7 @@ class Function : public Value { protected: Function(Module *parent, Type *type, const std::string &name) - : Value(type, name), parent(parent), blocks() { + : Value(type, name), parent(parent), variableID(0), blocks() { blocks.emplace_back(new BasicBlock(this, "entry")); } @@ -645,6 +697,8 @@ public: protected: Module *parent; + int variableID; + int blockID; block_list blocks; public: @@ -654,8 +708,8 @@ public: auto getParamTypes() const { return getType()->as()->getParamTypes(); } - auto getBasicBlocks() { return make_range(blocks); } - BasicBlock *getEntryBlock() { return blocks.front().get(); } + auto getBasicBlocks() const { return make_range(blocks); } + BasicBlock *getEntryBlock() const { return blocks.front().get(); } BasicBlock *addBasicBlock(const std::string &name = "") { blocks.emplace_back(new BasicBlock(this, name)); return blocks.back().get(); @@ -665,8 +719,30 @@ public: return block == b.get(); }); } + int allocateVariableID() { return variableID++; } + int allocateblockID() { return blockID++; } +public: + virtual void print(std::ostream &os) const override; }; // class Function +class ArrayValue : public User { +protected: + ArrayValue(Type *type, const std::vector &values = {}) + : User(type, "") { + addOperands(values); + } + +public: + static ArrayValue *get(Type *type, const std::vector &values); + static ArrayValue *get(const std::vector &values); + static ArrayValue *get(const std::vector &values); + +public: + auto getValues() const { return getOperands(); } +public: + virtual void print(std::ostream &os) const override {}; +}; // class ConstantArray + //! Global value declared at file scope class GlobalValue : public User { friend class Module; @@ -674,6 +750,7 @@ class GlobalValue : public User { protected: Module *parent; bool hasInit; + bool isConst; protected: GlobalValue(Module *parent, Type *type, const std::string &name, @@ -689,6 +766,8 @@ public: Value *init() const { return hasInit ? operands.back().getValue() : nullptr; } int getNumDims() const { return getNumOperands() - (hasInit ? 1 : 0); } Value *getDim(int index) { return getOperand(index); } +public: + virtual void print(std::ostream &os) const override {}; }; // class GlobalValue //! IR unit for representing a SysY compile unit @@ -708,9 +787,10 @@ public: return result.first->second.get(); }; GlobalValue *createGlobalValue(const std::string &name, Type *type, - const std::vector &dims = {}) { - auto result = - globals.try_emplace(name, new GlobalValue(this, type, name, dims)); + const std::vector &dims = {}, + Value *init = nullptr) { + auto result = globals.try_emplace( + name, new GlobalValue(this, type, name, dims, init)); if (not result.second) return nullptr; return result.first->second.get(); @@ -727,9 +807,21 @@ public: return nullptr; return result->second.get(); } +public: + void print(std::ostream &os) const; }; // class Module /*! * @} */ +inline std::ostream &operator<<(std::ostream &os, const Type &type) { + type.print(os); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const Value &value) { + value.print(os); + return os; +} + } // namespace sysy \ No newline at end of file diff --git a/src/SysY.g4 b/src/SysY.g4 index b6b4232..b42eb00 100644 --- a/src/SysY.g4 +++ b/src/SysY.g4 @@ -99,9 +99,7 @@ btype: INT | FLOAT; varDef: lValue (ASSIGN initValue)?; -initValue: - exp # scalarInitValue - | LBRACE (initValue (COMMA initValue)*)? # arrayInitValue; +initValue: exp | LBRACE (initValue (COMMA initValue)*)?; func: funcType ID LPAREN funcFParams? RPAREN blockStmt; diff --git a/src/SysYFormatter.h b/src/SysYFormatter.h index fdaf615..d4c9fb7 100644 --- a/src/SysYFormatter.h +++ b/src/SysYFormatter.h @@ -34,9 +34,9 @@ protected: } public: -// virtual std::any visitModule(SysYParser::ModuleContext *ctx) override { -// return visitChildren(ctx); -// } + // virtual std::any visitModule(SysYParser::ModuleContext *ctx) override { + // return visitChildren(ctx); + // } virtual std::any visitBtype(SysYParser::BtypeContext *ctx) override { os << ctx->getText(); @@ -62,13 +62,14 @@ public: return 0; } - virtual std::any - visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) override { - os << '{'; - auto values = ctx->initValue(); - if (values.size()) - interleave(values, ", "); - os << '}'; + virtual std::any visitInitValue(SysYParser::InitValueContext *ctx) override { + if (not ctx->exp()) { + os << '{'; + auto values = ctx->initValue(); + if (values.size()) + interleave(values, ", "); + os << '}'; + } return 0; } @@ -191,8 +192,7 @@ public: virtual std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override { - space() << ctx->CONTINUE()->getText() << ';' - << '\n'; + space() << ctx->CONTINUE()->getText() << ';' << '\n'; return 0; } @@ -235,13 +235,15 @@ public: return 0; } -// virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) override { -// return visitChildren(ctx); -// } + // virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) + // override { + // return visitChildren(ctx); + // } -// virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) override { -// return visitChildren(ctx); -// } + // virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) + // override { + // return visitChildren(ctx); + // } virtual std::any visitAndExp(SysYParser::AndExpContext *ctx) override { ctx->exp(0)->accept(this); @@ -275,9 +277,9 @@ public: return 0; } -// virtual std::any visitCallExp(SysYParser::CallExpContext *ctx) override { -// return visitChildren(ctx); -// } + // virtual std::any visitCallExp(SysYParser::CallExpContext *ctx) override { + // return visitChildren(ctx); + // } virtual std::any visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override { diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index b0a2cf3..3d76d67 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -1,65 +1,214 @@ +#include "IR.h" +#include #include #include #include using namespace std; +#include "Diagnostic.h" #include "SysYIRGenerator.h" namespace sysy { any SysYIRGenerator::visitModule(SysYParser::ModuleContext *ctx) { + // global scople of the module + SymbolTable::ModuleScope scope(symbols); + // create the IR module auto pModule = new Module(); assert(pModule); module.reset(pModule); + // generates globals and functions visitChildren(ctx); + // return the IR module return pModule; } +any SysYIRGenerator::visitDecl(SysYParser::DeclContext *ctx) { + // global and local declarations are handled in different ways + return symbols.isModuleScope() ? visitGlobalDecl(ctx) : visitLocalDecl(ctx); +} + +any SysYIRGenerator::visitGlobalDecl(SysYParser::DeclContext *ctx) { + error(ctx, "not implemented yet"); + std::vector values; + bool isConst = ctx->CONST(); + auto type = any_cast(visitBtype(ctx->btype())); + for (auto varDef : ctx->varDef()) { + auto name = varDef->lValue()->ID()->getText(); + vector dims; + for (auto exp : varDef->lValue()->exp()) + dims.push_back(any_cast(exp->accept(this))); + auto init = varDef->ASSIGN() + ? any_cast(visitInitValue(varDef->initValue())) + : nullptr; + values.push_back(module->createGlobalValue(name, type, dims, init)); + } + // FIXME const + return values; +} + +any SysYIRGenerator::visitLocalDecl(SysYParser::DeclContext *ctx) { + // a single declaration statement can declare several variables, + // collect them in a vector + std::vector values; + // obtain the declared type + auto type = Type::getPointerType(any_cast(visitBtype(ctx->btype()))); + // handle variables + for (auto varDef : ctx->varDef()) { + // obtain the variable name and allocate space on the stack + auto name = varDef->lValue()->ID()->getText(); + auto alloca = builder.createAllocaInst(type, {}, name); + // update the symbol table + symbols.insert(name, alloca); + // if an initial value is given, create a store instruction + if (varDef->ASSIGN()) { + auto value = any_cast(visitInitValue(varDef->initValue())); + auto store = builder.createStoreInst(value, alloca); + } + // collect the created variable (pointer) + values.push_back(alloca); + } + return values; +} + any SysYIRGenerator::visitFunc(SysYParser::FuncContext *ctx) { + // create the function scope + SymbolTable::FunctionScope scope(symbols); + // obtain function name and type signature auto name = ctx->ID()->getText(); - auto params = ctx->funcFParams()->funcFParam(); vector paramTypes; vector paramNames; - for (auto param : params) { - paramTypes.push_back(any_cast(visitBtype(param->btype()))); - paramNames.push_back(param->ID()->getText()); + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + for (auto param : params) { + paramTypes.push_back(any_cast(visitBtype(param->btype()))); + paramNames.push_back(param->ID()->getText()); + } } Type *returnType = any_cast(visitFuncType(ctx->funcType())); auto funcType = Type::getFunctionType(returnType, paramTypes); + // create the IR function auto function = module->createFunction(name, funcType); + // create the entry block with the same parameters as the function, + // and update the symbol table auto entry = function->getEntryBlock(); - for (auto i = 0; i < paramTypes.size(); ++i) - entry->createArgument(paramTypes[i], paramNames[i]); + for (auto i = 0; i < paramTypes.size(); ++i) { + auto arg = entry->createArgument(paramTypes[i], paramNames[i]); + symbols.insert(paramNames[i], arg); + } + // setup the instruction insert point builder.setPosition(entry, entry->end()); + // generate the function body visitBlockStmt(ctx->blockStmt()); return function; } + any SysYIRGenerator::visitBtype(SysYParser::BtypeContext *ctx) { return ctx->INT() ? Type::getIntType() : Type::getFloatType(); } +any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) { + return ctx->INT() + ? Type::getIntType() + : (ctx->FLOAT() ? Type::getFloatType() : Type::getVoidType()); +} any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) { + // the insert point has already been setup for (auto item : ctx->blockItem()) visitBlockItem(item); + // return the corresponding IR block return builder.getBasicBlock(); } -any SysYIRGenerator::visitBlockItem(SysYParser::BlockItemContext *ctx) { - return ctx->decl() ? visitDecl(ctx->decl()) : visitStmt(ctx->stmt()); +any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { + // generate the rhs expression + auto rhs = any_cast(ctx->exp()->accept(this)); + // get the address of the lhs variable + auto lValue = ctx->lValue(); + auto name = lValue->ID()->getText(); + auto pointer = symbols.lookup(name); + if (not pointer) + error(ctx, "undefined variable"); + // update the variable + Value *store = builder.createStoreInst(rhs, pointer); + return store; } -any SysYIRGenerator::visitDecl(SysYParser::DeclContext *ctx) { - std::vector values; - auto type = any_cast(visitBtype(ctx->btype())); - for (auto varDef : ctx->varDef()) { - auto name = varDef->lValue()->ID()->getText(); - auto alloca = builder.createAllocaInst(type, {}, name); - if (varDef->ASSIGN()) { - auto value = any_cast(varDef->initValue()->accept(this)); - auto store = builder.createStoreInst(value, alloca); - } - values.push_back(alloca); - } - return values; +any SysYIRGenerator::visitNumberExp(SysYParser::NumberExpContext *ctx) { + Value *result = nullptr; + assert(ctx->number()->ILITERAL() or ctx->number()->FLITERAL()); + if (auto iLiteral = ctx->number()->ILITERAL()) + result = ConstantValue::get(std::stoi(iLiteral->getText())); + else + result = + ConstantValue::get(std::stof(ctx->number()->FLITERAL()->getText())); + return result; +} + +any SysYIRGenerator::visitLValueExp(SysYParser::LValueExpContext *ctx) { + auto name = ctx->lValue()->ID()->getText(); + auto ptr = symbols.lookup(name); + if (not ptr) + error(ctx, "undefined variable"); + Value *value = builder.createLoadInst(ptr); + return value; +} + +any SysYIRGenerator::visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) { + // generate the operands + auto lhs = any_cast(ctx->exp(0)->accept(this)); + auto rhs = any_cast(ctx->exp(1)->accept(this)); + // create convert instruction if needed + auto lhsTy = lhs->getType(); + auto rhsTy = rhs->getType(); + auto type = getArithmeticResultType(lhsTy, rhsTy); + if (lhsTy != type) + lhs = builder.createIToFInst(lhs); + if (rhsTy != type) + rhs = builder.createIToFInst(rhs); + // create the arithmetic instruction + Value *result = nullptr; + if (ctx->ADD()) + result = type->isInt() ? builder.createAddInst(lhs, rhs) + : builder.createFAddInst(lhs, rhs); + else + result = type->isInt() ? builder.createSubInst(lhs, rhs) + : builder.createFSubInst(lhs, rhs); + return result; +} + +any SysYIRGenerator::visitMultiplicativeExp( + SysYParser::MultiplicativeExpContext *ctx) { + // generate the operands + auto lhs = any_cast(ctx->exp(0)->accept(this)); + auto rhs = any_cast(ctx->exp(1)->accept(this)); + // create convert instruction if needed + auto lhsTy = lhs->getType(); + auto rhsTy = rhs->getType(); + auto type = getArithmeticResultType(lhsTy, rhsTy); + if (lhsTy != type) + lhs = builder.createIToFInst(lhs); + if (rhsTy != type) + rhs = builder.createIToFInst(rhs); + // create the arithmetic instruction + Value *result = nullptr; + if (ctx->MUL()) + result = type->isInt() ? builder.createMulInst(lhs, rhs) + : builder.createFMulInst(lhs, rhs); + else if (ctx->DIV()) + result = type->isInt() ? builder.createDivInst(lhs, rhs) + : builder.createFDivInst(lhs, rhs); + + else + result = type->isInt() ? builder.createRemInst(lhs, rhs) + : builder.createFRemInst(lhs, rhs); + return result; +} + +any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { + auto value = + ctx->exp() ? any_cast(ctx->exp()->accept(this)) : nullptr; + Value *result = builder.createReturnInst(value); + return result; } } // namespace sysy \ No newline at end of file diff --git a/src/SysYIRGenerator.h b/src/SysYIRGenerator.h index 3ea51ef..ed87323 100644 --- a/src/SysYIRGenerator.h +++ b/src/SysYIRGenerator.h @@ -1,47 +1,111 @@ #pragma once -#include #include "IR.h" #include "IRBuilder.h" #include "SysYBaseVisitor.h" #include "SysYParser.h" +#include +#include +#include +#include namespace sysy { +class SymbolTable { +private: + enum Kind { + kModule, + kFunction, + kBlock, + }; + +public: + struct ModuleScope { + SymbolTable &table; + ModuleScope(SymbolTable &table) : table(table) { table.enter(kModule); } + ~ModuleScope() { table.exit(); } + }; + struct FunctionScope { + SymbolTable &table; + FunctionScope(SymbolTable &table) : table(table) { table.enter(kFunction); } + ~FunctionScope() { table.exit(); } + }; + + struct BlockScope { + SymbolTable &table; + BlockScope(SymbolTable &table) : table(table) { table.enter(kBlock); } + ~BlockScope() { table.exit(); } + }; + +private: + std::forward_list>> symbols; + +public: + SymbolTable() = default; + +public: + bool isModuleScope() const { return symbols.front().first == kModule; } + bool isFunctionScope() const { return symbols.front().first == kFunction; } + bool isBlockScope() const { return symbols.front().first == kBlock; } + Value *lookup(const std::string &name) const { + for (auto &scope : symbols) { + auto iter = scope.second.find(name); + if (iter != scope.second.end()) + return iter->second; + } + return nullptr; + } + auto insert(const std::string &name, Value *value) { + assert(not symbols.empty()); + return symbols.front().second.emplace(name, value); + } + // void remove(const std::string &name) { + // for (auto &scope : symbols) { + // auto iter = scope.find(name); + // if (iter != scope.end()) { + // scope.erase(iter); + // return; + // } + // } + // } +private: + void enter(Kind kind) { + symbols.emplace_front(); + symbols.front().first = kind; + } + void exit() { symbols.pop_front(); } +}; // class SymbolTable + class SysYIRGenerator : public SysYBaseVisitor { private: std::unique_ptr module; IRBuilder builder; + SymbolTable symbols; public: SysYIRGenerator() = default; +public: + Module *get() const { return module.get(); } + public: virtual std::any visitModule(SysYParser::ModuleContext *ctx) override; virtual std::any visitDecl(SysYParser::DeclContext *ctx) override; virtual std::any visitBtype(SysYParser::BtypeContext *ctx) override; - - virtual std::any visitVarDef(SysYParser::VarDefContext *ctx) override { - return visitChildren(ctx); - } - virtual std::any - visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) override { + virtual std::any visitVarDef(SysYParser::VarDefContext *ctx) override { return visitChildren(ctx); } - virtual std::any - visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) override { + virtual std::any visitInitValue(SysYParser::InitValueContext *ctx) override { return visitChildren(ctx); } virtual std::any visitFunc(SysYParser::FuncContext *ctx) override; - virtual std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override; virtual std::any visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override { @@ -55,16 +119,14 @@ public: virtual std::any visitBlockStmt(SysYParser::BlockStmtContext *ctx) override; - virtual std::any visitBlockItem(SysYParser::BlockItemContext *ctx) override; + // virtual std::any visitBlockItem(SysYParser::BlockItemContext *ctx) + // override; virtual std::any visitStmt(SysYParser::StmtContext *ctx) override { return visitChildren(ctx); } - virtual std::any - visitAssignStmt(SysYParser::AssignStmtContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; virtual std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override { return visitChildren(ctx); @@ -88,9 +150,7 @@ public: } virtual std::any - visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override { - return visitChildren(ctx); - } + visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; virtual std::any visitEmptyStmt(SysYParser::EmptyStmtContext *ctx) override { return visitChildren(ctx); @@ -102,17 +162,11 @@ public: } virtual std::any - visitMultiplicativeExp(SysYParser::MultiplicativeExpContext *ctx) override { - return visitChildren(ctx); - } + visitMultiplicativeExp(SysYParser::MultiplicativeExpContext *ctx) override; - virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) override; - virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) override; virtual std::any visitAndExp(SysYParser::AndExpContext *ctx) override { return visitChildren(ctx); @@ -139,9 +193,7 @@ public: } virtual std::any - visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override { - return visitChildren(ctx); - } + visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override; virtual std::any visitEqualExp(SysYParser::EqualExpContext *ctx) override { return visitChildren(ctx); @@ -168,6 +220,13 @@ public: return visitChildren(ctx); } +private: + std::any visitGlobalDecl(SysYParser::DeclContext *ctx); + std::any visitLocalDecl(SysYParser::DeclContext *ctx); + Type *getArithmeticResultType(Type *lhs, Type *rhs) { + assert(lhs->isIntOrFloat() and rhs->isIntOrFloat()); + return lhs == rhs ? lhs : Type::getFloatType(); + } }; // class SysYIRGenerator } // namespace sysy \ No newline at end of file diff --git a/src/sysyc.cpp b/src/sysyc.cpp index a6bb042..a865a7e 100644 --- a/src/sysyc.cpp +++ b/src/sysyc.cpp @@ -1,4 +1,3 @@ -#include "ASTPrinter.h" #include "tree/ParseTreeWalker.h" #include #include @@ -7,7 +6,7 @@ using namespace std; #include "SysYLexer.h" #include "SysYParser.h" using namespace antlr4; -#include "SysYFormatter.h" +// #include "SysYFormatter.h" #include "SysYIRGenerator.h" using namespace sysy; @@ -25,10 +24,12 @@ int main(int argc, char **argv) { SysYLexer lexer(&input); CommonTokenStream tokens(&lexer); SysYParser parser(&tokens); - auto module = parser.module(); + auto moduleAST = parser.module(); SysYIRGenerator generator; - generator.visitModule(module); - + generator.visitModule(moduleAST); + auto moduleIR = generator.get(); + moduleIR->print(cout); + return EXIT_SUCCESS; } \ No newline at end of file