You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

189 lines
8.2 KiB

import os
import dlib
import shutil
import requests
import numpy as np
import scipy.ndimage
import torch
import torchvision.transforms as transforms
import util.deeplab as deeplab
from PIL import Image
from util.util import download_file
from pdb import set_trace as st
resnet_file_path = 'deeplab_model/R-101-GN-WS.pth.tar'
deeplab_file_path = 'deeplab_model/deeplab_model.pth'
predictor_file_path = 'util/shape_predictor_68_face_landmarks.dat'
model_fname = 'deeplab_model/deeplab_model.pth'
deeplab_classes = ['background' ,'skin','nose','eye_g','l_eye','r_eye','l_brow','r_brow','l_ear','r_ear','mouth','u_lip','l_lip','hair','hat','ear_r','neck_l','neck','cloth']
class preprocessInTheWildImage():
def __init__(self, out_size=256):
self.out_size = out_size
# load landmark detector models
self.detector = dlib.get_frontal_face_detector()
if not os.path.isfile(predictor_file_path):
print('Cannot find landmarks shape predictor model.\n'\
'Please run download_models.py to download the model')
raise OSError
self.predictor = dlib.shape_predictor(predictor_file_path)
# deeplab data properties
self.deeplab_data_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
self.deeplab_input_size = 513
# load deeplab model
assert torch.cuda.is_available()
torch.backends.cudnn.benchmark = True
if not os.path.isfile(resnet_file_path):
print('Cannot find DeeplabV3 backbone Resnet model.\n' \
'Please run download_models.py to download the model')
raise OSError
self.deeplab_model = getattr(deeplab, 'resnet101')(
pretrained=True,
num_classes=len(deeplab_classes),
num_groups=32,
weight_std=True,
beta=False)
self.deeplab_model.eval()
if not os.path.isfile(deeplab_file_path):
print('Cannot find DeeplabV3 model.\n' \
'Please run download_models.py to download the model')
raise OSError
checkpoint = torch.load(model_fname)
state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items() if 'tracked' not in k}
self.deeplab_model.load_state_dict(state_dict)
def dlib_shape_to_landmarks(self, shape):
# initialize the list of (x, y)-coordinates
landmarks = np.zeros((68, 2), dtype=np.float32)
# loop over the 68 facial landmarks and convert them
# to a 2-tuple of (x, y)-coordinates
for i in range(0, 68):
landmarks[i] = (shape.part(i).x, shape.part(i).y)
# return the list of (x, y)-coordinates
return landmarks
def extract_face_landmarks(self, img):
# detect all faces in the image and
# keep the detection with the largest bounding box
dets = self.detector(img, 1)
if len(dets) == 0:
print ('Could not detect any face in the image, please try again with a different image')
raise
max_area = 0
max_idx = -1
for k, d in enumerate(dets):
area = (d.right() - d.left()) * (d.bottom() - d.top())
if area > max_area:
max_area = area
max_idx = k
# Get the landmarks/parts for the face in box d.
dlib_shape = self.predictor(img, dets[max_idx])
landmarks = self.dlib_shape_to_landmarks(dlib_shape)
return landmarks
def align_in_the_wild_image(self, np_img, lm, transform_size=4096, enable_padding=True):
# Parse landmarks.
lm_chin = lm[0 : 17] # left-right
lm_eyebrow_left = lm[17 : 22] # left-right
lm_eyebrow_right = lm[22 : 27] # left-right
lm_nose = lm[27 : 31] # top-down
lm_nostrils = lm[31 : 36] # top-down
lm_eye_left = lm[36 : 42] # left-clockwise
lm_eye_right = lm[42 : 48] # left-clockwise
lm_mouth_outer = lm[48 : 60] # left-clockwise
lm_mouth_inner = lm[60 : 68] # left-clockwise
# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg
# Choose oriented crop rectangle.
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 2.2) # This results in larger crops then the original FFHQ. For the original crops, replace 2.2 with 1.8
y = np.flipud(x) * [-1, 1]
c = eye_avg + eye_to_mouth * 0.1
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
qsize = np.hypot(*x) * 2
# Load in-the-wild image.
img = Image.fromarray(np_img)
# Shrink.
shrink = int(np.floor(qsize / self.out_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, Image.ANTIALIAS)
quad /= shrink
qsize /= shrink
# Crop.
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(crop)
quad -= crop[0:2]
# Pad.
pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w, _ = img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
blur = qsize * 0.02
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
quad += pad[:2]
# Transform.
img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
if self.out_size < transform_size:
img = img.resize((self.out_size, self.out_size), Image.ANTIALIAS)
return img
def get_segmentation_maps(self, img):
img = img.resize((self.deeplab_input_size,self.deeplab_input_size),Image.BILINEAR)
img = self.deeplab_data_transform(img)
img = img.cuda()
self.deeplab_model.cuda()
outputs = self.deeplab_model(img.unsqueeze(0))
self.deeplab_model.cpu()
_, pred = torch.max(outputs, 1)
pred = pred.data.cpu().numpy().squeeze().astype(np.uint8)
seg_map = Image.fromarray(pred)
seg_map = np.uint8(seg_map.resize((self.out_size,self.out_size), Image.NEAREST))
return seg_map
def forward(self, img):
landmarks = self.extract_face_landmarks(img)
aligned_img = self.align_in_the_wild_image(img, landmarks)
seg_map = self.get_segmentation_maps(aligned_img)
aligned_img = np.array(aligned_img.getdata(), dtype=np.uint8).reshape(self.out_size, self.out_size, 3)
return aligned_img, seg_map