You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
173 lines
6.3 KiB
173 lines
6.3 KiB
from flask import Flask, request, send_file, jsonify, send_from_directory
|
|
from flask_cors import CORS
|
|
from PIL import Image, ImageFilter, ImageEnhance, ImageOps
|
|
import io
|
|
import uuid
|
|
import os
|
|
import sys
|
|
import yaml
|
|
import torch
|
|
import imageio
|
|
import imageio_ffmpeg as ffmpeg
|
|
import numpy as np
|
|
from moviepy.editor import VideoFileClip
|
|
from skimage.transform import resize
|
|
from skimage import img_as_ubyte
|
|
from tqdm import tqdm
|
|
# Make the bundled first-order-model sources importable (the `modules` and
# `animate` imports depend on this), followed by the current working directory.
sys.path.extend([os.path.abspath('./firstordermodel'), "."])
|
|
from modules.generator import OcclusionAwareGenerator
|
|
from modules.keypoint_detector import KPDetector
|
|
from animate import normalize_kp
|
|
|
|
# The Flask application object; every route in this module registers on it.
app = Flask(import_name=__name__)
# Enable CORS on every route with flask-cors' default (permissive) settings.
CORS(app=app)
|
|
|
|
# Named PIL filters accepted by the 'filter' operation of /process-image.
_FILTERS = {
    'BLUR': ImageFilter.BLUR,
    'EMBOSS': ImageFilter.EMBOSS,
    'CONTOUR': ImageFilter.CONTOUR,
    'SHARPEN': ImageFilter.SHARPEN,
}


def _parse_channel_factor(text):
    """Parse one entry of a 'r,g,b' color_adjust parameter.

    An empty entry or a value of zero means 1.0 (channel left unchanged).

    Raises:
        ValueError: if the entry is not a number.
    """
    text = text.strip()
    factor = float(text) if text else 0.0
    return factor if factor else 1.0


def _apply_image_operation(image, operation, parameter):
    """Return *image* transformed by *operation*; unknown operations are no-ops.

    Raises:
        ValueError: if *parameter* is malformed for the requested operation.
    """
    if operation == 'rotate':
        # Angle in degrees, counter-clockwise (PIL convention).
        return image.rotate(float(parameter))
    if operation == 'flip':
        # Vertical flip (top <-> bottom).
        return ImageOps.flip(image)
    if operation == 'scale':
        scale_factor = float(parameter)
        # Reject zero, negatives, NaN and infinity, which would otherwise crash
        # int() or resize() with an unhelpful server error.
        if not 0 < scale_factor < float('inf'):
            raise ValueError(f"scale factor must be a positive finite number, got {parameter!r}")
        width, height = image.size
        new_size = (max(1, int(width * scale_factor)), max(1, int(height * scale_factor)))
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
        return image.resize(new_size, Image.LANCZOS)
    if operation == 'filter':
        image_filter = _FILTERS.get(parameter)
        # Unrecognized filter names leave the image unchanged.
        return image if image_filter is None else image.filter(image_filter)
    if operation == 'color_adjust':
        parts = parameter.split(',')
        if len(parts) != 3:
            raise ValueError(f"color_adjust expects 'r,g,b', got {parameter!r}")
        r, g, b = (_parse_channel_factor(part) for part in parts)
        # split() must yield exactly three bands, so grayscale input is widened first.
        r_channel, g_channel, b_channel = image.convert('RGB').split()
        return Image.merge('RGB', (
            r_channel.point(lambda i: i * r),
            g_channel.point(lambda i: i * g),
            b_channel.point(lambda i: i * b),
        ))
    if operation == 'contrast':
        return ImageEnhance.Contrast(image).enhance(2)
    if operation == 'smooth':
        return image.filter(ImageFilter.SMOOTH)
    return image


