You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

147 lines
5.4 KiB

import cv2 as cv
import numpy as np
import paddle
import paddle.nn.functional as F
from PIL import Image
from paddle.vision.transforms import functional
from ppgan.apps.base_predictor import BasePredictor
from ppgan.models.generators import DecoderNet, Encoder, RevisionNet
from ppgan.utils.download import get_path_from_url
from ppgan.utils.visual import tensor2img
# Download locations of the pretrained LapStyle checkpoints, one per style.
# All of them live under the same PaddleGAN model bucket.
_LAPSTYLE_WEIGHT_ROOT = 'https://paddlegan.bj.bcebos.com/models/'
LapStyle_circuit_WEIGHT_URL = _LAPSTYLE_WEIGHT_ROOT + 'lapstyle_circuit.pdparams'
LapStyle_ocean_WEIGHT_URL = _LAPSTYLE_WEIGHT_ROOT + 'lapstyle_ocean.pdparams'
LapStyle_starrynew_WEIGHT_URL = _LAPSTYLE_WEIGHT_ROOT + 'lapstyle_starrynew.pdparams'
LapStyle_stars_WEIGHT_URL = _LAPSTYLE_WEIGHT_ROOT + 'lapstyle_stars.pdparams'
def img(src):
    """Keep only the first three channels of an image array.

    Some images (e.g. RGBA PNGs) carry a fourth channel; it is sliced off so
    the networks receive three colour channels. The layout is left as HWC --
    no transpose happens here (``to_tensor`` does that later).

    Args:
        src: image array, either 2-D (H, W) or 3-D (H, W, C).

    Returns:
        ``src`` itself when it is 2-D or has at most three channels,
        otherwise a view holding only its first three channels.
    """
    # Check ndim first so 2-D (grayscale) arrays pass through instead of
    # raising IndexError on ``shape[2]``.
    if src.ndim == 3 and src.shape[2] > 3:
        src = src[:, :, :3]
    return src
def _prepare_image(image):
    """Resize an RGB(A) array to 512x512 and return it as a 1 x C x H x W tensor."""
    resized = Image.fromarray(image).resize((512, 512), Image.BILINEAR)
    # img() drops any alpha channel; to_tensor converts HWC -> CHW.
    chw = functional.to_tensor(img(np.array(resized)))
    return paddle.unsqueeze(chw, axis=0)


def img_read(content_img: np.ndarray, style_img_path: str):
    """Prepare the content and style images as 512x512 network inputs.

    Args:
        content_img: RGB(A) image array of shape (H, W, C), or a 2-D
            grayscale array, which is expanded to RGB first.
        style_img_path: path of the style image on disk. It is read with
            OpenCV (BGR order) and converted to RGB.

    Returns:
        ``(content, style, h, w)``. ``content`` and ``style`` are batched
        1 x 3 x 512 x 512 tensors produced by ``to_tensor``. ``h`` and ``w``
        are the content image's original height and width, so callers can
        resize the result back.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``style_img_path``.
    """
    if content_img.ndim == 2:
        content_img = cv.cvtColor(content_img, cv.COLOR_GRAY2RGB)
    h, w = content_img.shape[:2]
    content_tensor = _prepare_image(content_img)

    style_bgr = cv.imread(style_img_path)
    # cv.imread signals failure by returning None rather than raising.
    if style_bgr is None:
        raise FileNotFoundError(f'cannot read style image: {style_img_path}')
    style_tensor = _prepare_image(cv.cvtColor(style_bgr, cv.COLOR_BGR2RGB))

    return content_tensor, style_tensor, h, w
def tensor_resample(tensor, dst_size, mode='bilinear'):
    """Resize ``tensor`` to the spatial size ``dst_size`` with ``F.interpolate``.

    ``mode`` selects the interpolation kernel. ``align_corners`` is always
    False.
    """
    return F.interpolate(
        tensor,
        size=dst_size,
        mode=mode,
        align_corners=False,
    )
