from random import sample

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100, STL10


class CIFAR10Pair(CIFAR10):
    """CIFAR-10 dataset that returns two augmented views of each image."""

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            pos_1 = self.transform(img)
            pos_2 = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pos_1, pos_2, target
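

# The sketch below is illustrative and not part of the original file: it shows
# how the pair datasets are typically consumed, with each batch yielding two
# stochastic augmentations of the same images plus the labels. The batch size
# and worker count are assumptions.
def _example_pair_loader():
    train_data = CIFAR10Pair(root='../data', train=True,
                             transform=train_transform, download=True)
    loader = torch.utils.data.DataLoader(train_data, batch_size=256,
                                         shuffle=True, num_workers=4)
    pos_1, pos_2, target = next(iter(loader))
    # pos_1 and pos_2 are two independent augmentations of the same images.
    return pos_1.shape, pos_2.shape, target.shape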


class CIFAR100Pair_true_label(CIFAR100):
    """CIFAR-100 dataset where the two views of a positive pair are two
    randomly sampled images that share the same label."""

    def __init__(self, root='../data', train=True, transform=None):
        super().__init__(root=root, train=train, transform=transform)

        def get_labels(i):
            return [index for index in range(len(self)) if self.targets[index] == i]

        # For each of the 100 classes, precompute the indices of its examples.
        self.label_index = [get_labels(i) for i in range(100)]

    def __getitem__(self, index):
        img1, target = self.data[index], self.targets[index]
        # Draw a second example uniformly at random from the same class.
        index_example_same_label = sample(self.label_index[self.targets[index]], 1)[0]
        img2 = self.data[index_example_same_label]
        img1 = Image.fromarray(img1)
        img2 = Image.fromarray(img2)
        if self.transform is not None:
            pos_1 = self.transform(img1)
            pos_2 = self.transform(img2)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pos_1, pos_2, target


class CIFAR100Pair(CIFAR100):
    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            pos_1 = self.transform(img)
            pos_2 = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pos_1, pos_2, target


class STL10Pair(STL10):
    def __getitem__(self, index):
        img, target = self.data[index], self.labels[index]
        # STL-10 stores images as CHW; convert to HWC for PIL.
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
        if self.transform is not None:
            pos_1 = self.transform(img)
            pos_2 = self.transform(img)
        return pos_1, pos_2, target


class GaussianBlur(object):
    """Gaussian blur augmentation as described in the SimCLR paper."""

    def __init__(self, kernel_size, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # Kernel size is set to roughly 10% of the image height/width; it must
        # be odd for cv2.GaussianBlur.
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)
        # Blur the image with 50% probability, with sigma drawn uniformly
        # from [min, max].
        if np.random.random_sample() < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
        return sample


train_transform = transforms.Compose([
    transforms.RandomResizedCrop(28),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    GaussianBlur(kernel_size=int(0.1 * 32)),
    transforms.ToTensor(),
    # Standard CIFAR-10 channel means and standard deviations.
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
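

# Illustrative sketch, not part of the original file: applying `train_transform`
# twice to one image produces two different views, which is exactly what the
# *Pair datasets above rely on. The random input image is a stand-in.
def _example_two_views():
    img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
    view_1 = train_transform(img)
    view_2 = train_transform(img)
    # Both views are 3x28x28 tensors; they differ with high probability
    # because every transform in the pipeline is stochastic.
    return view_1.shape, torch.equal(view_1, view_2)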


def get_dataset(dataset_name, root='../data', pair=True):
    """Return (train, memory, test) datasets; pair=True yields two-view datasets."""
    if pair:
        if dataset_name == 'cifar10':
            train_data = CIFAR10Pair(root=root, train=True, transform=train_transform, download=True)
            memory_data = CIFAR10Pair(root=root, train=True, transform=test_transform, download=True)
            test_data = CIFAR10Pair(root=root, train=False, transform=test_transform, download=True)
        elif dataset_name == 'cifar100':
            train_data = CIFAR100Pair(root=root, train=True, transform=train_transform)
            memory_data = CIFAR100Pair(root=root, train=True, transform=test_transform)
            test_data = CIFAR100Pair(root=root, train=False, transform=test_transform)
        elif dataset_name == 'stl10':
            train_data = STL10Pair(root=root, split='train+unlabeled', transform=train_transform)
            memory_data = STL10Pair(root=root, split='train', transform=test_transform)
            test_data = STL10Pair(root=root, split='test', transform=test_transform)
        elif dataset_name == 'cifar100_true_label':
            train_data = CIFAR100Pair_true_label(root=root, train=True, transform=train_transform)
            memory_data = CIFAR100Pair_true_label(root=root, train=True, transform=test_transform)
            test_data = CIFAR100Pair_true_label(root=root, train=False, transform=test_transform)
        else:
            raise ValueError(f'Invalid dataset name: {dataset_name}')
    else:
        if dataset_name in ['cifar10', 'cifar10_true_label']:
            train_data = CIFAR10(root=root, train=True, transform=train_transform)
            memory_data = CIFAR10(root=root, train=True, transform=test_transform)
            test_data = CIFAR10(root=root, train=False, transform=test_transform)
        elif dataset_name in ['cifar100', 'cifar100_true_label']:
            train_data = CIFAR100(root=root, train=True, transform=train_transform)
            memory_data = CIFAR100(root=root, train=True, transform=test_transform)
            test_data = CIFAR100(root=root, train=False, transform=test_transform)
        elif dataset_name == 'stl10':
            train_data = STL10(root=root, split='train', transform=train_transform)
            memory_data = STL10(root=root, split='train', transform=test_transform)
            test_data = STL10(root=root, split='test', transform=test_transform)
        else:
            raise ValueError(f'Invalid dataset name: {dataset_name}')
    return train_data, memory_data, test_data
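

# Illustrative sketch, not part of the original file: fetch the three splits
# and wrap them in DataLoaders, as a typical training script would. The batch
# size is an assumption.
def _example_loaders():
    train_data, memory_data, test_data = get_dataset('cifar10', root='../data')
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=256,
                                               shuffle=True, drop_last=True)
    memory_loader = torch.utils.data.DataLoader(memory_data, batch_size=256)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=256)
    return train_loader, memory_loader, test_loader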


class DATASET:
    """Two-view dataset backed by a .npy file holding a pickled dict of
    HWC images and labels."""

    def __init__(self, dataset_path, transform_fn=None):
        # Needs the HWC version of the data.
        self.dataset = np.load(dataset_path, allow_pickle=True).item()
        self.transform_fn = transform_fn

    def __getitem__(self, idx):
        image = self.dataset["images"][idx]
        label = self.dataset["labels"][idx]
        image = Image.fromarray(image)
        if self.transform_fn is not None:
            pos_1 = self.transform_fn(image)
            pos_2 = self.transform_fn(image)
        return pos_1, pos_2, label

    def __len__(self):
        return len(self.dataset["images"])
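

# Illustrative sketch, not part of the original file: `DATASET` expects a dict
# with HWC uint8 images and integer labels, pickled via np.save. The shapes and
# dtypes below are assumptions, not the real MedMNIST layout.
def _example_write_npy(path='fake_split.npy'):
    split = {
        "images": np.zeros((10, 28, 28, 3), dtype=np.uint8),
        "labels": np.zeros(10, dtype=np.int64),
    }
    np.save(path, split)  # reload with np.load(path, allow_pickle=True).item()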


def get_medical_dataset(root='../data/MedMNIST/TissueMNIST/unnormalized_HWC/n_per_cls10/'):
    l_train_dataset = DATASET(root + "l_train.npy", transform_fn=train_transform)
    u_train_dataset = DATASET(root + "u_train.npy", transform_fn=train_transform)
    # Labeled and unlabeled examples are pooled for contrastive training.
    train_data = torch.utils.data.ConcatDataset([l_train_dataset, u_train_dataset])
    memory_dataset = DATASET(root + "l_train.npy", transform_fn=test_transform)
    val_dataset = DATASET(root + "val.npy", transform_fn=test_transform)
    test_dataset = DATASET(root + "test.npy", transform_fn=test_transform)
    return train_data, memory_dataset, val_dataset, test_dataset
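

# Illustrative sketch, not part of the original file: build loaders over the
# medical splits; `root` must point at .npy files in the layout sketched above.
# The batch size is an assumption.
def _example_medical_loaders():
    train_data, _memory, val_data, _test = get_medical_dataset()
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=128,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=128)
    return train_loader, val_loader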


class EarlyStopping:
    """Stops training early if the validation loss doesn't improve for a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after the last time the validation loss improved.
                Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                Default: 0
            path (str): Path the checkpoint is saved to.
                Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when the validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
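

# Illustrative sketch, not part of the original file: typical EarlyStopping
# usage inside a validation loop. `model` and the per-epoch `val_losses` are
# supplied by the caller.
def _example_early_stopping(model, val_losses):
    early_stopping = EarlyStopping(patience=7, verbose=True, path='checkpoint.pt')
    for val_loss in val_losses:
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            break
    # Restore the weights from the best epoch before evaluating.
    model.load_state_dict(torch.load('checkpoint.pt'))
    return model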