You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

544 lines
22 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "first-order-model-demo",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>",
"<a href=\"https://kaggle.com/kernels/welcome?src=https://github.com/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb\" target=\"_parent\"><img alt=\"Kaggle\" title=\"Open in Kaggle\" src=\"https://kaggle.com/static/images/open-in-kaggle.svg\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cdO_RxQZLahB"
},
"source": [
"# Demo for paper \"First Order Motion Model for Image Animation\"\n",
"To try the demo, press the play buttons of the two code cells below in order, then scroll to the bottom. Note that it may take several minutes to load."
]
},
{
"cell_type": "code",
"metadata": {
"id": "UCMFMJV7K-ag"
},
"source": [
"%%capture\n",
"# Install the Python-side dependencies used by the next cell (ffmpeg bindings for imageio/ffmpeg).\n",
"# NOTE(review): versions are unpinned - pin them if the demo must stay reproducible.\n",
"%pip install ffmpeg-python imageio-ffmpeg\n",
"# Check out the first-order-model repository into the working directory. Presumably this\n",
"# provides demo.py (load_checkpoints, make_animation) and config/*.yaml - TODO confirm.\n",
"!git init .\n",
"!git remote add origin https://github.com/AliaksandrSiarohin/first-order-model\n",
"!git pull origin master\n",
"# Clone the demo assets into ./demo (the next cell reads demo/images/*.png and demo/videos/*.mp4).\n",
"# %%capture hides all output, so errors here (e.g. on a re-run, when the remote or ./demo\n",
"# already exist) are silent.\n",
"!git clone https://github.com/graphemecluster/first-order-model-demo demo"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Oxi6-riLOgnm"
},
"source": [
"import IPython.display\n",
"import PIL.Image\n",
"import cv2\n",
"import ffmpeg\n",
"import imageio\n",
"import io\n",
"import ipywidgets\n",
"import numpy\n",
"import os.path\n",
"import requests\n",
"import skimage.transform\n",
"import warnings\n",
"from base64 import b64encode\n",
"from demo import load_checkpoints, make_animation # type: ignore (local file)\n",
"from google.colab import files, output\n",
"from IPython.display import HTML, Javascript\n",
"from shutil import copyfileobj\n",
"from skimage import img_as_ubyte\n",
"from tempfile import NamedTemporaryFile\n",
"from tqdm.auto import tqdm\n",
"# Hide library warnings in the demo UI output.\n",
"warnings.filterwarnings(\"ignore\")\n",
"# Uploaded driving videos are saved in this directory (see upload_video).\n",
"os.makedirs(\"user\", exist_ok=True)\n",
"\n",
"# Stylesheet for the demo widgets. Besides adjusting ipywidgets' built-in classes it defines\n",
"# the custom classes attached with add_class() below; .warn, .selected and .uploaded are\n",
"# toggled at runtime to reflect state.\n",
"display(HTML(\"\"\"\n",
"<style>\n",
".widget-box > * {\n",
"\tflex-shrink: 0;\n",
"}\n",
".widget-tab {\n",
"\tmin-width: 0;\n",
"\tflex: 1 1 auto;\n",
"}\n",
".widget-tab .p-TabBar-tabLabel {\n",
"\tfont-size: 15px;\n",
"}\n",
".widget-upload {\n",
"\tbackground-color: tan;\n",
"}\n",
".widget-button {\n",
"\tfont-size: 18px;\n",
"\twidth: 160px;\n",
"\theight: 34px;\n",
"\tline-height: 34px;\n",
"}\n",
".widget-dropdown {\n",
"\twidth: 250px;\n",
"}\n",
".widget-checkbox {\n",
"\twidth: 650px;\n",
"}\n",
".widget-checkbox + .widget-checkbox {\n",
"\tmargin-top: -6px;\n",
"}\n",
".input-widget .output_html {\n",
"\ttext-align: center;\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tline-height: 266px;\n",
"\tcolor: lightgray;\n",
"\tfont-size: 72px;\n",
"}\n",
".title {\n",
"\tfont-size: 20px;\n",
"\tfont-weight: bold;\n",
"\tmargin: 12px 0 6px 0;\n",
"}\n",
".warning {\n",
"\tdisplay: none;\n",
"\tcolor: red;\n",
"\tmargin-left: 10px;\n",
"}\n",
".warn {\n",
"\tdisplay: initial;\n",
"}\n",
".resource {\n",
"\tcursor: pointer;\n",
"\tborder: 1px solid gray;\n",
"\tmargin: 5px;\n",
"\twidth: 160px;\n",
"\theight: 160px;\n",
"\tmin-width: 160px;\n",
"\tmin-height: 160px;\n",
"\tmax-width: 160px;\n",
"\tmax-height: 160px;\n",
"\t-webkit-box-sizing: initial;\n",
"\tbox-sizing: initial;\n",
"}\n",
".resource:hover {\n",
"\tborder: 6px solid crimson;\n",
"\tmargin: 0;\n",
"}\n",
".selected {\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".input-widget {\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".input-button {\n",
"\twidth: 268px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".output-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".output-button {\n",
"\twidth: 258px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".uploaded {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".label-or {\n",
"\talign-self: center;\n",
"\tfont-size: 20px;\n",
"\tmargin: 16px;\n",
"}\n",
".loading {\n",
"\talign-items: center;\n",
"\twidth: fit-content;\n",
"}\n",
".loader {\n",
"\tmargin: 32px 0 16px 0;\n",
"\twidth: 48px;\n",
"\theight: 48px;\n",
"\tmin-width: 48px;\n",
"\tmin-height: 48px;\n",
"\tmax-width: 48px;\n",
"\tmax-height: 48px;\n",
"\tborder: 4px solid whitesmoke;\n",
"\tborder-top-color: gray;\n",
"\tborder-radius: 50%;\n",
"\tanimation: spin 1.8s linear infinite;\n",
"}\n",
".loading-label {\n",
"\tcolor: gray;\n",
"}\n",
".video {\n",
"\tmargin: 0;\n",
"}\n",
".comparison-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"\tmargin-left: 2px;\n",
"}\n",
".comparison-label {\n",
"\tcolor: gray;\n",
"\tfont-size: 14px;\n",
"\ttext-align: center;\n",
"\tposition: relative;\n",
"\tbottom: 3px;\n",
"}\n",
"@keyframes spin {\n",
"\tfrom { transform: rotate(0deg); }\n",
"\tto { transform: rotate(360deg); }\n",
"}\n",
"</style>\n",
"\"\"\"))\n",
"\n",
"def thumbnail(file):\n",
"\t\"\"\"Return the first frame of the video at `file` as an image array.\n",
"\n",
"\tCallers treat the frame as RGB. The reader is used as a context manager so the\n",
"\tFFMPEG reader (and whatever it holds open) is released right after the frame is\n",
"\tread, instead of whenever the reader object happens to be garbage-collected.\n",
"\t\"\"\"\n",
"\twith imageio.get_reader(file, mode='I', format='FFMPEG') as reader:\n",
"\t\treturn reader.get_next_data()\n",
"\n",
"def create_image(i, j):\n",
"\t\"\"\"Build the clickable thumbnail for sample image demo/images/<i><j>.png.\n",
"\n",
"\tThe last CSS class encodes the two-digit id that the browser-side click handler\n",
"\tpasses to select_image.\n",
"\t\"\"\"\n",
"\timage_widget = ipywidgets.Image.from_file('demo/images/%d%d.png' % (i, j))\n",
"\tfor css_class in ('resource', 'resource-image', 'resource-image%d%d' % (i, j)):\n",
"\t\timage_widget.add_class(css_class)\n",
"\treturn image_widget\n",
"\n",
"def create_video(i):\n",
"\t\"\"\"Build the clickable thumbnail for sample video demo/videos/<i>.mp4 from its first frame.\n",
"\n",
"\tOpenCV encodes in BGR order, so the RGB frame is converted before PNG encoding. The\n",
"\tbytes are taken with ndarray.tobytes(): the old tostring() alias was deprecated and\n",
"\tis gone in NumPy 2.0. The last CSS class encodes the id that the browser-side click\n",
"\thandler passes to select_video.\n",
"\t\"\"\"\n",
"\tframe = cv2.cvtColor(thumbnail('demo/videos/%d.mp4' % i), cv2.COLOR_RGB2BGR)\n",
"\tencoded, png = cv2.imencode('.png', frame)\n",
"\tif not encoded:\n",
"\t\traise ValueError('could not encode the first frame of demo/videos/%d.mp4 as PNG' % i)\n",
"\tvideo_widget = ipywidgets.Image(value=png.tobytes(), format='png')\n",
"\tvideo_widget.add_class('resource')\n",
"\tvideo_widget.add_class('resource-video')\n",
"\tvideo_widget.add_class('resource-video%d' % i)\n",
"\treturn video_widget\n",
"\n",
"def create_title(title):\n",
"\t\"\"\"Return a Label showing `title` as a section heading (styled by the .title CSS rule).\"\"\"\n",
"\theading = ipywidgets.Label(title)\n",
"\theading.add_class('title')\n",
"\treturn heading\n",
"\n",
"def download_output(button):\n",
"\t\"\"\"Send output.mp4 to the browser, showing the loading screen meanwhile.\n",
"\n",
"\t`button` is the clicked widget (unused). The result screen is restored in `finally`,\n",
"\tso a failed download no longer leaves the loading spinner up.\n",
"\t\"\"\"\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\ttry:\n",
"\t\tfiles.download('output.mp4')\n",
"\tfinally:\n",
"\t\tloading.layout.display = 'none'\n",
"\t\tcomplete.layout.display = ''\n",
"\n",
"def convert_output(button):\n",
"\t\"\"\"Letterbox output.mp4 to 1920x1080 as scaled.mp4 and send that file to the browser.\n",
"\n",
"\tThe video is scaled to 1080x1080 (Lanczos) and padded to 1920x1080 with the picture\n",
"\tplaced 420 px from the left edge, i.e. centred. `button` is the clicked widget\n",
"\t(unused). The result screen is restored in `finally`, so an ffmpeg or download error\n",
"\tno longer leaves the loading spinner up.\n",
"\t\"\"\"\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\ttry:\n",
"\t\tffmpeg.input('output.mp4').output('scaled.mp4', vf='scale=1080x1080:flags=lanczos,pad=1920:1080:420:0').overwrite_output().run()\n",
"\t\tfiles.download('scaled.mp4')\n",
"\tfinally:\n",
"\t\tloading.layout.display = 'none'\n",
"\t\tcomplete.layout.display = ''\n",
"\n",
"def back_to_main(button):\n",
"\t\"\"\"Hide the result screen and show the settings screen again (`button` is unused).\"\"\"\n",
"\tfor screen, visibility in ((complete, 'none'), (main, '')):\n",
"\t\tscreen.layout.display = visibility\n",
"\n",
"label_or = ipywidgets.Label('or')\n",
"label_or.add_class('label-or')\n",
"\n",
"# Sample image categories: tab title and number of images, in matching order. Image j of\n",
"# category i lives at demo/images/<i><j>.png and gets the id '%d%d' % (i, j); the\n",
"# browser-side click handler only matches two-digit ids, hence the second check.\n",
"image_titles = ['People', 'Cartoons', 'Dolls', 'Game of Thrones', 'Statues']\n",
"image_lengths = [8, 4, 8, 9, 4]\n",
"assert len(image_titles) == len(image_lengths), 'each image category needs both a title and a length'\n",
"assert len(image_lengths) <= 10 and max(image_lengths) <= 10, 'image ids must stay two digits long'\n",
"\n",
"image_tab = ipywidgets.Tab()\n",
"image_tab.children = [ipywidgets.HBox([create_image(i, j) for j in range(length)]) for i, length in enumerate(image_lengths)]\n",
"for i, title in enumerate(image_titles):\n",
"\timage_tab.set_title(i, title)\n",
"\n",
"def _create_input_part(accept, resource_tab):\n",
"\t\"\"\"Return (preview, upload_button, row) for one input section.\n",
"\n",
"\t`row` shows the preview area with its upload button underneath, the shared 'or'\n",
"\tlabel, and `resource_tab` with the bundled samples. `accept` is the MIME filter\n",
"\tfor the file picker, e.g. 'image/*'.\n",
"\t\"\"\"\n",
"\tpreview = ipywidgets.Output()\n",
"\tpreview.add_class('input-widget')\n",
"\tupload_button = ipywidgets.FileUpload(accept=accept, button_style='primary')\n",
"\tupload_button.add_class('input-button')\n",
"\trow = ipywidgets.HBox([\n",
"\t\tipywidgets.VBox([preview, upload_button]),\n",
"\t\tlabel_or,\n",
"\t\tresource_tab\n",
"\t])\n",
"\treturn preview, upload_button, row\n",
"\n",
"input_image_widget, upload_input_image_button, image_part = _create_input_part('image/*', image_tab)\n",
"\n",
"video_tab = ipywidgets.Tab()\n",
"video_tab.children = [ipywidgets.HBox([create_video(i) for i in range(5)])]\n",
"video_tab.set_title(0, 'All Videos')\n",
"\n",
"input_video_widget, upload_input_video_button, video_part = _create_input_part('video/*', video_tab)\n",
"\n",
"# Pretrained model choices; generate() loads config/<name>-256.yaml for the selected one.\n",
"model = ipywidgets.Dropdown(\n",
"\tdescription=\"Model:\",\n",
"\toptions=['vox', 'vox-adv', 'taichi', 'taichi-adv', 'nemo', 'mgif', 'fashion', 'bair']\n",
")\n",
"# Hidden by the .warning CSS rule until change_model() adds the .warn class.\n",
"warning = ipywidgets.HTML('<b>Warning:</b> Upload your own images and videos (see README)')\n",
"warning.add_class('warning')\n",
"model_part = ipywidgets.HBox([model, warning])\n",
"\n",
"# Animation options, forwarded to make_animation() by generate().\n",
"relative = ipywidgets.Checkbox(description=\"Relative keypoint displacement (Inherit object proportions from the video)\", value=True)\n",
"adapt_movement_scale = ipywidgets.Checkbox(description=\"Adapt movement scale (Don't touch unless you know what you are doing)\", value=True)\n",
"generate_button = ipywidgets.Button(description=\"Generate\", button_style='primary')\n",
"# Settings screen: source image picker, driving video picker, model and animation options.\n",
"main = ipywidgets.VBox([\n",
"\tcreate_title('Choose Image'),\n",
"\timage_part,\n",
"\tcreate_title('Choose Video'),\n",
"\tvideo_part,\n",
"\tcreate_title('Settings'),\n",
"\tmodel_part,\n",
"\trelative,\n",
"\tadapt_movement_scale,\n",
"\tgenerate_button\n",
"])\n",
"\n",
"# Loading screen: CSS spinner, a hint, and an Output area that hosts the tqdm progress bars.\n",
"loader = ipywidgets.Label()\n",
"loader.add_class(\"loader\")\n",
"loading_label = ipywidgets.Label(\"This may take several minutes to process…\")\n",
"loading_label.add_class(\"loading-label\")\n",
"progress_bar = ipywidgets.Output()\n",
"loading = ipywidgets.VBox([loader, loading_label, progress_bar])\n",
"loading.add_class('loading')\n",
"\n",
"def _create_output_button(description, on_click):\n",
"\t\"\"\"Return a primary-styled result-screen button that calls `on_click` when pressed.\"\"\"\n",
"\toutput_button = ipywidgets.Button(description=description, button_style='primary')\n",
"\toutput_button.add_class('output-button')\n",
"\toutput_button.on_click(on_click)\n",
"\treturn output_button\n",
"\n",
"# Result screen: the generated video with its action buttons, beside the driving video.\n",
"output_widget = ipywidgets.Output()\n",
"output_widget.add_class('output-widget')\n",
"download = _create_output_button('Download', download_output)\n",
"convert = _create_output_button('Convert to 1920×1080', convert_output)\n",
"back = _create_output_button('Back', back_to_main)\n",
"\n",
"comparison_widget = ipywidgets.Output()\n",
"comparison_widget.add_class('comparison-widget')\n",
"comparison_label = ipywidgets.Label('Comparison')\n",
"comparison_label.add_class('comparison-label')\n",
"complete = ipywidgets.HBox([\n",
"\tipywidgets.VBox([output_widget, download, convert, back]),\n",
"\tipywidgets.VBox([comparison_widget, comparison_label])\n",
"])\n",
"\n",
"# All three screens are displayed once; switching between them only toggles layout.display.\n",
"display(ipywidgets.VBox([main, loading, complete]))\n",
"# Browser-side wiring for the sample thumbnails: clicking one highlights it and calls the\n",
"# Python callback registered as \"notebook.select_image\" / \"notebook.select_video\" with the\n",
"# id encoded in its CSS class; the first image and video start out highlighted.\n",
"# NOTE(review): the 1 s timeout presumably waits for the widgets to render - a race on slow pages.\n",
"# Raw string: the regexes contain \\d, an invalid escape in a normal Python literal\n",
"# (SyntaxWarning on Python 3.12+); the JavaScript sent to the page is unchanged.\n",
"display(Javascript(r\"\"\"\n",
"var images, videos;\n",
"function deselectImages() {\n",
"\timages.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function deselectVideos() {\n",
"\tvideos.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function invokePython(func) {\n",
"\tgoogle.colab.kernel.invokeFunction(\"notebook.\" + func, [].slice.call(arguments, 1), {});\n",
"}\n",
"setTimeout(function() {\n",
"\t(images = [].slice.call(document.getElementsByClassName(\"resource-image\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectImages();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_image\", item.className.match(/resource-image(\\d\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\timages[0].classList.add(\"selected\");\n",
"\t(videos = [].slice.call(document.getElementsByClassName(\"resource-video\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectVideos();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_video\", item.className.match(/resource-video(\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\tvideos[0].classList.add(\"selected\");\n",
"}, 1000);\n",
"\"\"\"))\n",
"\n",
"selected_image = None  # source image for the model; replaced by select_image / upload_image\n",
"def select_image(filename):\n",
"\t\"\"\"Use the bundled sample demo/images/<filename>.png as the source image.\n",
"\n",
"\tCalled from the browser with the two-digit id of the clicked thumbnail (and once at\n",
"\tstart-up). The preview box only shows an 'Image' placeholder; the highlighted\n",
"\tthumbnail marks the choice.\n",
"\t\"\"\"\n",
"\tglobal selected_image\n",
"\tsample = PIL.Image.open('demo/images/%s.png' % filename)\n",
"\tselected_image = resize(sample.convert(\"RGB\"))\n",
"\tpreview = input_image_widget\n",
"\tpreview.clear_output(wait=True)\n",
"\twith preview:\n",
"\t\tdisplay(HTML('Image'))\n",
"\tpreview.remove_class('uploaded')\n",
"# Expose the function to the page JS as \"notebook.select_image\".\n",
"output.register_callback(\"notebook.select_image\", select_image)\n",
"\n",
"selected_video = None  # path of the driving video; replaced by select_video / upload_video\n",
"def select_video(filename):\n",
"\t\"\"\"Use the bundled sample demo/videos/<filename>.mp4 as the driving video.\n",
"\n",
"\tCalled from the browser with the one-digit id of the clicked thumbnail (and once at\n",
"\tstart-up). The preview box only shows a 'Video' placeholder; the highlighted\n",
"\tthumbnail marks the choice.\n",
"\t\"\"\"\n",
"\tglobal selected_video\n",
"\tsample_path = 'demo/videos/%s.mp4' % filename\n",
"\tselected_video = sample_path\n",
"\tpreview = input_video_widget\n",
"\tpreview.clear_output(wait=True)\n",
"\twith preview:\n",
"\t\tdisplay(HTML('Video'))\n",
"\tpreview.remove_class('uploaded')\n",
"# Expose the function to the page JS as \"notebook.select_video\".\n",
"output.register_callback(\"notebook.select_video\", select_video)\n",
"\n",
"def resize(image, size=(256, 256)):\n",
"\t\"\"\"Center-crop `image` to its largest square and resample that square to `size`.\n",
"\n",
"\tThe square's side is the shorter image edge; PIL crops (via `box`) and resizes with\n",
"\tLanczos filtering in a single call. Returns the new PIL image.\n",
"\t\"\"\"\n",
"\twidth, height = image.size\n",
"\tside = min(width, height)\n",
"\tleft = (width - side) // 2\n",
"\ttop = (height - side) // 2\n",
"\tcrop_box = (left, top, (width + side) // 2, (height + side) // 2)\n",
"\treturn image.resize(size, resample=PIL.Image.LANCZOS, box=crop_box)\n",
"\n",
"def upload_image(change):\n",
"\t\"\"\"Use the most recently uploaded file as the source image.\n",
"\n",
"\tObserver for upload_input_image_button.value (`change` is unused). Accepts both\n",
"\tFileUpload value layouts - {filename: {'content': ...}} (ipywidgets 7) and a tuple of\n",
"\tdicts with a 'content' key (ipywidgets 8) - and does nothing when no file is present\n",
"\t(previously that raised UnboundLocalError on `content`).\n",
"\t\"\"\"\n",
"\tglobal selected_image\n",
"\tuploads = upload_input_image_button.value\n",
"\tentries = uploads.values() if isinstance(uploads, dict) else uploads\n",
"\tcontents = [file_info['content'] for file_info in entries]\n",
"\tif not contents or contents[-1] is None:\n",
"\t\treturn\n",
"\t# bytes() also accepts the memoryview that ipywidgets 8 delivers.\n",
"\tselected_image = resize(PIL.Image.open(io.BytesIO(bytes(contents[-1]))).convert(\"RGB\"))\n",
"\tinput_image_widget.clear_output(wait=True)\n",
"\twith input_image_widget:\n",
"\t\tdisplay(selected_image)\n",
"\tinput_image_widget.add_class('uploaded')\n",
"\t# Drop the highlight from the sample thumbnails (deselectImages comes from the page JS).\n",
"\tdisplay(Javascript('deselectImages()'))\n",
"upload_input_image_button.observe(upload_image, names='value')\n",
"\n",
"def upload_video(change):\n",
"\t\"\"\"Save the most recently uploaded file under user/ and use it as the driving video.\n",
"\n",
"\tObserver for upload_input_video_button.value (`change` is unused). Accepts both\n",
"\tFileUpload value layouts - {filename: {'content': ...}} (ipywidgets 7) and a tuple of\n",
"\tdicts with 'name'/'content' keys (ipywidgets 8) - and does nothing when no file is\n",
"\tpresent. Only the base name of the browser-supplied filename is used, so a name with\n",
"\tpath separators cannot write outside user/.\n",
"\t\"\"\"\n",
"\tglobal selected_video\n",
"\tuploads = upload_input_video_button.value\n",
"\tif isinstance(uploads, dict):\n",
"\t\tentries = [(name, file_info['content']) for name, file_info in uploads.items()]\n",
"\telse:\n",
"\t\tentries = [(file_info['name'], file_info['content']) for file_info in uploads]\n",
"\tif not entries or entries[-1][1] is None:\n",
"\t\treturn\n",
"\tname, content = entries[-1]\n",
"\tselected_video = os.path.join('user', os.path.basename(name))\n",
"\twith open(selected_video, 'wb') as video:\n",
"\t\tvideo.write(content)\n",
"\tpreview = resize(PIL.Image.fromarray(thumbnail(selected_video)).convert(\"RGB\"))\n",
"\tinput_video_widget.clear_output(wait=True)\n",
"\twith input_video_widget:\n",
"\t\tdisplay(preview)\n",
"\tinput_video_widget.add_class('uploaded')\n",
"\t# Drop the highlight from the sample thumbnails (deselectVideos comes from the page JS).\n",
"\tdisplay(Javascript('deselectVideos()'))\n",
"upload_input_video_button.observe(upload_video, names='value')\n",
"\n",
"def change_model(change):\n",
"\t\"\"\"Show the upload warning unless a 'vox' model is selected (observer for model.value).\n",
"\n",
"\tNOTE(review): presumably the bundled samples only suit the vox models - confirm.\n",
"\t\"\"\"\n",
"\ttoggle_class = warning.remove_class if model.value.startswith('vox') else warning.add_class\n",
"\ttoggle_class('warn')\n",
"model.observe(change_model, names='value')\n",
"\n",
"def _download_checkpoint(filename, base_url='https://github.com/graphemecluster/first-order-model-demo/releases/download/checkpoints/'):\n",
"\t\"\"\"Download checkpoint `filename` from `base_url` into the working directory if it is missing.\n",
"\n",
"\tProgress is shown as a tqdm bar inside progress_bar. The body is streamed into\n",
"\t<filename>.part and renamed to `filename` only after a successful HTTP status and a\n",
"\tcomplete transfer, so an error page or an interrupted download never leaves a broken\n",
"\tfile that later runs would skip re-downloading because it already exists.\n",
"\t\"\"\"\n",
"\tif os.path.isfile(filename):\n",
"\t\treturn\n",
"\tpartial = filename + '.part'\n",
"\ttry:\n",
"\t\twith requests.get(base_url + filename, stream=True) as response:\n",
"\t\t\tresponse.raise_for_status()\n",
"\t\t\twith progress_bar:\n",
"\t\t\t\twith tqdm.wrapattr(response.raw, 'read', total=int(response.headers.get('Content-Length', 0)), unit='B', unit_scale=True, unit_divisor=1024) as raw:\n",
"\t\t\t\t\twith open(partial, 'wb') as file:\n",
"\t\t\t\t\t\tcopyfileobj(raw, file)\n",
"\t\tos.replace(partial, filename)\n",
"\tfinally:\n",
"\t\tprogress_bar.clear_output()\n",
"\t\tif os.path.exists(partial):\n",
"\t\t\tos.remove(partial)\n",
"\n",
"def _read_video(path):\n",
"\t\"\"\"Return (frames, fps) of the video at `path`, closing the FFMPEG reader afterwards.\"\"\"\n",
"\twith imageio.get_reader(path, mode='I', format='FFMPEG') as reader:\n",
"\t\tfps = reader.get_meta_data()['fps']\n",
"\t\tframes = list(reader)\n",
"\treturn frames, fps\n",
"\n",
"def _copy_audio(video_path, audio_source):\n",
"\t\"\"\"Best effort: rewrite `video_path` so it also carries the audio track of `audio_source`.\n",
"\n",
"\tStreams are copied without re-encoding. If ffmpeg fails (for example because\n",
"\t`audio_source` has no audio track) `video_path` is left untouched.\n",
"\t\"\"\"\n",
"\twith NamedTemporaryFile(suffix='.mp4') as merged:\n",
"\t\ttry:\n",
"\t\t\t# NamedTemporaryFile has already created merged.name, and ffmpeg will not overwrite an\n",
"\t\t\t# existing output without -y (it prompts on stdin instead), so overwrite_output() is\n",
"\t\t\t# needed for this merge to succeed at all.\n",
"\t\t\tffmpeg.output(ffmpeg.input(video_path).video, ffmpeg.input(audio_source).audio, merged.name, c='copy').overwrite_output().run()\n",
"\t\texcept ffmpeg.Error:\n",
"\t\t\treturn\n",
"\t\t# Read the merged result back through the still-open handle (at offset 0).\n",
"\t\twith open(video_path, 'wb') as result:\n",
"\t\t\tcopyfileobj(merged, result)\n",
"\n",
"def generate(button):\n",
"\t\"\"\"Animate the selected image with the selected driving video and show the result.\n",
"\n",
"\t`button` is the clicked Generate button (unused). Downloads the checkpoint for the\n",
"\tchosen model on first use, writes the animation to output.mp4 (adding the driving\n",
"\tvideo's audio when possible) and switches to the result screen. If any step fails,\n",
"\tthe settings screen is shown again before the error propagates, instead of leaving\n",
"\tthe loading spinner up forever.\n",
"\t\"\"\"\n",
"\tmain.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\ttry:\n",
"\t\tfilename = model.value + ('' if model.value == 'fashion' else '-cpk') + '.pth.tar'\n",
"\t\t_download_checkpoint(filename)\n",
"\t\tdriving_video, fps = _read_video(selected_video)\n",
"\t\tgenerator, kp_detector = load_checkpoints(config_path='config/%s-256.yaml' % model.value, checkpoint_path=filename)\n",
"\t\t# The *-256.yaml configs are used, so both inputs are resized to 256x256.\n",
"\t\twith progress_bar:\n",
"\t\t\tpredictions = make_animation(\n",
"\t\t\t\tskimage.transform.resize(numpy.asarray(selected_image), (256, 256)),\n",
"\t\t\t\t[skimage.transform.resize(frame, (256, 256)) for frame in driving_video],\n",
"\t\t\t\tgenerator,\n",
"\t\t\t\tkp_detector,\n",
"\t\t\t\trelative=relative.value,\n",
"\t\t\t\tadapt_movement_scale=adapt_movement_scale.value\n",
"\t\t\t)\n",
"\t\tprogress_bar.clear_output()\n",
"\t\timageio.mimsave('output.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
"\t\t_copy_audio('output.mp4', selected_video)\n",
"\texcept Exception:\n",
"\t\tprogress_bar.clear_output()\n",
"\t\tloading.layout.display = 'none'\n",
"\t\tmain.layout.display = ''\n",
"\t\traise\n",
"\toutput_widget.clear_output(True)\n",
"\twith output_widget:\n",
"\t\tvideo_widget = ipywidgets.Video.from_file('output.mp4', autoplay=False, loop=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-left')\n",
"\t\tdisplay(video_widget)\n",
"\tcomparison_widget.clear_output(True)\n",
"\twith comparison_widget:\n",
"\t\tvideo_widget = ipywidgets.Video.from_file(selected_video, autoplay=False, loop=False, controls=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-right')\n",
"\t\tdisplay(video_widget)\n",
"\t# Keep the muted comparison video in step with the output video (play, pause, seek).\n",
"\tdisplay(Javascript(\"\"\"\n",
"\tsetTimeout(function() {\n",
"\t\t(function(left, right) {\n",
"\t\t\tleft.addEventListener(\"play\", function() {\n",
"\t\t\t\tright.play();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"pause\", function() {\n",
"\t\t\t\tright.pause();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"seeking\", function() {\n",
"\t\t\t\tright.currentTime = left.currentTime;\n",
"\t\t\t});\n",
"\t\t\tright.muted = true;\n",
"\t\t})(document.getElementsByClassName(\"video-left\")[0], document.getElementsByClassName(\"video-right\")[0]);\n",
"\t}, 1000);\n",
"\t\"\"\"))\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\n",
"generate_button.on_click(generate)\n",
"\n",
"# Start on the settings screen with the first sample image and video selected (the page JS\n",
"# highlights the same two thumbnails).\n",
"loading.layout.display = complete.layout.display = 'none'\n",
"select_image('00')\n",
"select_video('0')"
],
"execution_count": null,
"outputs": []
}
]
}