You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.1 KiB
40 lines
1.1 KiB
### Copyright (C) 2020 Roy Or-El. All rights reserved.
|
|
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
|
import torch.utils.data
|
|
from data.multiclass_unaligned_dataset import MulticlassUnalignedDataset
|
|
from pdb import set_trace as st
|
|
|
|
class AgingDataLoader():
|
|
def name(self):
|
|
return 'AgingDataLoader'
|
|
|
|
def initialize(self, opt):
|
|
self.opt = opt
|
|
self.dataset = CreateDataset(opt)
|
|
self.dataloader = torch.utils.data.DataLoader(
|
|
self.dataset,
|
|
batch_size=opt.batchSize,
|
|
shuffle=not opt.serial_batches,
|
|
drop_last=True,
|
|
num_workers=int(opt.nThreads))
|
|
|
|
def load_data(self):
|
|
return self.dataloader
|
|
|
|
def __len__(self):
|
|
return min(len(self.dataset), self.opt.max_dataset_size)
|
|
|
|
|
|
def CreateDataset(opt):
|
|
dataset = MulticlassUnalignedDataset()
|
|
print("dataset [%s] was created" % (dataset.name()))
|
|
dataset.initialize(opt)
|
|
return dataset
|
|
|
|
|
|
def CreateDataLoader(opt):
|
|
data_loader = AgingDataLoader()
|
|
print(data_loader.name())
|
|
data_loader.initialize(opt)
|
|
return data_loader
|