You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
197 lines
8.2 KiB
197 lines
8.2 KiB
### Copyright (C) 2020 Roy Or-El. All rights reserved.
|
|
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
|
import os
|
|
import html
|
|
import glob
|
|
import uuid
|
|
import hashlib
|
|
import requests
|
|
import torch
|
|
import zipfile
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from PIL import Image
|
|
from pdb import set_trace as st
|
|
|
|
|
|
# Download specifications for the pretrained assets. Each entry holds the
# primary (Google Drive) URL, an alternate-server URL used as a fallback,
# the local destination path, and the expected size in bytes and MD5 digest
# that the downloader validates against.
males_model_spec = {
    'file_url': 'https://drive.google.com/uc?id=1MsXN54hPi9PWDmn1HKdmKfv-J5hWYFVZ',
    'alt_url': 'https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/males_model.zip',
    'file_path': 'checkpoints/males_model.zip',
    'file_size': 213175683,
    'file_md5': '0079186147ec816176b946a073d1f396',
}

females_model_spec = {
    'file_url': 'https://drive.google.com/uc?id=1LNm0zAuiY0CIJnI0lHTq1Ttcu9_M1NAJ',
    'alt_url': 'https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/females_model.zip',
    'file_path': 'checkpoints/females_model.zip',
    'file_size': 213218113,
    'file_md5': '0675f809413c026170cf1f22b27f3c5d',
}

resnet_file_spec = {
    'file_url': 'https://drive.google.com/uc?id=1oRGgrI4KNdefbWVpw0rRkEP1gbJIRokM',
    'alt_url': 'https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/R-101-GN-WS.pth.tar',
    'file_path': 'deeplab_model/R-101-GN-WS.pth.tar',
    'file_size': 178260167,
    'file_md5': 'aa48cc3d3ba3b7ac357c1489b169eb32',
}

deeplab_file_spec = {
    'file_url': 'https://drive.google.com/uc?id=1w2XjDywFr2NjuUWaLQDRktH7VwIfuNlY',
    'alt_url': 'https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/deeplab_model.pth',
    'file_path': 'deeplab_model/deeplab_model.pth',
    'file_size': 464446305,
    'file_md5': '8e8345b1b9d95e02780f9bed76cc0293',
}

predictor_file_spec = {
    'file_url': 'https://drive.google.com/uc?id=1fhq5lvWy-rjrzuHdMoZfLsULvF0gJGwD',
    'alt_url': 'https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/pretrained_models/shape_predictor_68_face_landmarks.dat',
    'file_path': 'util/shape_predictor_68_face_landmarks.dat',
    'file_size': 99693937,
    'file_md5': '73fde5e05226548677a050913eed4e04',
}
|
|
|
|
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
    """Convert an image tensor with values in [-1, 1] to a numpy array.

    The tensor is detached and copied to the CPU as float32. Its leading
    non-batch axis (assumed to be channels; confirm against callers) is
    moved to the end. Values are then mapped linearly from [-1, 1] to
    [0, 255] and cast to ``imtype``:

    * 2-D ``(H, W)``          -> ``(H, W)``
    * 3-D ``(C, H, W)``       -> ``(H, W, C)``
    * 4-D ``(1, C, H, W)``    -> ``(H, W, C)``  (singleton batch dropped)
    * 4-D ``(N, C, H, W)``    -> ``(N, H, W, C)``
    * 5-D ``(N, T, C, H, W)`` -> ``(N, T, H, W, C)``

    For an integer ``imtype`` the scaled values are first clipped to that
    dtype's range. Inputs slightly outside [-1, 1] therefore saturate
    instead of wrapping around during the cast.

    Args:
        image_tensor: torch tensor with 2 to 5 dimensions.
        imtype: numpy dtype of the returned array.
        normalize: accepted for API compatibility but currently ignored;
            the input is always mapped from [-1, 1].

    Returns:
        numpy array of dtype ``imtype`` laid out as listed above.

    Raises:
        ValueError: if the tensor does not have 2 to 5 dimensions.
    """
    ndims = image_tensor.dim()
    if ndims == 4 and image_tensor.size(0) == 1:
        # A batch holding a single image is returned as that image.
        image_tensor = image_tensor[0]
        ndims = 3

    # Per-rank axis permutation that moves the (assumed) channel axis last.
    permutations = {
        2: (0, 1),
        3: (1, 2, 0),
        4: (0, 2, 3, 1),
        5: (0, 1, 3, 4, 2),
    }
    if ndims not in permutations:
        raise ValueError('tensor2im expects a tensor with 2 to 5 dimensions, '
                         'got {}'.format(ndims))

    # detach() lets tensors that still require grad be converted.
    image_numpy = image_tensor.detach().cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, permutations[ndims]) + 1) / 2.0 * 255.0

    # Casting out-of-range floats to an integer dtype wraps or is undefined,
    # so saturate at the dtype's limits first.
    if np.issubdtype(np.dtype(imtype), np.integer):
        limits = np.iinfo(imtype)
        image_numpy = np.clip(image_numpy, limits.min, limits.max)
    return image_numpy.astype(imtype)
