diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index adc4431..8847ea2 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -7,225 +7,337 @@ using namespace std; #include "Diagnostic.h" #include "SysYIRGenerator.h" -namespace sysy { - -any SysYIRGenerator::visitModule(SysYParser::ModuleContext *ctx) { - // global scople of the module - SymbolTable::ModuleScope scope(symbols); - // create the IR module - auto pModule = new Module(); - assert(pModule); - module.reset(pModule); - // generates globals and functions - visitChildren(ctx); - // return the IR module - return pModule; -} - -any SysYIRGenerator::visitDecl(SysYParser::DeclContext *ctx) { - // global and local declarations are handled in different ways - return symbols.isModuleScope() ? visitGlobalDecl(ctx) : visitLocalDecl(ctx); -} - -any SysYIRGenerator::visitGlobalDecl(SysYParser::DeclContext *ctx) { - error(ctx, "not implemented yet"); - std::vector values; - bool isConst = ctx->CONST(); - auto type = any_cast(visitBtype(ctx->btype())); - for (auto varDef : ctx->varDef()) { - auto name = varDef->lValue()->ID()->getText(); - vector dims; - for (auto exp : varDef->lValue()->exp()) - dims.push_back(any_cast(exp->accept(this))); - auto init = varDef->ASSIGN() - ? any_cast(visitInitValue(varDef->initValue())) - : nullptr; - values.push_back(module->createGlobalValue(name, type, dims, init)); - } - // FIXME const - return values; -} - -any SysYIRGenerator::visitLocalDecl(SysYParser::DeclContext *ctx) { - // a single declaration statement can declare several variables, - // collect them in a vector - std::vector values; - // obtain the declared type - auto type = Type::getPointerType(any_cast(visitBtype(ctx->btype()))); - // handle variables - for (auto varDef : ctx->varDef()) { - // obtain the variable name and allocate space on the stack - auto name = varDef->lValue()->ID()->getText(); - auto alloca = builder.createAllocaInst(type, {}, name); +namespace sysy +{ + + any SysYIRGenerator::visitModule(SysYParser::ModuleContext *ctx) + { + // global scople of the module + SymbolTable::ModuleScope scope(symbols); + // create the IR module + auto pModule = new Module(); + assert(pModule); + module.reset(pModule); + // generates globals and functions + visitChildren(ctx); + // return the IR module + return pModule; + } + + any SysYIRGenerator::visitDecl(SysYParser::DeclContext *ctx) + { + // global and local declarations are handled in different ways + return symbols.isModuleScope() ? visitGlobalDecl(ctx) : visitLocalDecl(ctx); + } + + any SysYIRGenerator::visitGlobalDecl(SysYParser::DeclContext *ctx) + { + error(ctx, "not implemented yet"); + std::vector values; + bool isConst = ctx->CONST(); + auto type = any_cast(visitBtype(ctx->btype())); + for (auto varDef : ctx->varDef()) + { + auto name = varDef->lValue()->ID()->getText(); + vector dims; + for (auto exp : varDef->lValue()->exp()) + dims.push_back(any_cast(exp->accept(this))); + auto init = varDef->ASSIGN() + ? any_cast(visitInitValue(varDef->initValue())) + : nullptr; + values.push_back(module->createGlobalValue(name, type, dims, init)); + } + // FIXME const + return values; + } + + any SysYIRGenerator::visitLocalDecl(SysYParser::DeclContext *ctx) + { + // a single declaration statement can declare several variables, + // collect them in a vector + std::vector values; + // obtain the declared type + auto type = Type::getPointerType(any_cast(visitBtype(ctx->btype()))); + // handle variables + for (auto varDef : ctx->varDef()) + { + // obtain the variable name and allocate space on the stack + auto name = varDef->lValue()->ID()->getText(); + auto alloca = builder.createAllocaInst(type, {}, name); + // update the symbol table + symbols.insert(name, alloca); + // if an initial value is given, create a store instruction + if (varDef->ASSIGN()) + { + auto value = any_cast(visitInitValue(varDef->initValue())); + auto store = builder.createStoreInst(value, alloca); + } + // collect the created variable (pointer) + values.push_back(alloca); + } + return values; + } + + any SysYIRGenerator::visitFunc(SysYParser::FuncContext *ctx) + { + // obtain function name and type signature + auto name = ctx->ID()->getText(); + vector paramTypes; + vector paramNames; + if (ctx->funcFParams()) + { + auto params = ctx->funcFParams()->funcFParam(); + for (auto param : params) + { + paramTypes.push_back(any_cast(visitBtype(param->btype()))); + paramNames.push_back(param->ID()->getText()); + } + } + Type *returnType = any_cast(visitFuncType(ctx->funcType())); + auto funcType = Type::getFunctionType(returnType, paramTypes); + // create the IR function + auto function = module->createFunction(name, funcType); // update the symbol table - symbols.insert(name, alloca); - // if an initial value is given, create a store instruction - if (varDef->ASSIGN()) { - auto value = any_cast(visitInitValue(varDef->initValue())); - auto store = builder.createStoreInst(value, alloca); + symbols.insert(name, function); + // create the function scope + SymbolTable::FunctionScope scope(symbols); + // create the entry block with the same parameters as the function, + // and update the symbol table + auto entry = function->getEntryBlock(); + for (auto i = 0; i < paramTypes.size(); ++i) + { + auto arg = entry->createArgument(paramTypes[i], paramNames[i]); + symbols.insert(paramNames[i], arg); } - // collect the created variable (pointer) - values.push_back(alloca); - } - return values; -} - -any SysYIRGenerator::visitFunc(SysYParser::FuncContext *ctx) { - // obtain function name and type signature - auto name = ctx->ID()->getText(); - vector paramTypes; - vector paramNames; - if (ctx->funcFParams()) { - auto params = ctx->funcFParams()->funcFParam(); - for (auto param : params) { - paramTypes.push_back(any_cast(visitBtype(param->btype()))); - paramNames.push_back(param->ID()->getText()); + // setup the instruction insert point + builder.setPosition(entry, entry->end()); + // generate the function body + visitBlockStmt(ctx->blockStmt()); + return function; + } + + any SysYIRGenerator::visitBtype(SysYParser::BtypeContext *ctx) + { + return ctx->INT() ? Type::getIntType() : Type::getFloatType(); + } + any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) + { + return ctx->INT() + ? Type::getIntType() + : (ctx->FLOAT() ? Type::getFloatType() : Type::getVoidType()); + } + + any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) + { + // create the block scope + SymbolTable::BlockScope scope(symbols); + // create new basicblock + + // the insert point has already been setup + for (auto item : ctx->blockItem()) + visitBlockItem(item); + // return the corresponding IR block + return builder.getBasicBlock(); + } + + any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) + { + // generate the rhs expression + auto rhs = any_cast(ctx->exp()->accept(this)); + // get the address of the lhs variable + auto lValue = ctx->lValue(); + auto name = lValue->ID()->getText(); + auto pointer = symbols.lookup(name); + if (not pointer) + error(ctx, "undefined variable"); + // update the variable + Value *store = builder.createStoreInst(rhs, pointer); + return store; + } + + any SysYIRGenerator::visitNumberExp(SysYParser::NumberExpContext *ctx) + { + Value *result = nullptr; + assert(ctx->number()->ILITERAL() or ctx->number()->FLITERAL()); + if (auto iLiteral = ctx->number()->ILITERAL()) + result = ConstantValue::get(std::stoi(iLiteral->getText())); + else + result = + ConstantValue::get(std::stof(ctx->number()->FLITERAL()->getText())); + return result; + } + + any SysYIRGenerator::visitLValueExp(SysYParser::LValueExpContext *ctx) + { + auto name = ctx->lValue()->ID()->getText(); + Value *value = symbols.lookup(name); + if (not value) + error(ctx, "undefined variable"); + if (isa(value) or isa(value)) + value = builder.createLoadInst(value); + return value; + } + + any SysYIRGenerator::visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) + { + // generate the operands + auto lhs = any_cast(ctx->exp(0)->accept(this)); + auto rhs = any_cast(ctx->exp(1)->accept(this)); + // create convert instruction if needed + auto lhsTy = lhs->getType(); + auto rhsTy = rhs->getType(); + auto type = getArithmeticResultType(lhsTy, rhsTy); + if (lhsTy != type) + lhs = builder.createIToFInst(lhs); + if (rhsTy != type) + rhs = builder.createIToFInst(rhs); + // create the arithmetic instruction + Value *result = nullptr; + if (ctx->ADD()) + result = type->isInt() ? builder.createAddInst(lhs, rhs) + : builder.createFAddInst(lhs, rhs); + else + result = type->isInt() ? builder.createSubInst(lhs, rhs) + : builder.createFSubInst(lhs, rhs); + return result; + } + + any SysYIRGenerator::visitMultiplicativeExp( + SysYParser::MultiplicativeExpContext *ctx) + { + // generate the operands + auto lhs = any_cast(ctx->exp(0)->accept(this)); + auto rhs = any_cast(ctx->exp(1)->accept(this)); + // create convert instruction if needed + auto lhsTy = lhs->getType(); + auto rhsTy = rhs->getType(); + auto type = getArithmeticResultType(lhsTy, rhsTy); + if (lhsTy != type) + lhs = builder.createIToFInst(lhs); + if (rhsTy != type) + rhs = builder.createIToFInst(rhs); + // create the arithmetic instruction + Value *result = nullptr; + if (ctx->MUL()) + result = type->isInt() ? builder.createMulInst(lhs, rhs) + : builder.createFMulInst(lhs, rhs); + else if (ctx->DIV()) + result = type->isInt() ? builder.createDivInst(lhs, rhs) + : builder.createFDivInst(lhs, rhs); + + else + result = type->isInt() ? builder.createRemInst(lhs, rhs) + : builder.createFRemInst(lhs, rhs); + return result; + } + + any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) + { + auto value = + ctx->exp() ? any_cast(ctx->exp()->accept(this)) : nullptr; + Value *result = builder.createReturnInst(value); + return result; + } + + any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) + { + auto funcName = ctx->ID()->getText(); + auto func = dyncast(symbols.lookup(funcName)); + assert(func); + vector args; + if (auto rArgs = ctx->funcRParams()) + { + for (auto exp : rArgs->exp()) + { + args.push_back(any_cast(exp->accept(this))); + } } + Value *call = builder.createCallInst(func, args); + return call; } - Type *returnType = any_cast(visitFuncType(ctx->funcType())); - auto funcType = Type::getFunctionType(returnType, paramTypes); - // create the IR function - auto function = module->createFunction(name, funcType); - // update the symbol table - symbols.insert(name, function); - // create the function scope - SymbolTable::FunctionScope scope(symbols); - // create the entry block with the same parameters as the function, - // and update the symbol table - auto entry = function->getEntryBlock(); - for (auto i = 0; i < paramTypes.size(); ++i) { - auto arg = entry->createArgument(paramTypes[i], paramNames[i]); - symbols.insert(paramNames[i], arg); - } - // setup the instruction insert point - builder.setPosition(entry, entry->end()); - // generate the function body - visitBlockStmt(ctx->blockStmt()); - return function; -} - -any SysYIRGenerator::visitBtype(SysYParser::BtypeContext *ctx) { - return ctx->INT() ? Type::getIntType() : Type::getFloatType(); -} -any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) { - return ctx->INT() - ? Type::getIntType() - : (ctx->FLOAT() ? Type::getFloatType() : Type::getVoidType()); -} - -any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) { - // the insert point has already been setup - for (auto item : ctx->blockItem()) - visitBlockItem(item); - // return the corresponding IR block - return builder.getBasicBlock(); -} - -any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { - // generate the rhs expression - auto rhs = any_cast(ctx->exp()->accept(this)); - // get the address of the lhs variable - auto lValue = ctx->lValue(); - auto name = lValue->ID()->getText(); - auto pointer = symbols.lookup(name); - if (not pointer) - error(ctx, "undefined variable"); - // update the variable - Value *store = builder.createStoreInst(rhs, pointer); - return store; -} - -any SysYIRGenerator::visitNumberExp(SysYParser::NumberExpContext *ctx) { - Value *result = nullptr; - assert(ctx->number()->ILITERAL() or ctx->number()->FLITERAL()); - if (auto iLiteral = ctx->number()->ILITERAL()) - result = ConstantValue::get(std::stoi(iLiteral->getText())); - else - result = - ConstantValue::get(std::stof(ctx->number()->FLITERAL()->getText())); - return result; -} - -any SysYIRGenerator::visitLValueExp(SysYParser::LValueExpContext *ctx) { - auto name = ctx->lValue()->ID()->getText(); - Value *value = symbols.lookup(name); - if (not value) - error(ctx, "undefined variable"); - if (isa(value) or isa(value)) - value = builder.createLoadInst(value); - return value; -} - -any SysYIRGenerator::visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) { - // generate the operands - auto lhs = any_cast(ctx->exp(0)->accept(this)); - auto rhs = any_cast(ctx->exp(1)->accept(this)); - // create convert instruction if needed - auto lhsTy = lhs->getType(); - auto rhsTy = rhs->getType(); - auto type = getArithmeticResultType(lhsTy, rhsTy); - if (lhsTy != type) - lhs = builder.createIToFInst(lhs); - if (rhsTy != type) - rhs = builder.createIToFInst(rhs); - // create the arithmetic instruction - Value *result = nullptr; - if (ctx->ADD()) - result = type->isInt() ? builder.createAddInst(lhs, rhs) - : builder.createFAddInst(lhs, rhs); - else - result = type->isInt() ? builder.createSubInst(lhs, rhs) - : builder.createFSubInst(lhs, rhs); - return result; -} - -any SysYIRGenerator::visitMultiplicativeExp( - SysYParser::MultiplicativeExpContext *ctx) { - // generate the operands - auto lhs = any_cast(ctx->exp(0)->accept(this)); - auto rhs = any_cast(ctx->exp(1)->accept(this)); - // create convert instruction if needed - auto lhsTy = lhs->getType(); - auto rhsTy = rhs->getType(); - auto type = getArithmeticResultType(lhsTy, rhsTy); - if (lhsTy != type) - lhs = builder.createIToFInst(lhs); - if (rhsTy != type) - rhs = builder.createIToFInst(rhs); - // create the arithmetic instruction - Value *result = nullptr; - if (ctx->MUL()) - result = type->isInt() ? builder.createMulInst(lhs, rhs) - : builder.createFMulInst(lhs, rhs); - else if (ctx->DIV()) - result = type->isInt() ? builder.createDivInst(lhs, rhs) - : builder.createFDivInst(lhs, rhs); - - else - result = type->isInt() ? builder.createRemInst(lhs, rhs) - : builder.createFRemInst(lhs, rhs); - return result; -} - -any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { - auto value = - ctx->exp() ? any_cast(ctx->exp()->accept(this)) : nullptr; - Value *result = builder.createReturnInst(value); - return result; -} - -any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { - auto funcName = ctx->ID()->getText(); - auto func = dyncast(symbols.lookup(funcName)); - assert(func); - vector args; - if (auto rArgs = ctx->funcRParams()) { - for (auto exp : rArgs->exp()) { - args.push_back(any_cast(exp->accept(this))); + + any SysYIRGenerator::visitCondExp(SysYParser::ExpContext *ctx) + { + } + + any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) + { + // generate condition expression + auto cond = any_cast(ctx->exp()->accept(this)); + auto current_block = builder.getBasicBlock(); + auto func = current_block->getParent(); + // create then basicblock + auto thenblock = func->addBasicBlock("then"); + current_block->getSuccessors().push_back(thenblock); + thenblock->getPredecessors().push_back(current_block); + // create exit basicblock + auto exitblock = func->addBasicBlock("exit"); + exitblock->getPredecessors().push_back(thenblock); + thenblock->getSuccessors().push_back(exitblock); + // create condbr instr + // visit thenblock(and elseblock) + if (ctx->stmt().size() == 1) + { + // if-then + current_block->getSuccessors().push_back(exitblock); + exitblock->getPredecessors().push_back(current_block); + Value *CondBr = builder.createCondBrInst(cond, thenblock, exitblock, vector(), vector()); + builder.setPosition(thenblock, thenblock->begin()); + visitStmt(ctx->stmt()[0]); + Value *then_br = builder.createUncondBrInst(exitblock, vector()); } + if (ctx->stmt().size() == 2) + { + // if-then-else + // create else basicblock + auto elseblock = func->addBasicBlock("else"); + current_block->getSuccessors().push_back(elseblock); + elseblock->getPredecessors().push_back(current_block); + elseblock->getSuccessors().push_back(exitblock); + exitblock->getPredecessors().push_back(current_block); + CondBrInst *CondBr = builder.createCondBrInst(cond, thenblock, elseblock, vector(), vector()); + builder.setPosition(thenblock, thenblock->begin()); + visitStmt(ctx->stmt()[0]); + Value *then_br = builder.createUncondBrInst(exitblock, vector()); + builder.setPosition(elseblock, elseblock->begin()); + visitStmt(ctx->stmt()[1]); + Value *else_br = builder.createUncondBrInst(exitblock, vector()); + } + // setup the instruction insert position + builder.setPosition(exitblock, exitblock->begin()); + return builder.getBasicBlock(); } - Value *call = builder.createCallInst(func, args); - return call; -} -} // namespace sysy \ No newline at end of file + any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) + { + auto current_block = builder.getBasicBlock(); + auto func = current_block->getParent(); + // create header basicblock + auto headerblock = func->addBasicBlock("header"); + current_block->getSuccessors().push_back(headerblock); + headerblock->getPredecessors().push_back(current_block); + // uncondbr:current->header + // Value *Current_uncondbr = builder.createUncondBrInst(headerblock, vector()); + // generate condition expression + auto cond = any_cast(ctx->exp()->accept(this)); + // create body basicblock + auto bodyblock = func->addBasicBlock("body"); + headerblock->getSuccessors().push_back(bodyblock); + bodyblock->getPredecessors().push_back(headerblock); + // create exit basicblock + auto exitblock = func->addBasicBlock("exit"); + headerblock->getSuccessors().push_back(exitblock); + exitblock->getPredecessors().push_back(headerblock); + // create condbr in header + builder.setPosition(headerblock, headerblock->begin()); + Value *header_condbr = builder.createCondBrInst(cond, bodyblock, exitblock, vector(), vector()); + // generate code in body block + builder.setPosition(bodyblock, bodyblock->begin()); + visitStmt(ctx->stmt()); + // create uncondbr in body block + Value *body_uncondbr = builder.createUncondBrInst(headerblock, vector()); + // setup the instruction insert position + builder.setPosition(exitblock, exitblock->begin()); + return builder.getBasicBlock(); + } +} // namespace sysy diff --git a/src/SysYIRGenerator.h b/src/SysYIRGenerator.h index 6b9db90..f8d70fa 100644 --- a/src/SysYIRGenerator.h +++ b/src/SysYIRGenerator.h @@ -9,222 +9,253 @@ #include #include -namespace sysy { - -class SymbolTable { -private: - enum Kind { - kModule, - kFunction, - kBlock, - }; - -public: - struct ModuleScope { - SymbolTable &table; - ModuleScope(SymbolTable &table) : table(table) { table.enter(kModule); } - ~ModuleScope() { table.exit(); } - }; - struct FunctionScope { - SymbolTable &table; - FunctionScope(SymbolTable &table) : table(table) { table.enter(kFunction); } - ~FunctionScope() { table.exit(); } - }; - - struct BlockScope { - SymbolTable &table; - BlockScope(SymbolTable &table) : table(table) { table.enter(kBlock); } - ~BlockScope() { table.exit(); } - }; - -private: - std::forward_list>> symbols; - -public: - SymbolTable() = default; - -public: - bool isModuleScope() const { return symbols.front().first == kModule; } - bool isFunctionScope() const { return symbols.front().first == kFunction; } - bool isBlockScope() const { return symbols.front().first == kBlock; } - Value *lookup(const std::string &name) const { - for (auto &scope : symbols) { - auto iter = scope.second.find(name); - if (iter != scope.second.end()) - return iter->second; - } - return nullptr; - } - auto insert(const std::string &name, Value *value) { - assert(not symbols.empty()); - return symbols.front().second.emplace(name, value); - } - // void remove(const std::string &name) { - // for (auto &scope : symbols) { - // auto iter = scope.find(name); - // if (iter != scope.end()) { - // scope.erase(iter); - // return; - // } - // } - // } -private: - void enter(Kind kind) { - symbols.emplace_front(); - symbols.front().first = kind; - } - void exit() { symbols.pop_front(); } -}; // class SymbolTable - -class SysYIRGenerator : public SysYBaseVisitor { -private: - std::unique_ptr module; - IRBuilder builder; - SymbolTable symbols; - -public: - SysYIRGenerator() = default; - -public: - Module *get() const { return module.get(); } - -public: - virtual std::any visitModule(SysYParser::ModuleContext *ctx) override; - - virtual std::any visitDecl(SysYParser::DeclContext *ctx) override; - - virtual std::any visitBtype(SysYParser::BtypeContext *ctx) override; - - virtual std::any visitVarDef(SysYParser::VarDefContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any visitInitValue(SysYParser::InitValueContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any visitFunc(SysYParser::FuncContext *ctx) override; - - virtual std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override; - - virtual std::any - visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any - visitFuncFParam(SysYParser::FuncFParamContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any visitBlockStmt(SysYParser::BlockStmtContext *ctx) override; - - // virtual std::any visitBlockItem(SysYParser::BlockItemContext *ctx) - // override; - - virtual std::any visitStmt(SysYParser::StmtContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; - - virtual std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any - visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any - visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; - - virtual std::any visitEmptyStmt(SysYParser::EmptyStmtContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any - visitRelationExp(SysYParser::RelationExpContext *ctx) override { - return visitChildren(ctx); - } +namespace sysy +{ + + class SymbolTable + { + private: + enum Kind + { + kModule, + kFunction, + kBlock, + }; + + public: + struct ModuleScope + { + SymbolTable &table; + ModuleScope(SymbolTable &table) : table(table) { table.enter(kModule); } + ~ModuleScope() { table.exit(); } + }; + struct FunctionScope + { + SymbolTable &table; + FunctionScope(SymbolTable &table) : table(table) { table.enter(kFunction); } + ~FunctionScope() { table.exit(); } + }; + + struct BlockScope + { + SymbolTable &table; + BlockScope(SymbolTable &table) : table(table) { table.enter(kBlock); } + ~BlockScope() { table.exit(); } + }; + + private: + std::forward_list>> symbols; + + public: + SymbolTable() = default; + + public: + bool isModuleScope() const { return symbols.front().first == kModule; } + bool isFunctionScope() const { return symbols.front().first == kFunction; } + bool isBlockScope() const { return symbols.front().first == kBlock; } + Value *lookup(const std::string &name) const + { + for (auto &scope : symbols) + { + auto iter = scope.second.find(name); + if (iter != scope.second.end()) + return iter->second; + } + return nullptr; + } + auto insert(const std::string &name, Value *value) + { + assert(not symbols.empty()); + return symbols.front().second.emplace(name, value); + } + // void remove(const std::string &name) { + // for (auto &scope : symbols) { + // auto iter = scope.find(name); + // if (iter != scope.end()) { + // scope.erase(iter); + // return; + // } + // } + // } + private: + void enter(Kind kind) + { + symbols.emplace_front(); + symbols.front().first = kind; + } + void exit() { symbols.pop_front(); } + }; // class SymbolTable + + class SysYIRGenerator : public SysYBaseVisitor + { + private: + std::unique_ptr module; + IRBuilder builder; + SymbolTable symbols; + + public: + SysYIRGenerator() = default; + + public: + Module *get() const { return module.get(); } + + public: + virtual std::any visitModule(SysYParser::ModuleContext *ctx) override; + + virtual std::any visitDecl(SysYParser::DeclContext *ctx) override; + + virtual std::any visitBtype(SysYParser::BtypeContext *ctx) override; + + virtual std::any visitVarDef(SysYParser::VarDefContext *ctx) override + { + return visitChildren(ctx); + } + + virtual std::any visitInitValue(SysYParser::InitValueContext *ctx) override + { + return visitChildren(ctx); + } + + virtual std::any visitFunc(SysYParser::FuncContext *ctx) override; + + virtual std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override; + + virtual std::any + visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override + { + return visitChildren(ctx); + } + + virtual std::any + visitFuncFParam(SysYParser::FuncFParamContext *ctx) override + { + return visitChildren(ctx); + } + + virtual std::any visitBlockStmt(SysYParser::BlockStmtContext *ctx) override; + + // virtual std::any visitBlockItem(SysYParser::BlockItemContext *ctx) + // override; + + virtual std::any visitStmt(SysYParser::StmtContext *ctx) override + { + return visitChildren(ctx); + } + + virtual std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; - virtual std::any - visitMultiplicativeExp(SysYParser::MultiplicativeExpContext *ctx) override; + virtual std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override + { + return visitChildren(ctx); + } + + virtual std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; + + virtual std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; + + virtual std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override + { + return visitChildren(ctx); + } + + virtual std::any + visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override + { + return visitChildren(ctx); + } + + virtual std::any + visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; + + virtual std::any visitEmptyStmt(SysYParser::EmptyStmtContext *ctx) override + { + return visitChildren(ctx); + } + + virtual std::any + visitRelationExp(SysYParser::RelationExpContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) override; + virtual std::any + visitMultiplicativeExp(SysYParser::MultiplicativeExpContext *ctx) override; - virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) override; + virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) override; - virtual std::any visitAndExp(SysYParser::AndExpContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) override; - virtual std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitAndExp(SysYParser::AndExpContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any visitParenExp(SysYParser::ParenExpContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any visitStringExp(SysYParser::StringExpContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitParenExp(SysYParser::ParenExpContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any visitOrExp(SysYParser::OrExpContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitStringExp(SysYParser::StringExpContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any visitCallExp(SysYParser::CallExpContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitOrExp(SysYParser::OrExpContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any - visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override; + virtual std::any visitCallExp(SysYParser::CallExpContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any visitEqualExp(SysYParser::EqualExpContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any + visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override; - virtual std::any visitCall(SysYParser::CallContext *ctx) override; + virtual std::any visitEqualExp(SysYParser::EqualExpContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any visitLValue(SysYParser::LValueContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitCall(SysYParser::CallContext *ctx) override; - virtual std::any visitNumber(SysYParser::NumberContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitLValue(SysYParser::LValueContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any visitString(SysYParser::StringContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitNumber(SysYParser::NumberContext *ctx) override + { + return visitChildren(ctx); + } - virtual std::any - visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override { - return visitChildren(ctx); - } + virtual std::any visitString(SysYParser::StringContext *ctx) override + { + return visitChildren(ctx); + } -private: - std::any visitGlobalDecl(SysYParser::DeclContext *ctx); - std::any visitLocalDecl(SysYParser::DeclContext *ctx); - Type *getArithmeticResultType(Type *lhs, Type *rhs) { - assert(lhs->isIntOrFloat() and rhs->isIntOrFloat()); - return lhs == rhs ? lhs : Type::getFloatType(); - } -}; // class SysYIRGenerator + virtual std::any + visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override + { + return visitChildren(ctx); + } + virtual std::any + visitCondExp(SysYParser::ExpContext *ctx) override; + + private: + std::any visitGlobalDecl(SysYParser::DeclContext *ctx); + std::any visitLocalDecl(SysYParser::DeclContext *ctx); + Type *getArithmeticResultType(Type *lhs, Type *rhs) + { + assert(lhs->isIntOrFloat() and rhs->isIntOrFloat()); + return lhs == rhs ? lhs : Type::getFloatType(); + } + }; // class SysYIRGenerator } // namespace sysy \ No newline at end of file