Partial implementation of IR generator.

Now can generate a single block function within +/-/*// and return.
main
Su Xing 3 years ago
parent 5a9538c0ec
commit 308bcac3fa

@ -15,6 +15,7 @@ add_executable(sysyc
sysyc.cpp
IR.cpp
SysYIRGenerator.cpp
Diagnostic.cpp
)
target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(sysyc PRIVATE SysYParser)

@ -1,28 +1,50 @@
#include "IR.h"
#include "range.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <ostream>
#include <set>
#include <string>
#include <utility>
#include <vector>
using namespace std;
namespace sysy {
template <typename T>
std::ostream &interleave(std::ostream &os, const T &container,
const std::string sep = ", ") {
ostream &interleave(ostream &os, const T &container, const string sep = ", ") {
auto b = container.begin(), e = container.end();
if (b == e)
return os;
os << *b;
for (b = std::next(b); b != e; b = std::next(b))
for (b = next(b); b != e; b = next(b))
os << sep << *b;
return os;
}
static inline ostream &printVarName(ostream &os, const Value *var) {
return os << (dynamic_cast<const GlobalValue *>(var) ? '@' : '%')
<< var->getName();
}
static inline ostream &printBlockName(ostream &os, const BasicBlock *block) {
return os << '^' << block->getName();
}
static inline ostream &printFunctionName(ostream &os, const Function *fn) {
return os << '@' << fn->getName();
}
static inline ostream &printOperand(ostream &os, const Value *value) {
auto constant = dynamic_cast<const ConstantValue *>(value);
if (constant) {
constant->print(os);
return os;
}
return printVarName(os, value);
}
//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//
@ -53,7 +75,7 @@ Type *Type::getPointerType(Type *baseType) {
}
Type *Type::getFunctionType(Type *returnType,
const std::vector<Type *> &paramTypes) {
const vector<Type *> &paramTypes) {
// forward to FunctionType
return FunctionType::get(returnType, paramTypes);
}
@ -74,27 +96,28 @@ int Type::getSize() const {
}
void Type::print(ostream &os) const {
switch (getKind()) {
kInt:
auto kind = getKind();
switch (kind) {
case kInt:
os << "int";
break;
kFloat:
case kFloat:
os << "float";
break;
kVoid:
case kVoid:
os << "void";
break;
kPointer:
case kPointer:
static_cast<const PointerType *>(this)->getBaseType()->print(os);
os << "*";
break;
kFunction:
case kFunction:
static_cast<const FunctionType *>(this)->getReturnType()->print(os);
os << "(";
interleave(os, static_cast<const FunctionType *>(this)->getParamTypes());
os << ")";
break;
kLabel:
case kLabel:
default:
cerr << "Unexpected type!\n";
break;
@ -138,28 +161,343 @@ void Value::replaceAllUsesWith(Value *value) {
uses.clear();
}
ConstantValue *ConstantValue::get(int value, const std::string &name) {
bool Value::isConstant() const {
if (dynamic_cast<const ConstantValue *>(this))
return true;
if (dynamic_cast<const GlobalValue *>(this) or
dynamic_cast<const Function *>(this))
return true;
if (auto array = dynamic_cast<const ArrayValue *>(this)) {
auto elements = array->getValues();
return all_of(elements.begin(), elements.end(),
[](Value *v) -> bool { return v->isConstant(); });
}
return false;
}
ConstantValue *ConstantValue::get(int value) {
static std::map<int, std::unique_ptr<ConstantValue>> intConstants;
auto iter = intConstants.find(value);
if (iter != intConstants.end())
return iter->second.get();
auto inst = new ConstantValue(value);
assert(inst);
auto result = intConstants.emplace(value, inst);
auto constant = new ConstantValue(value);
assert(constant);
auto result = intConstants.emplace(value, constant);
return result.first->second.get();
}
ConstantValue *ConstantValue::get(float value, const std::string &name) {
ConstantValue *ConstantValue::get(float value) {
static std::map<float, std::unique_ptr<ConstantValue>> floatConstants;
auto iter = floatConstants.find(value);
if (iter != floatConstants.end())
return iter->second.get();
auto inst = new ConstantValue(value);
assert(inst);
auto result = floatConstants.emplace(value, inst);
auto constant = new ConstantValue(value);
assert(constant);
auto result = floatConstants.emplace(value, constant);
return result.first->second.get();
}
void ConstantValue::print(ostream &os) const {
if (isInt())
os << getInt();
else
os << getFloat();
}
Argument::Argument(Type *type, BasicBlock *block, int index,
const std::string &name)
: Value(type, name), block(block), index(index) {
if (not hasName())
setName(to_string(block->getParent()->allocateVariableID()));
}
void Argument::print(std::ostream &os) const {
assert(hasName());
printVarName(os, this) << ": " << *getType();
}
BasicBlock::BasicBlock(Function *parent, const std::string &name)
: Value(Type::getLabelType(), name), parent(parent), instructions(),
arguments(), successors(), predecessors() {
if (not hasName())
setName("bb" + to_string(getParent()->allocateblockID()));
}
void BasicBlock::print(std::ostream &os) const {
assert(hasName());
os << " ";
printBlockName(os, this);
auto args = getArguments();
auto b = args.begin(), e = args.end();
if (b != e) {
printVarName(os, &*b) << ": " << *b->getType();
for (auto arg : make_range(std::next(b), e)) {
os << ", ";
printVarName(os, &arg) << ": " << arg.getType();
}
}
os << ":\n";
for (auto &inst : instructions) {
os << " " << *inst << '\n';
}
}
Instruction::Instruction(Kind kind, Type *type, BasicBlock *parent,
const std::string &name)
: User(type, name), kind(kind), parent(parent) {
if (not type->isVoid() and not hasName())
setName(to_string(getFunction()->allocateVariableID()));
}
void CallInst::print(std::ostream &os) const {
if (not getType()->isVoid())
printVarName(os, this) << " = ";
printFunctionName(os, getCallee()) << '(';
auto args = getArguments();
auto b = args.begin(), e = args.end();
if (b != e) {
printOperand(os, *b);
for (auto arg : make_range(std::next(b), e)) {
os << ", ";
printOperand(os, arg);
}
}
os << ") : " << *getType();
}
void UnaryInst::print(std::ostream &os) const {
printVarName(os, this) << " = ";
switch (getKind()) {
case kNeg:
os << "neg";
break;
case kNot:
os << "not";
break;
case kFNeg:
os << "fneg";
break;
case kFtoI:
os << "ftoi";
break;
case kIToF:
os << "itof";
break;
default:
assert(false);
}
printOperand(os, getOperand()) << " : " << *getType();
}
void BinaryInst::print(std::ostream &os) const {
printVarName(os, this) << " = ";
switch (getKind()) {
case kAdd:
os << "add";
break;
case kSub:
os << "sub";
break;
case kMul:
os << "mul";
break;
case kDiv:
os << "div";
break;
case kRem:
os << "rem";
break;
case kICmpEQ:
os << "icmpeq";
break;
case kICmpNE:
os << "icmpne";
break;
case kICmpLT:
os << "icmplt";
break;
case kICmpGT:
os << "icmpgt";
break;
case kICmpLE:
os << "icmple";
break;
case kICmpGE:
os << "icmpge";
break;
case kFAdd:
os << "fadd";
break;
case kFSub:
os << "fsub";
break;
case kFMul:
os << "fmul";
break;
case kFDiv:
os << "fdiv";
break;
case kFRem:
os << "frem";
break;
case kFCmpEQ:
os << "fcmpeq";
break;
case kFCmpNE:
os << "fcmpne";
break;
case kFCmpLT:
os << "fcmplt";
break;
case kFCmpGT:
os << "fcmpgt";
break;
case kFCmpLE:
os << "fcmple";
break;
case kFCmpGE:
os << "fcmpge";
break;
default:
assert(false);
}
printOperand(os, getLhs()) << ", ";
printOperand(os, getRhs()) << " : " << *getType();
}
void ReturnInst::print(std::ostream &os) const {
os << "return";
if (auto value = getReturnValue()) {
os << ' ';
printOperand(os, value) << " : " << *value->getType();
}
}
void UncondBrInst::print(std::ostream &os) const {
os << "br ";
printBlockName(os, getBlock());
auto args = getArguments();
auto b = args.begin(), e = args.end();
if (b != e) {
os << '(';
printOperand(os, *b);
for (auto arg : make_range(std::next(b), e)) {
os << ", ";
printOperand(os, arg);
}
os << ')';
}
}
void CondBrInst::print(std::ostream &os) const {
os << "condbr ";
printOperand(os, getCondition()) << ", ";
printBlockName(os, getThenBlock());
{
auto args = getThenArguments();
auto b = args.begin(), e = args.end();
if (b != e) {
os << '(';
printOperand(os, *b);
for (auto arg : make_range(std::next(b), e)) {
os << ", ";
printOperand(os, arg);
}
os << ')';
}
}
os << ", ";
printBlockName(os, getElseBlock());
{
auto args = getElseArguments();
auto b = args.begin(), e = args.end();
if (b != e) {
os << '(';
printOperand(os, *b);
for (auto arg : make_range(std::next(b), e)) {
os << ", ";
printOperand(os, arg);
}
os << ')';
}
}
}
void AllocaInst::print(std::ostream &os) const {
if (getNumDims())
cerr << "not implemented yet\n";
printVarName(os, this) << " = ";
os << "alloca "
<< *static_cast<const PointerType *>(getType())->getBaseType();
os << " : " << *getType();
}
void LoadInst::print(std::ostream &os) const {
if (getNumIndices())
cerr << "not implemented yet\n";
printVarName(os, this) << " = ";
os << "load ";
printOperand(os, getPointer()) << " : " << *getType();
}
void StoreInst::print(std::ostream &os) const {
if (getNumIndices())
cerr << "not implemented yet\n";
os << "store ";
printOperand(os, getValue()) << ", ";
printOperand(os, getPointer()) << " : " << *getValue()->getType();
}
void Function::print(std::ostream &os) const {
auto returnType = getReturnType();
auto paramTypes = getParamTypes();
os << *returnType << ' ';
printFunctionName(os, this) << '(';
interleave(os, paramTypes) << ')';
os << "{\n";
for (auto &bb : getBasicBlocks()) {
os << *bb << '\n';
}
os << "}";
}
void Module::print(std::ostream &os) const {
for (auto &g : globals)
os << *g.second << '\n';
for (auto &f : functions)
os << *f.second << '\n';
}
ArrayValue *ArrayValue::get(Type *type, const vector<Value *> &values) {
static map<pair<Type *, size_t>, unique_ptr<ArrayValue>> arrayConstants;
hash<string> hasher;
auto key = make_pair(
type, hasher(string(reinterpret_cast<const char *>(values.data()),
values.size() * sizeof(Value *))));
auto iter = arrayConstants.find(key);
if (iter != arrayConstants.end())
return iter->second.get();
auto constant = new ArrayValue(type, values);
assert(constant);
auto result = arrayConstants.emplace(key, constant);
return result.first->second.get();
}
ArrayValue *ArrayValue::get(const std::vector<int> &values) {
vector<Value *> vals(values.size(), nullptr);
std::transform(values.begin(), values.end(), vals.begin(),
[](int v) { return ConstantValue::get(v); });
return get(Type::getIntType(), vals);
}
ArrayValue *ArrayValue::get(const std::vector<float> &values) {
vector<Value *> vals(values.size(), nullptr);
std::transform(values.begin(), values.end(), vals.begin(),
[](float v) { return ConstantValue::get(v); });
return get(Type::getFloatType(), vals);
}
void User::setOperand(int index, Value *value) {
assert(index < getNumOperands());
operands[index].setValue(value);
@ -180,7 +518,7 @@ CallInst::CallInst(Function *callee, const std::vector<Value *> args,
addOperand(arg);
}
Function *CallInst::getCallee() {
Function *CallInst::getCallee() const {
return dynamic_cast<Function *>(getOperand(0));
}

@ -4,6 +4,7 @@
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <list>
#include <map>
#include <memory>
@ -66,6 +67,7 @@ public:
bool isLabel() const { return kind == kLabel; }
bool isPointer() const { return kind == kPointer; }
bool isFunction() const { return kind == kFunction; }
bool isIntOrFloat() const { return kind == kInt or kind == kFloat; }
int getSize() const;
template <typename T>
std::enable_if_t<std::is_base_of_v<Type, T>, T *> as() const {
@ -185,6 +187,9 @@ protected:
public:
Type *getType() const { return type; }
const std::string &getName() const { return name; }
void setName(const std::string &n) { name = n; }
bool hasName() const { return not name.empty(); }
bool isInt() const { return type->isInt(); }
bool isFloat() const { return type->isFloat(); }
bool isPointer() const { return type->isPointer(); }
@ -192,6 +197,10 @@ public:
void addUse(Use *use) { uses.push_back(use); }
void replaceAllUsesWith(Value *value);
void removeUse(Use *use) { uses.remove(use); }
bool isConstant() const;
public:
virtual void print(std::ostream &os) const = 0;
}; // class Value
/*!
@ -208,14 +217,13 @@ protected:
};
protected:
ConstantValue(int value, const std::string &name = "")
: Value(Type::getIntType(), name), iScalar(value) {}
ConstantValue(float value, const std::string &name = "")
: Value(Type::getFloatType(), name), fScalar(value) {}
ConstantValue(int value) : Value(Type::getIntType(), ""), iScalar(value) {}
ConstantValue(float value)
: Value(Type::getFloatType(), ""), fScalar(value) {}
public:
static ConstantValue *get(int value, const std::string &name = "");
static ConstantValue *get(float value, const std::string &name = "");
static ConstantValue *get(int value);
static ConstantValue *get(float value);
public:
int getInt() const {
@ -226,6 +234,9 @@ public:
assert(isFloat());
return fScalar;
}
public:
virtual void print(std::ostream &os) const override;
}; // class ConstantValue
class BasicBlock;
@ -246,8 +257,14 @@ protected:
public:
Argument(Type *type, BasicBlock *block, int index,
const std::string &name = "")
: Value(type, name), block(block), index(index) {}
const std::string &name = "");
public:
BasicBlock *getParent() const { return block; }
int getIndex() const { return index; }
public:
virtual void print(std::ostream &os) const override;
};
class Instruction;
@ -276,9 +293,7 @@ protected:
block_list predecessors;
protected:
explicit BasicBlock(Function *parent, const std::string &name = "")
: Value(Type::getLabelType(), name), parent(parent), instructions(),
arguments(), successors(), predecessors() {}
explicit BasicBlock(Function *parent, const std::string &name = "");
public:
int getNumInstructions() const { return instructions.size(); }
@ -287,7 +302,7 @@ public:
int getNumSuccessors() const { return successors.size(); }
Function *getParent() const { return parent; }
inst_list &getInstructions() { return instructions; }
arg_list &getArguments() { return arguments; }
auto getArguments() const { return make_range(arguments); }
block_list &getPredecessors() { return predecessors; }
block_list &getSuccessors() { return successors; }
iterator begin() { return instructions.begin(); }
@ -297,6 +312,9 @@ public:
arguments.emplace_back(type, this, arguments.size(), name);
return &arguments.back();
};
public:
virtual void print(std::ostream &os) const override;
}; // class BasicBlock
//! User is the abstract base type of `Value` types which use other `Value` as
@ -311,17 +329,25 @@ protected:
: Value(type, name), operands() {}
public:
struct operand_iterator : std::vector<Use>::iterator {
using Base = std::vector<Use>::iterator;
using Base::Base;
using use_iterator = std::vector<Use>::const_iterator;
struct operand_iterator : public std::vector<Use>::const_iterator {
using Base = std::vector<Use>::const_iterator;
operand_iterator(const Base &iter) : Base(iter) {}
using value_type = Value *;
value_type operator->() { return operator*().getValue(); }
value_type operator->() { return Base::operator*().getValue(); }
value_type operator*() { return Base::operator*().getValue(); }
};
// struct const_operand_iterator : std::vector<Use>::const_iterator {
// using Base = std::vector<Use>::const_iterator;
// const_operand_iterator(const Base &iter) : Base(iter) {}
// using value_type = Value *;
// value_type operator->() { return operator*().getValue(); }
// };
public:
int getNumOperands() const { return operands.size(); }
auto operand_begin() const { return operands.begin(); }
auto operand_end() const { return operands.end(); }
operand_iterator operand_begin() const { return operands.begin(); }
operand_iterator operand_end() const { return operands.end(); }
auto getOperands() const {
return make_range(operand_begin(), operand_end());
}
@ -336,7 +362,6 @@ public:
}
void replaceOperand(int index, Value *value);
void setOperand(int index, Value *value);
const std::string &getName() const { return name; }
}; // class User
/*!
@ -372,7 +397,7 @@ public:
// Unary
kNeg = 0x1UL << 25,
kNot = 0x1UL << 26,
kFNeg = 0x1UL << 26,
kFNeg = 0x1UL << 27,
kFtoI = 0x1UL << 28,
kIToF = 0x1UL << 29,
// call
@ -395,12 +420,12 @@ protected:
protected:
Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr,
const std::string &name = "")
: User(type, name), kind(kind), parent(parent) {}
const std::string &name = "");
public:
Kind getKind() const { return kind; }
BasicBlock *getParent() const { return parent; }
Function *getFunction() const { return parent->getParent(); }
void setParent(BasicBlock *bb) { parent = bb; }
bool isBinary() const {
@ -452,10 +477,13 @@ protected:
BasicBlock *parent = nullptr, const std::string &name = "");
public:
Function *getCallee();
auto getArguments() {
Function *getCallee() const;
auto getArguments() const {
return make_range(std::next(operand_begin()), operand_end());
}
public:
virtual void print(std::ostream &os) const override;
}; // class CallInst
//! Unary instruction, includes '!', '-' and type conversion.
@ -471,6 +499,9 @@ protected:
public:
Value *getOperand() const { return User::getOperand(0); }
public:
virtual void print(std::ostream &os) const override;
}; // class UnaryInst
//! Binary instruction, e.g., arithmatic, relation, logic, etc.
@ -488,6 +519,9 @@ protected:
public:
Value *getLhs() const { return getOperand(0); }
Value *getRhs() const { return getOperand(1); }
public:
virtual void print(std::ostream &os) const override;
}; // class BinaryInst
//! The return statement
@ -506,6 +540,9 @@ public:
Value *getReturnValue() const {
return hasReturnValue() ? getOperand(0) : nullptr;
}
public:
virtual void print(std::ostream &os) const override;
}; // class ReturnInst
//! Unconditional branch
@ -526,8 +563,11 @@ public:
return dynamic_cast<BasicBlock *>(getOperand(0));
}
auto getArguments() const {
return make_range(std::next(operands.begin()), operands.end());
return make_range(std::next(operand_begin()), operand_end());
}
public:
virtual void print(std::ostream &os) const override;
}; // class UncondBrInst
//! Conditional branch
@ -549,22 +589,27 @@ protected:
}
public:
Value *getCondition() const { return getOperand(0); }
BasicBlock *getThenBlock() const {
return dynamic_cast<BasicBlock *>(getOperand(0));
return dynamic_cast<BasicBlock *>(getOperand(1));
}
BasicBlock *getElseBlock() const {
return dynamic_cast<BasicBlock *>(getOperand(1));
return dynamic_cast<BasicBlock *>(getOperand(2));
}
auto getThenArguments() const {
auto begin = operands.begin() + 2;
auto end = begin + getThenBlock()->getNumArguments();
auto begin = std::next(operand_begin(), 3);
auto end = std::next(begin, getThenBlock()->getNumArguments());
return make_range(begin, end);
}
auto getElseArguments() const {
auto begin = operands.begin() + 2 + getThenBlock()->getNumArguments();
auto end = operands.end();
auto begin =
std::next(operand_begin(), 3 + getThenBlock()->getNumArguments());
auto end = operand_end();
return make_range(begin, end);
}
public:
virtual void print(std::ostream &os) const override;
}; // class CondBrInst
//! Allocate memory for stack variables, used for non-global variable declartion
@ -582,6 +627,8 @@ public:
int getNumDims() const { return getNumOperands(); }
auto getDims() const { return getOperands(); }
Value *getDim(int index) { return getOperand(index); }
public:
virtual void print(std::ostream &os) const override;
}; // class AllocaInst
//! Load a value from memory address specified by a pointer value
@ -593,6 +640,7 @@ protected:
BasicBlock *parent = nullptr, const std::string &name = "")
: Instruction(kLoad, pointer->getType()->as<PointerType>()->getBaseType(),
parent, name) {
addOperand(pointer);
addOperands(indices);
}
@ -603,6 +651,8 @@ public:
return make_range(std::next(operand_begin()), operand_end());
}
Value *getIndex(int index) const { return getOperand(index + 1); }
public:
virtual void print(std::ostream &os) const override;
}; // class LoadInst
//! Store a value to memory address specified by a pointer value
@ -624,9 +674,11 @@ public:
Value *getValue() const { return getOperand(0); }
Value *getPointer() const { return getOperand(1); }
auto getIndices() const {
return make_range(operand_begin() + 2, operand_end());
return make_range(std::next(operand_begin(), 2), operand_end());
}
Value *getIndex(int index) const { return getOperand(index + 2); }
public:
virtual void print(std::ostream &os) const override;
}; // class StoreInst
class Module;
@ -636,7 +688,7 @@ class Function : public Value {
protected:
Function(Module *parent, Type *type, const std::string &name)
: Value(type, name), parent(parent), blocks() {
: Value(type, name), parent(parent), variableID(0), blocks() {
blocks.emplace_back(new BasicBlock(this, "entry"));
}
@ -645,6 +697,8 @@ public:
protected:
Module *parent;
int variableID;
int blockID;
block_list blocks;
public:
@ -654,8 +708,8 @@ public:
auto getParamTypes() const {
return getType()->as<FunctionType>()->getParamTypes();
}
auto getBasicBlocks() { return make_range(blocks); }
BasicBlock *getEntryBlock() { return blocks.front().get(); }
auto getBasicBlocks() const { return make_range(blocks); }
BasicBlock *getEntryBlock() const { return blocks.front().get(); }
BasicBlock *addBasicBlock(const std::string &name = "") {
blocks.emplace_back(new BasicBlock(this, name));
return blocks.back().get();
@ -665,8 +719,30 @@ public:
return block == b.get();
});
}
int allocateVariableID() { return variableID++; }
int allocateblockID() { return blockID++; }
public:
virtual void print(std::ostream &os) const override;
}; // class Function
class ArrayValue : public User {
protected:
ArrayValue(Type *type, const std::vector<Value *> &values = {})
: User(type, "") {
addOperands(values);
}
public:
static ArrayValue *get(Type *type, const std::vector<Value *> &values);
static ArrayValue *get(const std::vector<int> &values);
static ArrayValue *get(const std::vector<float> &values);
public:
auto getValues() const { return getOperands(); }
public:
virtual void print(std::ostream &os) const override {};
}; // class ConstantArray
//! Global value declared at file scope
class GlobalValue : public User {
friend class Module;
@ -674,6 +750,7 @@ class GlobalValue : public User {
protected:
Module *parent;
bool hasInit;
bool isConst;
protected:
GlobalValue(Module *parent, Type *type, const std::string &name,
@ -689,6 +766,8 @@ public:
Value *init() const { return hasInit ? operands.back().getValue() : nullptr; }
int getNumDims() const { return getNumOperands() - (hasInit ? 1 : 0); }
Value *getDim(int index) { return getOperand(index); }
public:
virtual void print(std::ostream &os) const override {};
}; // class GlobalValue
//! IR unit for representing a SysY compile unit
@ -708,9 +787,10 @@ public:
return result.first->second.get();
};
GlobalValue *createGlobalValue(const std::string &name, Type *type,
const std::vector<Value *> &dims = {}) {
auto result =
globals.try_emplace(name, new GlobalValue(this, type, name, dims));
const std::vector<Value *> &dims = {},
Value *init = nullptr) {
auto result = globals.try_emplace(
name, new GlobalValue(this, type, name, dims, init));
if (not result.second)
return nullptr;
return result.first->second.get();
@ -727,9 +807,21 @@ public:
return nullptr;
return result->second.get();
}
public:
void print(std::ostream &os) const;
}; // class Module
/*!
* @}
*/
inline std::ostream &operator<<(std::ostream &os, const Type &type) {
type.print(os);
return os;
}
inline std::ostream &operator<<(std::ostream &os, const Value &value) {
value.print(os);
return os;
}
} // namespace sysy

