@ -1,3 +0,0 @@
|
||||
from django.contrib import admin
|
||||
|
||||
# Register your models here.
|
||||
@ -1,6 +0,0 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class ActiondriveConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'actionDrive'
|
||||
@ -1,3 +0,0 @@
|
||||
from django.db import models
|
||||
|
||||
# Create your models here.
|
||||
@ -1,3 +0,0 @@
|
||||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
||||
@ -1,3 +0,0 @@
|
||||
from django.shortcuts import render
|
||||
|
||||
# Create your views here.
|
||||
@ -1,3 +0,0 @@
|
||||
from django.contrib import admin
|
||||
|
||||
# Register your models here.
|
||||
@ -1,6 +0,0 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class AgetransferConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'ageTransfer'
|
||||
@ -1,3 +0,0 @@
|
||||
from django.db import models
|
||||
|
||||
# Create your models here.
|
||||
@ -1,3 +0,0 @@
|
||||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
||||
@ -1,3 +0,0 @@
|
||||
from django.shortcuts import render
|
||||
|
||||
# Create your views here.
|
||||
|
After Width: | Height: | Size: 85 KiB |
@ -0,0 +1,179 @@
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 131, in keying
|
||||
result=cvtobase64(image)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 32, in cvtobase64
|
||||
res_b = cv2.imencode('.jpg', img)[1].tostring()
|
||||
cv2.error: OpenCV(4.6.0) D:\a\opencv-python\opencv-python\opencv\modules\imgcodecs\src\loadsave.cpp:976: error: (-215:Assertion failed) !image.empty() in function 'cv::imencode'
|
||||
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 131, in keying
|
||||
result=cvtobase64(image)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 32, in cvtobase64
|
||||
res_b = cv2.imencode('.jpg', img)[1].tostring()
|
||||
cv2.error: OpenCV(4.6.0) D:\a\opencv-python\opencv-python\opencv\modules\imgcodecs\src\loadsave.cpp:976: error: (-215:Assertion failed) !image.empty() in function 'cv::imencode'
|
||||
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 132, in keying
|
||||
return JsonResponse(data={"image": result}, json_dumps_params={'ensure_ascii': False})
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\http\response.py", line 603, in __init__
|
||||
data = json.dumps(data, cls=encoder, **json_dumps_params)
|
||||
File "C:\Users\ASUS\anaconda3\lib\json\__init__.py", line 234, in dumps
|
||||
return cls(
|
||||
File "C:\Users\ASUS\anaconda3\lib\json\encoder.py", line 199, in encode
|
||||
chunks = self.iterencode(o, _one_shot=True)
|
||||
File "C:\Users\ASUS\anaconda3\lib\json\encoder.py", line 257, in iterencode
|
||||
return _iterencode(o, 0)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\serializers\json.py", line 105, in default
|
||||
return super().default(o)
|
||||
File "C:\Users\ASUS\anaconda3\lib\json\encoder.py", line 179, in default
|
||||
raise TypeError(f'Object of type {o.__class__.__name__} '
|
||||
TypeError: Object of type bytes is not JSON serializable
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 129, in keying
|
||||
rmbg.remove_background_from_base64_img(image,bg_color=(0,0,0))
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\removebg\removebg.py", line 72, in remove_background_from_base64_img
|
||||
response.raise_for_status()
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\requests\models.py", line 943, in raise_for_status
|
||||
raise HTTPError(http_error_msg, response=self)
|
||||
requests.exceptions.HTTPError: 400 Client Error: Bad Request for url: https://api.remove.bg/v1.0/removebg
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 129, in keying
|
||||
rmbg.remove_background_from_base64_img(image,bg_color=(0,0,0))
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\removebg\removebg.py", line 72, in remove_background_from_base64_img
|
||||
response.raise_for_status()
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\requests\models.py", line 943, in raise_for_status
|
||||
raise HTTPError(http_error_msg, response=self)
|
||||
requests.exceptions.HTTPError: 400 Client Error: Bad Request for url: https://api.remove.bg/v1.0/removebg
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 128, in keying
|
||||
result=getbase64byndarray(image)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 42, in getbase64byndarray
|
||||
retval, buffer = cv2.imencode('.jpg', pic_img)
|
||||
cv2.error: OpenCV(4.6.0) :-1: error: (-5:Bad argument) in function 'imencode'
|
||||
> Overload resolution failed:
|
||||
> - img is not a numpy array, neither a scalar
|
||||
> - Expected Ptr<cv::UMat> for argument 'img'
|
||||
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 128, in keying
|
||||
result=getbase64byndarray(image)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 42, in getbase64byndarray
|
||||
retval, buffer = cv2.imencode('.jpg', pic_img)
|
||||
cv2.error: OpenCV(4.6.0) :-1: error: (-5:Bad argument) in function 'imencode'
|
||||
> Overload resolution failed:
|
||||
> - img is not a numpy array, neither a scalar
|
||||
> - Expected Ptr<cv::UMat> for argument 'img'
|
||||
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 129, in keying
|
||||
return JsonResponse(data={"image": result}, json_dumps_params={'ensure_ascii': False})
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\http\response.py", line 603, in __init__
|
||||
data = json.dumps(data, cls=encoder, **json_dumps_params)
|
||||
File "C:\Users\ASUS\anaconda3\lib\json\__init__.py", line 234, in dumps
|
||||
return cls(
|
||||
File "C:\Users\ASUS\anaconda3\lib\json\encoder.py", line 199, in encode
|
||||
chunks = self.iterencode(o, _one_shot=True)
|
||||
File "C:\Users\ASUS\anaconda3\lib\json\encoder.py", line 257, in iterencode
|
||||
return _iterencode(o, 0)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\serializers\json.py", line 105, in default
|
||||
return super().default(o)
|
||||
File "C:\Users\ASUS\anaconda3\lib\json\encoder.py", line 179, in default
|
||||
raise TypeError(f'Object of type {o.__class__.__name__} '
|
||||
TypeError: Object of type ndarray is not JSON serializable
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 127, in keying
|
||||
rmbg.remove_background_from_base64_img(image,new_file_name="output.jpg",bg_color=(0,0,0))
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\removebg\removebg.py", line 72, in remove_background_from_base64_img
|
||||
response.raise_for_status()
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\requests\models.py", line 943, in raise_for_status
|
||||
raise HTTPError(http_error_msg, response=self)
|
||||
requests.exceptions.HTTPError: 400 Client Error: Bad Request for url: https://api.remove.bg/v1.0/removebg
|
||||
ERROR:django.request:Internal Server Error: /keying/
|
||||
Traceback (most recent call last):
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\exception.py", line 47, in inner
|
||||
response = get_response(request)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\core\handlers\base.py", line 181, in _get_response
|
||||
response = wrapped_callback(request, *callback_args, **callback_kwargs)
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\django\views\decorators\http.py", line 40, in inner
|
||||
return func(request, *args, **kwargs)
|
||||
File "F:\imageprocess\imageProcess\basicFunction\views.py", line 127, in keying
|
||||
rmbg.remove_background_from_base64_img(image,new_file_name="output.jpg",bg_color=(0,0,0))
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\removebg\removebg.py", line 72, in remove_background_from_base64_img
|
||||
response.raise_for_status()
|
||||
File "C:\Users\ASUS\anaconda3\lib\site-packages\requests\models.py", line 943, in raise_for_status
|
||||
raise HTTPError(http_error_msg, response=self)
|
||||
requests.exceptions.HTTPError: 400 Client Error: Bad Request for url: https://api.remove.bg/v1.0/removebg
|
||||
INFO:django.utils.autoreload:F:\imageprocess\imageProcess\basicFunction\views.py changed, reloading.
|
||||
|
After Width: | Height: | Size: 111 KiB |
@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
|
||||
This program is free software; you can redistribute it and/or modify
|
||||
it under the terms of the Apache License Version 2.0.You may not use this file
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
Apache License for more details at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
|
||||
class Constant:
|
||||
# error code
|
||||
ACL_ERROR_NONE = 0
|
||||
|
||||
# rule for mem
|
||||
ACL_MEM_MALLOC_HUGE_FIRST = 0
|
||||
ACL_MEM_MALLOC_HUGE_ONLY = 1
|
||||
ACL_MEM_MALLOC_NORMAL_ONLY = 2
|
||||
|
||||
# rule for memory copy
|
||||
ACL_MEMCPY_HOST_TO_HOST = 0
|
||||
ACL_MEMCPY_HOST_TO_DEVICE = 1
|
||||
ACL_MEMCPY_DEVICE_TO_HOST = 2
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE = 3
|
||||
|
||||
# images format
|
||||
IMG_EXT = ['.jpg', '.JPG', '.png', '.PNG', '.bmp', '.BMP', '.jpeg', '.JPEG']
|
||||
|
||||
# numpy data type
|
||||
NPY_FLOAT32 = 11
|
||||
|
After Width: | Height: | Size: 30 KiB |
@ -0,0 +1,135 @@
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
import mindspore
|
||||
# import moxing as mox
|
||||
from mindspore import numpy
|
||||
from PIL import Image
|
||||
from mindspore import ops
|
||||
import os
|
||||
import cv2 as cv
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as c_version
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
from mindspore import Model, context, nn, Tensor, Parameter, load_checkpoint
|
||||
|
||||
|
||||
def save_img(i, optimizer, output_path):
|
||||
if not os.path.exists(output_path):
|
||||
os.mkdir(output_path)
|
||||
final_img = optimizer.parameters[0].asnumpy()
|
||||
final_img = final_img.squeeze(axis=0)
|
||||
final_img = np.moveaxis(final_img, 0, 2)
|
||||
dump_img = np.copy(final_img)
|
||||
dump_img += np.array([123.675, 116.28, 103.53]).reshape((1, 1, 3))
|
||||
dump_img = np.clip(dump_img, 0, 255).astype('uint8')
|
||||
# dump_img = cv.resize(dump_img, (224, 224), interpolation=cv.INTER_CUBIC)
|
||||
img_path = output_path+"/"+"iter_"+str(i)+".jpg" # imgput_path = ./output_path/content2_to_sty3/lr=0.5/iter_1
|
||||
cv.imwrite(img_path, dump_img[:, :, ::-1])
|
||||
# mox.file.copy_parallel(img_path, args.train_url+img_path[12:])
|
||||
|
||||
|
||||
def create_dataset(img):
|
||||
"""生成数据集"""
|
||||
dataset = ds.NumpySlicesDataset(data=img, column_names=['data'])
|
||||
return dataset
|
||||
|
||||
|
||||
def gram_matrix(x, should_normalize=True):
|
||||
"""
|
||||
Generate gram matrices of the representations of content and style images.
|
||||
"""
|
||||
# 对网络的特征进行矩阵编码
|
||||
b, ch, h, w = x.shape # x的形状
|
||||
features = x.view(b, ch, w * h) # 将x降维
|
||||
transpose = ops.Transpose()
|
||||
batmatmul = ops.BatchMatMul(transpose_a=False)
|
||||
features_t = transpose(features, (0, 2, 1))
|
||||
gram = batmatmul(features, features_t) # gram 为矩阵相乘计算得新图片的像素
|
||||
if should_normalize: # 标准化
|
||||
gram /= ch * h * w
|
||||
return gram
|
||||
|
||||
|
||||
def load_image(img_path, target_shape=None):
|
||||
# 图像预处理 返回 1 * 3 * 400 * x
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception(f'Path not found: {img_path}')
|
||||
img = cv.imread(img_path)[:, :, ::-1] # convert BGR to RGB when reading
|
||||
if target_shape is not None:
|
||||
if isinstance(target_shape, int) and target_shape != -1:
|
||||
current_height, current_width = img.shape[:2]
|
||||
new_height = target_shape
|
||||
new_width = int(current_width * (new_height / current_height))
|
||||
img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC)
|
||||
else:
|
||||
img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC)
|
||||
img = img.astype(np.float32)
|
||||
|
||||
to_tensor = py_vision.ToTensor() # channel conversion and pixel value normalization
|
||||
normalize = c_version.Normalize(mean=[123.675, 116.28, 103.53], std=[1, 1, 1])
|
||||
img = normalize(img) # <class 'numpy.ndarray'> (400, 533, 3)
|
||||
img = to_tensor(img) * 225 # <class 'numpy.ndarray'> (3, 400, 533)
|
||||
img = np.expand_dims(img, axis=0)
|
||||
# img /= 255
|
||||
# transform = transforms.Compose([
|
||||
# transforms.ToTensor(),
|
||||
# transforms.Lambda(lambda x: x.mul(255)), # 乘255
|
||||
# transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[1, 1, 1])
|
||||
# ])
|
||||
# img = transform(img).unsqueeze(0)
|
||||
# img = img.numpy()
|
||||
return img
|
||||
|
||||
|
||||
class Optim_Loss(nn.Cell):
|
||||
def __init__(self, net, target_maps):
|
||||
super(Optim_Loss, self).__init__()
|
||||
self.net = net
|
||||
self.target_maps = target_maps[:-1]
|
||||
self.weight = [100000.0, 30000.0, 1.0]
|
||||
self.get_style_loss = nn.MSELoss(reduction='sum')
|
||||
self.get_content_loss = nn.MSELoss(reduction='mean')
|
||||
self.cast = ops.Cast() # 转换为 mindspore.tensor
|
||||
self.ct = target_maps[2]
|
||||
|
||||
def construct(self):
|
||||
optimize_img = self.ct
|
||||
current_maps = self.net(self.cast(optimize_img, mindspore.float32)) # 6个特征图
|
||||
# 当前图片的特征
|
||||
current_content_maps = current_maps[4].squeeze(axis=0) # 内容特征 # 4_2的内容map
|
||||
for i in range(len(current_maps)): # 0, 1, 2, 3, 4, 5 1, 2, 3, 4, 4_2, 5
|
||||
if i != 4:
|
||||
current_maps[i] = gram_matrix(current_maps[i])
|
||||
target_content_maps = self.target_maps[0] # 任务的内容特征
|
||||
target_content_gram = self.target_maps[1] # 任务的风格特征
|
||||
content_loss = self.get_content_loss(current_content_maps, target_content_maps)
|
||||
style_loss = 0
|
||||
for j in range(6):
|
||||
if j == 5:
|
||||
style_loss += self.get_style_loss(current_maps[j], target_content_gram[j-1])
|
||||
if j < 4:
|
||||
style_loss += self.get_style_loss(current_maps[j], target_content_gram[j])
|
||||
|
||||
style_loss /= 5
|
||||
tv_loss = numpy.sum(numpy.abs(optimize_img[:, :, :, :-1] - optimize_img[:, :, :, 1:])) \
|
||||
+ numpy.sum(numpy.abs(optimize_img[:, :, :-1, :] - optimize_img[:, :, 1:, :]))
|
||||
|
||||
total_loss = content_loss * self.weight[0] + style_loss * self.weight[1] + tv_loss * self.weight[2]
|
||||
return total_loss/130001
|
||||
|
||||
|
||||
def load_parameters(file_name):
|
||||
param_dict = load_checkpoint(file_name)
|
||||
param_dict_new = {}
|
||||
# print(param_dict)
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith("layers."):
|
||||
param_dict_new['l'+key[7:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
return param_dict_new
|
||||
|
||||
@ -1,3 +0,0 @@
|
||||
from django.db import models
|
||||
|
||||
# Create your models here.
|
||||
|
After Width: | Height: | Size: 479 KiB |
@ -0,0 +1,54 @@
|
||||
|
||||
from mindspore import load_checkpoint, load_param_into_net
|
||||
from mindspore import nn, context
|
||||
|
||||
|
||||
def load_parameters(file_name):
|
||||
param_dict = load_checkpoint(file_name)
|
||||
param_dict_new = {}
|
||||
# print(param_dict)
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith("layers."):
|
||||
param_dict_new['l'+key[7:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
return param_dict_new
|
||||
|
||||
|
||||
class Vgg19(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l0 = nn.Conv2d(3, 64, kernel_size=3, weight_init='ones')
|
||||
self.l2 = nn.Conv2d(64, 64, kernel_size=3, weight_init='ones')
|
||||
self.l5 = nn.Conv2d(64, 128, kernel_size=3, weight_init='ones')
|
||||
self.l7 = nn.Conv2d(128, 128, kernel_size=3, weight_init='ones')
|
||||
self.l10 = nn.Conv2d(128, 256, kernel_size=3, weight_init='ones')
|
||||
self.l12 = nn.Conv2d(256, 256, kernel_size=3, weight_init='ones')
|
||||
self.l14 = nn.Conv2d(256, 256, kernel_size=3, weight_init='ones')
|
||||
self.l16 = nn.Conv2d(256, 256, kernel_size=3, weight_init='ones')
|
||||
self.l19 = nn.Conv2d(256, 512, kernel_size=3, weight_init='ones')
|
||||
self.l21 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
|
||||
self.l23 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
|
||||
self.l25 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
|
||||
self.l28 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
|
||||
self.l30 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
|
||||
self.l32 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
|
||||
self.l34 = nn.Conv2d(512, 512, kernel_size=3, weight_init='ones')
|
||||
self.relu = nn.ReLU()
|
||||
self.mp = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
layer_1 = self.relu(self.l0(x)) # 3-64
|
||||
layer_2 = self.relu(self.l5(self.mp(self.relu(self.l2(layer_1)))))
|
||||
layer_3 = self.relu(self.l10(self.relu(self.l7(self.mp(layer_2)))))
|
||||
layer_4 = self.relu(self.l19(self.mp(self.relu(self.l16(self.relu(self.l14(self.relu(self.l12(layer_3)))))))))
|
||||
layer_4_2 = self.relu(self.l21(layer_4))
|
||||
layer_5 = self.relu(self.l28(self.mp(self.relu(self.l25(self.relu(self.l23(layer_4_2)))))))
|
||||
return [layer_1, layer_2, layer_3, layer_4, layer_4_2, layer_5]
|
||||
|
||||
|
||||
|
||||
|
||||
|
Before Width: | Height: | Size: 35 KiB |
|
Before Width: | Height: | Size: 187 KiB |
|
After Width: | Height: | Size: 348 KiB |
|
Before Width: | Height: | Size: 248 KiB |
|
After Width: | Height: | Size: 431 KiB |
|
After Width: | Height: | Size: 447 KiB |
|
After Width: | Height: | Size: 352 KiB |
|
After Width: | Height: | Size: 346 KiB |