|
|
|
|
def save_image(image_numpy, image_path):
    """Write a numpy image array to ``image_path`` with PIL.

    Args:
        image_numpy: array accepted by ``PIL.Image.fromarray``, e.g. the
            uint8 ``(H, W)`` or ``(H, W, C)`` output of ``tensor2im``. A
            trailing singleton channel axis ``(H, W, 1)`` is squeezed to
            ``(H, W)`` first, because PIL cannot build an image from that
            shape.
        image_path: destination file path passed to ``Image.save``.
    """
    if getattr(image_numpy, 'ndim', 0) == 3 and image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)
|
|
|
|
def mkdirs(paths):
    """Create a single directory, or each directory in a list/tuple of paths.

    Every path is handed to ``mkdir``, which skips paths that already exist.

    Args:
        paths: one path, or a list or tuple of paths.
    """
    # A list/tuple is never a str, so no separate string check is needed;
    # tuples previously fell through to mkdir(tuple) and crashed.
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)
|
|
|
|
def mkdir(path):
    """Create ``path`` (including missing parents) if nothing exists there yet.

    An existing path, whether a directory or not, is left untouched.
    ``exist_ok=True`` covers the window between the existence check and the
    creation, during which another process may create the same directory;
    without it that race raised ``FileExistsError``.

    Args:
        path: directory path to create.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
|
|
|
|
def _download_with_fallback(file_spec):
    """Download ``file_spec`` from ``file_url``, retrying from ``alt_url`` on failure."""
    with requests.Session() as session:
        try:
            download_file(session, file_spec)
        except Exception:
            # Exception rather than a bare except: Ctrl-C should abort the
            # download instead of silently switching to the other server.
            print('Google Drive download failed.\n'
                  'Trying to download from alternate server')
            download_file(session, file_spec, use_alt_url=True)


def _extract_and_remove_zip(file_spec):
    """Extract the zip at ``file_spec['file_path']`` into its own directory, then delete it."""
    zip_path = file_spec['file_path']
    with zipfile.ZipFile(zip_path, 'r') as zip_file:
        zip_file.extractall(os.path.dirname(zip_path) or '.')
    os.remove(zip_path)


def download_pretrained_models():
    """Download every asset described by the module-level ``*_spec`` dicts.

    Downloads happen in this order: the males model, the females model, the
    face landmarks shape predictor, the DeeplabV3 ResNet backbone, and the
    DeeplabV3 model. The two model archives are zip files; each is extracted
    into the directory it was downloaded to and then deleted. The other
    files are kept at their ``file_path`` as-is.

    Each file is first fetched from ``file_url`` and, if that fails, from
    ``alt_url``. An error from the fallback download propagates.
    """
    zipped_models = (
        ('males model', males_model_spec),
        ('females model', females_model_spec),
    )
    for label, spec in zipped_models:
        print('Downloading ' + label)
        _download_with_fallback(spec)
        print('Extracting ' + label + ' zip file')
        _extract_and_remove_zip(spec)
        print('Done!')

    plain_files = (
        ('face landmarks shape predictor', predictor_file_spec),
        ('DeeplabV3 backbone Resnet Model parameters', resnet_file_spec),
        ('DeeplabV3 Model parameters', deeplab_file_spec),
    )
    for label, spec in plain_files:
        print('Downloading ' + label)
        _download_with_fallback(spec)
        print('Done!')
|
|
|
|
def _find_gdrive_confirm_url(tmp_path, file_url):
    """Return the confirmation URL from a Google Drive nag page, or None.

    Google Drive may answer with a small HTML warning page (the "virus
    checker nag") instead of the file. If the page saved at ``tmp_path``
    contains exactly one ``export=download`` link, that link is
    HTML-unescaped, resolved against ``file_url`` and returned.
    """
    with open(tmp_path, 'rb') as f:
        # errors='replace': a short non-UTF-8 payload must not raise here,
        # since this runs inside the retry handler and would abort retries.
        text = f.read().decode('utf-8', errors='replace')
    links = [html.unescape(link) for link in text.split('"') if 'export=download' in link]
    if len(links) == 1:
        return requests.compat.urljoin(file_url, links[0])
    return None


def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
    """Download ``file_spec`` to ``file_spec['file_path']`` with validation and retries.

    The response is streamed into a uniquely named temporary file next to
    the destination. It is checked against the optional ``file_size`` and
    ``file_md5`` entries and then renamed into place with ``os.replace``.
    Failed attempts are retried, for ``num_attempts`` attempts in total. If
    every attempt fails, the temporary file is removed and the last error is
    re-raised. After a successful download, any other ``<file_path>.tmp.*``
    files are deleted on a best-effort basis.

    Args:
        session: ``requests.Session`` used for the HTTP requests.
        file_spec: dict with ``file_path`` and ``file_url`` (plus
            ``alt_url`` when ``use_alt_url`` is set). ``file_size`` (bytes)
            and ``file_md5`` (hex digest) are optional and enable validation.
        use_alt_url: download from ``file_spec['alt_url']`` instead of
            ``file_spec['file_url']``.
        chunk_size: streaming chunk size in KiB.
        num_attempts: total number of download attempts.

    Raises:
        The exception from the final attempt, e.g. the HTTP error from
        ``raise_for_status`` or ``IOError`` on a size/MD5 mismatch.
    """
    file_path = file_spec['file_path']
    file_url = file_spec['alt_url'] if use_alt_url else file_spec['file_url']

    file_dir = os.path.dirname(file_path)
    tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    # .get(): 'file_size' is optional for validation below, so the progress
    # bar must not require it either.
    progress_bar = tqdm(total=file_spec.get('file_size'), unit='B', unit_scale=True)
    try:
        for attempts_left in reversed(range(num_attempts)):
            data_size = 0
            progress_bar.reset()
            try:
                # Download.
                data_md5 = hashlib.md5()
                with session.get(file_url, stream=True) as res:
                    res.raise_for_status()
                    with open(tmp_path, 'wb') as f:
                        for chunk in res.iter_content(chunk_size=chunk_size << 10):
                            progress_bar.update(len(chunk))
                            f.write(chunk)
                            data_size += len(chunk)
                            data_md5.update(chunk)

                # Validate.
                if 'file_size' in file_spec and data_size != file_spec['file_size']:
                    raise IOError('Incorrect file size', file_path)
                if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
                    raise IOError('Incorrect file MD5', file_path)
                break

            except Exception:
                # Last attempt => raise error.
                if not attempts_left:
                    raise
                # Handle Google Drive virus checker nag: a tiny payload is
                # likely the warning page, so follow its confirmation link.
                if 0 < data_size < 8192:
                    confirm_url = _find_gdrive_confirm_url(tmp_path, file_url)
                    if confirm_url is not None:
                        file_url = confirm_url
    except BaseException:
        # Every attempt failed (or the user interrupted): drop the partial file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    finally:
        progress_bar.close()

    # Rename temp file to the correct name.
    os.replace(tmp_path, file_path)  # atomic

    # Attempt to clean up any leftover temps (best effort). glob.escape keeps
    # metacharacters such as '[' in file_path from acting as patterns.
    for filename in glob.glob(glob.escape(file_path) + '.tmp.*'):
        try:
            os.remove(filename)
        except OSError:
            pass
|