"""LapStyle style-transfer inference built on PaddleGAN components.

The module provides image pre-processing, Laplacian-pyramid helpers and
``LapStylePredictor``, which stylizes a content image with one of several
pre-trained styles.
"""
import cv2 as cv
import numpy as np
import paddle
import paddle.nn.functional as F
from PIL import Image
from paddle.vision.transforms import functional
from ppgan.apps.base_predictor import BasePredictor
from ppgan.models.generators import DecoderNet, Encoder, RevisionNet
from ppgan.utils.download import get_path_from_url
from ppgan.utils.visual import tensor2img

LapStyle_circuit_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams'
LapStyle_ocean_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams'
LapStyle_starrynew_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams'
LapStyle_stars_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams'

# Spatial size (width, height) every image is resized to before inference.
_INPUT_SIZE = (512, 512)

# style name -> (weights URL, path of the reference style image)
_STYLES = {
    'starrynew': (LapStyle_starrynew_WEIGHT_URL, "assets/starrynew.png"),
    'circuit': (LapStyle_circuit_WEIGHT_URL, "assets/circuit.jpg"),
    'ocean': (LapStyle_ocean_WEIGHT_URL, "assets/ocean.png"),
    'stars': (LapStyle_stars_WEIGHT_URL, "assets/stars.png"),
}


def img(src):
    """Return ``src`` restricted to its first three channels.

    Some images carry a fourth channel; anything beyond the third channel
    of an H x W x C array is dropped. Arrays with fewer than three
    dimensions, or with at most three channels, are returned unchanged.
    No axis reordering happens here.
    """
    if len(src.shape) >= 3 and src.shape[2] > 3:
        src = src[:, :, :3]
    return src


def _to_batched_tensor(rgb):
    """Resize an RGB array to ``_INPUT_SIZE`` and wrap it as a batch of one."""
    resized = Image.fromarray(rgb).resize(_INPUT_SIZE, Image.BILINEAR)
    # Extra channels are stripped after resizing, not before, so the
    # resampling sees exactly the same pixels as it always has.
    tensor = functional.to_tensor(img(np.array(resized)))
    return paddle.unsqueeze(tensor, axis=0)


def img_read(content_img: np.ndarray, style_img_path: str):
    """Prepare the content image and the on-disk style image for the network.

    Both images are handled as RGB.

    Args:
        content_img: content image as an array; a 2-D (grayscale) array is
            expanded to three channels first.
        style_img_path: path of the style image, read with OpenCV and
            converted from BGR to RGB.

    Returns:
        ``(content_tensor, style_tensor, h, w)`` where both tensors were
        resized to 512x512 and given a leading batch axis, and ``h``/``w``
        are the height and width of the content image before resizing.

    Raises:
        FileNotFoundError: if the style image cannot be read.
    """
    if content_img.ndim == 2:
        content_img = cv.cvtColor(content_img, cv.COLOR_GRAY2RGB)
    h, w = content_img.shape[:2]

    style_bgr = cv.imread(style_img_path)
    if style_bgr is None:
        # cv.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            f'cannot read style image: {style_img_path}')
    style_rgb = cv.cvtColor(style_bgr, cv.COLOR_BGR2RGB)

    return _to_batched_tensor(content_img), _to_batched_tensor(style_rgb), h, w


def tensor_resample(tensor, dst_size, mode='bilinear'):
    """Resize ``tensor`` to ``dst_size`` with ``F.interpolate``.

    Interpolation uses ``mode`` (bilinear by default) with
    ``align_corners=False``.
    """
    return F.interpolate(tensor, dst_size, mode=mode, align_corners=False)


