from flask import Flask, request, send_file, jsonify, send_from_directory
from flask_cors import CORS
from PIL import Image, ImageFilter, ImageEnhance, ImageOps
import io
import uuid
import os
import sys

import yaml
import torch
import imageio
import imageio_ffmpeg  # provides the ffmpeg backend used by imageio's reader/writer
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
from tqdm import tqdm

# Make the first-order-model repository importable.
sys.path.append(os.path.abspath('./firstordermodel'))
sys.path.append(".")

from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from sync_batchnorm import DataParallelWithCallback

app = Flask(__name__)
CORS(app)

# Uploaded videos and results are written here; create it up front.
os.makedirs('data', exist_ok=True)


@app.route('/process-image', methods=['POST'])
def process_image():
    file = request.files['file']
    operation = request.form['operation']
    parameter = request.form.get('parameter', '')
    image = Image.open(file.stream)

    if operation == 'rotate':
        image = image.rotate(float(parameter))
    elif operation == 'flip':
        image = ImageOps.flip(image)
    elif operation == 'scale':
        scale_factor = float(parameter)
        w, h = image.size
        new_width = int(w * scale_factor)
        new_height = int(h * scale_factor)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is its replacement.
        image = image.resize((new_width, new_height), Image.LANCZOS)
    elif operation == 'filter':
        filters = {
            'BLUR': ImageFilter.BLUR,
            'EMBOSS': ImageFilter.EMBOSS,
            'CONTOUR': ImageFilter.CONTOUR,
            'SHARPEN': ImageFilter.SHARPEN,
        }
        if parameter in filters:
            image = image.filter(filters[parameter])
    elif operation == 'color_adjust':
        # Parameter is "r,g,b" channel multipliers; an empty field defaults to
        # 1.0 (float('') would raise, and 0.0 must remain a valid multiplier).
        r, g, b = [float(p) if p.strip() else 1.0 for p in parameter.split(',')]
        image = image.convert('RGB')
        r_channel, g_channel, b_channel = image.split()
        r_channel = r_channel.point(lambda i: i * r)
        g_channel = g_channel.point(lambda i: i * g)
        b_channel = b_channel.point(lambda i: i * b)
        image = Image.merge('RGB', (r_channel, g_channel, b_channel))
    elif operation == 'contrast':
        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(2)
    elif operation == 'smooth':
        image = image.filter(ImageFilter.SMOOTH)

    # JPEG cannot encode alpha or palette images, so normalize to RGB first.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    img_io = io.BytesIO()
    image.save(img_io, 'JPEG')
    img_io.seek(0)
    return send_file(img_io, mimetype='image/jpeg')


def load_checkpoints(config_path, checkpoint_path, cpu=True):
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()

    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()
    return generator, kp_detector


def make_animation(source_image, driving_video, generator, kp_detector,
                   relative=True, adapt_movement_scale=True, cpu=True):
    with torch.no_grad():
        predictions = []
        # Source image: (B, H, W, C) -> (B, C, H, W)
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            source = source.cuda()
        # Driving video: (B, T, H, W, C) -> (B, C, T, H, W)
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        kp_source = kp_detector(source)
        kp_driving_initial = kp_detector(driving[:, :, 0])

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            kp_driving = kp_detector(driving_frame)
            # Normalize driving keypoints relative to the first frame so the
            # source identity follows the driving motion, not its absolute pose.
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial,
                                   use_relative_movement=relative,
                                   use_relative_jacobian=relative,
                                   adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions


@app.route('/motion-drive', methods=['POST'])
def motion_drive():
    image_file = request.files['image']
    video_file = request.files['video']
    source_image = imageio.imread(image_file)

    # Save the uploaded video to a temporary path so ffmpeg can read it.
    video_path = f"./data/{uuid.uuid1().hex}.mp4"
    video_file.save(video_path)

    reader = imageio.get_reader(video_path, 'ffmpeg')
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    # The model expects 256x256 RGB inputs; [..., :3] drops any alpha channel.
    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    # Note: the checkpoint is reloaded on every request; caching the models at
    # startup would avoid this cost.
    config_path = "./firstordermodel/config/vox-adv-256.yaml"
    checkpoint_path = "./firstordermodel/checkpoints/vox-adv-cpk.pth.tar"
    generator, kp_detector = load_checkpoints(config_path=config_path,
                                              checkpoint_path=checkpoint_path, cpu=True)

    predictions = make_animation(source_image, driving_video, generator, kp_detector,
                                 relative=True, adapt_movement_scale=True, cpu=True)

    result_filename = f"result_{uuid.uuid1().hex}.mp4"
    result_path = os.path.join('data', result_filename)
    imageio.mimsave(result_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)

    return jsonify({"video_url": f"/data/{result_filename}"})


@app.route('/data/<filename>', methods=['GET'])
def download_file(filename):
    return send_from_directory('data', filename)


if __name__ == "__main__":
    app.run(debug=True)
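
# ---------------------------------------------------------------------------
# Illustrative client sketch, kept in comments so it does not execute as part
# of the server module. Assumptions (not in the original source): the server
# runs at http://127.0.0.1:5000 and the sample files 'face.jpg' and
# 'driving.mp4' are hypothetical placeholders.
#
#     import requests
#
#     # Rotate an image by 90 degrees via /process-image.
#     with open('face.jpg', 'rb') as f:
#         resp = requests.post('http://127.0.0.1:5000/process-image',
#                              files={'file': f},
#                              data={'operation': 'rotate', 'parameter': '90'})
#     with open('rotated.jpg', 'wb') as out:
#         out.write(resp.content)
#
#     # Animate a source image with a driving video via /motion-drive, then
#     # fetch the result from the /data/<filename> route it returns.
#     with open('face.jpg', 'rb') as img, open('driving.mp4', 'rb') as vid:
#         resp = requests.post('http://127.0.0.1:5000/motion-drive',
#                              files={'image': img, 'video': vid})
#     video_url = resp.json()['video_url']
#     result = requests.get('http://127.0.0.1:5000' + video_url)
#     with open('result.mp4', 'wb') as out:
#         out.write(result.content)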