@ -99,9 +99,7 @@ btype: INT | FLOAT;
varDef: lValue (ASSIGN initValue)?;
initValue:
exp # scalarInitValue
| LBRACE (initValue (COMMA initValue)*)? # arrayInitValue;
initValue: exp | LBRACE (initValue (COMMA initValue)*)?;
func: funcType ID LPAREN funcFParams? RPAREN blockStmt;

@ -34,9 +34,9 @@ protected:
}
public:
// virtual std::any visitModule(SysYParser::ModuleContext *ctx) override {
// return visitChildren(ctx);
// }
// virtual std::any visitModule(SysYParser::ModuleContext *ctx) override {
// return visitChildren(ctx);
// }
virtual std::any visitBtype(SysYParser::BtypeContext *ctx) override {
os << ctx->getText();
@ -62,13 +62,14 @@ public:
return 0;
}
virtual std::any
visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) override {
os << '{';
auto values = ctx->initValue();
if (values.size())
interleave(values, ", ");
os << '}';
virtual std::any visitInitValue(SysYParser::InitValueContext *ctx) override {
if (not ctx->exp()) {
os << '{';
auto values = ctx->initValue();
if (values.size())
interleave(values, ", ");
os << '}';
}
return 0;
}
@ -191,8 +192,7 @@ public:
virtual std::any
visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override {
space() << ctx->CONTINUE()->getText() << ';'
<< '\n';
space() << ctx->CONTINUE()->getText() << ';' << '\n';
return 0;
}
@ -235,13 +235,15 @@ public:
return 0;
}
// virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) override {
// return visitChildren(ctx);
// }
// virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx)
// override {
// return visitChildren(ctx);
// }
// virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) override {
// return visitChildren(ctx);
// }
// virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx)
// override {
// return visitChildren(ctx);
// }
virtual std::any visitAndExp(SysYParser::AndExpContext *ctx) override {
ctx->exp(0)->accept(this);
@ -275,9 +277,9 @@ public:
return 0;
}
// virtual std::any visitCallExp(SysYParser::CallExpContext *ctx) override {
// return visitChildren(ctx);
// }
// virtual std::any visitCallExp(SysYParser::CallExpContext *ctx) override {
// return visitChildren(ctx);
// }
virtual std::any
visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override {

@ -1,65 +1,214 @@
#include "IR.h"
#include <any>
#include <iostream>
#include <memory>
#include <vector>
using namespace std;
#include "Diagnostic.h"
#include "SysYIRGenerator.h"
namespace sysy {
any SysYIRGenerator::visitModule(SysYParser::ModuleContext *ctx) {
// global scople of the module
SymbolTable::ModuleScope scope(symbols);
// create the IR module
auto pModule = new Module();
assert(pModule);
module.reset(pModule);
// generates globals and functions
visitChildren(ctx);
// return the IR module
return pModule;
}
any SysYIRGenerator::visitDecl(SysYParser::DeclContext *ctx) {
// global and local declarations are handled in different ways
return symbols.isModuleScope() ? visitGlobalDecl(ctx) : visitLocalDecl(ctx);
}
any SysYIRGenerator::visitGlobalDecl(SysYParser::DeclContext *ctx) {
error(ctx, "not implemented yet");
std::vector<Value *> values;
bool isConst = ctx->CONST();
auto type = any_cast<Type *>(visitBtype(ctx->btype()));
for (auto varDef : ctx->varDef()) {
auto name = varDef->lValue()->ID()->getText();
vector<Value *> dims;
for (auto exp : varDef->lValue()->exp())
dims.push_back(any_cast<Value *>(exp->accept(this)));
auto init = varDef->ASSIGN()
? any_cast<Value *>(visitInitValue(varDef->initValue()))
: nullptr;
values.push_back(module->createGlobalValue(name, type, dims, init));
}
// FIXME const
return values;
}
any SysYIRGenerator::visitLocalDecl(SysYParser::DeclContext *ctx) {
// a single declaration statement can declare several variables,
// collect them in a vector
std::vector<Value *> values;
// obtain the declared type
auto type = Type::getPointerType(any_cast<Type *>(visitBtype(ctx->btype())));
// handle variables
for (auto varDef : ctx->varDef()) {
// obtain the variable name and allocate space on the stack
auto name = varDef->lValue()->ID()->getText();
auto alloca = builder.createAllocaInst(type, {}, name);
// update the symbol table
symbols.insert(name, alloca);
// if an initial value is given, create a store instruction
if (varDef->ASSIGN()) {
auto value = any_cast<Value *>(visitInitValue(varDef->initValue()));
auto store = builder.createStoreInst(value, alloca);
}
// collect the created variable (pointer)
values.push_back(alloca);
}
return values;
}
any SysYIRGenerator::visitFunc(SysYParser::FuncContext *ctx) {
// create the function scope
SymbolTable::FunctionScope scope(symbols);
// obtain function name and type signature
auto name = ctx->ID()->getText();
auto params = ctx->funcFParams()->funcFParam();
vector<Type *> paramTypes;
vector<string> paramNames;
for (auto param : params) {
paramTypes.push_back(any_cast<Type *>(visitBtype(param->btype())));
paramNames.push_back(param->ID()->getText());
if (ctx->funcFParams()) {
auto params = ctx->funcFParams()->funcFParam();
for (auto param : params) {
paramTypes.push_back(any_cast<Type *>(visitBtype(param->btype())));
paramNames.push_back(param->ID()->getText());
}
}
Type *returnType = any_cast<Type *>(visitFuncType(ctx->funcType()));
auto funcType = Type::getFunctionType(returnType, paramTypes);
// create the IR function
auto function = module->createFunction(name, funcType);
// create the entry block with the same parameters as the function,
// and update the symbol table
auto entry = function->getEntryBlock();
for (auto i = 0; i < paramTypes.size(); ++i)
entry->createArgument(paramTypes[i], paramNames[i]);
for (auto i = 0; i < paramTypes.size(); ++i) {
auto arg = entry->createArgument(paramTypes[i], paramNames[i]);
symbols.insert(paramNames[i], arg);
}
// setup the instruction insert point
builder.setPosition(entry, entry->end());
// generate the function body
visitBlockStmt(ctx->blockStmt());
return function;
}
any SysYIRGenerator::visitBtype(SysYParser::BtypeContext *ctx) {
return ctx->INT() ? Type::getIntType() : Type::getFloatType();
}
any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) {
return ctx->INT()
? Type::getIntType()
: (ctx->FLOAT() ? Type::getFloatType() : Type::getVoidType());
}
any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) {
// the insert point has already been setup
for (auto item : ctx->blockItem())
visitBlockItem(item);
// return the corresponding IR block
return builder.getBasicBlock();
}
any SysYIRGenerator::visitBlockItem(SysYParser::BlockItemContext *ctx) {
return ctx->decl() ? visitDecl(ctx->decl()) : visitStmt(ctx->stmt());
any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
// generate the rhs expression
auto rhs = any_cast<Value *>(ctx->exp()->accept(this));
// get the address of the lhs variable
auto lValue = ctx->lValue();
auto name = lValue->ID()->getText();
auto pointer = symbols.lookup(name);
if (not pointer)
error(ctx, "undefined variable");
// update the variable
Value *store = builder.createStoreInst(rhs, pointer);
return store;
}
any SysYIRGenerator::visitDecl(SysYParser::DeclContext *ctx) {
std::vector<Value *> values;
auto type = any_cast<Type *>(visitBtype(ctx->btype()));
for (auto varDef : ctx->varDef()) {
auto name = varDef->lValue()->ID()->getText();
auto alloca = builder.createAllocaInst(type, {}, name);
if (varDef->ASSIGN()) {
auto value = any_cast<Value *>(varDef->initValue()->accept(this));
auto store = builder.createStoreInst(value, alloca);
}
values.push_back(alloca);
}
return values;
any SysYIRGenerator::visitNumberExp(SysYParser::NumberExpContext *ctx) {
Value *result = nullptr;
assert(ctx->number()->ILITERAL() or ctx->number()->FLITERAL());
if (auto iLiteral = ctx->number()->ILITERAL())
result = ConstantValue::get(std::stoi(iLiteral->getText()));
else
result =
ConstantValue::get(std::stof(ctx->number()->FLITERAL()->getText()));
return result;
}
any SysYIRGenerator::visitLValueExp(SysYParser::LValueExpContext *ctx) {
auto name = ctx->lValue()->ID()->getText();
auto ptr = symbols.lookup(name);
if (not ptr)
error(ctx, "undefined variable");
Value *value = builder.createLoadInst(ptr);
return value;
}
any SysYIRGenerator::visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) {
// generate the operands
auto lhs = any_cast<Value *>(ctx->exp(0)->accept(this));
auto rhs = any_cast<Value *>(ctx->exp(1)->accept(this));
// create convert instruction if needed
auto lhsTy = lhs->getType();
auto rhsTy = rhs->getType();
auto type = getArithmeticResultType(lhsTy, rhsTy);
if (lhsTy != type)
lhs = builder.createIToFInst(lhs);
if (rhsTy != type)
rhs = builder.createIToFInst(rhs);
// create the arithmetic instruction
Value *result = nullptr;
if (ctx->ADD())
result = type->isInt() ? builder.createAddInst(lhs, rhs)
: builder.createFAddInst(lhs, rhs);
else
result = type->isInt() ? builder.createSubInst(lhs, rhs)
: builder.createFSubInst(lhs, rhs);
return result;
}
any SysYIRGenerator::visitMultiplicativeExp(
SysYParser::MultiplicativeExpContext *ctx) {
// generate the operands
auto lhs = any_cast<Value *>(ctx->exp(0)->accept(this));
auto rhs = any_cast<Value *>(ctx->exp(1)->accept(this));
// create convert instruction if needed
auto lhsTy = lhs->getType();
auto rhsTy = rhs->getType();
auto type = getArithmeticResultType(lhsTy, rhsTy);
if (lhsTy != type)
lhs = builder.createIToFInst(lhs);
if (rhsTy != type)
rhs = builder.createIToFInst(rhs);
// create the arithmetic instruction
Value *result = nullptr;
if (ctx->MUL())
result = type->isInt() ? builder.createMulInst(lhs, rhs)
: builder.createFMulInst(lhs, rhs);
else if (ctx->DIV())
result = type->isInt() ? builder.createDivInst(lhs, rhs)
: builder.createFDivInst(lhs, rhs);
else
result = type->isInt() ? builder.createRemInst(lhs, rhs)
: builder.createFRemInst(lhs, rhs);
return result;
}
any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
auto value =
ctx->exp() ? any_cast<Value *>(ctx->exp()->accept(this)) : nullptr;
Value *result = builder.createReturnInst(value);
return result;
}
} // namespace sysy

