You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
5.4 KiB
147 lines
5.4 KiB
import cv2 as cv
|
|
import numpy as np
|
|
import paddle
|
|
import paddle.nn.functional as F
|
|
from PIL import Image
|
|
from paddle.vision.transforms import functional
|
|
from ppgan.apps.base_predictor import BasePredictor
|
|
from ppgan.models.generators import DecoderNet, Encoder, RevisionNet
|
|
from ppgan.utils.download import get_path_from_url
|
|
from ppgan.utils.visual import tensor2img
|
|
|
|
# Download URLs of the pretrained LapStyle checkpoints, one per supported style.
# LapStylePredictor fetches the one matching its ``style`` argument via
# get_path_from_url and reads it with paddle.load.
LapStyle_circuit_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams'
LapStyle_ocean_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams'
LapStyle_starrynew_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams'
LapStyle_stars_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams'
|
|
|
|
|
|
def img(src):
    """Drop any channels beyond the first three of an image array.

    Args:
        src: Image array. Arrays with three or more dimensions are indexed
            as height x width x channels.

    Returns:
        ``src[:, :, :3]`` when ``src`` has a channel axis with more than
        three entries; otherwise ``src`` itself, unchanged. No axis
        reordering is performed here.
    """
    # Some images have 4 channels (e.g. an alpha plane); keep the first three.
    # A 2-D array has no channel axis, so it is passed through instead of
    # failing on ``shape[2]``.
    if src.ndim > 2 and src.shape[2] > 3:
        src = src[:, :, :3]
    return src
|
|
|
|
|
|
# RGB image
|
|
def _to_batched_tensor(image):
    """Resize an image array to 512x512 and convert it to a batched tensor."""
    resized = Image.fromarray(image).resize((512, 512), Image.BILINEAR)
    # img() strips any channel beyond the third (e.g. alpha).
    tensor = functional.to_tensor(img(np.array(resized)))
    # Add a leading batch axis; downstream code reads the spatial size from
    # shape[2] and shape[3].
    return paddle.unsqueeze(tensor, axis=0)


def img_read(content_img: np.ndarray, style_img_path: str):
    """Prepare the content and style images as network inputs.

    Args:
        content_img: Content image array. A 2-D (grayscale) array is
            expanded to three channels first; otherwise it is expected to be
            height x width x channels in RGB order.
        style_img_path: Path of the style image file, read with OpenCV and
            converted from BGR to RGB.

    Returns:
        A tuple ``(content_tensor, style_tensor, h, w)`` where both tensors
        come from 512x512 resized images with a leading batch axis, and
        ``h``/``w`` are the height and width of the original content image.

    Raises:
        FileNotFoundError: If the style image cannot be read from
            ``style_img_path``.
    """
    if content_img.ndim == 2:
        content_img = cv.cvtColor(content_img, cv.COLOR_GRAY2RGB)
    # Remember the original size so the caller can resize the result back.
    h, w = content_img.shape[:2]

    style_img = cv.imread(style_img_path)
    # cv.imread signals failure by returning None rather than raising; fail
    # here with a clear message instead of a cryptic cvtColor error.
    if style_img is None:
        raise FileNotFoundError(
            f'style image not found or unreadable: {style_img_path}')
    style_img = cv.cvtColor(style_img, cv.COLOR_BGR2RGB)

    return _to_batched_tensor(content_img), _to_batched_tensor(style_img), h, w
|
|
|
|
|
|
def tensor_resample(tensor, dst_size, mode='bilinear'):
    """Resize ``tensor`` to ``dst_size`` with ``F.interpolate``.

    Args:
        tensor: Input tensor to resample.
        dst_size: Target size handed to ``F.interpolate``.
        mode: Interpolation mode; defaults to ``'bilinear'``.

    Returns:
        The resampled tensor (always computed with ``align_corners=False``).
    """
    resized = F.interpolate(
        tensor,
        size=dst_size,
        mode=mode,
        align_corners=False,
    )
    return resized
