main
Su Xing 3 years ago
parent f83512e305
commit d94dce0488

@ -45,6 +45,21 @@ Type *Type::getFunctionType(Type *returnType,
return FunctionType::get(returnType, paramTypes);
}
int Type::getSize() const {
switch (kind) {
case kInt:
case kFloat:
return 4;
case kLabel:
case kPointer:
case kFunction:
return 8;
case kVoid:
return 0;
}
return 0;
}
PointerType *PointerType::get(Type *baseType) {
static std::map<Type *, std::unique_ptr<PointerType>> pointerTypes;
auto iter = pointerTypes.find(baseType);
@ -59,16 +74,15 @@ PointerType *PointerType::get(Type *baseType) {
FunctionType *FunctionType::get(Type *returnType,
const std::vector<Type *> &paramTypes) {
static std::set<std::unique_ptr<FunctionType>> functionTypes;
auto iter = std::find_if(functionTypes.begin(), functionTypes.end(),
[&](const std::unique_ptr<FunctionType> &type) -> bool {
if (returnType != type->getReturnType() or
paramTypes.size() != type->getParamTypes().size())
return false;
for (int i = 0; i < paramTypes.size(); ++i)
if (paramTypes[i] != type->getParamTypes()[i])
return false;
return true;
});
auto iter =
std::find_if(functionTypes.begin(), functionTypes.end(),
[&](const std::unique_ptr<FunctionType> &type) -> bool {
if (returnType != type->getReturnType() or
paramTypes.size() != type->getParamTypes().size())
return false;
return std::equal(paramTypes.begin(), paramTypes.end(),
type->getParamTypes().begin());
});
if (iter != functionTypes.end())
return iter->get();
auto type = new FunctionType(returnType, paramTypes);
@ -83,7 +97,7 @@ void Value::replaceAllUsesWith(Value *value) {
uses.clear();
}
ConstantValue *ConstantValue::getInt(int value, const std::string &name) {
ConstantValue *ConstantValue::get(int value, const std::string &name) {
static std::map<int, std::unique_ptr<ConstantValue>> intConstants;
auto iter = intConstants.find(value);
if (iter != intConstants.end())
@ -94,7 +108,7 @@ ConstantValue *ConstantValue::getInt(int value, const std::string &name) {
return result.first->second.get();
}
ConstantValue *ConstantValue::getFloat(float value, const std::string &name) {
ConstantValue *ConstantValue::get(float value, const std::string &name) {
static std::map<float, std::unique_ptr<ConstantValue>> floatConstants;
auto iter = floatConstants.find(value);
if (iter != floatConstants.end())

@ -22,11 +22,21 @@ template <typename IterT> struct range {
explicit range(iterator b, iterator e) : b(b), e(e) {}
iterator begin() { return b; }
iterator end() { return e; }
auto size() const { return std::distance(b, e); }
auto empty() const { return b == e; }
};
template <typename IterT> range<IterT> make_range(IterT b, IterT e) {
return range(b, e);
}
template <typename ContainerT>
range<typename ContainerT::iterator> make_range(ContainerT &c) {
return make_range(c.begin(), c.end());
}
template <typename ContainerT>
range<typename ContainerT::const_iterator> make_range(const ContainerT &c) {
return make_range(c.begin(), c.end());
}
//===----------------------------------------------------------------------===//
// Types
@ -66,7 +76,6 @@ public:
static Type *getPointerType(Type *baseType);
static Type *getFunctionType(Type *returnType,
const std::vector<Type *> &paramTypes = {});
static int getTypeSize();
public:
Kind getKind() const { return kind; }
@ -76,6 +85,7 @@ public:
bool isLabel() const { return kind == kLabel; }
bool isPointer() const { return kind == kPointer; }
bool isFunction() const { return kind == kFunction; }
int getSize() const;
}; // class Type
class PointerType : public Type {
@ -107,7 +117,7 @@ public:
public:
Type *getReturnType() const { return returnType; }
const std::vector<Type *> &getParamTypes() const { return paramTypes; }
auto getParamTypes() const { return make_range(paramTypes); }
int getNumParams() const { return paramTypes.size(); }
}; // class FunctionType
@ -120,6 +130,7 @@ public:
class User;
class Value;
// `Use` represents the relation between a `Value` and its `User`
class Use {
public:
enum Kind {
@ -173,8 +184,6 @@ public:
}; // class Value
class ConstantValue : public Value {
friend class IRBuilder;
protected:
union {
int iConstant;
@ -188,8 +197,8 @@ protected:
: Value(Type::getFloatType(), name), fConstant(value) {}
public:
static ConstantValue *getInt(int value, const std::string &name = "");
static ConstantValue *getFloat(float value, const std::string &name = "");
static ConstantValue *get(int value, const std::string &name = "");
static ConstantValue *get(float value, const std::string &name = "");
public:
int getInt() const {
@ -238,6 +247,7 @@ protected:
arguments(), successors(), predecessors() {}
public:
int getNumInstructions() const { return instructions.size(); }
int getNumArguments() const { return arguments.size(); }
int getNumPredecessors() const { return predecessors.size(); }
int getNumSuccessors() const { return successors.size(); }
@ -259,15 +269,27 @@ protected:
User(Type *type, const std::string &name = "")
: Value(type, name), operands() {}
public:
struct operand_iterator : std::vector<Use>::iterator {
using Base = std::vector<Use>::iterator;
using Base::Base;
using value_type = Value *;
value_type operator->() { return operator*().getValue(); }
};
public:
int getNumOperands() const { return operands.size(); }
const std::vector<Use> &getOperands() const { return operands; }
const Use &getOperand(int index) const { return operands[index]; }
auto operand_begin() const { return operands.begin(); }
auto operand_end() const { return operands.end(); }
auto getOperands() const {
return make_range(operand_begin(), operand_end());
}
Value *getOperand(int index) const { return operands[index].getValue(); }
void addOperand(Value *value, Use::Kind mode = Use::kRead) {
operands.emplace_back(mode, operands.size(), this, value);
value->addUse(&operands.back());
}
void addOperands(const std::vector<Value *> &operands) {
template <typename ContainerT> void addOperands(const ContainerT &operands) {
for (auto value : operands)
addOperand(value);
}
@ -343,15 +365,19 @@ public:
(kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE);
return kind & CondOpMask;
}
bool isTerminator() {
bool isTerminator() const {
static constexpr uint64_t TerminatorOpMask = kCondBr | kBr | kReturn;
return kind & TerminatorOpMask;
}
bool isCommutative() {
bool isCommutative() const {
static constexpr uint64_t CommutativeOpMask =
kAdd | kMul | kICmpEQ | kICmpNE | kFAdd | kFMul | kFCmpEQ | kFCmpNE;
return kind & CommutativeOpMask;
}
bool isMemory() const {
static constexpr uint64_t MemoryOpMask = kAlloca | kLoad | kStore;
return kind & MemoryOpMask;
}
}; // class Instruction
class Function;
@ -365,7 +391,7 @@ protected:
public:
Function *getCallee();
auto getArguments() {
return make_range(std::next(operands.begin()), operands.end());
return make_range(std::next(operand_begin()), operand_end());
}
}; // class CallInst
@ -380,7 +406,7 @@ protected:
}
public:
Value *getOperand() const { return operands[0].getValue(); }
Value *getOperand() const { return User::getOperand(0); }
}; // class UnaryInst
class BinaryInst : public Instruction {
@ -395,8 +421,8 @@ protected:
}
public:
Value *getLhs() const { return operands[0].getValue(); }
Value *getRhs() const { return operands[1].getValue(); }
Value *getLhs() const { return getOperand(0); }
Value *getRhs() const { return getOperand(1); }
}; // class BinaryInst
class TerminatorInst : public Instruction {
@ -413,6 +439,12 @@ protected:
if (value)
addOperand(value);
}
public:
bool hasReturnValue() const { return not operands.empty(); }
Value *getReturnValue() const {
return hasReturnValue() ? getOperand(0) : nullptr;
}
}; // class ReturnInst
class BranchInst : public TerminatorInst {
@ -438,7 +470,7 @@ protected:
public:
BasicBlock *getBlock() const {
return dynamic_cast<BasicBlock *>(getOperand(0).getValue());
return dynamic_cast<BasicBlock *>(getOperand(0));
}
auto getArguments() const {
return make_range(std::next(operands.begin()), operands.end());
@ -464,10 +496,10 @@ protected:
public:
BasicBlock *getThenBlock() const {
return dynamic_cast<BasicBlock *>(getOperand(0).getValue());
return dynamic_cast<BasicBlock *>(getOperand(0));
}
BasicBlock *getElseBlock() const {
return dynamic_cast<BasicBlock *>(getOperand(1).getValue());
return dynamic_cast<BasicBlock *>(getOperand(1));
}
auto getThenArguments() const {
auto begin = operands.begin() + 2;
@ -498,8 +530,8 @@ protected:
public:
int getNumDims() const { return getNumOperands(); }
auto &getDims() const { return operands; }
Value *getDim(int index) { return getOperand(index).getValue(); }
auto getDims() const { return getOperands(); }
Value *getDim(int index) { return getOperand(index); }
}; // class AllocaInst
class LoadInst : public MemoryInst {
@ -517,11 +549,11 @@ protected:
public:
int getNumIndices() const { return getNumOperands() - 1; }
Value *getPointer() const { return operands.front().getValue(); }
Value *getPointer() const { return getOperand(0); }
auto getIndices() const {
return make_range(std::next(operands.begin()), operands.end());
return make_range(std::next(operand_begin()), operand_end());
}
Value *getIndex(int index) const { return getOperand(index + 1).getValue(); }
Value *getIndex(int index) const { return getOperand(index + 1); }
}; // class LoadInst
class StoreInst : public MemoryInst {
@ -539,12 +571,12 @@ protected:
public:
int getNumIndices() const { return getNumOperands() - 2; }
Value *getValue() const { return operands[0].getValue(); }
Value *getPointer() const { return operands[1].getValue(); }
Value *getValue() const { return getOperand(0); }
Value *getPointer() const { return getOperand(1); }
auto getIndices() const {
return make_range(operands.begin() + 2, operands.end());
return make_range(operand_begin() + 2, operand_end());
}
Value *getIndex(int index) const { return getOperand(index + 2).getValue(); }
Value *getIndex(int index) const { return getOperand(index + 2); }
}; // class StoreInst
class Module;
@ -568,10 +600,10 @@ public:
Type *getReturnType() const {
return dynamic_cast<FunctionType *>(getType())->getReturnType();
}
auto &getParamTypes() const {
auto getParamTypes() const {
return dynamic_cast<FunctionType *>(getType())->getParamTypes();
}
block_list &getBasicBlocks() { return blocks; }
auto getBasicBlocks() { return make_range(blocks); }
BasicBlock *getEntryBlock() { return blocks.front().get(); }
BasicBlock *addBasicBlock(const std::string &name = "") {
blocks.emplace_back(new BasicBlock(this, name));
@ -599,7 +631,7 @@ protected:
public:
int getNumDims() const { return getNumOperands(); }
Value *getDim(int index) { return getOperand(index).getValue(); }
Value *getDim(int index) { return getOperand(index); }
}; // class GlobalValue
class Module {
@ -648,7 +680,7 @@ inline CallInst::CallInst(Function *callee, const std::vector<Value *> args,
}
inline Function *CallInst::getCallee() {
return dynamic_cast<Function *>(getOperand(0).getValue());
return dynamic_cast<Function *>(getOperand(0));
}
} // namespace sysy
Loading…
Cancel
Save