diff --git a/src/IR.h b/src/IR.h index f54205a..97d3f79 100644 --- a/src/IR.h +++ b/src/IR.h @@ -8,6 +8,7 @@ #include #include #include +#include #include namespace sysy { @@ -65,6 +66,10 @@ public: bool isPointer() const { return kind == kPointer; } bool isFunction() const { return kind == kFunction; } int getSize() const; + template + std::enable_if_t, T *> as() const { + return dynamic_cast(const_cast(this)); + } }; // class Type //! Pointer type @@ -196,15 +201,15 @@ public: class ConstantValue : public Value { protected: union { - int iConstant; - float fConstant; + int iScalar; + float fScalar; }; protected: ConstantValue(int value, const std::string &name = "") - : Value(Type::getIntType(), name), iConstant(value) {} + : Value(Type::getIntType(), name), iScalar(value) {} ConstantValue(float value, const std::string &name = "") - : Value(Type::getFloatType(), name), fConstant(value) {} + : Value(Type::getFloatType(), name), fScalar(value) {} public: static ConstantValue *get(int value, const std::string &name = ""); @@ -213,11 +218,11 @@ public: public: int getInt() const { assert(isInt()); - return iConstant; + return iScalar; } float getFloat() const { assert(isFloat()); - return fConstant; + return fScalar; } }; // class ConstantValue @@ -584,10 +589,8 @@ class LoadInst : public Instruction { protected: LoadInst(Value *pointer, const std::vector &indices = {}, BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction( - kLoad, - dynamic_cast(pointer->getType())->getBaseType(), - parent, name) { + : Instruction(kLoad, pointer->getType()->as()->getBaseType(), + parent, name) { addOperands(indices); } @@ -644,10 +647,10 @@ protected: public: Type *getReturnType() const { - return dynamic_cast(getType())->getReturnType(); + return getType()->as()->getReturnType(); } auto getParamTypes() const { - return dynamic_cast(getType())->getParamTypes(); + return getType()->as()->getParamTypes(); } auto getBasicBlocks() { return make_range(blocks); } BasicBlock *getEntryBlock() { return blocks.front().get(); } @@ -668,16 +671,21 @@ class GlobalValue : public User { protected: Module *parent; + bool hasInit; protected: GlobalValue(Module *parent, Type *type, const std::string &name, - const std::vector &dims = {}) - : User(type, name), parent(parent) { + const std::vector &dims = {}, Value *init = nullptr) + : User(type, name), parent(parent), hasInit(init) { + assert(type->isPointer()); addOperands(dims); + if (init) + addOperand(init); } public: - int getNumDims() const { return getNumOperands(); } + Value *init() const { return hasInit ? operands.back().getValue() : nullptr; } + int getNumDims() const { return getNumOperands() - (hasInit ? 1 : 0); } Value *getDim(int index) { return getOperand(index); } }; // class GlobalValue diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 976c95f..b0a2cf3 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -10,10 +10,7 @@ any SysYIRGenerator::visitModule(SysYParser::ModuleContext *ctx) { auto pModule = new Module(); assert(pModule); module.reset(pModule); - for (auto decl : ctx->decl()) - visitDecl(decl); - for (auto func : ctx->func()) - visitFunc(func); + visitChildren(ctx); return pModule; }