@app.route('/process-image', methods=['POST'])
def process_image():
    """Apply one basic operation to an uploaded image and return it as JPEG.

    Form fields:
        file: the uploaded image (required).
        operation: 'rotate', 'flip', 'scale', 'filter', 'color_adjust',
            'contrast' or 'smooth' (required). Any other value returns the
            image unchanged, re-encoded as JPEG.
        parameter: operation argument; may be omitted when unused.
            rotate: angle in degrees. scale: positive factor.
            filter: 'BLUR', 'EMBOSS', 'CONTOUR' or 'SHARPEN'.
            color_adjust: 'r,g,b' channel multipliers (empty or 0 means 1.0).

    Returns:
        The processed image as ``image/jpeg``, or a 400 JSON error when the
        upload is not a readable image or ``parameter`` is malformed.
    """
    file = request.files['file']
    operation = request.form['operation']
    # flip/contrast/smooth ignore the parameter, so don't require it.
    parameter = request.form.get('parameter', '')

    try:
        image = Image.open(file.stream)
        # JPEG (the response format) cannot store alpha/palette/bilevel modes,
        # and PIL filters reject palette images; normalize to RGB up front.
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
    except OSError as err:
        return jsonify({"error": f"cannot read image: {err}"}), 400

    try:
        image = _apply_image_operation(image, operation, parameter)
    except ValueError as err:
        return jsonify({"error": str(err)}), 400

    img_io = io.BytesIO()
    image.save(img_io, 'JPEG')
    img_io.seek(0)
    return send_file(img_io, mimetype='image/jpeg')
|
|
|
|
|
|
def load_checkpoints(config_path, checkpoint_path, cpu=True):
    """Build the first-order-model generator and keypoint detector from disk.

    Args:
        config_path: YAML file whose ``model_params`` mapping provides
            ``generator_params``, ``kp_detector_params`` and ``common_params``
            (the latter is passed to both networks).
        checkpoint_path: file readable by ``torch.load`` holding the
            ``'generator'`` and ``'kp_detector'`` state dicts.
        cpu: when True, load everything onto the CPU; otherwise move both
            networks to CUDA and wrap them in ``torch.nn.DataParallel``.

    Returns:
        ``(generator, kp_detector)``, both switched to eval mode.
    """
    with open(config_path, encoding='utf-8') as f:
        # NOTE(review): FullLoader is acceptable for this bundled config file,
        # but it must never be pointed at user-supplied YAML.
        config = yaml.load(f, Loader=yaml.FullLoader)

    model_params = config['model_params']
    common_params = model_params['common_params']

    generator = OcclusionAwareGenerator(**model_params['generator_params'], **common_params)
    kp_detector = KPDetector(**model_params['kp_detector_params'], **common_params)
    if not cpu:
        generator.cuda()
        kp_detector.cuda()

    # map_location=None is torch.load's default, so the GPU path loads exactly
    # as a plain torch.load(checkpoint_path) would.
    map_location = torch.device('cpu') if cpu else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    if not cpu:
        # DataParallelWithCallback is not imported in this module, so the
        # GPU path used to fail with NameError. Plain DataParallel is used
        # instead; the callback variant (presumably first-order-model's
        # sync_batchnorm helper) matters for synchronized batch norm in
        # training — TODO confirm if multi-GPU training is ever run from here.
        generator = torch.nn.DataParallel(generator)
        kp_detector = torch.nn.DataParallel(kp_detector)

    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
|
|
|
|
def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=True):
    """Transfer the motion of *driving_video* onto *source_image*, frame by frame.

    Args:
        source_image: channels-last image array (H, W, C); cast to float32.
        driving_video: non-empty sequence of channels-last frames (H, W, C).
            NOTE(review): presumably the same spatial size as source_image and
            floats in [0, 1], as prepared by motion_drive — confirm for other callers.
        generator: network from load_checkpoints, called once per frame.
        kp_detector: keypoint network from load_checkpoints.
        relative: forwarded to normalize_kp as both use_relative_movement and
            use_relative_jacobian.
        adapt_movement_scale: forwarded to normalize_kp.
        cpu: when False, the source image and every driving frame are moved to CUDA.

    Returns:
        A list with one channels-last numpy array per driving frame.

    Raises:
        ValueError: if driving_video contains no frames.
    """
    if len(driving_video) == 0:
        # Fail with a clear message instead of an opaque permute/indexing error.
        raise ValueError("driving_video must contain at least one frame")

    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            source = source.cuda()
        # (T, H, W, C) frames -> (1, C, T, H, W); frames are indexed on axis 2.
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        kp_source = kp_detector(source)

        initial_frame = driving[:, :, 0]
        if not cpu:
            # Move it like every other frame, so a CUDA model that is not
            # wrapped in DataParallel (which would scatter CPU input) also works.
            initial_frame = initial_frame.cuda()
        kp_driving_initial = kp_detector(initial_frame)

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            kp_driving = kp_detector(driving_frame)
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)

            # Back to channels-last numpy and drop the batch axis.
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions
|
|
|
|
# Pretrained first-order-model assets, relative to the working directory.
_FOMM_CONFIG_PATH = "./firstordermodel/config/vox-adv-256.yaml"
_FOMM_CHECKPOINT_PATH = "./firstordermodel/checkpoints/vox-adv-cpk.pth.tar"
# Working-directory-relative folder for temporary uploads and generated videos.
_DATA_DIR = 'data'
# (generator, kp_detector), loaded on first use and reused by later requests.
_fomm_models = None


