You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

165 lines
4.4 KiB

import cv2
import numpy as np
import torch
from torchvision.transforms import Normalize as Normalize_th
class CustomTransform:
    """Base class for sample transforms that are identified by class name.

    An instance compares equal to the string holding its class name, so a
    transform (or a pipeline of transforms) can be queried with expressions
    such as ``"Resize" in pipeline``.
    """

    def __call__(self, *args, **kwargs):
        """Apply the transform.

        Raises:
            NotImplementedError: always; subclasses provide the behavior.
        """
        raise NotImplementedError

    def __str__(self):
        """Return the name of the transform's class."""
        return self.__class__.__name__

    def __eq__(self, name):
        """Return whether ``name`` equals this transform's class name."""
        return str(self) == name

    def __hash__(self):
        """Return a hash consistent with the name-based ``__eq__``."""
        # Defining __eq__ without __hash__ makes Python set __hash__ to None,
        # which would leave every transform unhashable. Hashing the class
        # name keeps "equal objects have equal hashes" true, including
        # against the plain string name.
        return hash(str(self))

    def __iter__(self):
        """Iterate over the transform itself as a one-element sequence."""
        yield self

    def __contains__(self, name):
        """Return whether a transform matching ``name`` is reachable from here.

        Every item produced by iterating ``self`` is checked: a ``Compose``
        item is searched recursively, any other item is compared to ``name``.
        """
        for t in self:
            if isinstance(t, Compose):
                if name in t:
                    return True
            elif name == t:
                return True
        return False
class Compose(CustomTransform):
    """
    Chain several transforms into a single pipeline.

    Each transform is called with the sample returned by the previous one,
    so every transform must accept one ``sample`` argument and return the
    sample to hand to the next stage.
    """

    def __init__(self, *transforms):
        # Keep a private list so the pipeline has its own container.
        self.transforms = list(transforms)

    def __call__(self, sample):
        """Run ``sample`` through every transform in order and return the result."""
        result = sample
        for transform in self.transforms:
            result = transform(result)
        return result

    def __iter__(self):
        """Iterate over the directly contained transforms (no flattening)."""
        return iter(self.transforms)

    def modules(self):
        """Yield this Compose, then its transforms in order.

        A nested ``Compose`` is expanded recursively through its own
        ``modules()`` rather than being yielded as a single item here.
        """
        yield self
        for child in self.transforms:
            if not isinstance(child, Compose):
                yield child
            else:
                yield from child.modules()
class Resize(CustomTransform):
    """Resize ``sample['img']`` and, when present, ``sample['segLabel']``."""

    def __init__(self, size):
        # An int is expanded to a square (size, size); anything else is
        # stored unchanged and passed to cv2.resize as the target size.
        self.size = (size, size) if isinstance(size, int) else size  # (W, H)

    def __call__(self, sample):
        """Return a shallow copy of ``sample`` with resized entries.

        The image is resized with cubic interpolation and the label with
        nearest-neighbour interpolation. The returned dict always has a
        ``'segLabel'`` key, set to None when the input had no label.
        """
        image = sample.get('img')
        label = sample.get('segLabel', None)

        image = cv2.resize(image, self.size, interpolation=cv2.INTER_CUBIC)
        if label is not None:
            label = cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)

        out = sample.copy()
        out['img'] = image
        out['segLabel'] = label
        return out

    def reset_size(self, size):
        """Replace the target size; an int is expanded to ``(size, size)``."""
        self.size = (size, size) if isinstance(size, int) else size
