diff --git a/src/IR.cpp b/src/IR.cpp index 318b8c8..038fd61 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -15,517 +15,579 @@ #include using namespace std; -namespace sysy { +namespace sysy +{ -template -ostream &interleave(ostream &os, const T &container, const string sep = ", ") { - auto b = container.begin(), e = container.end(); - if (b == e) - return os; - os << *b; - for (b = next(b); b != e; b = next(b)) - os << sep << *b; - return os; -} -static inline ostream &printVarName(ostream &os, const Value *var) { - return os << (dyncast(var) ? '@' : '%') - << var->getName(); -} -static inline ostream &printBlockName(ostream &os, const BasicBlock *block) { - return os << '^' << block->getName(); -} -static inline ostream &printFunctionName(ostream &os, const Function *fn) { - return os << '@' << fn->getName(); -} -static inline ostream &printOperand(ostream &os, const Value *value) { - auto constant = dyncast(value); - if (constant) { - constant->print(os); + template + ostream &interleave(ostream &os, const T &container, const string sep = ", ") + { + auto b = container.begin(), e = container.end(); + if (b == e) + return os; + os << *b; + for (b = next(b); b != e; b = next(b)) + os << sep << *b; return os; } - return printVarName(os, value); -} -//===----------------------------------------------------------------------===// -// Types -//===----------------------------------------------------------------------===// - -Type *Type::getIntType() { - static Type intType(kInt); - return &intType; -} - -Type *Type::getFloatType() { - static Type floatType(kFloat); - return &floatType; -} - -Type *Type::getVoidType() { - static Type voidType(kVoid); - return &voidType; -} - -Type *Type::getLabelType() { - static Type labelType(kLabel); - return &labelType; -} - -Type *Type::getPointerType(Type *baseType) { - // forward to PointerType - return PointerType::get(baseType); -} - -Type *Type::getFunctionType(Type *returnType, - const vector ¶mTypes) { - // forward to FunctionType - return FunctionType::get(returnType, paramTypes); -} - -int Type::getSize() const { - switch (kind) { - case kInt: - case kFloat: - return 4; - case kLabel: - case kPointer: - case kFunction: - return 8; - case kVoid: - return 0; + static inline ostream &printVarName(ostream &os, const Value *var) + { + return os << (dyncast(var) ? '@' : '%') + << var->getName(); } - return 0; -} - -void Type::print(ostream &os) const { - auto kind = getKind(); - switch (kind) { - case kInt: - os << "int"; - break; - case kFloat: - os << "float"; - break; - case kVoid: - os << "void"; - break; - case kPointer: - static_cast(this)->getBaseType()->print(os); - os << "*"; - break; - case kFunction: - static_cast(this)->getReturnType()->print(os); - os << "("; - interleave(os, static_cast(this)->getParamTypes()); - os << ")"; - break; - case kLabel: - default: - cerr << "Unexpected type!\n"; - break; - } -} - -PointerType *PointerType::get(Type *baseType) { - static std::map> pointerTypes; - auto iter = pointerTypes.find(baseType); - if (iter != pointerTypes.end()) - return iter->second.get(); - auto type = new PointerType(baseType); - assert(type); - auto result = pointerTypes.emplace(baseType, type); - return result.first->second.get(); -} - -FunctionType *FunctionType::get(Type *returnType, - const std::vector ¶mTypes) { - static std::set> functionTypes; - auto iter = - std::find_if(functionTypes.begin(), functionTypes.end(), - [&](const std::unique_ptr &type) -> bool { - if (returnType != type->getReturnType() or - paramTypes.size() != type->getParamTypes().size()) - return false; - return std::equal(paramTypes.begin(), paramTypes.end(), - type->getParamTypes().begin()); - }); - if (iter != functionTypes.end()) - return iter->get(); - auto type = new FunctionType(returnType, paramTypes); - assert(type); - auto result = functionTypes.emplace(type); - return result.first->get(); -} - -void Value::replaceAllUsesWith(Value *value) { - for (auto &use : uses) - use->getUser()->setOperand(use->getIndex(), value); - uses.clear(); -} - -bool Value::isConstant() const { - if (dyncast(this)) - return true; - if (dyncast(this) or - dyncast(this)) - return true; - // if (auto array = dyncast(this)) { - // auto elements = array->getValues(); - // return all_of(elements.begin(), elements.end(), - // [](Value *v) -> bool { return v->isConstant(); }); - // } - return false; -} - -ConstantValue *ConstantValue::get(int value) { - static std::map> intConstants; - auto iter = intConstants.find(value); - if (iter != intConstants.end()) - return iter->second.get(); - auto constant = new ConstantValue(value); - assert(constant); - auto result = intConstants.emplace(value, constant); - return result.first->second.get(); -} - -ConstantValue *ConstantValue::get(float value) { - static std::map> floatConstants; - auto iter = floatConstants.find(value); - if (iter != floatConstants.end()) - return iter->second.get(); - auto constant = new ConstantValue(value); - assert(constant); - auto result = floatConstants.emplace(value, constant); - return result.first->second.get(); -} - -void ConstantValue::print(ostream &os) const { - if (isInt()) - os << getInt(); - else - os << getFloat(); -} - -Argument::Argument(Type *type, BasicBlock *block, int index, - const std::string &name) - : Value(kArgument, type, name), block(block), index(index) { - if (not hasName()) - setName(to_string(block->getParent()->allocateVariableID())); -} - -void Argument::print(std::ostream &os) const { - assert(hasName()); - printVarName(os, this) << ": " << *getType(); -} - -BasicBlock::BasicBlock(Function *parent, const std::string &name) - : Value(kBasicBlock, Type::getLabelType(), name), parent(parent), - instructions(), arguments(), successors(), predecessors() { - if (not hasName()) - setName("bb" + to_string(getParent()->allocateblockID())); -} - -void BasicBlock::print(std::ostream &os) const { - assert(hasName()); - os << " "; - printBlockName(os, this); - auto args = getArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printVarName(os, b->get()) << ": " << *b->get()->getType(); - for (auto &arg : make_range(std::next(b), e)) { - os << ", "; - printVarName(os, arg.get()) << ": " << *arg->getType(); + static inline ostream &printBlockName(ostream &os, const BasicBlock *block) + { + return os << '^' << block->getName(); + } + static inline ostream &printFunctionName(ostream &os, const Function *fn) + { + return os << '@' << fn->getName(); + } + static inline ostream &printOperand(ostream &os, const Value *value) + { + auto constant = dyncast(value); + if (constant) + { + constant->print(os); + return os; } - os << ')'; - } - os << ":\n"; - for (auto &inst : instructions) { - os << " " << *inst << '\n'; - } -} - -Instruction::Instruction(Kind kind, Type *type, BasicBlock *parent, - const std::string &name) - : User(kind, type, name), kind(kind), parent(parent) { - if (not type->isVoid() and not hasName()) - setName(to_string(getFunction()->allocateVariableID())); -} - -void CallInst::print(std::ostream &os) const { - if (not getType()->isVoid()) - printVarName(os, this) << " = call "; - printFunctionName(os, getCallee()) << '('; - auto args = getArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); + return printVarName(os, value); + } + //===----------------------------------------------------------------------===// + // Types + //===----------------------------------------------------------------------===// + + Type *Type::getIntType() + { + static Type intType(kInt); + return &intType; + } + + Type *Type::getFloatType() + { + static Type floatType(kFloat); + return &floatType; + } + + Type *Type::getVoidType() + { + static Type voidType(kVoid); + return &voidType; + } + + Type *Type::getLabelType() + { + static Type labelType(kLabel); + return &labelType; + } + + Type *Type::getPointerType(Type *baseType) + { + // forward to PointerType + return PointerType::get(baseType); + } + + Type *Type::getFunctionType(Type *returnType, + const vector ¶mTypes) + { + // forward to FunctionType + return FunctionType::get(returnType, paramTypes); + } + + int Type::getSize() const + { + switch (kind) + { + case kInt: + case kFloat: + return 4; + case kLabel: + case kPointer: + case kFunction: + return 8; + case kVoid: + return 0; } + return 0; } - os << ") : " << *getType(); -} - -void UnaryInst::print(std::ostream &os) const { - printVarName(os, this) << " = "; - switch (getKind()) { - case kNeg: - os << "neg"; - break; - case kNot: - os << "not"; - break; - case kFNeg: - os << "fneg"; - break; - case kFtoI: - os << "ftoi"; - break; - case kIToF: - os << "itof"; - break; - default: - assert(false); - } - printOperand(os, getOperand()) << " : " << *getType(); -} - -void BinaryInst::print(std::ostream &os) const { - printVarName(os, this) << " = "; - switch (getKind()) { - case kAdd: - os << "add"; - break; - case kSub: - os << "sub"; - break; - case kMul: - os << "mul"; - break; - case kDiv: - os << "div"; - break; - case kRem: - os << "rem"; - break; - case kICmpEQ: - os << "icmpeq"; - break; - case kICmpNE: - os << "icmpne"; - break; - case kICmpLT: - os << "icmplt"; - break; - case kICmpGT: - os << "icmpgt"; - break; - case kICmpLE: - os << "icmple"; - break; - case kICmpGE: - os << "icmpge"; - break; - case kFAdd: - os << "fadd"; - break; - case kFSub: - os << "fsub"; - break; - case kFMul: - os << "fmul"; - break; - case kFDiv: - os << "fdiv"; - break; - case kFRem: - os << "frem"; - break; - case kFCmpEQ: - os << "fcmpeq"; - break; - case kFCmpNE: - os << "fcmpne"; - break; - case kFCmpLT: - os << "fcmplt"; - break; - case kFCmpGT: - os << "fcmpgt"; - break; - case kFCmpLE: - os << "fcmple"; - break; - case kFCmpGE: - os << "fcmpge"; - break; - default: - assert(false); - } - os << ' '; - printOperand(os, getLhs()) << ", "; - printOperand(os, getRhs()) << " : " << *getType(); -} - -void ReturnInst::print(std::ostream &os) const { - os << "return"; - if (auto value = getReturnValue()) { - os << ' '; - printOperand(os, value) << " : " << *value->getType(); - } -} - -void UncondBrInst::print(std::ostream &os) const { - os << "br "; - printBlockName(os, getBlock()); - auto args = getArguments(); - auto b = args.begin(), e = args.end(); - if (b != e) { - os << '('; - printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { - os << ", "; - printOperand(os, arg); + + void Type::print(ostream &os) const + { + auto kind = getKind(); + switch (kind) + { + case kInt: + os << "int"; + break; + case kFloat: + os << "float"; + break; + case kVoid: + os << "void"; + break; + case kPointer: + static_cast(this)->getBaseType()->print(os); + os << "*"; + break; + case kFunction: + static_cast(this)->getReturnType()->print(os); + os << "("; + interleave(os, static_cast(this)->getParamTypes()); + os << ")"; + break; + case kLabel: + default: + cerr << "Unexpected type!\n"; + break; } - os << ')'; } -} -void CondBrInst::print(std::ostream &os) const { - os << "condbr "; - printOperand(os, getCondition()) << ", "; - printBlockName(os, getThenBlock()); + PointerType *PointerType::get(Type *baseType) + { + static std::map> pointerTypes; + auto iter = pointerTypes.find(baseType); + if (iter != pointerTypes.end()) + return iter->second.get(); + auto type = new PointerType(baseType); + assert(type); + auto result = pointerTypes.emplace(baseType, type); + return result.first->second.get(); + } + + FunctionType *FunctionType::get(Type *returnType, + const std::vector ¶mTypes) + { + static std::set> functionTypes; + auto iter = + std::find_if(functionTypes.begin(), functionTypes.end(), + [&](const std::unique_ptr &type) -> bool + { + if (returnType != type->getReturnType() or + paramTypes.size() != type->getParamTypes().size()) + return false; + return std::equal(paramTypes.begin(), paramTypes.end(), + type->getParamTypes().begin()); + }); + if (iter != functionTypes.end()) + return iter->get(); + auto type = new FunctionType(returnType, paramTypes); + assert(type); + auto result = functionTypes.emplace(type); + return result.first->get(); + } + + void Value::replaceAllUsesWith(Value *value) + { + for (auto &use : uses) + use->getUser()->setOperand(use->getIndex(), value); + uses.clear(); + } + + bool Value::isConstant() const + { + if (dyncast(this)) + return true; + if (dyncast(this) or + dyncast(this)) + return true; + // if (auto array = dyncast(this)) { + // auto elements = array->getValues(); + // return all_of(elements.begin(), elements.end(), + // [](Value *v) -> bool { return v->isConstant(); }); + // } + return false; + } + + ConstantValue *ConstantValue::get(int value) + { + static std::map> intConstants; + auto iter = intConstants.find(value); + if (iter != intConstants.end()) + return iter->second.get(); + auto constant = new ConstantValue(value); + assert(constant); + auto result = intConstants.emplace(value, constant); + return result.first->second.get(); + } + + ConstantValue *ConstantValue::get(float value) + { + static std::map> floatConstants; + auto iter = floatConstants.find(value); + if (iter != floatConstants.end()) + return iter->second.get(); + auto constant = new ConstantValue(value); + assert(constant); + auto result = floatConstants.emplace(value, constant); + return result.first->second.get(); + } + + void ConstantValue::print(ostream &os) const + { + if (isInt()) + os << getInt(); + else + os << getFloat(); + } + + Argument::Argument(Type *type, BasicBlock *block, int index, + const std::string &name) + : Value(kArgument, type, name), block(block), index(index) { - auto args = getThenArguments(); + if (not hasName()) + setName(to_string(block->getParent()->allocateVariableID())); + } + + void Argument::print(std::ostream &os) const + { + assert(hasName()); + printVarName(os, this) << ": " << *getType(); + } + + BasicBlock::BasicBlock(Function *parent, const std::string &name) + : Value(kBasicBlock, Type::getLabelType(), name), parent(parent), + instructions(), arguments(), successors(), predecessors() + { + if (not hasName()) + setName("bb" + to_string(getParent()->allocateblockID())); + } + + void BasicBlock::print(std::ostream &os) const + { + assert(hasName()); + os << " "; + printBlockName(os, this); + auto args = getArguments(); auto b = args.begin(), e = args.end(); - if (b != e) { + if (b != e) + { os << '('; + printVarName(os, b->get()) << ": " << *b->get()->getType(); + for (auto &arg : make_range(std::next(b), e)) + { + os << ", "; + printVarName(os, arg.get()) << ": " << *arg->getType(); + } + os << ')'; + } + os << ":\n"; + for (auto &inst : instructions) + { + os << " " << *inst << '\n'; + } + } + + Instruction::Instruction(Kind kind, Type *type, BasicBlock *parent, + const std::string &name) + : User(kind, type, name), kind(kind), parent(parent) + { + if (not type->isVoid() and not hasName()) + setName(to_string(getFunction()->allocateVariableID())); + } + + void CallInst::print(std::ostream &os) const + { + if (not getType()->isVoid()) + printVarName(os, this) << " = call "; + printFunctionName(os, getCallee()) << '('; + auto args = getArguments(); + auto b = args.begin(), e = args.end(); + if (b != e) + { printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { + for (auto arg : make_range(std::next(b), e)) + { os << ", "; printOperand(os, arg); } - os << ')'; } + os << ") : " << *getType(); + } + + void UnaryInst::print(std::ostream &os) const + { + printVarName(os, this) << " = "; + switch (getKind()) + { + case kNeg: + os << "neg"; + break; + case kNot: + os << "not"; + break; + case kFNeg: + os << "fneg"; + break; + case kFtoI: + os << "ftoi"; + break; + case kIToF: + os << "itof"; + break; + default: + assert(false); + } + os << ' '; + printOperand(os, getOperand()) << " : " << *getType(); + } + + void BinaryInst::print(std::ostream &os) const + { + printVarName(os, this) << " = "; + switch (getKind()) + { + case kAdd: + os << "add"; + break; + case kSub: + os << "sub"; + break; + case kMul: + os << "mul"; + break; + case kDiv: + os << "div"; + break; + case kRem: + os << "rem"; + break; + case kICmpEQ: + os << "icmpeq"; + break; + case kICmpNE: + os << "icmpne"; + break; + case kICmpLT: + os << "icmplt"; + break; + case kICmpGT: + os << "icmpgt"; + break; + case kICmpLE: + os << "icmple"; + break; + case kICmpGE: + os << "icmpge"; + break; + case kFAdd: + os << "fadd"; + break; + case kFSub: + os << "fsub"; + break; + case kFMul: + os << "fmul"; + break; + case kFDiv: + os << "fdiv"; + break; + case kFRem: + os << "frem"; + break; + case kFCmpEQ: + os << "fcmpeq"; + break; + case kFCmpNE: + os << "fcmpne"; + break; + case kFCmpLT: + os << "fcmplt"; + break; + case kFCmpGT: + os << "fcmpgt"; + break; + case kFCmpLE: + os << "fcmple"; + break; + case kFCmpGE: + os << "fcmpge"; + break; + default: + assert(false); + } + os << ' '; + printOperand(os, getLhs()) << ", "; + printOperand(os, getRhs()) << " : " << *getType(); } - os << ", "; - printBlockName(os, getElseBlock()); + + void ReturnInst::print(std::ostream &os) const { - auto args = getElseArguments(); + os << "return"; + if (auto value = getReturnValue()) + { + os << ' '; + printOperand(os, value) << " : " << *value->getType(); + } + } + + void UncondBrInst::print(std::ostream &os) const + { + os << "br "; + printBlockName(os, getBlock()); + auto args = getArguments(); auto b = args.begin(), e = args.end(); - if (b != e) { + if (b != e) + { os << '('; printOperand(os, *b); - for (auto arg : make_range(std::next(b), e)) { + for (auto arg : make_range(std::next(b), e)) + { os << ", "; printOperand(os, arg); } os << ')'; } } -} - -void AllocaInst::print(std::ostream &os) const { - if (getNumDims()) - cerr << "not implemented yet\n"; - printVarName(os, this) << " = "; - os << "alloca " - << *static_cast(getType())->getBaseType(); - os << " : " << *getType(); -} - -void LoadInst::print(std::ostream &os) const { - if (getNumIndices()) - cerr << "not implemented yet\n"; - printVarName(os, this) << " = "; - os << "load "; - printOperand(os, getPointer()) << " : " << *getType(); -} - -void StoreInst::print(std::ostream &os) const { - if (getNumIndices()) - cerr << "not implemented yet\n"; - os << "store "; - printOperand(os, getValue()) << ", "; - printOperand(os, getPointer()) << " : " << *getValue()->getType(); -} - -void Function::print(std::ostream &os) const { - auto returnType = getReturnType(); - auto paramTypes = getParamTypes(); - os << *returnType << ' '; - printFunctionName(os, this) << '('; - auto b = paramTypes.begin(), e = paramTypes.end(); - if (b != e) { - os << *(*b); - for (auto type : make_range(std::next(b), e)) - os << ", " << *type; - } - os << ") {\n"; - for (auto &bb : getBasicBlocks()) { - os << *bb << '\n'; - } - os << "}"; -} - -void Module::print(std::ostream &os) const { - for (auto &value : children) - os << *value << '\n'; -} - -// ArrayValue *ArrayValue::get(Type *type, const vector &values) { -// static map, unique_ptr> arrayConstants; -// hash hasher; -// auto key = make_pair( -// type, hasher(string(reinterpret_cast(values.data()), -// values.size() * sizeof(Value *)))); - -// auto iter = arrayConstants.find(key); -// if (iter != arrayConstants.end()) -// return iter->second.get(); -// auto constant = new ArrayValue(type, values); -// assert(constant); -// auto result = arrayConstants.emplace(key, constant); -// return result.first->second.get(); -// } - -// ArrayValue *ArrayValue::get(const std::vector &values) { -// vector vals(values.size(), nullptr); -// std::transform(values.begin(), values.end(), vals.begin(), -// [](int v) { return ConstantValue::get(v); }); -// return get(Type::getIntType(), vals); -// } - -// ArrayValue *ArrayValue::get(const std::vector &values) { -// vector vals(values.size(), nullptr); -// std::transform(values.begin(), values.end(), vals.begin(), -// [](float v) { return ConstantValue::get(v); }); -// return get(Type::getFloatType(), vals); -// } - -void User::setOperand(int index, Value *value) { - assert(index < getNumOperands()); - operands[index].setValue(value); -} - -void User::replaceOperand(int index, Value *value) { - assert(index < getNumOperands()); - auto &use = operands[index]; - use.getValue()->removeUse(&use); - use.setValue(value); -} - -CallInst::CallInst(Function *callee, const std::vector &args, - BasicBlock *parent, const std::string &name) - : Instruction(kCall, callee->getReturnType(), parent, name) { - addOperand(callee); - for (auto arg : args) - addOperand(arg); -} - -Function *CallInst::getCallee() const { - return dyncast(getOperand(0)); -} + + void CondBrInst::print(std::ostream &os) const + { + os << "condbr "; + printOperand(os, getCondition()) << ", "; + printBlockName(os, getThenBlock()); + { + auto args = getThenArguments(); + auto b = args.begin(), e = args.end(); + if (b != e) + { + os << '('; + printOperand(os, *b); + for (auto arg : make_range(std::next(b), e)) + { + os << ", "; + printOperand(os, arg); + } + os << ')'; + } + } + os << ", "; + printBlockName(os, getElseBlock()); + { + auto args = getElseArguments(); + auto b = args.begin(), e = args.end(); + if (b != e) + { + os << '('; + printOperand(os, *b); + for (auto arg : make_range(std::next(b), e)) + { + os << ", "; + printOperand(os, arg); + } + os << ')'; + } + } + } + + void AllocaInst::print(std::ostream &os) const + { + if (getNumDims()) + cerr << "not implemented yet\n"; + printVarName(os, this) << " = "; + os << "alloca " + << *static_cast(getType())->getBaseType(); + os << " : " << *getType(); + } + + void LoadInst::print(std::ostream &os) const + { + if (getNumIndices()) + cerr << "not implemented yet\n"; + printVarName(os, this) << " = "; + os << "load "; + printOperand(os, getPointer()) << " : " << *getType(); + } + + void StoreInst::print(std::ostream &os) const + { + if (getNumIndices()) + cerr << "not implemented yet\n"; + os << "store "; + printOperand(os, getValue()) << ", "; + printOperand(os, getPointer()) << " : " << *getValue()->getType(); + } + + void Function::print(std::ostream &os) const + { + auto returnType = getReturnType(); + auto paramTypes = getParamTypes(); + os << *returnType << ' '; + printFunctionName(os, this) << '('; + auto b = paramTypes.begin(), e = paramTypes.end(); + if (b != e) + { + os << *(*b); + for (auto type : make_range(std::next(b), e)) + os << ", " << *type; + } + os << ") {\n"; + for (auto &bb : getBasicBlocks()) + { + os << *bb << '\n'; + } + os << "}"; + } + + void Module::print(std::ostream &os) const + { + for (auto &value : children) + os << *value << '\n'; + } + + // ArrayValue *ArrayValue::get(Type *type, const vector &values) { + // static map, unique_ptr> arrayConstants; + // hash hasher; + // auto key = make_pair( + // type, hasher(string(reinterpret_cast(values.data()), + // values.size() * sizeof(Value *)))); + + // auto iter = arrayConstants.find(key); + // if (iter != arrayConstants.end()) + // return iter->second.get(); + // auto constant = new ArrayValue(type, values); + // assert(constant); + // auto result = arrayConstants.emplace(key, constant); + // return result.first->second.get(); + // } + + // ArrayValue *ArrayValue::get(const std::vector &values) { + // vector vals(values.size(), nullptr); + // std::transform(values.begin(), values.end(), vals.begin(), + // [](int v) { return ConstantValue::get(v); }); + // return get(Type::getIntType(), vals); + // } + + // ArrayValue *ArrayValue::get(const std::vector &values) { + // vector vals(values.size(), nullptr); + // std::transform(values.begin(), values.end(), vals.begin(), + // [](float v) { return ConstantValue::get(v); }); + // return get(Type::getFloatType(), vals); + // } + + void User::setOperand(int index, Value *value) + { + assert(index < getNumOperands()); + operands[index].setValue(value); + } + + void User::replaceOperand(int index, Value *value) + { + assert(index < getNumOperands()); + auto &use = operands[index]; + use.getValue()->removeUse(&use); + use.setValue(value); + } + + CallInst::CallInst(Function *callee, const std::vector &args, + BasicBlock *parent, const std::string &name) + : Instruction(kCall, callee->getReturnType(), parent, name) + { + addOperand(callee); + for (auto arg : args) + addOperand(arg); + } + + Function *CallInst::getCallee() const + { + return dyncast(getOperand(0)); + } } // namespace sysy \ No newline at end of file diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 82e6a30..e8131b4 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -260,6 +260,7 @@ namespace sysy { // generate condition expression auto cond = any_cast(ctx->exp()->accept(this)); + cond->setName("flag"); auto current_block = builder.getBasicBlock(); auto func = current_block->getParent(); // create then basicblock @@ -336,4 +337,74 @@ namespace sysy builder.setPosition(exitblock, exitblock->begin()); return builder.getBasicBlock(); } + //******************Revised by lyq BEGIN************************************* + any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) + { + // generate the operands + auto hs = any_cast(ctx->exp()->accept(this)); + + Value *result = nullptr; + if (ctx->SUB()) + result = builder.createNegInst(hs); + else if (ctx->NOT()) + result = builder.createNotInst(hs); + else if (ctx->ADD()) + result = hs; + return result; + } + + any SysYIRGenerator::visitRelationExp(SysYParser::RelationExpContext *ctx) + { + // generate the operands + auto lhs = any_cast(ctx->exp(0)->accept(this)); + auto rhs = any_cast(ctx->exp(1)->accept(this)); + // create convert instruction if needed + auto lhsTy = lhs->getType(); + auto rhsTy = rhs->getType(); + auto type = getArithmeticResultType(lhsTy, rhsTy); + if (lhsTy != type) + lhs = builder.createIToFInst(lhs); + if (rhsTy != type) + rhs = builder.createIToFInst(rhs); + + Value *result = nullptr; + if (ctx->LT()) + result = type->isInt() ? builder.createICmpLTInst(lhs, rhs) + : builder.createFCmpLTInst(lhs, rhs); + else if (ctx->GT()) + result = type->isInt() ? builder.createICmpGTInst(lhs, rhs) + : builder.createFCmpGTInst(lhs, rhs); + else if (ctx->LE()) + result = type->isInt() ? builder.createICmpLEInst(lhs, rhs) + : builder.createFCmpLEInst(lhs, rhs); + else if (ctx->GE()) + result = type->isInt() ? builder.createICmpGEInst(lhs, rhs) + : builder.createFCmpGEInst(lhs, rhs); + return result; + } + + any SysYIRGenerator::visitEqualExp(SysYParser::EqualExpContext *ctx) + { + // generate the operands + auto lhs = any_cast(ctx->exp(0)->accept(this)); + auto rhs = any_cast(ctx->exp(1)->accept(this)); + // create convert instruction if needed + auto lhsTy = lhs->getType(); + auto rhsTy = rhs->getType(); + auto type = getArithmeticResultType(lhsTy, rhsTy); + if (lhsTy != type) + lhs = builder.createIToFInst(lhs); + if (rhsTy != type) + rhs = builder.createIToFInst(rhs); + + Value *result = nullptr; + if (ctx->EQ()) + result = type->isInt() ? builder.createICmpEQInst(lhs, rhs) + : builder.createFCmpEQInst(lhs, rhs); + else if (ctx->NE()) + result = type->isInt() ? builder.createICmpNEInst(lhs, rhs) + : builder.createFCmpNEInst(lhs, rhs); + return result; + } + //******************Revised by lyq END*************************************** } // namespace sysy diff --git a/src/SysYIRGenerator.h b/src/SysYIRGenerator.h index 5264c09..fc678a3 100644 --- a/src/SysYIRGenerator.h +++ b/src/SysYIRGenerator.h @@ -173,10 +173,7 @@ namespace sysy } virtual std::any - visitRelationExp(SysYParser::RelationExpContext *ctx) override - { - return visitChildren(ctx); - } + visitRelationExp(SysYParser::RelationExpContext *ctx) override; virtual std::any visitMultiplicativeExp(SysYParser::MultiplicativeExpContext *ctx) override; @@ -190,10 +187,7 @@ namespace sysy return visitChildren(ctx); } - virtual std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override - { - return visitChildren(ctx); - } + virtual std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override; virtual std::any visitParenExp(SysYParser::ParenExpContext *ctx) override { @@ -218,10 +212,7 @@ namespace sysy virtual std::any visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override; - virtual std::any visitEqualExp(SysYParser::EqualExpContext *ctx) override - { - return visitChildren(ctx); - } + virtual std::any visitEqualExp(SysYParser::EqualExpContext *ctx) override; virtual std::any visitCall(SysYParser::CallContext *ctx) override; diff --git a/test/if_else.sy b/test/if_else.sy index 399264f..bd36a7a 100644 --- a/test/if_else.sy +++ b/test/if_else.sy @@ -8,7 +8,7 @@ int main(){ int a, b; a = 10; b = 0; - if(a+1) + if(a>0) a = mul(a, b); else b = mul(a, b);