From 8208469b1356c428df47c5a2211bcd3dc6d82879 Mon Sep 17 00:00:00 2001 From: Xing Su Date: Fri, 7 Apr 2023 21:46:19 -0400 Subject: [PATCH] Support function call and IR printing. --- src/IR.cpp | 108 ++++++++-------- src/IR.h | 277 +++++++++++++++++++++++++++++----------- src/SysYIRGenerator.cpp | 13 +- 3 files changed, 267 insertions(+), 131 deletions(-) diff --git a/src/IR.cpp b/src/IR.cpp index 3335598..9ecbede 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -28,7 +28,7 @@ ostream &interleave(ostream &os, const T &container, const string sep = ", ") { return os; } static inline ostream &printVarName(ostream &os, const Value *var) { - return os << (dynamic_cast(var) ? '@' : '%') + return os << (dyncast(var) ? '@' : '%') << var->getName(); } static inline ostream &printBlockName(ostream &os, const BasicBlock *block) { @@ -38,7 +38,7 @@ static inline ostream &printFunctionName(ostream &os, const Function *fn) { return os << '@' << fn->getName(); } static inline ostream &printOperand(ostream &os, const Value *value) { - auto constant = dynamic_cast(value); + auto constant = dyncast(value); if (constant) { constant->print(os); return os; @@ -162,16 +162,16 @@ void Value::replaceAllUsesWith(Value *value) { } bool Value::isConstant() const { - if (dynamic_cast(this)) + if (dyncast(this)) return true; - if (dynamic_cast(this) or - dynamic_cast(this)) + if (dyncast(this) or + dyncast(this)) return true; - if (auto array = dynamic_cast(this)) { - auto elements = array->getValues(); - return all_of(elements.begin(), elements.end(), - [](Value *v) -> bool { return v->isConstant(); }); - } + // if (auto array = dyncast(this)) { + // auto elements = array->getValues(); + // return all_of(elements.begin(), elements.end(), + // [](Value *v) -> bool { return v->isConstant(); }); + // } return false; } @@ -206,7 +206,7 @@ void ConstantValue::print(ostream &os) const { Argument::Argument(Type *type, BasicBlock *block, int index, const std::string &name) - : Value(type, name), block(block), index(index) { + : Value(kArgument, type, name), block(block), index(index) { if (not hasName()) setName(to_string(block->getParent()->allocateVariableID())); } @@ -217,8 +217,8 @@ void Argument::print(std::ostream &os) const { } BasicBlock::BasicBlock(Function *parent, const std::string &name) - : Value(Type::getLabelType(), name), parent(parent), instructions(), - arguments(), successors(), predecessors() { + : Value(kBasicBlock, Type::getLabelType(), name), parent(parent), + instructions(), arguments(), successors(), predecessors() { if (not hasName()) setName("bb" + to_string(getParent()->allocateblockID())); } @@ -230,11 +230,13 @@ void BasicBlock::print(std::ostream &os) const { auto args = getArguments(); auto b = args.begin(), e = args.end(); if (b != e) { - printVarName(os, &*b) << ": " << *b->getType(); - for (auto arg : make_range(std::next(b), e)) { + os << '('; + printVarName(os, b->get()) << ": " << *b->get()->getType(); + for (auto &arg : make_range(std::next(b), e)) { os << ", "; - printVarName(os, &arg) << ": " << arg.getType(); + printVarName(os, arg.get()) << ": " << *arg->getType(); } + os << ')'; } os << ":\n"; for (auto &inst : instructions) { @@ -244,14 +246,14 @@ void BasicBlock::print(std::ostream &os) const { Instruction::Instruction(Kind kind, Type *type, BasicBlock *parent, const std::string &name) - : User(type, name), kind(kind), parent(parent) { + : User(kind, type, name), kind(kind), parent(parent) { if (not type->isVoid() and not hasName()) setName(to_string(getFunction()->allocateVariableID())); } void CallInst::print(std::ostream &os) const { if (not getType()->isVoid()) - printVarName(os, this) << " = "; + printVarName(os, this) << " = call "; printFunctionName(os, getCallee()) << '('; auto args = getArguments(); auto b = args.begin(), e = args.end(); @@ -361,6 +363,7 @@ void BinaryInst::print(std::ostream &os) const { default: assert(false); } + os << ' '; printOperand(os, getLhs()) << ", "; printOperand(os, getRhs()) << " : " << *getType(); } @@ -453,8 +456,13 @@ void Function::print(std::ostream &os) const { auto paramTypes = getParamTypes(); os << *returnType << ' '; printFunctionName(os, this) << '('; - interleave(os, paramTypes) << ')'; - os << "{\n"; + auto b = paramTypes.begin(), e = paramTypes.end(); + if (b != e) { + os << *(*b); + for (auto type : make_range(std::next(b), e)) + os << ", " << *type; + } + os << ") {\n"; for (auto &bb : getBasicBlocks()) { os << *bb << '\n'; } @@ -468,35 +476,35 @@ void Module::print(std::ostream &os) const { os << *f.second << '\n'; } -ArrayValue *ArrayValue::get(Type *type, const vector &values) { - static map, unique_ptr> arrayConstants; - hash hasher; - auto key = make_pair( - type, hasher(string(reinterpret_cast(values.data()), - values.size() * sizeof(Value *)))); - - auto iter = arrayConstants.find(key); - if (iter != arrayConstants.end()) - return iter->second.get(); - auto constant = new ArrayValue(type, values); - assert(constant); - auto result = arrayConstants.emplace(key, constant); - return result.first->second.get(); -} - -ArrayValue *ArrayValue::get(const std::vector &values) { - vector vals(values.size(), nullptr); - std::transform(values.begin(), values.end(), vals.begin(), - [](int v) { return ConstantValue::get(v); }); - return get(Type::getIntType(), vals); -} - -ArrayValue *ArrayValue::get(const std::vector &values) { - vector vals(values.size(), nullptr); - std::transform(values.begin(), values.end(), vals.begin(), - [](float v) { return ConstantValue::get(v); }); - return get(Type::getFloatType(), vals); -} +// ArrayValue *ArrayValue::get(Type *type, const vector &values) { +// static map, unique_ptr> arrayConstants; +// hash hasher; +// auto key = make_pair( +// type, hasher(string(reinterpret_cast(values.data()), +// values.size() * sizeof(Value *)))); + +// auto iter = arrayConstants.find(key); +// if (iter != arrayConstants.end()) +// return iter->second.get(); +// auto constant = new ArrayValue(type, values); +// assert(constant); +// auto result = arrayConstants.emplace(key, constant); +// return result.first->second.get(); +// } + +// ArrayValue *ArrayValue::get(const std::vector &values) { +// vector vals(values.size(), nullptr); +// std::transform(values.begin(), values.end(), vals.begin(), +// [](int v) { return ConstantValue::get(v); }); +// return get(Type::getIntType(), vals); +// } + +// ArrayValue *ArrayValue::get(const std::vector &values) { +// vector vals(values.size(), nullptr); +// std::transform(values.begin(), values.end(), vals.begin(), +// [](float v) { return ConstantValue::get(v); }); +// return get(Type::getFloatType(), vals); +// } void User::setOperand(int index, Value *value) { assert(index < getNumOperands()); @@ -519,7 +527,7 @@ CallInst::CallInst(Function *callee, const std::vector &args, } Function *CallInst::getCallee() const { - return dynamic_cast(getOperand(0)); + return dyncast(getOperand(0)); } } // namespace sysy \ No newline at end of file diff --git a/src/IR.h b/src/IR.h index 1de92d3..2ed3902 100644 --- a/src/IR.h +++ b/src/IR.h @@ -2,6 +2,7 @@ #include "range.h" #include +#include #include #include #include @@ -173,18 +174,91 @@ public: void setValue(Value *value) { value = value; } }; // class Use +template +inline std::enable_if_t, bool> isa(const Value *value) { + return T::classof(value); +} + +template +inline std::enable_if_t, T *> dyncast(Value *value) { + return isa(value) ? static_cast(value) : nullptr; +} + +template +inline std::enable_if_t, const T *> dyncast(const Value *value) { + return isa(value) ? static_cast(value) : nullptr; +} + //! The base class of all value types class Value { +public: + enum Kind : uint64_t { + kInvalid, + // Instructions + // Binary + kAdd = 0x1UL << 0, + kSub = 0x1UL << 1, + kMul = 0x1UL << 2, + kDiv = 0x1UL << 3, + kRem = 0x1UL << 4, + kICmpEQ = 0x1UL << 5, + kICmpNE = 0x1UL << 6, + kICmpLT = 0x1UL << 7, + kICmpGT = 0x1UL << 8, + kICmpLE = 0x1UL << 9, + kICmpGE = 0x1UL << 10, + kFAdd = 0x1UL << 14, + kFSub = 0x1UL << 15, + kFMul = 0x1UL << 16, + kFDiv = 0x1UL << 17, + kFRem = 0x1UL << 18, + kFCmpEQ = 0x1UL << 19, + kFCmpNE = 0x1UL << 20, + kFCmpLT = 0x1UL << 21, + kFCmpGT = 0x1UL << 22, + kFCmpLE = 0x1UL << 23, + kFCmpGE = 0x1UL << 24, + // Unary + kNeg = 0x1UL << 25, + kNot = 0x1UL << 26, + kFNeg = 0x1UL << 27, + kFtoI = 0x1UL << 28, + kIToF = 0x1UL << 29, + // call + kCall = 0x1UL << 30, + // terminator + kCondBr = 0x1UL << 31, + kBr = 0x1UL << 32, + kReturn = 0x1UL << 33, + // mem op + kAlloca = 0x1UL << 34, + kLoad = 0x1UL << 35, + kStore = 0x1UL << 36, + kFirstInst = kAdd, + kLastInst = kStore, + // others + kArgument = 0x1UL << 37, + kBasicBlock = 0x1UL << 38, + kFunction = 0x1UL << 39, + kConstant = 0x1UL << 40, + kGlobal = 0x1UL << 41, + }; + protected: + Kind kind; Type *type; std::string name; std::list uses; protected: - Value(Type *type, const std::string &name = "") - : type(type), name(name), uses() {} + Value(Kind kind, Type *type, const std::string &name = "") + : kind(kind), type(type), name(name), uses() {} virtual ~Value() = default; +public: + Kind getKind() const { return kind; } + static bool classof(const Value *) { return true; } + public: Type *getType() const { return type; } const std::string &getName() const { return name; } @@ -200,7 +274,7 @@ public: bool isConstant() const; public: - virtual void print(std::ostream &os) const = 0; + virtual void print(std::ostream &os) const {}; }; // class Value /*! @@ -217,14 +291,18 @@ protected: }; protected: - ConstantValue(int value) : Value(Type::getIntType(), ""), iScalar(value) {} + ConstantValue(int value) + : Value(kConstant, Type::getIntType(), ""), iScalar(value) {} ConstantValue(float value) - : Value(Type::getFloatType(), ""), fScalar(value) {} + : Value(kConstant, Type::getFloatType(), ""), fScalar(value) {} public: static ConstantValue *get(int value); static ConstantValue *get(float value); +public: + static bool classof(const Value *value) { return value->getKind() == kConstant; } + public: int getInt() const { assert(isInt()); @@ -259,6 +337,9 @@ public: Argument(Type *type, BasicBlock *block, int index, const std::string &name = ""); +public: + static bool classof(const Value *value) { return value->getKind() == kConstant; } + public: BasicBlock *getParent() const { return block; } int getIndex() const { return index; } @@ -282,7 +363,7 @@ class BasicBlock : public Value { public: using inst_list = std::list>; using iterator = inst_list::iterator; - using arg_list = std::vector; + using arg_list = std::vector>; using block_list = std::vector; protected: @@ -295,6 +376,9 @@ protected: protected: explicit BasicBlock(Function *parent, const std::string &name = ""); +public: + static bool classof(const Value *value) { return value->getKind() == kBasicBlock; } + public: int getNumInstructions() const { return instructions.size(); } int getNumArguments() const { return arguments.size(); } @@ -309,8 +393,10 @@ public: iterator end() { return instructions.end(); } iterator terminator() { return std::prev(end()); } Argument *createArgument(Type *type, const std::string &name = "") { - arguments.emplace_back(type, this, arguments.size(), name); - return &arguments.back(); + auto arg = new Argument(type, this, arguments.size(), name); + assert(arg); + arguments.emplace_back(arg); + return arguments.back().get(); }; public: @@ -325,8 +411,8 @@ protected: std::vector operands; protected: - User(Type *type, const std::string &name = "") - : Value(type, name), operands() {} + User(Kind kind, Type *type, const std::string &name = "") + : Value(kind, type, name), operands() {} public: using use_iterator = std::vector::const_iterator; @@ -369,50 +455,50 @@ public: */ class Instruction : public User { public: - enum Kind : uint64_t { - kInvalid = 0x0UL, - // Binary - kAdd = 0x1UL << 0, - kSub = 0x1UL << 1, - kMul = 0x1UL << 2, - kDiv = 0x1UL << 3, - kRem = 0x1UL << 4, - kICmpEQ = 0x1UL << 5, - kICmpNE = 0x1UL << 6, - kICmpLT = 0x1UL << 7, - kICmpGT = 0x1UL << 8, - kICmpLE = 0x1UL << 9, - kICmpGE = 0x1UL << 10, - kFAdd = 0x1UL << 14, - kFSub = 0x1UL << 15, - kFMul = 0x1UL << 16, - kFDiv = 0x1UL << 17, - kFRem = 0x1UL << 18, - kFCmpEQ = 0x1UL << 19, - kFCmpNE = 0x1UL << 20, - kFCmpLT = 0x1UL << 21, - kFCmpGT = 0x1UL << 22, - kFCmpLE = 0x1UL << 23, - kFCmpGE = 0x1UL << 24, - // Unary - kNeg = 0x1UL << 25, - kNot = 0x1UL << 26, - kFNeg = 0x1UL << 27, - kFtoI = 0x1UL << 28, - kIToF = 0x1UL << 29, - // call - kCall = 0x1UL << 30, - // terminator - kCondBr = 0x1UL << 31, - kBr = 0x1UL << 32, - kReturn = 0x1UL << 33, - // mem op - kAlloca = 0x1UL << 34, - kLoad = 0x1UL << 35, - kStore = 0x1UL << 36, - // constant - // kConstant = 0x1UL << 37, - }; + // enum Kind : uint64_t { + // kInvalid = 0x0UL, + // // Binary + // kAdd = 0x1UL << 0, + // kSub = 0x1UL << 1, + // kMul = 0x1UL << 2, + // kDiv = 0x1UL << 3, + // kRem = 0x1UL << 4, + // kICmpEQ = 0x1UL << 5, + // kICmpNE = 0x1UL << 6, + // kICmpLT = 0x1UL << 7, + // kICmpGT = 0x1UL << 8, + // kICmpLE = 0x1UL << 9, + // kICmpGE = 0x1UL << 10, + // kFAdd = 0x1UL << 14, + // kFSub = 0x1UL << 15, + // kFMul = 0x1UL << 16, + // kFDiv = 0x1UL << 17, + // kFRem = 0x1UL << 18, + // kFCmpEQ = 0x1UL << 19, + // kFCmpNE = 0x1UL << 20, + // kFCmpLT = 0x1UL << 21, + // kFCmpGT = 0x1UL << 22, + // kFCmpLE = 0x1UL << 23, + // kFCmpGE = 0x1UL << 24, + // // Unary + // kNeg = 0x1UL << 25, + // kNot = 0x1UL << 26, + // kFNeg = 0x1UL << 27, + // kFtoI = 0x1UL << 28, + // kIToF = 0x1UL << 29, + // // call + // kCall = 0x1UL << 30, + // // terminator + // kCondBr = 0x1UL << 31, + // kBr = 0x1UL << 32, + // kReturn = 0x1UL << 33, + // // mem op + // kAlloca = 0x1UL << 34, + // kLoad = 0x1UL << 35, + // kStore = 0x1UL << 36, + // // constant + // // kConstant = 0x1UL << 37, + // }; protected: Kind kind; @@ -422,6 +508,11 @@ protected: Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr, const std::string &name = ""); +public: + static bool classof(const Value *value) { + return value->getKind() >= kFirstInst and value->getKind() <= kLastInst; + } + public: Kind getKind() const { return kind; } BasicBlock *getParent() const { return parent; } @@ -476,6 +567,9 @@ protected: CallInst(Function *callee, const std::vector &args = {}, BasicBlock *parent = nullptr, const std::string &name = ""); +public: + static bool classof(const Value *value) { return value->getKind() == kCall; } + public: Function *getCallee() const; auto getArguments() const { @@ -497,6 +591,12 @@ protected: addOperand(operand); } +public: + static bool classof(const Value *value) { + return Instruction::classof(value) and + static_cast(value)->isUnary(); + } + public: Value *getOperand() const { return User::getOperand(0); } @@ -516,6 +616,12 @@ protected: addOperand(rhs); } +public: + static bool classof(const Value *value) { + return Instruction::classof(value) and + static_cast(value)->isBinary(); + } + public: Value *getLhs() const { return getOperand(0); } Value *getRhs() const { return getOperand(1); } @@ -535,6 +641,9 @@ protected: addOperand(value); } +public: + static bool classof(const Value *value) { return value->getKind() == kReturn; } + public: bool hasReturnValue() const { return not operands.empty(); } Value *getReturnValue() const { @@ -558,9 +667,12 @@ protected: addOperands(args); } +public: + static bool classof(const Value *value) { return value->getKind() == kBr; } + public: BasicBlock *getBlock() const { - return dynamic_cast(getOperand(0)); + return dyncast(getOperand(0)); } auto getArguments() const { return make_range(std::next(operand_begin()), operand_end()); @@ -588,13 +700,16 @@ protected: addOperands(elseArgs); } +public: + static bool classof(const Value *value) { return value->getKind() == kCondBr; } + public: Value *getCondition() const { return getOperand(0); } BasicBlock *getThenBlock() const { - return dynamic_cast(getOperand(1)); + return dyncast(getOperand(1)); } BasicBlock *getElseBlock() const { - return dynamic_cast(getOperand(2)); + return dyncast(getOperand(2)); } auto getThenArguments() const { auto begin = std::next(operand_begin(), 3); @@ -623,6 +738,9 @@ protected: addOperands(dims); } +public: + static bool classof(const Value *value) { return value->getKind() == kAlloca; } + public: int getNumDims() const { return getNumOperands(); } auto getDims() const { return getOperands(); } @@ -645,6 +763,9 @@ protected: addOperands(indices); } +public: + static bool classof(const Value *value) { return value->getKind() == kLoad; } + public: int getNumIndices() const { return getNumOperands() - 1; } Value *getPointer() const { return getOperand(0); } @@ -671,6 +792,9 @@ protected: addOperands(indices); } +public: + static bool classof(const Value *value) { return value->getKind() == kStore; } + public: int getNumIndices() const { return getNumOperands() - 2; } Value *getValue() const { return getOperand(0); } @@ -691,10 +815,13 @@ class Function : public Value { protected: Function(Module *parent, Type *type, const std::string &name) - : Value(type, name), parent(parent), variableID(0), blocks() { + : Value(kFunction, type, name), parent(parent), variableID(0), blocks() { blocks.emplace_back(new BasicBlock(this, "entry")); } +public: + static bool classof(const Value *value) { return value->getKind() == kFunction; } + public: using block_list = std::list>; @@ -729,24 +856,24 @@ public: void print(std::ostream &os) const override; }; // class Function -class ArrayValue : public User { -protected: - ArrayValue(Type *type, const std::vector &values = {}) - : User(type, "") { - addOperands(values); - } +// class ArrayValue : public User { +// protected: +// ArrayValue(Type *type, const std::vector &values = {}) +// : User(type, "") { +// addOperands(values); +// } -public: - static ArrayValue *get(Type *type, const std::vector &values); - static ArrayValue *get(const std::vector &values); - static ArrayValue *get(const std::vector &values); +// public: +// static ArrayValue *get(Type *type, const std::vector &values); +// static ArrayValue *get(const std::vector &values); +// static ArrayValue *get(const std::vector &values); -public: - auto getValues() const { return getOperands(); } +// public: +// auto getValues() const { return getOperands(); } -public: - void print(std::ostream &os) const override{}; -}; // class ConstantArray +// public: +// void print(std::ostream &os) const override{}; +// }; // class ConstantArray //! Global value declared at file scope class GlobalValue : public User { @@ -760,13 +887,15 @@ protected: protected: GlobalValue(Module *parent, Type *type, const std::string &name, const std::vector &dims = {}, Value *init = nullptr) - : User(type, name), parent(parent), hasInit(init) { + : User(kGlobal, type, name), parent(parent), hasInit(init) { assert(type->isPointer()); addOperands(dims); if (init) addOperand(init); } +public: + static bool classof(const Value *value) { return value->getKind() == kGlobal; } public: Value *init() const { return hasInit ? operands.back().getValue() : nullptr; } int getNumDims() const { return getNumOperands() - (hasInit ? 1 : 0); } diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index dad6a64..adc4431 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -71,8 +71,6 @@ any SysYIRGenerator::visitLocalDecl(SysYParser::DeclContext *ctx) { } any SysYIRGenerator::visitFunc(SysYParser::FuncContext *ctx) { - // create the function scope - SymbolTable::FunctionScope scope(symbols); // obtain function name and type signature auto name = ctx->ID()->getText(); vector paramTypes; @@ -88,6 +86,10 @@ any SysYIRGenerator::visitFunc(SysYParser::FuncContext *ctx) { auto funcType = Type::getFunctionType(returnType, paramTypes); // create the IR function auto function = module->createFunction(name, funcType); + // update the symbol table + symbols.insert(name, function); + // create the function scope + SymbolTable::FunctionScope scope(symbols); // create the entry block with the same parameters as the function, // and update the symbol table auto entry = function->getEntryBlock(); @@ -149,10 +151,7 @@ any SysYIRGenerator::visitLValueExp(SysYParser::LValueExpContext *ctx) { Value *value = symbols.lookup(name); if (not value) error(ctx, "undefined variable"); - auto a = dynamic_cast(value); - if (dynamic_cast(value) or - (dynamic_cast(value) and - dynamic_cast(value))) + if (isa(value) or isa(value)) value = builder.createLoadInst(value); return value; } @@ -217,7 +216,7 @@ any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { auto funcName = ctx->ID()->getText(); - auto func = dynamic_cast(symbols.lookup(funcName)); + auto func = dyncast(symbols.lookup(funcName)); assert(func); vector args; if (auto rArgs = ctx->funcRParams()) {