def _get_fomm_models():
    """Load the CPU first-order-model networks once per process and cache them."""
    global _fomm_models
    if _fomm_models is None:
        _fomm_models = load_checkpoints(config_path=_FOMM_CONFIG_PATH,
                                        checkpoint_path=_FOMM_CHECKPOINT_PATH, cpu=True)
    return _fomm_models


def _read_video_frames(video_path):
    """Return ``(frames, fps)`` decoded from *video_path* with imageio's ffmpeg reader.

    Decoding stops at the first RuntimeError and keeps the frames read so far
    (presumably to tolerate truncated/corrupt trailing frames).
    """
    reader = imageio.get_reader(video_path, 'ffmpeg')
    try:
        fps = reader.get_meta_data()['fps']
        frames = []
        try:
            for frame in reader:
                frames.append(frame)
        except RuntimeError:
            pass  # deliberate best effort: keep the frames decoded before the error
    finally:
        reader.close()
    return frames, fps


def _to_rgb(image):
    """Return *image* with exactly three channels: gray is replicated, alpha dropped."""
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.shape[-1] < 3:
        # Grayscale (optionally with alpha): replicate the first channel.
        image = np.repeat(image[..., :1], 3, axis=-1)
    return image[..., :3]


@app.route('/motion-drive', methods=['POST'])
def motion_drive():
    """Animate an uploaded source image with the motion of an uploaded video.

    Form files:
        image: source image, decoded with ``imageio.imread``.
        video: driving video, decoded with imageio's ffmpeg reader.

    Returns:
        JSON ``{"video_url": "/data/result_<hex>.mp4"}`` (served by the
        /data route), or a 400 JSON error if the video has no decodable frames.
    """
    image_file = request.files['image']
    video_file = request.files['video']

    source_image = imageio.imread(image_file)

    os.makedirs(_DATA_DIR, exist_ok=True)
    # Save the uploaded video to a temporary path so ffmpeg can read it.
    # uuid4 (random) avoids exposing the host MAC/time that uuid1 encodes.
    video_path = os.path.join(_DATA_DIR, f"{uuid.uuid4().hex}.mp4")
    video_file.save(video_path)
    try:
        driving_video, fps = _read_video_frames(video_path)
    finally:
        # The upload is only needed for decoding; otherwise uploads pile up on disk.
        try:
            os.remove(video_path)
        except OSError:
            pass  # best-effort cleanup; never mask the real result/error

    if not driving_video:
        return jsonify({"error": "the driving video contains no decodable frames"}), 400

    # Inputs for the vox-adv-256 checkpoint: 256x256, three channels
    # (skimage's resize also converts the pixels to floats).
    source_image = resize(_to_rgb(source_image), (256, 256))
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    generator, kp_detector = _get_fomm_models()
    predictions = make_animation(source_image, driving_video, generator, kp_detector,
                                 relative=True, adapt_movement_scale=True, cpu=True)

    result_filename = f"result_{uuid.uuid4().hex}.mp4"
    result_path = os.path.join(_DATA_DIR, result_filename)
    imageio.mimsave(result_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)

    return jsonify({"video_url": f"/data/{result_filename}"})
|
|
|
|
@app.route('/data/<path:filename>', methods=['GET'])
def download_file(filename):
    """Serve *filename* (e.g. a generated result video) from the data folder.

    ``send_from_directory`` refuses paths that escape the folder and answers
    404 for missing files.
    """
    # Flask resolves a relative directory against the application's root
    # path, whereas motion_drive writes to 'data' relative to the current
    # working directory. Resolving it here keeps both on the same folder even
    # when the server is started from another directory.
    return send_from_directory(os.path.abspath('data'), filename)
|
|
|
|
if __name__ == "__main__":
    # Flask's development server; debug=True also enables the Werkzeug
    # in-browser debugger, which must never be reachable in production.
    # Debug stays on by default; FLASK_DEBUG=0/false/no turns it off.
    debug = os.environ.get("FLASK_DEBUG", "1").lower() not in ("0", "false", "no")
    app.run(debug=debug)
|