def laplacian(x):
    """Return the Laplacian (detail) band of a batched image tensor.

    Computed as ``x - upsample(downsample(x))``. The tensor is halved in
    both spatial dimensions, resized back, and the result is subtracted.

    Args:
        x: 4-D tensor whose dims 2 and 3 are height and width (NCHW, as
            produced by ``img_read``).

    Returns:
        Tensor with the same shape as ``x``.
    """
    height, width = x.shape[2], x.shape[3]
    # Clamp to 1, as make_laplace_pyramid does, so 1-pixel-tall/-wide inputs
    # do not request a zero-sized interpolation.
    half = [max(height // 2, 1), max(width // 2, 1)]
    return x - tensor_resample(tensor_resample(x, half), [height, width])
def make_laplace_pyramid(x, levels):
    """Build a Laplacian pyramid of ``x`` with ``levels`` detail bands.

    Returns a list of ``levels + 1`` tensors. Entry ``i`` is the Laplacian
    band of ``x`` halved ``i`` times. The last entry is the low-resolution
    residual left after ``levels`` halvings (each side clamped to >= 1).
    """
    bands = []
    level_input = x
    for _ in range(levels):
        bands.append(laplacian(level_input))
        next_h = max(level_input.shape[2] // 2, 1)
        next_w = max(level_input.shape[3] // 2, 1)
        level_input = tensor_resample(level_input, (next_h, next_w))
    bands.append(level_input)
    return bands
def fold_laplace_pyramid(pyramid):
    """Collapse a Laplacian pyramid back into a single tensor.

    The fold starts from the coarsest entry, ``pyramid[-1]``. Each step
    upsamples the running result to the next finer band's spatial size and
    adds that band. It finishes at the resolution of ``pyramid[0]``.
    """
    folded = pyramid[-1]
    for band in reversed(pyramid[:-1]):
        target = (band.shape[2], band.shape[3])
        folded = band + tensor_resample(folded, target)
    return folded
class LapStylePredictor(BasePredictor):
    """Stylize images with a pretrained LapStyle model for one fixed style.

    The style name selects two things: the checkpoint to download and the
    style image that is read on every ``run`` call.
    """

    # style name -> (checkpoint URL, style image path).
    # NOTE(review): the style image paths are relative, so they resolve
    # against the current working directory. Confirm that callers run from
    # the project root.
    _STYLES = {
        'starrynew': (LapStyle_starrynew_WEIGHT_URL, "assets/starrynew.png"),
        'circuit': (LapStyle_circuit_WEIGHT_URL, "assets/circuit.jpg"),
        'ocean': (LapStyle_ocean_WEIGHT_URL, "assets/ocean.png"),
        'stars': (LapStyle_stars_WEIGHT_URL, "assets/stars.png"),
    }

    # Sub-network attribute names. Each is also the key of that network's
    # state dict inside the checkpoint.
    _NETS = ('net_enc', 'net_dec', 'net_rev', 'net_rev_2')

    def __init__(self, style='starrynew'):
        """Build the networks and load the pretrained weights for ``style``.

        Args:
            style: one of 'starrynew', 'circuit', 'ocean', 'stars'.

        Raises:
            ValueError: if ``style`` is not supported.
        """
        super().__init__()
        self.net_enc = Encoder()
        self.net_dec = DecoderNet()
        self.net_rev = RevisionNet()
        self.net_rev_2 = RevisionNet()
        self.style_img_path = None
        self.weight_path = None

        entry = self._STYLES.get(style)
        if entry is None:
            raise ValueError(f'has not implemented {style}.')
        weight_url, style_img_path = entry
        self.weight_path = get_path_from_url(weight_url)
        self.style_img_path = style_img_path

        # Deserialize the checkpoint once and share it across all sub-networks.
        state = paddle.load(self.weight_path)
        for name in self._NETS:
            net = getattr(self, name)
            net.set_dict(state[name])
            net.eval()

    def run(self, content_img):
        """Stylize ``content_img`` and return it at its original size.

        Args:
            content_img: RGB(A) image array, or 2-D grayscale (see
                ``img_read``).

        Returns:
            The stylized image from ``tensor2img``, resized back to the
            content image's original width and height.
        """
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with paddle.no_grad():
            content, style, h, w = img_read(content_img, self.style_img_path)
            # With 2 levels the pyramid holds bands at full and 1/2
            # resolution, plus a 1/4-resolution residual at index 2.
            pyr_ci = make_laplace_pyramid(content, 2)
            pyr_si = make_laplace_pyramid(style, 2)

            # Coarse stylization on the lowest-resolution residuals.
            content_feat = self.net_enc(pyr_ci[2])
            style_feat = self.net_enc(pyr_si[2])
            stylized_small = self.net_dec(content_feat, style_feat)

            # First revision, guided by the 1/2-resolution content band.
            stylized_up = F.interpolate(stylized_small, scale_factor=2)
            revnet_input = paddle.concat(x=[pyr_ci[1], stylized_up], axis=1)
            stylized_rev_lap = self.net_rev(revnet_input)
            stylized_rev = fold_laplace_pyramid(
                [stylized_rev_lap, stylized_small])

            # Second revision, guided by the full-resolution content band.
            stylized_up = F.interpolate(stylized_rev, scale_factor=2)
            revnet_input = paddle.concat(x=[pyr_ci[0], stylized_up], axis=1)
            stylized_rev_lap_second = self.net_rev_2(revnet_input)
            stylized = fold_laplace_pyramid(
                [stylized_rev_lap_second, stylized_rev_lap, stylized_small])

            stylized_visual = tensor2img(stylized, min_max=(0., 1.))
        return cv.resize(stylized_visual, (w, h))