You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

117 lines
3.5 KiB

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from random import randint
import numpy as np
import cv2
from PIL import Image
import random
###################################################################
# random mask generation
###################################################################
def random_regular_mask(img):
    """Build a mask with between one and five random rectangular holes.

    Args:
        img: Tensor indexed as (channel, x, y) -- presumably an image laid
            out as (C, H, W). Only its size, dtype and device are used;
            the tensor itself is not modified.

    Returns:
        A tensor shaped like ``img`` that is 1 everywhere except inside the
        randomly placed rectangles, where it is 0. Each rectangle is zeroed
        across the whole leading dimension.
    """
    mask = torch.ones_like(img)
    size = img.size()
    extent_x, extent_y = size[1], size[2]

    n_holes = random.randint(1, 5)

    # Largest allowed origin per axis; keeps at least 1/(n_holes + 1) of the
    # extent free after the origin so a hole always fits.
    max_origin_x = int(extent_x - extent_x / (n_holes + 1))
    max_origin_y = int(extent_y - extent_y / (n_holes + 1))

    # Smallest allowed hole side length per axis.
    min_side_x = int(extent_x / (n_holes + 7))
    min_side_y = int(extent_y / (n_holes + 7))

    for _ in range(n_holes):
        x = random.randint(0, max_origin_x)
        y = random.randint(0, max_origin_y)
        # The side length is capped so the hole never leaves the tensor.
        x_end = x + random.randint(min_side_x, extent_x - x)
        y_end = y + random.randint(min_side_y, extent_y - y)
        mask[:, x:x_end, y:y_end] = 0
    return mask
def center_mask(img):
    """Build a mask with a single centered rectangular hole.

    The hole runs from one quarter to three quarters of dimensions 1 and 2,
    so it covers half of each of those extents.

    Args:
        img: Tensor indexed as (channel, x, y) -- presumably an image laid
            out as (C, H, W). Only its size, dtype and device are used.

    Returns:
        A tensor shaped like ``img`` that is 0 inside the centered
        rectangle (across the whole leading dimension) and 1 elsewhere.
    """
    mask = torch.ones_like(img)
    size = img.size()
    x_start, x_end = size[1] // 4, size[1] * 3 // 4
    y_start, y_end = size[2] // 4, size[2] * 3 // 4
    mask[:, x_start:x_end, y_start:y_end] = 0
    return mask
def random_irregular_mask(img):
    """Generate a random irregular mask from lines, circles and ellipses.

    Between 16 and 64 shapes are drawn on a blank single-channel canvas;
    every pixel touched by a shape becomes a hole in the mask.

    Args:
        img: Tensor indexed as (channel, row, column). Only its size, dtype
            and device are used; the tensor itself is not modified.

    Returns:
        A tensor shaped like ``img`` that is 0 wherever a shape was drawn
        and 1 elsewhere, identical across the leading dimension.

    Raises:
        ValueError: If dimension 1 or 2 of ``img`` is smaller than 64.
    """
    size = img.size()
    height, width = size[1], size[2]
    # Validate before allocating anything.
    if height < 64 or width < 64:
        raise ValueError("Width and Height of mask must be at least 64!")

    # Upper bound for stroke thickness and circle radius.
    max_width = 20
    canvas = np.zeros((height, width, 1), np.uint8)

    number = random.randint(16, 64)
    for _ in range(number):
        model = random.random()
        # OpenCV points are (x, y) = (column, row), so x coordinates are
        # drawn from the width and y coordinates from the height.
        if model < 0.6:
            # Draw random lines
            x1, x2 = random.randint(1, width), random.randint(1, width)
            y1, y2 = random.randint(1, height), random.randint(1, height)
            thickness = random.randint(4, max_width)
            cv2.line(canvas, (x1, y1), (x2, y2), (1, 1, 1), thickness)
        elif model < 0.8:
            # Draw random filled circles (thickness -1 means filled)
            x1, y1 = random.randint(1, width), random.randint(1, height)
            radius = random.randint(4, max_width)
            cv2.circle(canvas, (x1, y1), radius, (1, 1, 1), -1)
        else:
            # Draw random elliptic arcs
            x1, y1 = random.randint(1, width), random.randint(1, height)
            s1, s2 = random.randint(1, width), random.randint(1, height)
            a1 = random.randint(3, 180)
            a2 = random.randint(3, 180)
            a3 = random.randint(3, 180)
            thickness = random.randint(4, max_width)
            cv2.ellipse(canvas, (x1, y1), (s1, s2), a1, a2, a3, (1, 1, 1), thickness)

    # Pixels never touched by a shape are kept (1); drawn pixels are holes (0).
    keep = torch.from_numpy(canvas[:, :, 0] == 0)
    mask = torch.ones_like(img)
    mask[...] = keep  # broadcasts the (row, column) pattern over channels
    return mask
###################################################################
# multi scale for image generation
###################################################################
def scale_img(img, size):
    """Resize ``img`` to ``size`` using bilinear interpolation.

    Interpolation uses ``align_corners=True``, so the corner values of the
    input are reproduced exactly at the corners of the output.

    Args:
        img: Input tensor; bilinear mode of ``F.interpolate`` expects a
            4-D tensor (batch, channel, height, width).
        size: Target spatial size, forwarded to ``F.interpolate``.

    Returns:
        The resized tensor.
    """
    return F.interpolate(img, size=size, mode='bilinear', align_corners=True)
def scale_pyramid(img, num_scales):
    """Build a multi-scale pyramid of ``img``, ordered coarsest first.

    Level ``i`` (for ``1 <= i < num_scales``) is ``img`` resized with
    ``scale_img`` to ``(h // 2**i, w // 2**i)``, where ``h`` and ``w`` are
    dimensions 2 and 3 of ``img``. Every level is computed from the
    original tensor, not from the previous level.

    Args:
        img: 4-D tensor whose dimensions 2 and 3 are the spatial extents.
        num_scales: Total number of levels, including the original.

    Returns:
        A list of tensors from the smallest scale up to ``img`` itself,
        which is the last element (the same object, not a copy).
    """
    size = img.size()
    height, width = size[2], size[3]
    # Walk the levels from the largest reduction to the smallest so the
    # list comes out coarse-to-fine without a separate reverse step.
    pyramid = [
        scale_img(img, size=[height // 2 ** level, width // 2 ** level])
        for level in range(num_scales - 1, 0, -1)
    ]
    pyramid.append(img)
    return pyramid