You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

805 lines
28 KiB

5 months ago
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a><a href=\"https://kaggle.com/kernels/welcome?src=https://github.com/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb\" target=\"_parent\"><img alt=\"Kaggle\" title=\"Open in Kaggle\" src=\"https://kaggle.com/static/images/open-in-kaggle.svg\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cdO_RxQZLahB"
},
"source": [
"# Demo for paper \"First Order Motion Model for Image Animation\"\n",
"To try the demo, run the two code cells below in order (press each cell's play button), then scroll to the bottom of the output. Note that loading may take several minutes."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "UCMFMJV7K-ag"
},
"outputs": [],
"source": [
"%%capture\n",
"# %%capture (above) hides all output of this cell, including errors from the commands below.\n",
"# ffmpeg bindings and imageio's ffmpeg backend, used by the demo cell.\n",
"%pip install ffmpeg-python imageio-ffmpeg\n",
"# Turn the working directory into a checkout of the first-order-model repository;\n",
"# presumably this provides demo.py (load_checkpoints, make_animation) and config/.\n",
"# NOTE(review): `git remote add` errors if the remote already exists (e.g. on a re-run).\n",
"!git init .\n",
"!git remote add origin https://github.com/AliaksandrSiarohin/first-order-model\n",
"!git pull origin master\n",
"# Preset images/videos read later from demo/images and demo/videos.\n",
"!git clone https://github.com/graphemecluster/first-order-model-demo demo"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {
"id": "Oxi6-riLOgnm"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<style>\n",
".widget-box > * {\n",
"\tflex-shrink: 0;\n",
"}\n",
".widget-tab {\n",
"\tmin-width: 0;\n",
"\tflex: 1 1 auto;\n",
"}\n",
".widget-tab .p-TabBar-tabLabel {\n",
"\tfont-size: 15px;\n",
"}\n",
".widget-upload {\n",
"\tbackground-color: tan;\n",
"}\n",
".widget-button {\n",
"\tfont-size: 18px;\n",
"\twidth: 160px;\n",
"\theight: 34px;\n",
"\tline-height: 34px;\n",
"}\n",
".widget-dropdown {\n",
"\twidth: 250px;\n",
"}\n",
".widget-checkbox {\n",
"\twidth: 650px;\n",
"}\n",
".widget-checkbox + .widget-checkbox {\n",
"\tmargin-top: -6px;\n",
"}\n",
".input-widget .output_html {\n",
"\ttext-align: center;\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tline-height: 266px;\n",
"\tcolor: lightgray;\n",
"\tfont-size: 72px;\n",
"}\n",
".title {\n",
"\tfont-size: 20px;\n",
"\tfont-weight: bold;\n",
"\tmargin: 12px 0 6px 0;\n",
"}\n",
".warning {\n",
"\tdisplay: none;\n",
"\tcolor: red;\n",
"\tmargin-left: 10px;\n",
"}\n",
".warn {\n",
"\tdisplay: initial;\n",
"}\n",
".resource {\n",
"\tcursor: pointer;\n",
"\tborder: 1px solid gray;\n",
"\tmargin: 5px;\n",
"\twidth: 160px;\n",
"\theight: 160px;\n",
"\tmin-width: 160px;\n",
"\tmin-height: 160px;\n",
"\tmax-width: 160px;\n",
"\tmax-height: 160px;\n",
"\t-webkit-box-sizing: initial;\n",
"\tbox-sizing: initial;\n",
"}\n",
".resource:hover {\n",
"\tborder: 6px solid crimson;\n",
"\tmargin: 0;\n",
"}\n",
".selected {\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".input-widget {\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".input-button {\n",
"\twidth: 268px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".output-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".output-button {\n",
"\twidth: 258px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".uploaded {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".label-or {\n",
"\talign-self: center;\n",
"\tfont-size: 20px;\n",
"\tmargin: 16px;\n",
"}\n",
".loading {\n",
"\talign-items: center;\n",
"\twidth: fit-content;\n",
"}\n",
".loader {\n",
"\tmargin: 32px 0 16px 0;\n",
"\twidth: 48px;\n",
"\theight: 48px;\n",
"\tmin-width: 48px;\n",
"\tmin-height: 48px;\n",
"\tmax-width: 48px;\n",
"\tmax-height: 48px;\n",
"\tborder: 4px solid whitesmoke;\n",
"\tborder-top-color: gray;\n",
"\tborder-radius: 50%;\n",
"\tanimation: spin 1.8s linear infinite;\n",
"}\n",
".loading-label {\n",
"\tcolor: gray;\n",
"}\n",
".video {\n",
"\tmargin: 0;\n",
"}\n",
".comparison-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"\tmargin-left: 2px;\n",
"}\n",
".comparison-label {\n",
"\tcolor: gray;\n",
"\tfont-size: 14px;\n",
"\ttext-align: center;\n",
"\tposition: relative;\n",
"\tbottom: 3px;\n",
"}\n",
"@keyframes spin {\n",
"\tfrom { transform: rotate(0deg); }\n",
"\tto { transform: rotate(360deg); }\n",
"}\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f53c7ccd3ec34f7ea8491237d5bf03ff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(VBox(children=(Label(value='Choose Image', _dom_classes=('title',)), HBox(children=(VBox(childr…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
"var images, videos;\n",
"function deselectImages() {\n",
"\timages.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function deselectVideos() {\n",
"\tvideos.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function invokePython(func) {\n",
"\tgoogle.colab.kernel.invokeFunction(\"notebook.\" + func, [].slice.call(arguments, 1), {});\n",
"}\n",
"setTimeout(function() {\n",
"\t(images = [].slice.call(document.getElementsByClassName(\"resource-image\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectImages();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_image\", item.className.match(/resource-image(\\d\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\timages[0].classList.add(\"selected\");\n",
"\t(videos = [].slice.call(document.getElementsByClassName(\"resource-video\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectVideos();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_video\", item.className.match(/resource-video(\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\tvideos[0].classList.add(\"selected\");\n",
"}, 1000);\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import IPython.display\n",
"import PIL.Image\n",
"import cv2\n",
"import ffmpeg\n",
"import imageio\n",
"import io\n",
"import ipywidgets\n",
"import numpy\n",
"import os.path\n",
"import requests\n",
"import skimage.transform\n",
"import warnings\n",
"from base64 import b64encode\n",
"from demo import load_checkpoints, make_animation # type: ignore (local file)\n",
"from IPython.display import HTML, Javascript\n",
"from shutil import copyfileobj\n",
"from skimage import img_as_ubyte\n",
"from tempfile import NamedTemporaryFile\n",
"import os\n",
"import ipywidgets as ipyw\n",
"from IPython.display import display, FileLink\n",
"# Keep the widget output free of library deprecation noise.\n",
"warnings.filterwarnings(action=\"ignore\")\n",
"# Uploaded driving videos are saved under ./user (see upload_video).\n",
"os.makedirs(name=\"user\", exist_ok=True)\n",
"\n",
"# Stylesheet for the custom CSS classes attached below with add_class()\n",
"# (.resource/.selected thumbnails, .loader spinner, .warning/.warn, panels).\n",
"display(HTML(\"\"\"\n",
"<style>\n",
".widget-box > * {\n",
"\tflex-shrink: 0;\n",
"}\n",
".widget-tab {\n",
"\tmin-width: 0;\n",
"\tflex: 1 1 auto;\n",
"}\n",
".widget-tab .p-TabBar-tabLabel {\n",
"\tfont-size: 15px;\n",
"}\n",
".widget-upload {\n",
"\tbackground-color: tan;\n",
"}\n",
".widget-button {\n",
"\tfont-size: 18px;\n",
"\twidth: 160px;\n",
"\theight: 34px;\n",
"\tline-height: 34px;\n",
"}\n",
".widget-dropdown {\n",
"\twidth: 250px;\n",
"}\n",
".widget-checkbox {\n",
"\twidth: 650px;\n",
"}\n",
".widget-checkbox + .widget-checkbox {\n",
"\tmargin-top: -6px;\n",
"}\n",
".input-widget .output_html {\n",
"\ttext-align: center;\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tline-height: 266px;\n",
"\tcolor: lightgray;\n",
"\tfont-size: 72px;\n",
"}\n",
".title {\n",
"\tfont-size: 20px;\n",
"\tfont-weight: bold;\n",
"\tmargin: 12px 0 6px 0;\n",
"}\n",
".warning {\n",
"\tdisplay: none;\n",
"\tcolor: red;\n",
"\tmargin-left: 10px;\n",
"}\n",
".warn {\n",
"\tdisplay: initial;\n",
"}\n",
".resource {\n",
"\tcursor: pointer;\n",
"\tborder: 1px solid gray;\n",
"\tmargin: 5px;\n",
"\twidth: 160px;\n",
"\theight: 160px;\n",
"\tmin-width: 160px;\n",
"\tmin-height: 160px;\n",
"\tmax-width: 160px;\n",
"\tmax-height: 160px;\n",
"\t-webkit-box-sizing: initial;\n",
"\tbox-sizing: initial;\n",
"}\n",
".resource:hover {\n",
"\tborder: 6px solid crimson;\n",
"\tmargin: 0;\n",
"}\n",
".selected {\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".input-widget {\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".input-button {\n",
"\twidth: 268px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".output-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".output-button {\n",
"\twidth: 258px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".uploaded {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".label-or {\n",
"\talign-self: center;\n",
"\tfont-size: 20px;\n",
"\tmargin: 16px;\n",
"}\n",
".loading {\n",
"\talign-items: center;\n",
"\twidth: fit-content;\n",
"}\n",
".loader {\n",
"\tmargin: 32px 0 16px 0;\n",
"\twidth: 48px;\n",
"\theight: 48px;\n",
"\tmin-width: 48px;\n",
"\tmin-height: 48px;\n",
"\tmax-width: 48px;\n",
"\tmax-height: 48px;\n",
"\tborder: 4px solid whitesmoke;\n",
"\tborder-top-color: gray;\n",
"\tborder-radius: 50%;\n",
"\tanimation: spin 1.8s linear infinite;\n",
"}\n",
".loading-label {\n",
"\tcolor: gray;\n",
"}\n",
".video {\n",
"\tmargin: 0;\n",
"}\n",
".comparison-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"\tmargin-left: 2px;\n",
"}\n",
".comparison-label {\n",
"\tcolor: gray;\n",
"\tfont-size: 14px;\n",
"\ttext-align: center;\n",
"\tposition: relative;\n",
"\tbottom: 3px;\n",
"}\n",
"@keyframes spin {\n",
"\tfrom { transform: rotate(0deg); }\n",
"\tto { transform: rotate(360deg); }\n",
"}\n",
"</style>\n",
"\"\"\"))\n",
"\n",
"\n",
"def uploaded_file(change):\n",
"\t\"\"\"Save every file from a FileUpload change event into ./uploads.\n",
"\n",
"\tObserver for ``FileUpload.observe(..., names='value')``. ``change['new']`` is\n",
"\texpected to be a sequence of dicts with ``name`` and ``content`` entries\n",
"\t(the same shape upload_image/upload_video read via ``value[0]``).\n",
"\t\"\"\"\n",
"\tsave_dir = 'uploads'\n",
"\tos.makedirs(save_dir, exist_ok=True)\n",
"\tfor upload in change['new']:\n",
"\t\t# Keep only the base name so a crafted file name cannot escape save_dir.\n",
"\t\tfilename = os.path.basename(upload['name'])\n",
"\t\tif not filename:\n",
"\t\t\tcontinue\n",
"\t\twith open(os.path.join(save_dir, filename), 'wb') as f:\n",
"\t\t\tf.write(upload['content'])\n",
" \n",
"def create_uploader():\n",
"\t\"\"\"Display a multi-file upload button; chosen files are saved by uploaded_file.\"\"\"\n",
"\tpicker = ipyw.FileUpload(multiple=True)\n",
"\tdisplay(picker)\n",
"\tpicker.description = '📂 Upload'\n",
"\tpicker.observe(handler=uploaded_file, names='value')\n",
"\n",
"def download_file(filename='./face.mp4') -> None:\n",
"\t\"\"\"Show a link to ``filename`` and make the browser download it.\n",
"\n",
"\tThe link is rendered with IPython's FileLink using a ``downloadLink`` class; a\n",
"\tsmall script turns every such link into a plain download link (``download``\n",
"\tattribute, data-commandlinker-* attributes removed) and clicks the newest one.\n",
"\t\"\"\"\n",
"\tlink = FileLink(filename)\n",
"\tlink.html_link_str = \"<a href='%s' target='' class='downloadLink'>%s</a>\"\n",
"\tdisplay(link)\n",
"\t# Each call appends another .downloadLink to the page, so click the last one;\n",
"\t# clicking links[0] would re-download the first file ever linked.\n",
"\tdisplay(HTML(\"\"\"\n",
"<script>\n",
"\tvar links = Array.from(document.getElementsByClassName('downloadLink'));\n",
"\tlinks.forEach(function(link) {\n",
"\t\tlink.setAttribute('download', '');\n",
"\t\tlink.removeAttribute('data-commandlinker-args');\n",
"\t\tlink.removeAttribute('data-commandlinker-command');\n",
"\t});\n",
"\tif (links.length) {\n",
"\t\tlinks[links.length - 1].click();\n",
"\t}\n",
"</script>\n",
"\"\"\"))\n",
" \n",
"def thumbnail(file):\n",
"\t\"\"\"Return the first frame of the video ``file`` (callers treat it as an RGB array).\n",
"\n",
"\tThe ffmpeg reader is closed before returning instead of lingering until\n",
"\tgarbage collection.\n",
"\t\"\"\"\n",
"\twith imageio.get_reader(file, mode='I', format='FFMPEG') as reader:\n",
"\t\treturn reader.get_next_data()\n",
"\n",
"def create_image(i, j):\n",
"\t\"\"\"Build the thumbnail widget for preset image demo/images/<i><j>.png.\n",
"\n",
"\tThe 'resource-image<i><j>' class lets the selection script recover the id.\n",
"\t\"\"\"\n",
"\tcode = '%d%d' % (i, j)\n",
"\twidget = ipywidgets.Image.from_file('demo/images/' + code + '.png')\n",
"\tfor css_class in ('resource', 'resource-image', 'resource-image' + code):\n",
"\t\twidget.add_class(css_class)\n",
"\treturn widget\n",
"\n",
"def create_video(i):\n",
"\t\"\"\"Build the thumbnail widget for preset video demo/videos/<i>.mp4.\n",
"\n",
"\tThe thumbnail is the video's first frame, PNG-encoded for ipywidgets.Image;\n",
"\tthe 'resource-video<i>' class lets the selection script recover the id.\n",
"\n",
"\tRaises:\n",
"\t\tValueError: if OpenCV reports that PNG encoding failed.\n",
"\t\"\"\"\n",
"\t# thumbnail() yields RGB, OpenCV encodes from BGR.\n",
"\tframe = cv2.cvtColor(thumbnail('demo/videos/%d.mp4' % i), cv2.COLOR_RGB2BGR)\n",
"\tok, png = cv2.imencode('.png', frame)\n",
"\tif not ok:\n",
"\t\traise ValueError('could not encode a thumbnail for demo/videos/%d.mp4' % i)\n",
"\t# tobytes() is the non-deprecated equivalent of tostring() (same bytes).\n",
"\tvideo_widget = ipywidgets.Image(value=png.tobytes(), format='png')\n",
"\tvideo_widget.add_class('resource')\n",
"\tvideo_widget.add_class('resource-video')\n",
"\tvideo_widget.add_class('resource-video%d' % i)\n",
"\treturn video_widget\n",
"\n",
"def create_title(title):\n",
"\t\"\"\"Return a Label styled as a section heading (CSS class 'title').\"\"\"\n",
"\theading = ipywidgets.Label(title)\n",
"\theading.add_class('title')\n",
"\treturn heading\n",
"\n",
"def download_output(button):\n",
"\t\"\"\"Button handler: send the generated output.mp4 to the browser.\n",
"\n",
"\tThe spinner is shown while the download is set up; the result panel is\n",
"\trestored even if that fails, so the UI never stays stuck on the spinner.\n",
"\t\"\"\"\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\ttry:\n",
"\t\tdownload_file('./output.mp4')\n",
"\tfinally:\n",
"\t\tloading.layout.display = 'none'\n",
"\t\tcomplete.layout.display = ''\n",
"\n",
"def convert_output(button):\n",
"\t\"\"\"Button handler: write a 1920x1080 version of output.mp4 to scaled.mp4 and download it.\n",
"\n",
"\tThe square video is upscaled to 1080x1080 (Lanczos) and padded to 1920x1080.\n",
"\tThe result panel is restored even if ffmpeg or the download fails.\n",
"\t\"\"\"\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\ttry:\n",
"\t\t# pad x-offset 420 = (1920 - 1080) / 2 centres the square frame horizontally.\n",
"\t\tffmpeg.input('output.mp4').output('scaled.mp4', vf='scale=1080x1080:flags=lanczos,pad=1920:1080:420:0').overwrite_output().run()\n",
"\t\t# download_file rather than Colab's files.download: `files` is never imported here.\n",
"\t\tdownload_file('scaled.mp4')\n",
"\tfinally:\n",
"\t\tloading.layout.display = 'none'\n",
"\t\tcomplete.layout.display = ''\n",
"\n",
"def back_to_main(button):\n",
"\t\"\"\"Button handler: hide the result panel and show the settings form again.\"\"\"\n",
"\tfor panel, display_value in ((complete, 'none'), (main, '')):\n",
"\t\tpanel.layout.display = display_value\n",
"\n",
"# \"or\" separator placed between an upload box and its preset tabs.\n",
"label_or = ipywidgets.Label('or')\n",
"label_or.add_class('label-or')\n",
"\n",
"# Preset images are demo/images/<tab><index>.png; one tab per category.\n",
"image_titles = ['Peoples', 'Cartoons', 'Dolls', 'Game of Thrones', 'Statues']\n",
"image_lengths = [8, 4, 8, 9, 4]\n",
"\n",
"image_tab = ipywidgets.Tab()\n",
"image_tab.children = [\n",
"\tipywidgets.HBox([create_image(i, j) for j in range(length)])\n",
"\tfor i, length in enumerate(image_lengths)\n",
"]\n",
"for i, title in enumerate(image_titles):\n",
"\timage_tab.set_title(i, title)\n",
"\n",
"# Upload alternative: preview panel above an image-only upload button.\n",
"input_image_widget = ipywidgets.Output()\n",
"input_image_widget.add_class('input-widget')\n",
"upload_input_image_button = ipywidgets.FileUpload(accept='image/*', button_style='primary')\n",
"upload_input_image_button.add_class('input-button')\n",
"image_part = ipywidgets.HBox(children=[\n",
"\tipywidgets.VBox(children=[input_image_widget, upload_input_image_button]),\n",
"\tlabel_or,\n",
"\timage_tab,\n",
"])\n",
"\n",
"# Preset driving videos demo/videos/0.mp4 to 4.mp4, shown in a single tab.\n",
"video_tab = ipywidgets.Tab()\n",
"video_tab.children = [ipywidgets.HBox([create_video(i) for i in range(5)])]\n",
"video_tab.set_title(0, 'All Videos')\n",
"\n",
"# Upload alternative: preview panel above a video-only upload button.\n",
"input_video_widget = ipywidgets.Output()\n",
"input_video_widget.add_class('input-widget')\n",
"upload_input_video_button = ipywidgets.FileUpload(accept='video/*', button_style='primary')\n",
"upload_input_video_button.add_class('input-button')\n",
"video_part = ipywidgets.HBox(children=[\n",
"\tipywidgets.VBox(children=[input_video_widget, upload_input_video_button]),\n",
"\tlabel_or,\n",
"\tvideo_tab,\n",
"])\n",
"\n",
"# Pretrained model; generate() loads config/<name>-256.yaml and its checkpoint.\n",
"model = ipywidgets.Dropdown(\n",
"\tdescription=\"Model:\",\n",
"\toptions=[\n",
"\t\t'vox', 'vox-adv',\n",
"\t\t'taichi', 'taichi-adv',\n",
"\t\t'nemo', 'mgif', 'fashion', 'bair',\n",
"\t],\n",
")\n",
"# Hidden by default (CSS .warning); change_model toggles the 'warn' class for non-vox models.\n",
"warning = ipywidgets.HTML('<b>Warning:</b> Upload your own images and videos (see README)')\n",
"warning.add_class('warning')\n",
"model_part = ipywidgets.HBox(children=[model, warning])\n",
"\n",
"# Options forwarded by generate() to make_animation().\n",
"relative = ipywidgets.Checkbox(description=\"Relative keypoint displacement (Inherit object proportions from the video)\", value=True)\n",
"adapt_movement_scale = ipywidgets.Checkbox(description=\"Adapt movement scale (Don't touch unless you know what you are doing)\", value=True)\n",
"generate_button = ipywidgets.Button(description=\"Generate\", button_style='primary')\n",
"# Settings form; hidden while a result is generated or shown.\n",
"main = ipywidgets.VBox([\n",
"\tcreate_title('Choose Image'),\n",
"\timage_part,\n",
"\tcreate_title('Choose Video'),\n",
"\tvideo_part,\n",
"\tcreate_title('Settings'),\n",
"\tmodel_part,\n",
"\trelative,\n",
"\tadapt_movement_scale,\n",
"\tgenerate_button\n",
"])\n",
"\n",
"# Spinner panel shown while generate/convert/download work; make_animation's\n",
"# progress output is captured into progress_bar.\n",
"loader = ipywidgets.Label()\n",
"loader.add_class(\"loader\")\n",
"loading_label = ipywidgets.Label(\"This may take several minutes to process…\")\n",
"loading_label.add_class(\"loading-label\")\n",
"progress_bar = ipywidgets.Output()\n",
"loading = ipywidgets.VBox(children=[loader, loading_label, progress_bar])\n",
"loading.add_class('loading')\n",
"\n",
"# Result panel: generated video with its action buttons, driving video beside it.\n",
"output_widget = ipywidgets.Output()\n",
"output_widget.add_class('output-widget')\n",
"download = ipywidgets.Button(description='Download', button_style='primary')\n",
"convert = ipywidgets.Button(description='Convert to 1920×1080', button_style='primary')\n",
"back = ipywidgets.Button(description='Back', button_style='primary')\n",
"download.add_class('output-button')\n",
"convert.add_class('output-button')\n",
"back.add_class('output-button')\n",
"download.on_click(download_output)\n",
"convert.on_click(convert_output)\n",
"back.on_click(back_to_main)\n",
"\n",
"comparison_widget = ipywidgets.Output()\n",
"comparison_widget.add_class('comparison-widget')\n",
"comparison_label = ipywidgets.Label('Comparison')\n",
"comparison_label.add_class('comparison-label')\n",
"complete = ipywidgets.HBox(children=[\n",
"\tipywidgets.VBox(children=[output_widget, download, convert, back]),\n",
"\tipywidgets.VBox(children=[comparison_widget, comparison_label]),\n",
"])\n",
"\n",
"display(ipywidgets.VBox([main, loading, complete]))\n",
"# Thumbnail picker: highlight the clicked preset and report it to the kernel as\n",
"# select_image(\"<tab><index>\") / select_video(\"<n>\"). In Colab this goes through\n",
"# callbacks registered with google.colab.output.register_callback; the classic\n",
"# Notebook executes the call directly; elsewhere only the upload buttons work.\n",
"# Raw string: the regex escapes (\\d) reach JavaScript unchanged without Python's\n",
"# invalid-escape warning.\n",
"display(Javascript(r\"\"\"\n",
"var images, videos;\n",
"function deselectImages() {\n",
"\timages.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function deselectVideos() {\n",
"\tvideos.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function invokePython(func) {\n",
"\tvar args = [].slice.call(arguments, 1);\n",
"\tif (window.google && google.colab && google.colab.kernel) {\n",
"\t\tgoogle.colab.kernel.invokeFunction(\"notebook.\" + func, args, {});\n",
"\t} else if (window.IPython && IPython.notebook && IPython.notebook.kernel) {\n",
"\t\tIPython.notebook.kernel.execute(func + \"(\" + args.map(function(arg) {\n",
"\t\t\treturn JSON.stringify(String(arg));\n",
"\t\t}).join(\", \") + \")\");\n",
"\t} else {\n",
"\t\tconsole.warn(\"No kernel bridge available for \" + func + \"; use the upload buttons instead.\");\n",
"\t}\n",
"}\n",
"setTimeout(function() {\n",
"\t(images = [].slice.call(document.getElementsByClassName(\"resource-image\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectImages();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_image\", item.className.match(/resource-image(\\d\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\tif (images.length) {\n",
"\t\timages[0].classList.add(\"selected\");\n",
"\t}\n",
"\t(videos = [].slice.call(document.getElementsByClassName(\"resource-video\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectVideos();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_video\", item.className.match(/resource-video(\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\tif (videos.length) {\n",
"\t\tvideos[0].classList.add(\"selected\");\n",
"\t}\n",
"}, 1000);\n",
"\"\"\"))\n",
"\n",
"selected_image = None\n",
"def select_image(filename):\n",
"\t\"\"\"Use preset demo/images/<filename>.png as the source image.\n",
"\n",
"\tCalled by the thumbnail-click script and once at start-up. The image is\n",
"\tcenter-cropped/resized by resize(); the input panel shows an 'Image'\n",
"\tplaceholder instead of an upload preview.\n",
"\t\"\"\"\n",
"\tglobal selected_image\n",
"\t# The context manager closes the file handle PIL.Image.open keeps open.\n",
"\twith PIL.Image.open('demo/images/%s.png' % filename) as image:\n",
"\t\tselected_image = resize(image.convert(\"RGB\"))\n",
"\tinput_image_widget.clear_output(wait=True)\n",
"\twith input_image_widget:\n",
"\t\tdisplay(HTML('Image'))\n",
"\tinput_image_widget.remove_class('uploaded')\n",
"\n",
"# In Colab the thumbnail script calls invokeFunction(\"notebook.select_image\"),\n",
"# which only works once the callback is registered; skipped outside Colab.\n",
"try:\n",
"\tfrom google.colab import output as colab_output\n",
"\tcolab_output.register_callback(\"notebook.select_image\", select_image)\n",
"except ImportError:\n",
"\tpass\n",
"\n",
"selected_video = None\n",
"def select_video(filename):\n",
"\t\"\"\"Use preset demo/videos/<filename>.mp4 as the driving video.\n",
"\n",
"\tCalled by the thumbnail-click script and once at start-up; the input panel\n",
"\tshows a 'Video' placeholder instead of an upload preview.\n",
"\t\"\"\"\n",
"\tglobal selected_video\n",
"\tselected_video = 'demo/videos/%s.mp4' % filename\n",
"\tinput_video_widget.clear_output(wait=True)\n",
"\twith input_video_widget:\n",
"\t\tdisplay(HTML('Video'))\n",
"\tinput_video_widget.remove_class('uploaded')\n",
"\n",
"# In Colab the thumbnail script calls invokeFunction(\"notebook.select_video\"),\n",
"# which only works once the callback is registered; skipped outside Colab.\n",
"try:\n",
"\tfrom google.colab import output as colab_output\n",
"\tcolab_output.register_callback(\"notebook.select_video\", select_video)\n",
"except ImportError:\n",
"\tpass\n",
"\n",
"def resize(image, size=(256, 256)):\n",
"\t\"\"\"Center-crop ``image`` to its largest square and scale it to ``size`` (Lanczos).\n",
"\n",
"\tPIL's ``box`` argument performs the crop and the resampling in one call.\n",
"\t\"\"\"\n",
"\twidth, height = image.size\n",
"\tside = min(width, height)\n",
"\tleft, top = (width - side) // 2, (height - side) // 2\n",
"\tright, bottom = (width + side) // 2, (height + side) // 2\n",
"\treturn image.resize(size, resample=PIL.Image.LANCZOS, box=(left, top, right, bottom))\n",
"\n",
"def upload_image(change):\n",
"\t\"\"\"Observer for the image FileUpload: use the uploaded file as the source image.\n",
"\n",
"\tReads the first entry of the widget's value (a sequence of dicts with a\n",
"\t``content`` entry), center-crops/resizes it, previews it in the input panel\n",
"\tand clears the preset-thumbnail highlight.\n",
"\t\"\"\"\n",
"\tglobal selected_image\n",
"\tuploads = upload_input_image_button.value\n",
"\t# The observer also fires when the value is reset to empty.\n",
"\tif not uploads:\n",
"\t\treturn\n",
"\tcontent = uploads[0]['content']\n",
"\tif content is None:\n",
"\t\treturn\n",
"\twith PIL.Image.open(io.BytesIO(content)) as image:\n",
"\t\tselected_image = resize(image.convert(\"RGB\"))\n",
"\tinput_image_widget.clear_output(wait=True)\n",
"\twith input_image_widget:\n",
"\t\tdisplay(selected_image)\n",
"\tinput_image_widget.add_class('uploaded')\n",
"\tdisplay(Javascript('deselectImages()'))\n",
"upload_input_image_button.observe(upload_image, names='value')\n",
"\n",
"def upload_video(change):\n",
"\t\"\"\"Observer for the video FileUpload: use the uploaded file as the driving video.\n",
"\n",
"\tThe file is saved under ./user/ (created at start-up), its first frame is\n",
"\tshown as a preview, and only then does it become ``selected_video``.\n",
"\t\"\"\"\n",
"\tglobal selected_video\n",
"\tuploads = upload_input_video_button.value\n",
"\t# The observer also fires when the value is reset to empty.\n",
"\tif not uploads:\n",
"\t\treturn\n",
"\tcontent = uploads[0]['content']\n",
"\tif content is None:\n",
"\t\treturn\n",
"\t# basename() keeps a crafted upload name from writing outside ./user/.\n",
"\tpath = 'user/' + (os.path.basename(uploads[0]['name']) or 'upload.mp4')\n",
"\t# Write first: thumbnail() hands its argument to imageio as a path; the raw\n",
"\t# upload buffer is not something it can open.\n",
"\twith open(path, 'wb') as video:\n",
"\t\tvideo.write(content)\n",
"\tpreview = resize(PIL.Image.fromarray(thumbnail(path)).convert(\"RGB\"))\n",
"\tselected_video = path\n",
"\tinput_video_widget.clear_output(wait=True)\n",
"\twith input_video_widget:\n",
"\t\tdisplay(preview)\n",
"\tinput_video_widget.add_class('uploaded')\n",
"\tdisplay(Javascript('deselectVideos()'))\n",
"upload_input_video_button.observe(upload_video, names='value')\n",
"\n",
"def change_model(change):\n",
"\t\"\"\"Show the upload warning for every model whose name does not start with 'vox'.\"\"\"\n",
"\tif not model.value.startswith('vox'):\n",
"\t\twarning.add_class('warn')\n",
"\t\treturn\n",
"\twarning.remove_class('warn')\n",
"model.observe(change_model, names='value')\n",
"\n",
"def _ensure_checkpoint(filename):\n",
"\t\"\"\"Download checkpoint ``filename`` into the working directory unless already present.\n",
"\n",
"\tThe file comes from the public Yandex.Disk folder via its download API and is\n",
"\twritten under a temporary name first, so a failed or interrupted download is\n",
"\tnever mistaken for a cached checkpoint on the next run.\n",
"\t\"\"\"\n",
"\tif os.path.isfile(filename):\n",
"\t\treturn\n",
"\tmeta = requests.get('https://cloud-api.yandex.net/v1/disk/public/resources/download?public_key=https://yadi.sk/d/lEw8uRm140L_eQ&path=/' + filename)\n",
"\tmeta.raise_for_status()\n",
"\tpartial = filename + '.part'\n",
"\twith requests.get(meta.json()['href'], stream=True) as response:\n",
"\t\tresponse.raise_for_status()\n",
"\t\twith open(partial, 'wb') as checkpoint:\n",
"\t\t\tfor chunk in response.iter_content(chunk_size=1 << 20):\n",
"\t\t\t\tcheckpoint.write(chunk)\n",
"\tos.replace(partial, filename)\n",
"\n",
"def _copy_audio(source_video, target='output.mp4'):\n",
"\t\"\"\"Rewrite ``target`` so it carries the audio track of ``source_video``.\n",
"\n",
"\tStreams are copied without re-encoding. If ffmpeg fails (e.g. when\n",
"\t``source_video`` has no audio stream to map) ``target`` is left unchanged.\n",
"\t\"\"\"\n",
"\twith NamedTemporaryFile(suffix='.mp4') as merged:\n",
"\t\ttry:\n",
"\t\t\tffmpeg.output(ffmpeg.input(target).video, ffmpeg.input(source_video).audio, merged.name, c='copy').overwrite_output().run()\n",
"\t\texcept ffmpeg.Error:\n",
"\t\t\treturn\n",
"\t\t# Relies on ffmpeg overwriting merged.name in place, so the open handle sees the result.\n",
"\t\twith open(target, 'wb') as result:\n",
"\t\t\tcopyfileobj(merged, result)\n",
"\n",
"def _show_result(driving_video):\n",
"\t\"\"\"Show output.mp4 next to ``driving_video`` and keep their playback in sync.\n",
"\n",
"\tBoth panels are cleared first: otherwise a second run appends new players and\n",
"\tthe sync script, which binds the first .video-left/.video-right element,\n",
"\twould attach to the stale ones.\n",
"\t\"\"\"\n",
"\toutput_widget.clear_output()\n",
"\tcomparison_widget.clear_output()\n",
"\twith output_widget:\n",
"\t\tvideo_widget = ipywidgets.Video.from_file('output.mp4', autoplay=False, loop=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-left')\n",
"\t\tdisplay(video_widget)\n",
"\twith comparison_widget:\n",
"\t\tvideo_widget = ipywidgets.Video.from_file(driving_video, autoplay=False, loop=False, controls=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-right')\n",
"\t\tdisplay(video_widget)\n",
"\tdisplay(Javascript(\"\"\"\n",
"\tsetTimeout(function() {\n",
"\t\t(function(left, right) {\n",
"\t\t\tleft.addEventListener(\"play\", function() {\n",
"\t\t\t\tright.play();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"pause\", function() {\n",
"\t\t\t\tright.pause();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"seeking\", function() {\n",
"\t\t\t\tright.currentTime = left.currentTime;\n",
"\t\t\t});\n",
"\t\t\tright.muted = true;\n",
"\t\t})(document.getElementsByClassName(\"video-left\")[0], document.getElementsByClassName(\"video-right\")[0]);\n",
"\t}, 1000);\n",
"\t\"\"\"))\n",
"\n",
"def generate(button):\n",
"\t\"\"\"Button handler: animate the selected image with the selected driving video.\n",
"\n",
"\tFetches the model checkpoint if needed, reads the driving video, runs\n",
"\tmake_animation on 256x256 inputs and saves output.mp4. For uploaded videos\n",
"\tand demo video 0 the driving video's audio is added back. Then the result\n",
"\tpanel is shown. On failure the settings form is shown again and the\n",
"\texception is re-raised.\n",
"\t\"\"\"\n",
"\tmain.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\ttry:\n",
"\t\tfilename = model.value + ('' if model.value == 'fashion' else '-cpk') + '.pth.tar'\n",
"\t\t_ensure_checkpoint(filename)\n",
"\t\t# Close the ffmpeg reader once all frames are in memory.\n",
"\t\twith imageio.get_reader(selected_video, mode='I', format='FFMPEG') as reader:\n",
"\t\t\tfps = reader.get_meta_data()['fps']\n",
"\t\t\tdriving_video = list(reader)\n",
"\t\tgenerator, kp_detector = load_checkpoints(config_path='config/%s-256.yaml' % model.value, checkpoint_path=filename)\n",
"\t\twith progress_bar:\n",
"\t\t\tpredictions = make_animation(\n",
"\t\t\t\tskimage.transform.resize(numpy.asarray(selected_image), (256, 256)),\n",
"\t\t\t\t[skimage.transform.resize(frame, (256, 256)) for frame in driving_video],\n",
"\t\t\t\tgenerator,\n",
"\t\t\t\tkp_detector,\n",
"\t\t\t\trelative=relative.value,\n",
"\t\t\t\tadapt_movement_scale=adapt_movement_scale.value\n",
"\t\t\t)\n",
"\t\tprogress_bar.clear_output()\n",
"\t\timageio.mimsave('output.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
"\t\tif selected_video.startswith('user/') or selected_video == 'demo/videos/0.mp4':\n",
"\t\t\t_copy_audio(selected_video, 'output.mp4')\n",
"\t\t_show_result(selected_video)\n",
"\texcept Exception:\n",
"\t\tprogress_bar.clear_output()\n",
"\t\tloading.layout.display = 'none'\n",
"\t\tmain.layout.display = ''\n",
"\t\traise\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\n",
"# Wire the Generate button, then start on the settings form with the first\n",
"# preset image and video selected.\n",
"generate_button.on_click(generate)\n",
"loading.layout.display = complete.layout.display = 'none'\n",
"select_image('00')\n",
"select_video('0')"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "first-order-model-demo",
"provenance": []
},
"kernelspec": {
"display_name": "ldm",
"language": "python",
"name": "ldm"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}