class RandomResize(Resize):
    """
    Resize whose target (w, h) can be re-drawn at random.

    ``random_set_size`` draws w from minW..maxW and h from minH..maxH
    (upper bounds included) and installs the pair as the new target size.
    """

    def __init__(self, minW, maxW, minH=None, maxH=None, batch=False):
        # Unless both height bounds are supplied, reuse the width bounds.
        if minH is None or maxH is None:
            minH, maxH = minW, maxW
        # Start at the smallest size until random_set_size() is called.
        super().__init__((minW, minH))
        self.minW, self.maxW = minW, maxW
        self.minH, self.maxH = minH, maxH
        # Stored on the instance only; not read by any code in this class.
        self.batch = batch

    def random_set_size(self):
        """Draw a new random (w, h) and make it the current target size."""
        # np.random.randint excludes its upper bound, hence the +1.
        width = np.random.randint(self.minW, self.maxW + 1)
        height = np.random.randint(self.minH, self.maxH + 1)
        self.reset_size((width, height))
class Rotation(CustomTransform):
    """Rotate the image (and label, if present) by a random angle."""

    def __init__(self, theta):
        # Total width of the angle range sampled on each call.
        self.theta = theta

    def __call__(self, sample):
        """Return a shallow copy of ``sample`` with rotated entries.

        One angle is drawn per call as ``(u - 0.5) * theta`` with ``u`` from
        ``np.random.uniform()``. The same rotation matrix, centred on the
        image, is applied to the image (linear interpolation) and to the
        label (nearest-neighbour). The returned dict always has a
        ``'segLabel'`` key, set to None when the input had no label.
        """
        image = sample.get('img')
        label = sample.get('segLabel', None)

        angle = (np.random.uniform() - 0.5) * self.theta
        height, width = image.shape[0], image.shape[1]
        matrix = cv2.getRotationMatrix2D((width // 2, height // 2), angle, 1)

        image = cv2.warpAffine(image, matrix, (width, height), flags=cv2.INTER_LINEAR)
        if label is not None:
            # The output size follows the label's own shape, not the image's.
            label_size = (label.shape[1], label.shape[0])
            label = cv2.warpAffine(label, matrix, label_size, flags=cv2.INTER_NEAREST)

        out = sample.copy()
        out['img'] = image
        out['segLabel'] = label
        return out

    def reset_theta(self, theta):
        """Replace the angle range used by subsequent calls."""
        self.theta = theta
class Normalize(CustomTransform):
    """Apply torchvision's ``Normalize`` to ``sample['img']``."""

    def __init__(self, mean, std):
        # Delegate the actual normalisation to the torchvision transform.
        self.transform = Normalize_th(mean, std)

    def __call__(self, sample):
        """Return a shallow copy of ``sample`` whose ``'img'`` is normalised.

        All other entries of the sample are carried over unchanged.
        """
        normalized = self.transform(sample.get('img'))
        out = sample.copy()
        out['img'] = normalized
        return out
class ToTensor(CustomTransform):
    """Convert the numpy entries of a sample into torch tensors."""

    def __init__(self, dtype=torch.float):
        # dtype the image tensor is cast to before scaling.
        self.dtype = dtype

    def __call__(self, sample):
        """Return a shallow copy of ``sample`` with tensor entries.

        * ``'img'``: axes reordered with ``transpose(2, 0, 1)`` (presumably
          H x W x C to C x H x W; confirm against the loader), cast to
          ``self.dtype`` and divided by 255.
        * ``'segLabel'``: cast to ``torch.long`` when present.
        * ``'exist'``: cast to ``torch.float32`` when present.

        The returned dict always has ``'segLabel'`` and ``'exist'`` keys,
        each set to None when the input did not provide it.
        """
        image = sample.get('img')
        label = sample.get('segLabel', None)
        exist_flags = sample.get('exist', None)

        channels_first = image.transpose(2, 0, 1)
        image = torch.from_numpy(channels_first).type(self.dtype) / 255.

        label = None if label is None else torch.from_numpy(label).type(torch.long)
        # BCEloss requires a float tensor.
        exist_flags = (
            None if exist_flags is None
            else torch.from_numpy(exist_flags).type(torch.float32)
        )

        out = sample.copy()
        out['img'] = image
        out['segLabel'] = label
        out['exist'] = exist_flags
        return out