|
|
|
|
|
|
def laplacian(x):
    """Compute the Laplacian (detail) band of ``x``.

    The input is downsampled to half its spatial size, upsampled back to the
    original size, and subtracted from the input.

    Args:
        x: Tensor whose spatial dimensions are ``shape[2]`` and ``shape[3]``.

    Returns:
        ``x - upsample(downsample(x))``, same shape as ``x``.
    """
    height, width = x.shape[2], x.shape[3]
    # Clamp to at least one pixel, matching make_laplace_pyramid, so a
    # 1-pixel dimension does not request a zero-sized intermediate.
    half_size = [max(height // 2, 1), max(width // 2, 1)]
    blurred = tensor_resample(tensor_resample(x, half_size), [height, width])
    return x - blurred
|
|
|
|
|
|
def make_laplace_pyramid(x, levels):
    """Build a Laplacian pyramid of ``x``.

    Args:
        x: Tensor whose spatial dimensions are ``shape[2]`` and ``shape[3]``.
        levels: Number of Laplacian bands to extract.

    Returns:
        A list of ``levels + 1`` entries: the Laplacian band of each
        successive scale (finest first), followed by the remaining
        low-resolution tensor.
    """
    bands = []
    low_res = x
    for _ in range(levels):
        bands.append(laplacian(low_res))
        # Halve each spatial dimension, never dropping below one pixel.
        half_h = max(low_res.shape[2] // 2, 1)
        half_w = max(low_res.shape[3] // 2, 1)
        low_res = tensor_resample(low_res, (half_h, half_w))
    return bands + [low_res]
|
|
|
|
|
|
def fold_laplace_pyramid(pyramid):
    """Collapse a Laplacian pyramid back into a single tensor.

    Args:
        pyramid: Sequence ordered finest band first, with the low-resolution
            base as the last entry.

    Returns:
        The reconstruction obtained by repeatedly upsampling the running
        result to the next finer band's size and adding that band.
    """
    merged = pyramid[-1]
    # Walk the bands from the coarsest one up to the finest one.
    for band in reversed(pyramid[:-1]):
        target_size = (band.shape[2], band.shape[3])
        merged = band + tensor_resample(merged, target_size)
    return merged
|
|
|
|
|
|
class LapStylePredictor(BasePredictor):
    """Stylize images with a pretrained LapStyle model.

    One predictor instance is bound to a single style: its checkpoint is
    downloaded at construction time and the matching style image is read
    from ``self.style_img_path`` on every ``run`` call.
    """

    # Supported style name -> (checkpoint URL, style image path).
    _STYLES = {
        'starrynew': (LapStyle_starrynew_WEIGHT_URL, "assets/starrynew.png"),
        'circuit': (LapStyle_circuit_WEIGHT_URL, "assets/circuit.jpg"),
        'ocean': (LapStyle_ocean_WEIGHT_URL, "assets/ocean.png"),
        'stars': (LapStyle_stars_WEIGHT_URL, "assets/stars.png"),
    }

    def __init__(self, style='starrynew'):
        """Build the networks and load the weights for ``style``.

        Args:
            style: One of ``'starrynew'``, ``'circuit'``, ``'ocean'`` or
                ``'stars'``.

        Raises:
            ValueError: If ``style`` is not one of the supported names.
        """
        super().__init__()
        # Validate before constructing networks or downloading anything.
        if style not in self._STYLES:
            raise ValueError(f'has not implemented {style}.')
        weight_url, self.style_img_path = self._STYLES[style]
        self.weight_path = get_path_from_url(weight_url)

        self.net_enc = Encoder()
        self.net_dec = DecoderNet()
        self.net_rev = RevisionNet()
        self.net_rev_2 = RevisionNet()

        # Read the checkpoint once; it holds one state dict per sub-network,
        # keyed by the same names as the attributes above.
        checkpoint = paddle.load(self.weight_path)
        for name in ('net_enc', 'net_dec', 'net_rev', 'net_rev_2'):
            net = getattr(self, name)
            net.set_dict(checkpoint[name])
            net.eval()

    def run(self, content_img):
        """Stylize ``content_img`` with this predictor's style.

        Args:
            content_img: Content image as a NumPy array, as accepted by
                ``img_read`` (2-D grayscale, or height x width x channels).

        Returns:
            The stylized image produced by ``tensor2img``, resized with
            ``cv.resize`` to the original content width and height.
        """
        content_img, style_img, h, w = img_read(content_img,
                                                self.style_img_path)
        # Two-level pyramids: index 0 and 1 are the detail bands (finest
        # first), index 2 is the low-resolution base.
        pyr_ci = make_laplace_pyramid(content_img, 2)
        pyr_si = make_laplace_pyramid(style_img, 2)

        # Inference only: skip autograd bookkeeping.
        with paddle.no_grad():
            # Draft stage on the low-resolution base images.
            cF = self.net_enc(pyr_ci[2])
            sF = self.net_enc(pyr_si[2])
            stylized_small = self.net_dec(cF, sF)

            # First revision: predict a detail band from the upsampled draft
            # and the matching content detail band.
            stylized_up = F.interpolate(stylized_small, scale_factor=2)
            revnet_input = paddle.concat(x=[pyr_ci[1], stylized_up], axis=1)
            stylized_rev_lap = self.net_rev(revnet_input)
            stylized_rev = fold_laplace_pyramid(
                [stylized_rev_lap, stylized_small])

            # Second revision at the next finer scale.
            stylized_up = F.interpolate(stylized_rev, scale_factor=2)
            revnet_input = paddle.concat(x=[pyr_ci[0], stylized_up], axis=1)
            stylized_rev_lap_second = self.net_rev_2(revnet_input)
            stylized = fold_laplace_pyramid(
                [stylized_rev_lap_second, stylized_rev_lap, stylized_small])

        stylized_visual = tensor2img(stylized, min_max=(0., 1.))
        # Restore the content image's original size (cv.resize takes (w, h)).
        return cv.resize(stylized_visual, (w, h))
|