def laplacian(x):
    """
    Laplacian

    return:
        x - upsample(downsample(x))
    """
    full_size = [x.shape[2], x.shape[3]]
    # Clamp to 1 so a one-pixel axis is never downsampled to size 0; this
    # matches the clamping done in make_laplace_pyramid.
    half_size = [max(x.shape[2] // 2, 1), max(x.shape[3] // 2, 1)]
    return x - tensor_resample(tensor_resample(x, half_size), full_size)


def make_laplace_pyramid(x, levels):
    """
    Make Laplacian Pyramid

    Returns a list of ``levels`` Laplacian residuals, ordered from the
    largest resolution to the smallest, followed by the remaining
    low-resolution tensor (``levels + 1`` entries in total).
    """
    pyramid = []
    current = x
    for _ in range(levels):
        pyramid.append(laplacian(current))
        current = tensor_resample(
            current,
            (max(current.shape[2] // 2, 1), max(current.shape[3] // 2, 1)))
    pyramid.append(current)
    return pyramid


def fold_laplace_pyramid(pyramid):
    """
    Fold Laplacian Pyramid

    Starting from the last entry, repeatedly upsample the running result
    to the size of the preceding entry and add that entry to it.
    """
    current = pyramid[-1]
    # Walk from the second-to-last entry back to the first.
    for level in reversed(pyramid[:-1]):
        up_h, up_w = level.shape[2], level.shape[3]
        current = level + tensor_resample(current, (up_h, up_w))
    return current


class LapStylePredictor(BasePredictor):
    """Stylize images with pre-trained LapStyle networks.

    Attributes set by the constructor:
        net_enc, net_dec, net_rev, net_rev_2: the networks, loaded with
            pre-trained weights and switched to eval mode.
        style_img_path: path of the reference image of the chosen style.
        weight_path: local path of the downloaded weights file.
    """

    def __init__(self, style='starrynew'):
        """Download the weights for ``style`` and build the networks.

        Args:
            style: one of ``'starrynew'``, ``'circuit'``, ``'ocean'`` or
                ``'stars'``.

        Raises:
            NotImplementedError: if ``style`` is not a known style.
        """
        super().__init__()
        # NOTE(review): this binds the builtin ``input`` function; it looks
        # accidental but is kept so the attribute still exists - confirm.
        self.input = input

        # Validate the style before building networks or downloading.
        try:
            weight_url, style_img_path = _STYLES[style]
        except KeyError:
            raise NotImplementedError(
                f'has not implemented {style}.') from None

        self.net_enc = Encoder()
        self.net_dec = DecoderNet()
        self.net_rev = RevisionNet()
        self.net_rev_2 = RevisionNet()
        self.style_img_path = style_img_path
        self.weight_path = get_path_from_url(weight_url)

        # Read the checkpoint once; it holds the weights of all four nets.
        state = paddle.load(self.weight_path)
        for net, key in ((self.net_enc, 'net_enc'),
                         (self.net_dec, 'net_dec'),
                         (self.net_rev, 'net_rev'),
                         (self.net_rev_2, 'net_rev_2')):
            net.set_dict(state[key])
            net.eval()

    def run(self, content_img):
        """Stylize ``content_img`` and return the result as an image array.

        Args:
            content_img: content image array, as accepted by ``img_read``.

        Returns:
            The stylized image produced by ``tensor2img``, resized back to
            the original width and height of ``content_img``.
        """
        content_img, style_img, h, w = img_read(content_img,
                                                self.style_img_path)
        # Inference only: do not record operations for autograd.
        with paddle.no_grad():
            pyr_ci = make_laplace_pyramid(content_img, 2)
            pyr_si = make_laplace_pyramid(style_img, 2)

            # Draft stage on the lowest-resolution pyramid entries.
            cF = self.net_enc(pyr_ci[2])
            sF = self.net_enc(pyr_si[2])
            stylized_small = self.net_dec(cF, sF)

            # First revision: predict a residual at the middle resolution.
            stylized_up = F.interpolate(stylized_small, scale_factor=2)
            revnet_input = paddle.concat(x=[pyr_ci[1], stylized_up], axis=1)
            stylized_rev_lap = self.net_rev(revnet_input)
            stylized_rev = fold_laplace_pyramid(
                [stylized_rev_lap, stylized_small])

            # Second revision: predict a residual at the full resolution.
            stylized_up = F.interpolate(stylized_rev, scale_factor=2)
            revnet_input = paddle.concat(x=[pyr_ci[0], stylized_up], axis=1)
            stylized_rev_lap_second = self.net_rev_2(revnet_input)
            stylized = fold_laplace_pyramid(
                [stylized_rev_lap_second, stylized_rev_lap, stylized_small])

        stylized_visual = tensor2img(stylized, min_max=(0., 1.))
        # Restore the content image's original size (cv.resize takes (w, h)).
        stylized_visual = cv.resize(stylized_visual, (w, h))
        return stylized_visual