Support function call and IR printing.

main
Xing Su 3 years ago
parent 4795a4d813
commit 8208469b13

@ -28,7 +28,7 @@ ostream &interleave(ostream &os, const T &container, const string sep = ", ") {
return os;
}
static inline ostream &printVarName(ostream &os, const Value *var) {
return os << (dynamic_cast<const GlobalValue *>(var) ? '@' : '%')
return os << (dyncast<GlobalValue>(var) ? '@' : '%')
<< var->getName();
}
static inline ostream &printBlockName(ostream &os, const BasicBlock *block) {
@ -38,7 +38,7 @@ static inline ostream &printFunctionName(ostream &os, const Function *fn) {
return os << '@' << fn->getName();
}
static inline ostream &printOperand(ostream &os, const Value *value) {
auto constant = dynamic_cast<const ConstantValue *>(value);
auto constant = dyncast<ConstantValue>(value);
if (constant) {
constant->print(os);
return os;
@ -162,16 +162,16 @@ void Value::replaceAllUsesWith(Value *value) {
}
bool Value::isConstant() const {
if (dynamic_cast<const ConstantValue *>(this))
if (dyncast<ConstantValue>(this))
return true;
if (dynamic_cast<const GlobalValue *>(this) or
dynamic_cast<const Function *>(this))
if (dyncast<GlobalValue>(this) or
dyncast<Function>(this))
return true;
if (auto array = dynamic_cast<const ArrayValue *>(this)) {
auto elements = array->getValues();
return all_of(elements.begin(), elements.end(),
[](Value *v) -> bool { return v->isConstant(); });
}
// if (auto array = dyncast<const ArrayValue *>(this)) {
// auto elements = array->getValues();
// return all_of(elements.begin(), elements.end(),
// [](Value *v) -> bool { return v->isConstant(); });
// }
return false;
}
@ -206,7 +206,7 @@ void ConstantValue::print(ostream &os) const {
Argument::Argument(Type *type, BasicBlock *block, int index,
const std::string &name)
: Value(type, name), block(block), index(index) {
: Value(kArgument, type, name), block(block), index(index) {
if (not hasName())
setName(to_string(block->getParent()->allocateVariableID()));
}
@ -217,8 +217,8 @@ void Argument::print(std::ostream &os) const {
}
BasicBlock::BasicBlock(Function *parent, const std::string &name)
: Value(Type::getLabelType(), name), parent(parent), instructions(),
arguments(), successors(), predecessors() {
: Value(kBasicBlock, Type::getLabelType(), name), parent(parent),
instructions(), arguments(), successors(), predecessors() {
if (not hasName())
setName("bb" + to_string(getParent()->allocateblockID()));
}
@ -230,11 +230,13 @@ void BasicBlock::print(std::ostream &os) const {
auto args = getArguments();
auto b = args.begin(), e = args.end();
if (b != e) {
printVarName(os, &*b) << ": " << *b->getType();
for (auto arg : make_range(std::next(b), e)) {
os << '(';
printVarName(os, b->get()) << ": " << *b->get()->getType();
for (auto &arg : make_range(std::next(b), e)) {
os << ", ";
printVarName(os, &arg) << ": " << arg.getType();
printVarName(os, arg.get()) << ": " << *arg->getType();
}
os << ')';
}
os << ":\n";
for (auto &inst : instructions) {
@ -244,14 +246,14 @@ void BasicBlock::print(std::ostream &os) const {
Instruction::Instruction(Kind kind, Type *type, BasicBlock *parent,
const std::string &name)
: User(type, name), kind(kind), parent(parent) {
: User(kind, type, name), kind(kind), parent(parent) {
if (not type->isVoid() and not hasName())
setName(to_string(getFunction()->allocateVariableID()));
}
void CallInst::print(std::ostream &os) const {
if (not getType()->isVoid())
printVarName(os, this) << " = ";
printVarName(os, this) << " = call ";
printFunctionName(os, getCallee()) << '(';
auto args = getArguments();
auto b = args.begin(), e = args.end();
@ -361,6 +363,7 @@ void BinaryInst::print(std::ostream &os) const {
default:
assert(false);
}
os << ' ';
printOperand(os, getLhs()) << ", ";
printOperand(os, getRhs()) << " : " << *getType();
}
@ -453,8 +456,13 @@ void Function::print(std::ostream &os) const {
auto paramTypes = getParamTypes();
os << *returnType << ' ';
printFunctionName(os, this) << '(';
interleave(os, paramTypes) << ')';
os << "{\n";
auto b = paramTypes.begin(), e = paramTypes.end();
if (b != e) {
os << *(*b);
for (auto type : make_range(std::next(b), e))
os << ", " << *type;
}
os << ") {\n";
for (auto &bb : getBasicBlocks()) {
os << *bb << '\n';
}
@ -468,35 +476,35 @@ void Module::print(std::ostream &os) const {
os << *f.second << '\n';
}
ArrayValue *ArrayValue::get(Type *type, const vector<Value *> &values) {
static map<pair<Type *, size_t>, unique_ptr<ArrayValue>> arrayConstants;
hash<string> hasher;
auto key = make_pair(
type, hasher(string(reinterpret_cast<const char *>(values.data()),
values.size() * sizeof(Value *))));
auto iter = arrayConstants.find(key);
if (iter != arrayConstants.end())
return iter->second.get();
auto constant = new ArrayValue(type, values);
assert(constant);
auto result = arrayConstants.emplace(key, constant);
return result.first->second.get();
}
ArrayValue *ArrayValue::get(const std::vector<int> &values) {
vector<Value *> vals(values.size(), nullptr);
std::transform(values.begin(), values.end(), vals.begin(),
[](int v) { return ConstantValue::get(v); });
return get(Type::getIntType(), vals);
}
ArrayValue *ArrayValue::get(const std::vector<float> &values) {
vector<Value *> vals(values.size(), nullptr);
std::transform(values.begin(), values.end(), vals.begin(),
[](float v) { return ConstantValue::get(v); });
return get(Type::getFloatType(), vals);
}
// ArrayValue *ArrayValue::get(Type *type, const vector<Value *> &values) {
// static map<pair<Type *, size_t>, unique_ptr<ArrayValue>> arrayConstants;
// hash<string> hasher;
// auto key = make_pair(
// type, hasher(string(reinterpret_cast<const char *>(values.data()),
// values.size() * sizeof(Value *))));
// auto iter = arrayConstants.find(key);
// if (iter != arrayConstants.end())
// return iter->second.get();
// auto constant = new ArrayValue(type, values);
// assert(constant);
// auto result = arrayConstants.emplace(key, constant);
// return result.first->second.get();
// }
// ArrayValue *ArrayValue::get(const std::vector<int> &values) {
// vector<Value *> vals(values.size(), nullptr);
// std::transform(values.begin(), values.end(), vals.begin(),
// [](int v) { return ConstantValue::get(v); });
// return get(Type::getIntType(), vals);
// }
// ArrayValue *ArrayValue::get(const std::vector<float> &values) {
// vector<Value *> vals(values.size(), nullptr);
// std::transform(values.begin(), values.end(), vals.begin(),
// [](float v) { return ConstantValue::get(v); });
// return get(Type::getFloatType(), vals);
// }
void User::setOperand(int index, Value *value) {
assert(index < getNumOperands());
@ -519,7 +527,7 @@ CallInst::CallInst(Function *callee, const std::vector<Value *> &args,
}
Function *CallInst::getCallee() const {
return dynamic_cast<Function *>(getOperand(0));
return dyncast<Function>(getOperand(0));
}
} // namespace sysy

@ -2,6 +2,7 @@
#include "range.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iterator>
@ -173,18 +174,91 @@ public:
void setValue(Value *value) { value = value; }
}; // class Use
template <typename T>
inline std::enable_if_t<std::is_base_of_v<Value, T>, bool> isa(const Value *value) {
return T::classof(value);
}
template <typename T>
inline std::enable_if_t<std::is_base_of_v<Value, T>, T *> dyncast(Value *value) {
return isa<T>(value) ? static_cast<T *>(value) : nullptr;
}
template <typename T>
inline std::enable_if_t<std::is_base_of_v<Value, T>, const T *> dyncast(const Value *value) {
return isa<T>(value) ? static_cast<const T *>(value) : nullptr;
}
//! The base class of all value types
class Value {
public:
enum Kind : uint64_t {
kInvalid,
// Instructions
// Binary
kAdd = 0x1UL << 0,
kSub = 0x1UL << 1,
kMul = 0x1UL << 2,
kDiv = 0x1UL << 3,
kRem = 0x1UL << 4,
kICmpEQ = 0x1UL << 5,
kICmpNE = 0x1UL << 6,
kICmpLT = 0x1UL << 7,
kICmpGT = 0x1UL << 8,
kICmpLE = 0x1UL << 9,
kICmpGE = 0x1UL << 10,
kFAdd = 0x1UL << 14,
kFSub = 0x1UL << 15,
kFMul = 0x1UL << 16,
kFDiv = 0x1UL << 17,
kFRem = 0x1UL << 18,
kFCmpEQ = 0x1UL << 19,
kFCmpNE = 0x1UL << 20,
kFCmpLT = 0x1UL << 21,
kFCmpGT = 0x1UL << 22,
kFCmpLE = 0x1UL << 23,
kFCmpGE = 0x1UL << 24,
// Unary
kNeg = 0x1UL << 25,
kNot = 0x1UL << 26,
kFNeg = 0x1UL << 27,
kFtoI = 0x1UL << 28,
kIToF = 0x1UL << 29,
// call
kCall = 0x1UL << 30,
// terminator
kCondBr = 0x1UL << 31,
kBr = 0x1UL << 32,
kReturn = 0x1UL << 33,
// mem op
kAlloca = 0x1UL << 34,
kLoad = 0x1UL << 35,
kStore = 0x1UL << 36,
kFirstInst = kAdd,
kLastInst = kStore,
// others
kArgument = 0x1UL << 37,
kBasicBlock = 0x1UL << 38,
kFunction = 0x1UL << 39,
kConstant = 0x1UL << 40,
kGlobal = 0x1UL << 41,
};
protected:
Kind kind;
Type *type;
std::string name;
std::list<Use *> uses;
protected:
Value(Type *type, const std::string &name = "")
: type(type), name(name), uses() {}
Value(Kind kind, Type *type, const std::string &name = "")
: kind(kind), type(type), name(name), uses() {}
virtual ~Value() = default;
public:
Kind getKind() const { return kind; }
static bool classof(const Value *) { return true; }
public:
Type *getType() const { return type; }
const std::string &getName() const { return name; }
@ -200,7 +274,7 @@ public:
bool isConstant() const;
public:
virtual void print(std::ostream &os) const = 0;
virtual void print(std::ostream &os) const {};
}; // class Value
/*!
@ -217,14 +291,18 @@ protected:
};
protected:
ConstantValue(int value) : Value(Type::getIntType(), ""), iScalar(value) {}
ConstantValue(int value)
: Value(kConstant, Type::getIntType(), ""), iScalar(value) {}
ConstantValue(float value)
: Value(Type::getFloatType(), ""), fScalar(value) {}
: Value(kConstant, Type::getFloatType(), ""), fScalar(value) {}
public:
static ConstantValue *get(int value);
static ConstantValue *get(float value);
public:
static bool classof(const Value *value) { return value->getKind() == kConstant; }
public:
int getInt() const {
assert(isInt());
@ -259,6 +337,9 @@ public:
Argument(Type *type, BasicBlock *block, int index,
const std::string &name = "");
public:
static bool classof(const Value *value) { return value->getKind() == kConstant; }
public:
BasicBlock *getParent() const { return block; }
int getIndex() const { return index; }
@ -282,7 +363,7 @@ class BasicBlock : public Value {
public:
using inst_list = std::list<std::unique_ptr<Instruction>>;
using iterator = inst_list::iterator;
using arg_list = std::vector<Argument>;
using arg_list = std::vector<std::unique_ptr<Argument>>;
using block_list = std::vector<BasicBlock *>;
protected:
@ -295,6 +376,9 @@ protected:
protected:
explicit BasicBlock(Function *parent, const std::string &name = "");
public:
static bool classof(const Value *value) { return value->getKind() == kBasicBlock; }
public:
int getNumInstructions() const { return instructions.size(); }
int getNumArguments() const { return arguments.size(); }
@ -309,8 +393,10 @@ public:
iterator end() { return instructions.end(); }
iterator terminator() { return std::prev(end()); }
Argument *createArgument(Type *type, const std::string &name = "") {
arguments.emplace_back(type, this, arguments.size(), name);
return &arguments.back();
auto arg = new Argument(type, this, arguments.size(), name);
assert(arg);
arguments.emplace_back(arg);
return arguments.back().get();
};
public:
@ -325,8 +411,8 @@ protected:
std::vector<Use> operands;
protected:
User(Type *type, const std::string &name = "")
: Value(type, name), operands() {}
User(Kind kind, Type *type, const std::string &name = "")
: Value(kind, type, name), operands() {}
public:
using use_iterator = std::vector<Use>::const_iterator;
@ -369,50 +455,50 @@ public:
*/
class Instruction : public User {
public:
enum Kind : uint64_t {
kInvalid = 0x0UL,
// Binary
kAdd = 0x1UL << 0,
kSub = 0x1UL << 1,
kMul = 0x1UL << 2,
kDiv = 0x1UL << 3,
kRem = 0x1UL << 4,
kICmpEQ = 0x1UL << 5,
kICmpNE = 0x1UL << 6,
kICmpLT = 0x1UL << 7,
kICmpGT = 0x1UL << 8,
kICmpLE = 0x1UL << 9,
kICmpGE = 0x1UL << 10,
kFAdd = 0x1UL << 14,
kFSub = 0x1UL << 15,
kFMul = 0x1UL << 16,
kFDiv = 0x1UL << 17,
kFRem = 0x1UL << 18,
kFCmpEQ = 0x1UL << 19,
kFCmpNE = 0x1UL << 20,
kFCmpLT = 0x1UL << 21,
kFCmpGT = 0x1UL << 22,
kFCmpLE = 0x1UL << 23,
kFCmpGE = 0x1UL << 24,
// Unary
kNeg = 0x1UL << 25,
kNot = 0x1UL << 26,
kFNeg = 0x1UL << 27,
kFtoI = 0x1UL << 28,
kIToF = 0x1UL << 29,
// call
kCall = 0x1UL << 30,
// terminator
kCondBr = 0x1UL << 31,
kBr = 0x1UL << 32,
kReturn = 0x1UL << 33,
// mem op
kAlloca = 0x1UL << 34,
kLoad = 0x1UL << 35,
kStore = 0x1UL << 36,
// constant
// kConstant = 0x1UL << 37,
};
// enum Kind : uint64_t {
// kInvalid = 0x0UL,
// // Binary
// kAdd = 0x1UL << 0,
// kSub = 0x1UL << 1,
// kMul = 0x1UL << 2,
// kDiv = 0x1UL << 3,
// kRem = 0x1UL << 4,
// kICmpEQ = 0x1UL << 5,
// kICmpNE = 0x1UL << 6,
// kICmpLT = 0x1UL << 7,
// kICmpGT = 0x1UL << 8,
// kICmpLE = 0x1UL << 9,
// kICmpGE = 0x1UL << 10,
// kFAdd = 0x1UL << 14,
// kFSub = 0x1UL << 15,
// kFMul = 0x1UL << 16,
// kFDiv = 0x1UL << 17,
// kFRem = 0x1UL << 18,
// kFCmpEQ = 0x1UL << 19,
// kFCmpNE = 0x1UL << 20,
// kFCmpLT = 0x1UL << 21,
// kFCmpGT = 0x1UL << 22,
// kFCmpLE = 0x1UL << 23,
// kFCmpGE = 0x1UL << 24,
// // Unary
// kNeg = 0x1UL << 25,
// kNot = 0x1UL << 26,
// kFNeg = 0x1UL << 27,
// kFtoI = 0x1UL << 28,
// kIToF = 0x1UL << 29,
// // call
// kCall = 0x1UL << 30,
// // terminator
// kCondBr = 0x1UL << 31,
// kBr = 0x1UL << 32,
// kReturn = 0x1UL << 33,
// // mem op
// kAlloca = 0x1UL << 34,
// kLoad = 0x1UL << 35,
// kStore = 0x1UL << 36,
// // constant
// // kConstant = 0x1UL << 37,
// };
protected:
Kind kind;
@ -422,6 +508,11 @@ protected:
Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr,
const std::string &name = "");
public:
static bool classof(const Value *value) {
return value->getKind() >= kFirstInst and value->getKind() <= kLastInst;
}
public:
Kind getKind() const { return kind; }
BasicBlock *getParent() const { return parent; }
@ -476,6 +567,9 @@ protected:
CallInst(Function *callee, const std::vector<Value *> &args = {},
BasicBlock *parent = nullptr, const std::string &name = "");
public:
static bool classof(const Value *value) { return value->getKind() == kCall; }
public:
Function *getCallee() const;
auto getArguments() const {
@ -497,6 +591,12 @@ protected:
addOperand(operand);
}
public:
static bool classof(const Value *value) {
return Instruction::classof(value) and
static_cast<const Instruction *>(value)->isUnary();
}
public:
Value *getOperand() const { return User::getOperand(0); }
@ -516,6 +616,12 @@ protected:
addOperand(rhs);
}
public:
static bool classof(const Value *value) {
return Instruction::classof(value) and
static_cast<const Instruction *>(value)->isBinary();
}
public:
Value *getLhs() const { return getOperand(0); }
Value *getRhs() const { return getOperand(1); }
@ -535,6 +641,9 @@ protected:
addOperand(value);
}
public:
static bool classof(const Value *value) { return value->getKind() == kReturn; }
public:
bool hasReturnValue() const { return not operands.empty(); }
Value *getReturnValue() const {
@ -558,9 +667,12 @@ protected:
addOperands(args);
}
public:
static bool classof(const Value *value) { return value->getKind() == kBr; }
public:
BasicBlock *getBlock() const {
return dynamic_cast<BasicBlock *>(getOperand(0));
return dyncast<BasicBlock>(getOperand(0));
}
auto getArguments() const {
return make_range(std::next(operand_begin()), operand_end());
@ -588,13 +700,16 @@ protected:
addOperands(elseArgs);
}
public:
static bool classof(const Value *value) { return value->getKind() == kCondBr; }
public:
Value *getCondition() const { return getOperand(0); }
BasicBlock *getThenBlock() const {
return dynamic_cast<BasicBlock *>(getOperand(1));
return dyncast<BasicBlock>(getOperand(1));
}
BasicBlock *getElseBlock() const {
return dynamic_cast<BasicBlock *>(getOperand(2));
return dyncast<BasicBlock>(getOperand(2));
}
auto getThenArguments() const {
auto begin = std::next(operand_begin(), 3);
@ -623,6 +738,9 @@ protected:
addOperands(dims);
}
public:
static bool classof(const Value *value) { return value->getKind() == kAlloca; }
public:
int getNumDims() const { return getNumOperands(); }
auto getDims() const { return getOperands(); }
@ -645,6 +763,9 @@ protected:
addOperands(indices);
}
public:
static bool classof(const Value *value) { return value->getKind() == kLoad; }
public:
int getNumIndices() const { return getNumOperands() - 1; }
Value *getPointer() const { return getOperand(0); }
@ -671,6 +792,9 @@ protected:
addOperands(indices);
}
public:
static bool classof(const Value *value) { return value->getKind() == kStore; }
public:
int getNumIndices() const { return getNumOperands() - 2; }
Value *getValue() const { return getOperand(0); }
@ -691,10 +815,13 @@ class Function : public Value {
protected:
Function(Module *parent, Type *type, const std::string &name)
: Value(type, name), parent(parent), variableID(0), blocks() {
: Value(kFunction, type, name), parent(parent), variableID(0), blocks() {
blocks.emplace_back(new BasicBlock(this, "entry"));
}
public:
static bool classof(const Value *value) { return value->getKind() == kFunction; }
public:
using block_list = std::list<std::unique_ptr<BasicBlock>>;
@ -729,24 +856,24 @@ public:
void print(std::ostream &os) const override;
}; // class Function
class ArrayValue : public User {
protected:
ArrayValue(Type *type, const std::vector<Value *> &values = {})
: User(type, "") {
addOperands(values);
}
// class ArrayValue : public User {
// protected:
// ArrayValue(Type *type, const std::vector<Value *> &values = {})
// : User(type, "") {
// addOperands(values);
// }
public:
static ArrayValue *get(Type *type, const std::vector<Value *> &values);
static ArrayValue *get(const std::vector<int> &values);
static ArrayValue *get(const std::vector<float> &values);
// public:
// static ArrayValue *get(Type *type, const std::vector<Value *> &values);
// static ArrayValue *get(const std::vector<int> &values);
// static ArrayValue *get(const std::vector<float> &values);
public:
auto getValues() const { return getOperands(); }
// public:
// auto getValues() const { return getOperands(); }
public:
void print(std::ostream &os) const override{};
}; // class ConstantArray
// public:
// void print(std::ostream &os) const override{};
// }; // class ConstantArray
//! Global value declared at file scope
class GlobalValue : public User {
@ -760,13 +887,15 @@ protected:
protected:
GlobalValue(Module *parent, Type *type, const std::string &name,
const std::vector<Value *> &dims = {}, Value *init = nullptr)
: User(type, name), parent(parent), hasInit(init) {
: User(kGlobal, type, name), parent(parent), hasInit(init) {
assert(type->isPointer());
addOperands(dims);
if (init)
addOperand(init);
}
public:
static bool classof(const Value *value) { return value->getKind() == kGlobal; }
public:
Value *init() const { return hasInit ? operands.back().getValue() : nullptr; }
int getNumDims() const { return getNumOperands() - (hasInit ? 1 : 0); }

@ -71,8 +71,6 @@ any SysYIRGenerator::visitLocalDecl(SysYParser::DeclContext *ctx) {
}
any SysYIRGenerator::visitFunc(SysYParser::FuncContext *ctx) {
// create the function scope
SymbolTable::FunctionScope scope(symbols);
// obtain function name and type signature
auto name = ctx->ID()->getText();
vector<Type *> paramTypes;
@ -88,6 +86,10 @@ any SysYIRGenerator::visitFunc(SysYParser::FuncContext *ctx) {
auto funcType = Type::getFunctionType(returnType, paramTypes);
// create the IR function
auto function = module->createFunction(name, funcType);
// update the symbol table
symbols.insert(name, function);
// create the function scope
SymbolTable::FunctionScope scope(symbols);
// create the entry block with the same parameters as the function,
// and update the symbol table
auto entry = function->getEntryBlock();
@ -149,10 +151,7 @@ any SysYIRGenerator::visitLValueExp(SysYParser::LValueExpContext *ctx) {
Value *value = symbols.lookup(name);
if (not value)
error(ctx, "undefined variable");
auto a = dynamic_cast<Instruction *>(value);
if (dynamic_cast<GlobalValue *>(value) or
(dynamic_cast<Instruction *>(value) and
dynamic_cast<AllocaInst *>(value)))
if (isa<GlobalValue>(value) or isa<AllocaInst>(value))
value = builder.createLoadInst(value);
return value;
}
@ -217,7 +216,7 @@ any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
auto funcName = ctx->ID()->getText();
auto func = dynamic_cast<Function *>(symbols.lookup(funcName));
auto func = dyncast<Function>(symbols.lookup(funcName));
assert(func);
vector<Value *> args;
if (auto rArgs = ctx->funcRParams()) {

Loading…
Cancel
Save