Refine IR and IRBuilder

main
Su Xing 3 years ago
parent 026717d3d8
commit 083c639034

@ -4,8 +4,8 @@
#include <cstddef>
#include <iterator>
#include <map>
#include <set>
#include <memory>
#include <set>
#include <vector>
namespace sysy {
@ -47,24 +47,30 @@ Type *Type::getFunctionType(Type *returnType,
PointerType *PointerType::get(Type *baseType) {
static std::map<Type *, std::unique_ptr<PointerType>> pointerTypes;
auto iter = pointerTypes.find(baseType);
if (iter != pointerTypes.end())
return iter->second.get();
auto type = new PointerType(baseType);
assert(type);
auto result = pointerTypes.try_emplace(baseType, type);
return result.first->second.get();
}
bool operator<(const FunctionType &lhs, const FunctionType &rhs) {
return lhs.getReturnType() < rhs.getReturnType() or
lhs.getParamTypes().size() < rhs.getParamTypes().size() and
std::lexicographical_compare(lhs.getParamTypes().begin(),
lhs.getParamTypes().end(),
rhs.getParamTypes().begin(),
rhs.getParamTypes().end());
}
FunctionType *FunctionType::get(Type *returnType,
const std::vector<Type *> &paramTypes) {
static std::set<std::unique_ptr<FunctionType>> functionTypes;
auto iter = std::find_if(functionTypes.begin(), functionTypes.end(),
[&](const std::unique_ptr<FunctionType> &type) -> bool {
if (returnType != type->getReturnType() or
paramTypes.size() != type->getParamTypes().size())
return false;
for (int i = 0; i < paramTypes.size(); ++i)
if (paramTypes[i] != type->getParamTypes()[i])
return false;
return true;
});
if (iter != functionTypes.end())
return iter->get();
auto type = new FunctionType(returnType, paramTypes);
assert(type);
auto result = functionTypes.emplace(type);
@ -77,6 +83,19 @@ void Value::replaceAllUsesWith(Value *value) {
uses.clear();
}
ConstantValue *ConstantValue::getInt(int value, const std::string &name) {
static std::map<int, Value *> intConstants;
auto inst = new ConstantValue(value);
assert(inst);
return inst;
}
ConstantValue *ConstantValue::getFloat(float value, const std::string &name) {
auto inst = new ConstantValue(value);
assert(inst);
return inst;
}
void User::setOperand(int index, Value *value) {
assert(index < getNumOperands());
operands[index].setValue(value);

@ -37,6 +37,9 @@ template <typename IterT> range<IterT> make_range(IterT b, IterT e) {
// targets.
// 2. `PointerType` and `FunctionType` derive from `Type` and represent pointer
// type and function type, respectively.
//
// NOTE `Type` and its derived classes have their ctors declared as 'protected'.
// Users must use Type::getXXXType() methods to obtain `Type` pointers.
//===----------------------------------------------------------------------===//
class Type {
@ -169,6 +172,36 @@ public:
void removeUse(Use *use) { uses.remove(use); }
}; // class Value
class ConstantValue : public Value {
friend class IRBuilder;
protected:
union {
int iConstant;
float fConstant;
};
protected:
ConstantValue(int value, const std::string &name = "")
: Value(Type::getIntType(), name), iConstant(value) {}
ConstantValue(float value, const std::string &name = "")
: Value(Type::getFloatType(), name), fConstant(value) {}
public:
static ConstantValue *getInt(int value, const std::string &name = "");
static ConstantValue *getFloat(float value, const std::string &name = "");
public:
int getInt() const {
assert(isInt());
return iConstant;
}
float getFloat() const {
assert(isFloat());
return fConstant;
}
}; // class ConstantValue
class BasicBlock;
class Argument : public Value {
protected:
@ -215,6 +248,7 @@ public:
block_list &getSuccessors() { return successors; }
iterator begin() { return instructions.begin(); }
iterator end() { return instructions.end(); }
iterator terminator() { return std::prev(end()); }
}; // class BasicBlock
class User : public Value {
@ -242,29 +276,6 @@ public:
const std::string &getName() const { return name; }
}; // class User
// class Constant : public User {
// protected:
// union scalar {
// int iConstant;
// float fConstant;
// };
// std::vector<int> dims;
// std::vector<scalar> data;
// protected:
// Constant(Type *type, const std::string &name = "") : User(type, name) {}
// public:
// int getInt() const {
// assert(isInt());
// return iConstant;
// }
// float getFloat() const {
// assert(isFloat());
// return fConstant;
// }
// }; // class ConstantInst
class Instruction : public User {
public:
enum Kind : uint64_t {
@ -273,17 +284,14 @@ public:
kAdd = 0x1UL << 0,
kSub = 0x1UL << 1,
kMul = 0x1UL << 2,
kSDiv = 0x1UL << 3,
kSRem = 0x1UL << 4,
kDiv = 0x1UL << 3,
kRem = 0x1UL << 4,
kICmpEQ = 0x1UL << 5,
kICmpNE = 0x1UL << 6,
kICmpLT = 0x1UL << 7,
kICmpGT = 0x1UL << 8,
kICmpLE = 0x1UL << 9,
kICmpGE = 0x1UL << 10,
kAShr = 0x1UL << 11,
kLShr = 0x1UL << 12,
kShl = 0x1UL << 13,
kFAdd = 0x1UL << 14,
kFSub = 0x1UL << 15,
kFMul = 0x1UL << 16,
@ -301,19 +309,18 @@ public:
kFNeg = 0x1UL << 26,
kFtoI = 0x1UL << 28,
kIToF = 0x1UL << 29,
kBitCast = 0x1UL << 30,
// call
kCall = 0x1UL << 33,
kCall = 0x1UL << 30,
// terminator
kCondBr = 0x1UL << 31,
kBr = 0x1UL << 32,
kReturn = 0x1UL << 34,
kReturn = 0x1UL << 33,
// mem op
kAlloca = 0x1UL << 35,
kLoad = 0x1UL << 36,
kStore = 0x1UL << 37,
kAlloca = 0x1UL << 34,
kLoad = 0x1UL << 35,
kStore = 0x1UL << 36,
// constant
// kConstant = 0x1UL << 38,
// kConstant = 0x1UL << 37,
};
protected:
@ -330,17 +337,10 @@ public:
BasicBlock *getParent() const { return parent; }
void setParent(BasicBlock *bb) { parent = bb; }
bool isInteger() const {
static constexpr uint64_t IntegerOpMask =
kAdd | kSub | kMul | kSDiv | kSRem | kICmpEQ | kICmpNE | kICmpLT |
kICmpGT | kICmpLE | kICmpGE | kAShr | kLShr | kShl | kNeg | kNot |
kIToF;
return kind & IntegerOpMask;
}
bool isCmp() const {
static constexpr uint64_t CondOpMask =
kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE | kFCmpEQ |
kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE;
(kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) |
(kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE);
return kind & CondOpMask;
}
bool isTerminator() {
@ -547,11 +547,13 @@ public:
Value *getIndex(int index) const { return getOperand(index + 2).getValue(); }
}; // class StoreInst
class Module;
class Function : public Value {
friend class Module;
protected:
Function(Type *type, const std::string &name) : Value(type, name) {
Function(Module *parent, Type *type, const std::string &name)
: Value(type, name), parent(parent), blocks() {
blocks.emplace_back(new BasicBlock(this, "entry"));
}
@ -559,6 +561,7 @@ public:
using block_list = std::list<std::unique_ptr<BasicBlock>>;
protected:
Module *parent;
block_list blocks;
public:
@ -585,9 +588,12 @@ class GlobalValue : public User {
friend class Module;
protected:
GlobalValue(Type *type, const std::vector<Value *> &dims = {},
const std::string &name = "")
: User(type, name) {
Module *parent;
protected:
GlobalValue(Module *parent, Type *type, const std::string &name,
const std::vector<Value *> &dims = {})
: User(type, name), parent(parent) {
addOperands(dims);
}
@ -600,22 +606,21 @@ class Module {
protected:
std::map<std::string, std::unique_ptr<Function>> functions;
std::map<std::string, std::unique_ptr<GlobalValue>> globals;
// std::list<std::unique_ptr<Function>> functions;
// std::list<std::unique_ptr<GlobalValue>> globals;
public:
Module() = default;
public:
Function *createFunction(const std::string &name, Type *type) {
auto result = functions.try_emplace(name, new Function(type, name));
auto result = functions.try_emplace(name, new Function(this, type, name));
if (not result.second)
return nullptr;
return result.first->second.get();
};
GlobalValue *createGlobalValue(const std::string &name, Type *type,
const std::vector<Value *> &dims = {}) {
auto result = globals.try_emplace(name, new GlobalValue(type, dims, name));
auto result =
globals.try_emplace(name, new GlobalValue(this, type, name, dims));
if (not result.second)
return nullptr;
return result.first->second.get();

@ -27,14 +27,6 @@ public:
void setPosition(BasicBlock::iterator position) { this->position = position; }
public:
CallInst *createCallInst(Function *callee,
const std::vector<Value *> args = {},
const std::string &name = "") {
auto inst = new CallInst(callee, args, block, name);
assert(inst);
block->getInstructions().emplace(position, inst);
return inst;
}
UnaryInst *createUnaryInst(Instruction::Kind kind, Type *type, Value *operand,
const std::string &name = "") {
@ -43,6 +35,26 @@ public:
block->getInstructions().emplace(position, inst);
return inst;
}
UnaryInst *createNegInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand,
name);
}
UnaryInst *createNotInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kNot, Type::getIntType(), operand,
name);
}
UnaryInst *createFtoIInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand,
name);
}
UnaryInst *createFNegInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand,
name);
}
UnaryInst *createIToFInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kIToF, Type::getFloatType(), operand,
name);
}
BinaryInst *createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs,
Value *rhs, const std::string &name = "") {
auto inst = new BinaryInst(kind, type, lhs, rhs, block, name);
@ -50,6 +62,116 @@ public:
block->getInstructions().emplace(position, inst);
return inst;
}
BinaryInst *createAddInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createSubInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createMulInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createDivInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createRemInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpEQInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpNEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpLTInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpLEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpGTInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createICmpGEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs,
name);
}
BinaryInst *createFAddInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFSubInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFMulInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFDivInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFRemInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFRem, Type::getFloatType(), lhs, rhs,
name);
}
BinaryInst *createFCmpEQInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpEQ, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpNEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpNE, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpLTInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpLT, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpLEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpLE, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpGTInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpGT, Type::getFloatType(), lhs,
rhs, name);
}
BinaryInst *createFCmpGEInst(Value *lhs, Value *rhs,
const std::string &name = "") {
return createBinaryInst(Instruction::kFCmpGE, Type::getFloatType(), lhs,
rhs, name);
}
ReturnInst *createReturnInst(Value *value = nullptr) {
auto inst = new ReturnInst(value);
assert(inst);

Loading…
Cancel
Save