@ -1,47 +1,111 @@
#pragma once
#include <memory>
#include "IR.h"
#include "IRBuilder.h"
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include <cassert>
#include <forward_list>
#include <memory>
#include <string>
namespace sysy {
class SymbolTable {
private:
enum Kind {
kModule,
kFunction,
kBlock,
};
public:
struct ModuleScope {
SymbolTable &table;
ModuleScope(SymbolTable &table) : table(table) { table.enter(kModule); }
~ModuleScope() { table.exit(); }
};
struct FunctionScope {
SymbolTable &table;
FunctionScope(SymbolTable &table) : table(table) { table.enter(kFunction); }
~FunctionScope() { table.exit(); }
};
struct BlockScope {
SymbolTable &table;
BlockScope(SymbolTable &table) : table(table) { table.enter(kBlock); }
~BlockScope() { table.exit(); }
};
private:
std::forward_list<std::pair<Kind, std::map<std::string, Value *>>> symbols;
public:
SymbolTable() = default;
public:
bool isModuleScope() const { return symbols.front().first == kModule; }
bool isFunctionScope() const { return symbols.front().first == kFunction; }
bool isBlockScope() const { return symbols.front().first == kBlock; }
Value *lookup(const std::string &name) const {
for (auto &scope : symbols) {
auto iter = scope.second.find(name);
if (iter != scope.second.end())
return iter->second;
}
return nullptr;
}
auto insert(const std::string &name, Value *value) {
assert(not symbols.empty());
return symbols.front().second.emplace(name, value);
}
// void remove(const std::string &name) {
// for (auto &scope : symbols) {
// auto iter = scope.find(name);
// if (iter != scope.end()) {
// scope.erase(iter);
// return;
// }
// }
// }
private:
void enter(Kind kind) {
symbols.emplace_front();
symbols.front().first = kind;
}
void exit() { symbols.pop_front(); }
}; // class SymbolTable
class SysYIRGenerator : public SysYBaseVisitor {
private:
std::unique_ptr<Module> module;
IRBuilder builder;
SymbolTable symbols;
public:
SysYIRGenerator() = default;
public:
Module *get() const { return module.get(); }
public:
virtual std::any visitModule(SysYParser::ModuleContext *ctx) override;
virtual std::any visitDecl(SysYParser::DeclContext *ctx) override;
virtual std::any visitBtype(SysYParser::BtypeContext *ctx) override;
virtual std::any visitVarDef(SysYParser::VarDefContext *ctx) override {
return visitChildren(ctx);
}
virtual std::any
visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) override {
virtual std::any visitVarDef(SysYParser::VarDefContext *ctx) override {
return visitChildren(ctx);
}
virtual std::any
visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) override {
virtual std::any visitInitValue(SysYParser::InitValueContext *ctx) override {
return visitChildren(ctx);
}
virtual std::any visitFunc(SysYParser::FuncContext *ctx) override;
virtual std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override {
return visitChildren(ctx);
}
virtual std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override;
virtual std::any
visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override {
@ -55,16 +119,14 @@ public:
virtual std::any visitBlockStmt(SysYParser::BlockStmtContext *ctx) override;
virtual std::any visitBlockItem(SysYParser::BlockItemContext *ctx) override;
// virtual std::any visitBlockItem(SysYParser::BlockItemContext *ctx)
// override;
virtual std::any visitStmt(SysYParser::StmtContext *ctx) override {
return visitChildren(ctx);
}
virtual std::any
visitAssignStmt(SysYParser::AssignStmtContext *ctx) override {
return visitChildren(ctx);
}
virtual std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override;
virtual std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override {
return visitChildren(ctx);
@ -88,9 +150,7 @@ public:
}
virtual std::any
visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override {
return visitChildren(ctx);
}
visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override;
virtual std::any visitEmptyStmt(SysYParser::EmptyStmtContext *ctx) override {
return visitChildren(ctx);
@ -102,17 +162,11 @@ public:
}
virtual std::any
visitMultiplicativeExp(SysYParser::MultiplicativeExpContext *ctx) override {
return visitChildren(ctx);
}
visitMultiplicativeExp(SysYParser::MultiplicativeExpContext *ctx) override;
virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) override {
return visitChildren(ctx);
}
virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) override;
virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) override {
return visitChildren(ctx);
}
virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) override;
virtual std::any visitAndExp(SysYParser::AndExpContext *ctx) override {
return visitChildren(ctx);
@ -139,9 +193,7 @@ public:
}
virtual std::any
visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override {
return visitChildren(ctx);
}
visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override;
virtual std::any visitEqualExp(SysYParser::EqualExpContext *ctx) override {
return visitChildren(ctx);
@ -168,6 +220,13 @@ public:
return visitChildren(ctx);
}
private:
std::any visitGlobalDecl(SysYParser::DeclContext *ctx);
std::any visitLocalDecl(SysYParser::DeclContext *ctx);
Type *getArithmeticResultType(Type *lhs, Type *rhs) {
assert(lhs->isIntOrFloat() and rhs->isIntOrFloat());
return lhs == rhs ? lhs : Type::getFloatType();
}
}; // class SysYIRGenerator
} // namespace sysy

@ -1,4 +1,3 @@
#include "ASTPrinter.h"
#include "tree/ParseTreeWalker.h"
#include <cstdlib>
#include <fstream>
@ -7,7 +6,7 @@ using namespace std;
#include "SysYLexer.h"
#include "SysYParser.h"
using namespace antlr4;
#include "SysYFormatter.h"
// #include "SysYFormatter.h"
#include "SysYIRGenerator.h"
using namespace sysy;
@ -25,10 +24,12 @@ int main(int argc, char **argv) {
SysYLexer lexer(&input);
CommonTokenStream tokens(&lexer);
SysYParser parser(&tokens);
auto module = parser.module();
auto moduleAST = parser.module();
SysYIRGenerator generator;
generator.visitModule(module);
generator.visitModule(moduleAST);
auto moduleIR = generator.get();
moduleIR->print(cout);
return EXIT_SUCCESS;
}
Loading…
Cancel
Save