@@ -1,53 +1,5 @@
# ---> Eagle
# Ignore list for Eagle, a PCB layout tool

# Backup files
*.s#?
*.b#?
*.l#?
*.b$?
*.s$?
*.l$?

# Eagle project file
# It contains a serial number and references to the file structure
# on your computer.
# comment the following line if you want to have your project file included.
eagle.epf

# Autorouter files
*.pro
*.job

# CAM files
*.$$$
*.cmp
*.ly2
*.l15
*.sol
*.plc
*.stc
*.sts
*.crc
*.crs

*.dri
*.drl
*.gpi
*.pls
*.ger
*.xln

*.drd
*.drd.*

*.s#*
*.b#*

*.info

*.eps

# file locks introduced since 7.x
*.lck

*.pyc
tags
test
models
.idea

@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2017 Christopher Hesse

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -1,19 +1,262 @@
# pix2pix-tensorflow

Based on [pix2pix](https://phillipi.github.io/pix2pix/) by Isola et al.

[Article about this implementation](https://affinelayer.com/pix2pix/)

[Interactive Demo](https://affinelayer.com/pixsrv/)

Tensorflow implementation of pix2pix. Learns a mapping from input images to output images, like these examples from the original paper:

<img src="docs/examples.jpg" width="900px"/>

This port is based directly on the Torch implementation, not on an existing Tensorflow implementation. It is meant to be a faithful implementation of the original work and so does not add anything. In testing, the processing speed on a GPU with cuDNN was equivalent to that of the Torch implementation.

## Setup

### Prerequisites
- Tensorflow 1.4.1

### Recommended
- Linux with Tensorflow GPU edition + cuDNN

### Getting Started

```sh
# clone this repo
git clone https://github.com/affinelayer/pix2pix-tensorflow.git
cd pix2pix-tensorflow
# download the CMP Facades dataset (generated from http://cmp.felk.cvut.cz/~tylecr1/facade/)
python tools/download-dataset.py facades
# train the model (this may take 1-8 hours depending on GPU; on CPU you will be waiting for a while)
python pix2pix.py \
  --mode train \
  --output_dir facades_train \
  --max_epochs 200 \
  --input_dir facades/train \
  --which_direction BtoA
# test the model
python pix2pix.py \
  --mode test \
  --output_dir facades_test \
  --input_dir facades/val \
  --checkpoint facades_train
```

The test run will output an HTML file at `facades_test/index.html` that shows input/output/target image sets.

If you have Docker installed, you can use the provided Docker image to run pix2pix without installing the correct version of Tensorflow:

```sh
# train the model
python tools/dockrun.py python pix2pix.py \
  --mode train \
  --output_dir facades_train \
  --max_epochs 200 \
  --input_dir facades/train \
  --which_direction BtoA
# test the model
python tools/dockrun.py python pix2pix.py \
  --mode test \
  --output_dir facades_test \
  --input_dir facades/val \
  --checkpoint facades_train
```

## Datasets and Trained Models

The data format used by this program is the same as the original pix2pix format, which consists of images of the input and desired output side by side like:

<img src="docs/ab.png" width="256px"/>

For example:

<img src="docs/418.png" width="256px"/>
Some datasets have been made available by the authors of the pix2pix paper. To download those datasets, use the included script `tools/download-dataset.py`. There are also links to pre-trained models alongside each dataset; note that these pre-trained models require the current version of `pix2pix.py`:

| dataset | example |
| --- | --- |
| `python tools/download-dataset.py facades` <br> 400 images from [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). (31MB) <br> Pre-trained: [BtoA](https://mega.nz/#!H0AmER7Y!pBHcH4M11eiHBmJEWvGr-E_jxK4jluKBUlbfyLSKgpY) | <img src="docs/facades.jpg" width="256px"/> |
| `python tools/download-dataset.py cityscapes` <br> 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). (113M) <br> Pre-trained: [AtoB](https://mega.nz/#!K1hXlbJA!rrZuEnL3nqOcRhjb-AnSkK0Ggf9NibhDymLOkhzwuQk) [BtoA](https://mega.nz/#!y1YxxB5D!1817IXQFcydjDdhk_ILbCourhA6WSYRttKLrGE97q7k) | <img src="docs/cityscapes.jpg" width="256px"/> |
| `python tools/download-dataset.py maps` <br> 1096 training images scraped from Google Maps (246M) <br> Pre-trained: [AtoB](https://mega.nz/#!7oxklCzZ!8fRZoF3jMRS_rylCfw2RNBeewp4DFPVE_tSCjCKr-TI) [BtoA](https://mega.nz/#!S4AGzQJD!UH7B5SV7DJSTqKvtbFKqFkjdAh60kpdhTk9WerI-Q1I) | <img src="docs/maps.jpg" width="256px"/> |
| `python tools/download-dataset.py edges2shoes` <br> 50k training images from [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k/). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. (2.2GB) <br> Pre-trained: [AtoB](https://mega.nz/#!u9pnmC4Q!2uHCZvHsCkHBJhHZ7xo5wI-mfekTwOK8hFPy0uBOrb4) | <img src="docs/edges2shoes.jpg" width="256px"/> |
| `python tools/download-dataset.py edges2handbags` <br> 137K Amazon Handbag images from [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. (8.6GB) <br> Pre-trained: [AtoB](https://mega.nz/#!G1xlDCIS!sFDN3ZXKLUWU1TX6Kqt7UG4Yp-eLcinmf6HVRuSHjrM) | <img src="docs/edges2handbags.jpg" width="256px"/> |

The `facades` dataset is the smallest and easiest to get started with.

### Creating your own dataset

#### Example: creating images with blank centers for [inpainting](https://people.eecs.berkeley.edu/~pathak/context_encoder/)

<img src="docs/combine.png" width="900px"/>

```sh
# Resize source images
python tools/process.py \
  --input_dir photos/original \
  --operation resize \
  --output_dir photos/resized
# Create images with blank centers
python tools/process.py \
  --input_dir photos/resized \
  --operation blank \
  --output_dir photos/blank
# Combine resized images with blanked images
python tools/process.py \
  --input_dir photos/resized \
  --b_dir photos/blank \
  --operation combine \
  --output_dir photos/combined
# Split into train/val set
python tools/split.py \
  --dir photos/combined
```

The folder `photos/combined` will now have `train` and `val` subfolders that you can use for training and testing.

#### Creating image pairs from existing images

If you have two directories `a` and `b` with corresponding images (same name, same dimensions, different data), you can combine them with `process.py`:

```sh
python tools/process.py \
  --input_dir a \
  --b_dir b \
  --operation combine \
  --output_dir c
```

This puts the images into the side-by-side combined format that `pix2pix.py` expects.
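
The combine operation itself is just a horizontal concatenation; a minimal sketch of the same idea with NumPy and Pillow (the file names are hypothetical):

```python
import numpy as np
from PIL import Image  # assumption: Pillow is installed

# hypothetical pair with the same name and dimensions in a/ and b/
a_image = np.asarray(Image.open("a/1.png"))
b_image = np.asarray(Image.open("b/1.png"))
combined = np.concatenate([a_image, b_image], axis=1)  # A on the left, B on the right
Image.fromarray(combined).save("c/1.png")
```
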
#### Colorization

For colorization, your images should ideally all be the same aspect ratio. You can resize and crop them with the resize command:

```sh
python tools/process.py \
  --input_dir photos/original \
  --operation resize \
  --output_dir photos/resized
```

No other processing is required; the colorization mode (see the Training section below) uses single images instead of image pairs.

## Training

### Image Pairs

For normal training with image pairs, you need to specify which directory contains the training images and which direction to train on. The direction options are `AtoB` or `BtoA`.

```sh
python pix2pix.py \
  --mode train \
  --output_dir facades_train \
  --max_epochs 200 \
  --input_dir facades/train \
  --which_direction BtoA
```

### Colorization

`pix2pix.py` includes special code to handle colorization with single images instead of pairs; using it looks like this:

```sh
python pix2pix.py \
  --mode train \
  --output_dir photos_train \
  --max_epochs 200 \
  --input_dir photos/train \
  --lab_colorization
```

In this mode, image A is the black and white image (lightness only), and image B contains the color channels of that image (no lightness information).
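
This split mirrors `preprocess_lab` in `pix2pix.py`: the image is converted to CIELAB, the L channel becomes image A, and the a/b channels become image B. A sketch of the equivalent scaling using scikit-image (an assumption; the real code does this conversion in TensorFlow):

```python
import numpy as np
from skimage import color  # assumption: scikit-image; pix2pix.py converts in TensorFlow instead

rgb = np.random.rand(256, 256, 3)     # stand-in for an RGB image in [0, 1]
lab = color.rgb2lab(rgb)              # L in [0, 100], a/b roughly in [-110, 110]
a_image = lab[:, :, :1] / 50.0 - 1.0  # image A: lightness only, scaled to [-1, 1]
b_image = lab[:, :, 1:] / 110.0       # image B: color channels only, scaled to roughly [-1, 1]
```
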
### Tips

You can look at the loss and computation graph using tensorboard:

```sh
tensorboard --logdir=facades_train
```

<img src="docs/tensorboard-scalar.png" width="250px"/> <img src="docs/tensorboard-image.png" width="250px"/> <img src="docs/tensorboard-graph.png" width="250px"/>

If you wish to write in-progress pictures as the network is training, use `--display_freq 50`. This will update `facades_train/index.html` every 50 steps with the current training inputs and outputs.

## Testing

Testing is done with `--mode test`. You should specify the checkpoint to use with `--checkpoint`; this should point to the `output_dir` that you created previously with `--mode train`:

```sh
python pix2pix.py \
  --mode test \
  --output_dir facades_test \
  --input_dir facades/val \
  --checkpoint facades_train
```

The testing mode will load some of the configuration options from the provided checkpoint, so you do not need to specify `which_direction`, for instance.
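
Internally, test mode reads the `options.json` that training wrote to the checkpoint directory and restores a few settings so they match the checkpoint; a short sketch of what it does (mirroring the option-loading code in `pix2pix.py`):

```python
import json

# pix2pix.py restores exactly these options from the training run's options.json
with open("facades_train/options.json") as f:
    saved = json.load(f)
for key in ("which_direction", "ngf", "ndf", "lab_colorization"):
    print(key, "=", saved[key])
```
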
The test run will output an HTML file at `facades_test/index.html` that shows input/output/target image sets:

<img src="docs/test-html.png" width="300px"/>

## Code Validation

Validation of the code was performed on a Linux machine with a ~1.3 TFLOPS Nvidia GTX 750 Ti GPU and on an Azure NC6 instance with a K80 GPU.

```sh
git clone https://github.com/affinelayer/pix2pix-tensorflow.git
cd pix2pix-tensorflow
python tools/download-dataset.py facades
sudo nvidia-docker run \
  --volume $PWD:/prj \
  --workdir /prj \
  --env PYTHONUNBUFFERED=x \
  affinelayer/pix2pix-tensorflow \
  python pix2pix.py \
    --mode train \
    --output_dir facades_train \
    --max_epochs 200 \
    --input_dir facades/train \
    --which_direction BtoA
sudo nvidia-docker run \
  --volume $PWD:/prj \
  --workdir /prj \
  --env PYTHONUNBUFFERED=x \
  affinelayer/pix2pix-tensorflow \
  python pix2pix.py \
    --mode test \
    --output_dir facades_test \
    --input_dir facades/val \
    --checkpoint facades_train
```

Comparison on the facades dataset:

| Input | Tensorflow | Torch | Target |
| --- | --- | --- | --- |
| <img src="docs/1-inputs.png" width="256px"> | <img src="docs/1-tensorflow.png" width="256px"> | <img src="docs/1-torch.jpg" width="256px"> | <img src="docs/1-targets.png" width="256px"> |
| <img src="docs/5-inputs.png" width="256px"> | <img src="docs/5-tensorflow.png" width="256px"> | <img src="docs/5-torch.jpg" width="256px"> | <img src="docs/5-targets.png" width="256px"> |
| <img src="docs/51-inputs.png" width="256px"> | <img src="docs/51-tensorflow.png" width="256px"> | <img src="docs/51-torch.jpg" width="256px"> | <img src="docs/51-targets.png" width="256px"> |
| <img src="docs/95-inputs.png" width="256px"> | <img src="docs/95-tensorflow.png" width="256px"> | <img src="docs/95-torch.jpg" width="256px"> | <img src="docs/95-targets.png" width="256px"> |

## Unimplemented Features

The following models have not been implemented:
- defineG_encoder_decoder
- defineG_unet_128
- defineD_pixelGAN

## Citation

If you use this code for your research, please cite the paper this code is based on: <a href="https://arxiv.org/pdf/1611.07004v1.pdf">Image-to-Image Translation with Conditional Adversarial Networks</a>.

```
@article{pix2pix2016,
  title={Image-to-Image Translation with Conditional Adversarial Networks},
  author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
  journal={arxiv},
  year={2016}
}
```

## Acknowledgments

This is a port of [pix2pix](https://github.com/phillipi/pix2pix) from Torch to Tensorflow. It also contains colorspace conversion code ported from Torch. Thanks to the Tensorflow team for making such a quality library! And special thanks to Phillip Isola for answering my questions about the pix2pix code.

@@ -0,0 +1,119 @@
FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04

WORKDIR /root

RUN apt-get update

# caffe
# from https://github.com/BVLC/caffe/blob/master/docker/cpu/Dockerfile
RUN apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        git \
        wget \
        curl \
        libatlas-base-dev \
        libboost-all-dev \
        libgflags-dev \
        libgoogle-glog-dev \
        libhdf5-serial-dev \
        libleveldb-dev \
        liblmdb-dev \
        libopencv-dev \
        libprotobuf-dev \
        libsnappy-dev \
        protobuf-compiler \
        python-dev \
        python-numpy \
        python-pip \
        python-setuptools \
        python-scipy

ENV CAFFE_ROOT=/opt/caffe

RUN mkdir -p $CAFFE_ROOT && \
    cd $CAFFE_ROOT && \
    git clone https://github.com/s9xie/hed . && \
    git checkout 9e74dd710773d8d8a469ad905c76f4a7fa08f945 && \
    pip install --upgrade pip && \
    cd python && for req in $(cat requirements.txt) pydot; do pip install $req; done && cd .. && \
    # https://github.com/s9xie/hed/pull/23
    sed -i "s|add_subdirectory(examples)||g" CMakeLists.txt && \
    # https://github.com/s9xie/hed/issues/11
    sed -i "647s|//||" include/caffe/loss_layers.hpp && \
    sed -i "648s|//||" include/caffe/loss_layers.hpp && \
    mkdir build && cd build && \
    cmake -DCPU_ONLY=1 .. && \
    make -j"$(nproc)"

ENV PYCAFFE_ROOT $CAFFE_ROOT/python
ENV PYTHONPATH $PYCAFFE_ROOT:$PYTHONPATH
ENV PATH $CAFFE_ROOT/build/tools:$PYCAFFE_ROOT:$PATH
RUN echo "$CAFFE_ROOT/build/lib" >> /etc/ld.so.conf.d/caffe.conf && ldconfig

RUN cd $CAFFE_ROOT && curl -O http://vcl.ucsd.edu/hed/hed_pretrained_bsds.caffemodel

# octave
RUN apt-get install -y --no-install-recommends octave liboctave-dev && \
    octave --eval "pkg install -forge image" && \
    echo "pkg load image;" >> /root/.octaverc

RUN apt-get install -y --no-install-recommends unzip && \
    curl -O https://pdollar.github.io/toolbox/archive/piotr_toolbox.zip && \
    unzip piotr_toolbox.zip && \
    octave --eval "addpath(genpath('/root/toolbox')); savepath;" && \
    echo "#include <stdlib.h>" > wrappers.hpp && \
    cat /root/toolbox/channels/private/wrappers.hpp >> wrappers.hpp && \
    mv wrappers.hpp /root/toolbox/channels/private/wrappers.hpp && \
    mkdir /root/mex && \
    cd /root/toolbox/channels/private && \
    mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/convConst.mex convConst.cpp && \
    mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/gradientMex.mex gradientMex.cpp && \
    mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/imPadMex.mex imPadMex.cpp && \
    mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/imResampleMex.mex imResampleMex.cpp && \
    mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/rgbConvertMex.mex rgbConvertMex.cpp && \
    octave --eval "addpath('/root/mex'); savepath;"

RUN curl -O https://raw.githubusercontent.com/pdollar/edges/master/private/edgesNmsMex.cpp && \
    octave --eval "mex edgesNmsMex.cpp" && \
    mv edgesNmsMex.mex /root/mex/

# from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.gpu
RUN apt-get install -y --no-install-recommends \
        build-essential \
        libfreetype6-dev \
        libpng12-dev \
        libzmq3-dev \
        pkg-config \
        python \
        python-dev \
        rsync \
        software-properties-common \
        unzip

# gpu tracing in tensorflow
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH

RUN pip install \
    appdirs==1.4.0 \
    funcsigs==1.0.2 \
    google-api-python-client==1.6.2 \
    google-auth==0.7.0 \
    google-auth-httplib2==0.0.2 \
    google-cloud-core==0.22.1 \
    google-cloud-storage==0.22.0 \
    googleapis-common-protos==1.5.2 \
    httplib2==0.10.3 \
    mock==2.0.0 \
    numpy==1.12.0 \
    oauth2client==4.0.0 \
    packaging==16.8 \
    pbr==1.10.0 \
    protobuf==3.2.0 \
    pyasn1==0.2.2 \
    pyasn1-modules==0.0.8 \
    pyparsing==2.1.10 \
    rsa==3.4.2 \
    six==1.10.0 \
    uritemplate==3.0.0 \
    tensorflow-gpu==1.4.1

@@ -0,0 +1,28 @@
{
    "aspect_ratio": 1.0,
    "batch_size": 1,
    "beta1": 0.5,
    "checkpoint": null,
    "display_freq": 0,
    "flip": true,
    "gan_weight": 1.0,
    "input_dir": "facades/train",
    "l1_weight": 100.0,
    "lab_colorization": false,
    "lr": 0.0002,
    "max_epochs": 200,
    "max_steps": null,
    "mode": "train",
    "ndf": 64,
    "ngf": 64,
    "output_dir": "facades_train",
    "output_filetype": "png",
    "progress_freq": 50,
    "save_freq": 5000,
    "scale_size": 286,
    "seed": 1550500280,
    "separable_conv": false,
    "summary_freq": 100,
    "trace_freq": 0,
    "which_direction": "BtoA"
}

@@ -0,0 +1,803 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import argparse
import os
import json
import glob
import random
import collections
import math
import time

parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", help="path to folder containing images")
parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
parser.add_argument("--output_dir", required=True, help="where to put output files")
parser.add_argument("--seed", type=int)
parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing")

parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
parser.add_argument("--max_epochs", type=int, help="number of training epochs")
parser.add_argument("--summary_freq", type=int, default=100, help="update summaries every summary_freq steps")
parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps")
parser.add_argument("--display_freq", type=int, default=0, help="write current training images every display_freq steps")
parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")

parser.add_argument("--separable_conv", action="store_true", help="use separable convolutions in the generator")
parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
parser.add_argument("--lab_colorization", action="store_true", help="split input image into brightness (A) and color (B)")
parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
parser.add_argument("--which_direction", type=str, default="AtoB", choices=["AtoB", "BtoA"])
parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
parser.add_argument("--scale_size", type=int, default=286, help="scale images to this size before cropping to 256x256")
parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally")
parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally")
parser.set_defaults(flip=True)
parser.add_argument("--lr", type=float, default=0.0002, help="initial learning rate for adam")
parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")

# export options
parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"])
a = parser.parse_args()

EPS = 1e-12
CROP_SIZE = 256

Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")


def preprocess(image):
    with tf.name_scope("preprocess"):
        # [0, 1] => [-1, 1]
        return image * 2 - 1


def deprocess(image):
    with tf.name_scope("deprocess"):
        # [-1, 1] => [0, 1]
        return (image + 1) / 2


def preprocess_lab(lab):
    with tf.name_scope("preprocess_lab"):
        L_chan, a_chan, b_chan = tf.unstack(lab, axis=2)
        # L_chan: black and white with input range [0, 100]
        # a_chan/b_chan: color channels with input range ~[-110, 110], not exact
        # [0, 100] => [-1, 1], ~[-110, 110] => [-1, 1]
        return [L_chan / 50 - 1, a_chan / 110, b_chan / 110]


def deprocess_lab(L_chan, a_chan, b_chan):
    with tf.name_scope("deprocess_lab"):
        # this is axis=3 instead of axis=2 because we process individual images but deprocess batches
        return tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3)


def augment(image, brightness):
    # (a, b) color channels, combine with L channel and convert to rgb
    a_chan, b_chan = tf.unstack(image, axis=3)
    L_chan = tf.squeeze(brightness, axis=3)
    lab = deprocess_lab(L_chan, a_chan, b_chan)
    rgb = lab_to_rgb(lab)
    return rgb


def discrim_conv(batch_input, out_channels, stride):
    # pad one pixel on each side, then apply a 4x4 convolution with "valid" padding
    padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
    return tf.layers.conv2d(padded_input, out_channels, kernel_size=4, strides=(stride, stride), padding="valid", kernel_initializer=tf.random_normal_initializer(0, 0.02))


def gen_conv(batch_input, out_channels):
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    if a.separable_conv:
        return tf.layers.separable_conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)


def gen_deconv(batch_input, out_channels):
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    if a.separable_conv:
        _b, h, w, _c = batch_input.shape
        resized_input = tf.image.resize_images(batch_input, [h * 2, w * 2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)


def lrelu(x, a):
    with tf.name_scope("lrelu"):
        # adding these together creates the leak part and linear part
        # then cancels them out by subtracting/adding an absolute value term
        # leak: a*x/2 - a*abs(x)/2
        # linear: x/2 + abs(x)/2

        # this block looks like it has 2 inputs on the graph unless we do this
        x = tf.identity(x)
        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)


def batchnorm(inputs):
    # batch normalization over the channel axis; note training=True, so batch statistics are always used
    return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))


def check_image(image):
    assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels")
    with tf.control_dependencies([assertion]):
        image = tf.identity(image)

    if image.get_shape().ndims not in (3, 4):
        raise ValueError("image must be either 3 or 4 dimensions")

    # make the last dimension 3 so that you can unstack the colors
    shape = list(image.get_shape())
    shape[-1] = 3
    image.set_shape(shape)
    return image


# based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
def rgb_to_lab(srgb):
    with tf.name_scope("rgb_to_lab"):
        srgb = check_image(srgb)
        srgb_pixels = tf.reshape(srgb, [-1, 3])

        with tf.name_scope("srgb_to_xyz"):
            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
            rgb_to_xyz = tf.constant([
                #    X        Y          Z
                [0.412453, 0.212671, 0.019334],  # R
                [0.357580, 0.715160, 0.119193],  # G
                [0.180423, 0.072169, 0.950227],  # B
            ])
            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("xyz_to_cielab"):
            # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)

            # normalize for D65 white point
            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])

            epsilon = 6/29
            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask

            # convert to lab
            fxfyfz_to_lab = tf.constant([
                #  l       a       b
                [  0.0,  500.0,    0.0],  # fx
                [116.0, -500.0,  200.0],  # fy
                [  0.0,    0.0, -200.0],  # fz
            ])
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])

        return tf.reshape(lab_pixels, tf.shape(srgb))


def lab_to_rgb(lab):
    with tf.name_scope("lab_to_rgb"):
        lab = check_image(lab)
        lab_pixels = tf.reshape(lab, [-1, 3])

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("cielab_to_xyz"):
            # convert to fxfyfz
            lab_to_fxfyfz = tf.constant([
                #   fx      fy        fz
                [1/116.0, 1/116.0,  1/116.0],  # l
                [1/500.0,     0.0,      0.0],  # a
                [    0.0,     0.0, -1/200.0],  # b
            ])
            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)

            # convert to xyz
            epsilon = 6/29
            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
            xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask

            # denormalize for D65 white point
            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])

        with tf.name_scope("xyz_to_srgb"):
            xyz_to_rgb = tf.constant([
                #     r           g          b
                [ 3.2404542, -0.9692660,  0.0556434],  # x
                [-1.5371385,  1.8760108, -0.2040259],  # y
                [-0.4985314,  0.0415560,  1.0572252],  # z
            ])
            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
            # avoid a slightly negative number messing up the conversion
            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask

        return tf.reshape(srgb_pixels, tf.shape(lab))


def load_examples():
    if a.input_dir is None or not os.path.exists(a.input_dir):
        raise Exception("input_dir does not exist")

    input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
    decode = tf.image.decode_jpeg
    if len(input_paths) == 0:
        input_paths = glob.glob(os.path.join(a.input_dir, "*.png"))
        decode = tf.image.decode_png

    if len(input_paths) == 0:
        raise Exception("input_dir contains no image files")

    def get_name(path):
        name, _ = os.path.splitext(os.path.basename(path))
        return name

    # if the image names are numbers, sort by the value rather than asciibetically
    # having sorted inputs means that the outputs are sorted in test mode
    if all(get_name(path).isdigit() for path in input_paths):
        input_paths = sorted(input_paths, key=lambda path: int(get_name(path)))
    else:
        input_paths = sorted(input_paths)

    with tf.name_scope("load_images"):
        path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train")
        reader = tf.WholeFileReader()
        paths, contents = reader.read(path_queue)
        raw_input = decode(contents)
        raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)

        assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels")
        with tf.control_dependencies([assertion]):
            raw_input = tf.identity(raw_input)

        raw_input.set_shape([None, None, 3])

        if a.lab_colorization:
            # load color and brightness from image, no B image exists here
            lab = rgb_to_lab(raw_input)
            L_chan, a_chan, b_chan = preprocess_lab(lab)
            a_images = tf.expand_dims(L_chan, axis=2)
            b_images = tf.stack([a_chan, b_chan], axis=2)
        else:
            # break apart image pair and move to range [-1, 1]
            width = tf.shape(raw_input)[1]  # [height, width, channels]
            a_images = preprocess(raw_input[:, :width//2, :])
            b_images = preprocess(raw_input[:, width//2:, :])

    if a.which_direction == "AtoB":
        inputs, targets = [a_images, b_images]
    elif a.which_direction == "BtoA":
        inputs, targets = [b_images, a_images]
    else:
        raise Exception("invalid direction")

    # synchronize seed for image operations so that we do the same operations to both
    # input and output images
    seed = random.randint(0, 2**31 - 1)
    def transform(image):
        r = image
        if a.flip:
            r = tf.image.random_flip_left_right(r, seed=seed)

        # area produces a nice downscaling, but does nearest neighbor for upscaling
        # assume we're going to be doing downscaling here
        r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)

        offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
        if a.scale_size > CROP_SIZE:
            r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE)
        elif a.scale_size < CROP_SIZE:
            raise Exception("scale size cannot be less than crop size")
        return r

    with tf.name_scope("input_images"):
        input_images = transform(inputs)

    with tf.name_scope("target_images"):
        target_images = transform(targets)

    paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
    steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))

    return Examples(
        paths=paths_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        count=len(input_paths),
        steps_per_epoch=steps_per_epoch,
    )


def create_generator(generator_inputs, generator_outputs_channels):
    layers = []

    # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
    with tf.variable_scope("encoder_1"):
        output = gen_conv(generator_inputs, a.ngf)
        layers.append(output)

    layer_specs = [
        a.ngf * 2,  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
        a.ngf * 4,  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
        a.ngf * 8,  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
        a.ngf * 8,  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
        a.ngf * 8,  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
        a.ngf * 8,  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        a.ngf * 8,  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
    ]

    for out_channels in layer_specs:
        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
            convolved = gen_conv(rectified, out_channels)
            output = batchnorm(convolved)
            layers.append(output)

    layer_specs = [
        (a.ngf * 8, 0.5),  # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
        (a.ngf * 8, 0.5),  # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
        (a.ngf * 8, 0.5),  # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
        (a.ngf * 8, 0.0),  # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
        (a.ngf * 4, 0.0),  # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
        (a.ngf * 2, 0.0),  # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
        (a.ngf, 0.0),      # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
    ]

    num_encoder_layers = len(layers)
    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
        skip_layer = num_encoder_layers - decoder_layer - 1
        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
            if decoder_layer == 0:
                # first decoder layer doesn't have skip connections
                # since it is directly connected to the skip_layer
                input = layers[-1]
            else:
                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)

            rectified = tf.nn.relu(input)
            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
            output = gen_deconv(rectified, out_channels)
            output = batchnorm(output)

            if dropout > 0.0:
                output = tf.nn.dropout(output, keep_prob=1 - dropout)

            layers.append(output)

    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
    with tf.variable_scope("decoder_1"):
        input = tf.concat([layers[-1], layers[0]], axis=3)
        rectified = tf.nn.relu(input)
        output = gen_deconv(rectified, generator_outputs_channels)
        output = tf.tanh(output)
        layers.append(output)

    return layers[-1]


def create_model(inputs, targets):
    def create_discriminator(discrim_inputs, discrim_targets):
        n_layers = 3
        layers = []

        # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
        input = tf.concat([discrim_inputs, discrim_targets], axis=3)

        # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
        with tf.variable_scope("layer_1"):
            convolved = discrim_conv(input, a.ndf, stride=2)
            rectified = lrelu(convolved, 0.2)
            layers.append(rectified)

        # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
        # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
        # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
        for i in range(n_layers):
            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                out_channels = a.ndf * min(2**(i+1), 8)
                stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
                convolved = discrim_conv(layers[-1], out_channels, stride=stride)
                normalized = batchnorm(convolved)
                rectified = lrelu(normalized, 0.2)
                layers.append(rectified)

        # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            convolved = discrim_conv(rectified, out_channels=1, stride=1)
            output = tf.sigmoid(convolved)
            layers.append(output)

        return layers[-1]

    with tf.variable_scope("generator"):
        out_channels = int(targets.get_shape()[-1])
        outputs = create_generator(inputs, out_channels)

    # create two copies of discriminator, one for real pairs and one for fake pairs
    # they share the same underlying variables
    with tf.name_scope("real_discriminator"):
        with tf.variable_scope("discriminator"):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            predict_real = create_discriminator(inputs, targets)

    with tf.name_scope("fake_discriminator"):
        with tf.variable_scope("discriminator", reuse=True):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            predict_fake = create_discriminator(inputs, outputs)

    with tf.name_scope("discriminator_loss"):
        # minimizing -tf.log will try to get inputs to 1
        # predict_real => 1
        # predict_fake => 0
        discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))

    with tf.name_scope("generator_loss"):
        # predict_fake => 1
        # abs(targets - outputs) => 0
        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
        gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight

    with tf.name_scope("discriminator_train"):
        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
        discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
        discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("generator_train"):
        with tf.control_dependencies([discrim_train]):
            gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
            gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
            gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
            gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])

    global_step = tf.train.get_or_create_global_step()
    incr_global_step = tf.assign(global_step, global_step+1)

    return Model(
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_L1=ema.average(gen_loss_L1),
        gen_grads_and_vars=gen_grads_and_vars,
        outputs=outputs,
        train=tf.group(update_losses, incr_global_step, gen_train),
    )


def save_images(fetches, step=None):
    image_dir = os.path.join(a.output_dir, "images")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    filesets = []
    for i, in_path in enumerate(fetches["paths"]):
        name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
        fileset = {"name": name, "step": step}
        for kind in ["inputs", "outputs", "targets"]:
            filename = name + "-" + kind + ".png"
            if step is not None:
                filename = "%08d-%s" % (step, filename)
            fileset[kind] = filename
            out_path = os.path.join(image_dir, filename)
            contents = fetches[kind][i]
            with open(out_path, "wb") as f:
                f.write(contents)
        filesets.append(fileset)
    return filesets


def append_index(filesets, step=False):
    index_path = os.path.join(a.output_dir, "index.html")
    if os.path.exists(index_path):
        index = open(index_path, "a")
    else:
        index = open(index_path, "w")
        index.write("<html><body><table><tr>")
        if step:
            index.write("<th>step</th>")
        index.write("<th>name</th><th>input</th><th>output</th><th>target</th></tr>")

    for fileset in filesets:
        index.write("<tr>")

        if step:
            index.write("<td>%d</td>" % fileset["step"])
        index.write("<td>%s</td>" % fileset["name"])

        for kind in ["inputs", "outputs", "targets"]:
            index.write("<td><img src='images/%s'></td>" % fileset[kind])

        index.write("</tr>")
    return index_path


def main():
    if a.seed is None:
        a.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)
    random.seed(a.seed)

    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    if a.mode == "test" or a.mode == "export":
        if a.checkpoint is None:
            raise Exception("checkpoint required for test or export mode")

        # load some options from the checkpoint
        options = {"which_direction", "ngf", "ndf", "lab_colorization"}
        with open(os.path.join(a.checkpoint, "options.json")) as f:
            for key, val in json.loads(f.read()).items():
                if key in options:
                    print("loaded", key, "=", val)
                    setattr(a, key, val)
        # disable these features in test mode
        a.scale_size = CROP_SIZE
        a.flip = False

    for k, v in a._get_kwargs():
        print(k, "=", v)

    with open(os.path.join(a.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(a), sort_keys=True, indent=4))

    if a.mode == "export":
        # export the generator to a meta graph that can be imported later for standalone generation
        if a.lab_colorization:
            raise Exception("export not supported for lab_colorization")

        input = tf.placeholder(tf.string, shape=[1])
        input_data = tf.decode_base64(input[0])
        input_image = tf.image.decode_png(input_data)

        # remove alpha channel if present
        input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 4), lambda: input_image[:, :, :3], lambda: input_image)
        # convert grayscale to RGB
        input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 1), lambda: tf.image.grayscale_to_rgb(input_image), lambda: input_image)

        input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32)
        input_image.set_shape([CROP_SIZE, CROP_SIZE, 3])
        batch_input = tf.expand_dims(input_image, axis=0)

        with tf.variable_scope("generator"):
            batch_output = deprocess(create_generator(preprocess(batch_input), 3))

        output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]
        if a.output_filetype == "png":
            output_data = tf.image.encode_png(output_image)
        elif a.output_filetype == "jpeg":
            output_data = tf.image.encode_jpeg(output_image, quality=80)
        else:
            raise Exception("invalid filetype")
        output = tf.convert_to_tensor([tf.encode_base64(output_data)])

        key = tf.placeholder(tf.string, shape=[1])
        inputs = {
            "key": key.name,
            "input": input.name
        }
        tf.add_to_collection("inputs", json.dumps(inputs))
        outputs = {
            "key": tf.identity(key).name,
            "output": output.name,
        }
        tf.add_to_collection("outputs", json.dumps(outputs))

        init_op = tf.global_variables_initializer()
        restore_saver = tf.train.Saver()
        export_saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init_op)
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            restore_saver.restore(sess, checkpoint)
            print("exporting model")
            export_saver.export_meta_graph(filename=os.path.join(a.output_dir, "export.meta"))
            export_saver.save(sess, os.path.join(a.output_dir, "export"), write_meta_graph=False)

        return

    examples = load_examples()
    print("examples count = %d" % examples.count)

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(examples.inputs, examples.targets)

    # undo colorization splitting on images that we use for display/output
    if a.lab_colorization:
        if a.which_direction == "AtoB":
            # inputs is brightness, this will be handled fine as a grayscale image
            # need to augment targets and outputs with brightness
            targets = augment(examples.targets, examples.inputs)
            outputs = augment(model.outputs, examples.inputs)
            # inputs can be deprocessed normally and handled as if they are single channel
            # grayscale images
            inputs = deprocess(examples.inputs)
        elif a.which_direction == "BtoA":
            # inputs will be color channels only, get brightness from targets
            inputs = augment(examples.inputs, examples.targets)
            targets = deprocess(examples.targets)
            outputs = deprocess(model.outputs)
        else:
            raise Exception("invalid direction")
    else:
        inputs = deprocess(examples.inputs)
        targets = deprocess(examples.targets)
        outputs = deprocess(model.outputs)

    def convert(image):
        if a.aspect_ratio != 1.0:
            # upscale to correct aspect ratio
            size = [CROP_SIZE, int(round(CROP_SIZE * a.aspect_ratio))]
            image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC)

        return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_inputs"):
        converted_inputs = convert(inputs)

    with tf.name_scope("convert_targets"):
        converted_targets = convert(targets)

    with tf.name_scope("convert_outputs"):
        converted_outputs = convert(outputs)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "paths": examples.paths,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
        }

    # summaries
    with tf.name_scope("inputs_summary"):
        tf.summary.image("inputs", converted_inputs)

    with tf.name_scope("targets_summary"):
        tf.summary.image("targets", converted_targets)

    with tf.name_scope("outputs_summary"):
        tf.summary.image("outputs", converted_outputs)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", model.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)

    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name + "/values", var)

    for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
        tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=1)

    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            saver.restore(sess, checkpoint)

        max_steps = 2**32
        if a.max_epochs is not None:
            max_steps = examples.steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        if a.mode == "test":
            # testing
            # at most, process the test data once
            start = time.time()
            max_steps = min(examples.steps_per_epoch, max_steps)
            for step in range(max_steps):
                results = sess.run(display_fetches)
                filesets = save_images(results)
                for f in filesets:
                    print("evaluated image", f["name"])
                index_path = append_index(filesets)
            print("wrote index at", index_path)
            print("rate", (time.time() - start) / max_steps)
        else:
            # training
            start = time.time()

            for step in range(max_steps):
                def should(freq):
                    return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)

                options = None
                run_metadata = None
                if should(a.trace_freq):
                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": model.train,
                    "global_step": sv.global_step,
                }

                if should(a.progress_freq):
                    fetches["discrim_loss"] = model.discrim_loss
                    fetches["gen_loss_GAN"] = model.gen_loss_GAN
                    fetches["gen_loss_L1"] = model.gen_loss_L1

                if should(a.summary_freq):
                    fetches["summary"] = sv.summary_op

                if should(a.display_freq):
                    fetches["display"] = display_fetches

                results = sess.run(fetches, options=options, run_metadata=run_metadata)

                if should(a.summary_freq):
                    print("recording summary")
                    sv.summary_writer.add_summary(results["summary"], results["global_step"])

                if should(a.display_freq):
                    print("saving display images")
                    filesets = save_images(results["display"], step=results["global_step"])
                    append_index(filesets, step=True)

                if should(a.trace_freq):
                    print("recording trace")
                    sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"])

                if should(a.progress_freq):
                    # global_step will have the correct step count if we resume from a checkpoint
                    train_epoch = math.ceil(results["global_step"] / examples.steps_per_epoch)
                    train_step = (results["global_step"] - 1) % examples.steps_per_epoch + 1
                    rate = (step + 1) * a.batch_size / (time.time() - start)
                    remaining = (max_steps - step) * a.batch_size / rate
                    print("progress epoch %d step %d image/sec %0.1f remaining %dm" % (train_epoch, train_step, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])

                if should(a.save_freq):
                    print("saving model")
                    saver.save(sess, os.path.join(a.output_dir, "model"), global_step=sv.global_step)

                if sv.should_stop():
                    break


main()

@ -0,0 +1,29 @@
# pix2pix-tensorflow server

Host pix2pix-tensorflow models to be used with something like the [Image-to-Image Demo](https://affinelayer.com/pixsrv/).

This is a simple Python server that uses [deeplearn.js](https://deeplearnjs.org/) and weights exported from pix2pix checkpoints using `tools/export-checkpoint.py`.

## Exporting

You can export a model to be served with `tools/export-checkpoint.py`:

```sh
python tools/export-checkpoint.py \
  --checkpoint facades_BtoA \
  --output_file static/models/facades_BtoA.bin
```

You can also copy models from the `pix2pix-tensorflow-models` repo:

```sh
git clone git@github.com:affinelayer/pix2pix-tensorflow-models.git static/models
```
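
For reference, the `.bin` file written by `tools/export-checkpoint.py` is three length-prefixed sections: a JSON list of variable names and shapes, a float32 quantization table, and uint8 codes that index into that table. A minimal reader sketch, assuming that layout (`read_export` is an illustrative helper, not part of this repo):

```python
import json
import struct

import numpy as np


def read_export(path):
    # each section is a big-endian uint32 byte length followed by the payload,
    # matching the struct.pack(">L", len(buf)) calls in export-checkpoint.py
    with open(path, "rb") as f:
        def read_section():
            length, = struct.unpack(">L", f.read(4))
            return f.read(length)

        shapes = json.loads(read_section().decode("utf8"))
        index = np.frombuffer(read_section(), dtype=np.float32)
        encoded = np.frombuffer(read_section(), dtype=np.uint8)

    # de-quantize, then split the flat buffer back into per-variable arrays
    flat = index[encoded]
    arrays = {}
    offset = 0
    for entry in shapes:
        size = int(np.prod(entry["shape"]))
        arrays[entry["name"]] = flat[offset:offset + size].reshape(entry["shape"])
        offset += size
    return arrays
```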

## Serving

```sh
python serve.py --port 8000
```

If you open [http://localhost:8000/](http://localhost:8000/) in a browser, you should see an interactive demo.
@ -0,0 +1,18 @@
import os
import argparse
from http.server import HTTPServer, SimpleHTTPRequestHandler


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", default=8000, type=int, help="port to listen on")
    args = parser.parse_args()

    os.chdir('static')
    server_address = ('', args.port)
    httpd = HTTPServer(server_address, SimpleHTTPRequestHandler)
    print('serving at http://127.0.0.1:%d' % args.port)
    httpd.serve_forever()


main()
Binary image files added (14 images, 320 B to 265 KiB).
@ -0,0 +1,138 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""This script defines CheckpointDumper class.

This class serves as a base class for other deeplearning checkpoint dumper
classes and defines common methods, attributes etc.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import re
import string


class CheckpointDumper(object):
    """Base Checkpoint Dumper class.

    Attributes
    ----------
    checkpoint_file : str
        Path to the model checkpoint
    FILENAME_CHARS : str
        Characters allowed in variable file names
    manifest : dict
        Manifest defining variables
    output_dir : str
        Output directory path
    remove_variables_regex : str
        Regex expression for variables to be ignored
    remove_variables_regex_re : compiled regex pattern
        Compiled `remove variable` regex
    """

    FILENAME_CHARS = string.ascii_letters + string.digits + '_'

    def __init__(self, checkpoint_file, output_dir, remove_variables_regex):
        """Constructs object for Checkpoint Dumper.

        Parameters
        ----------
        checkpoint_file : str
            Path to the model checkpoint
        output_dir : str
            Output directory path
        remove_variables_regex : str
            Regex expression for variables to be ignored
        """
        self.checkpoint_file = os.path.expanduser(checkpoint_file)
        self.output_dir = os.path.expanduser(output_dir)
        self.remove_variables_regex = remove_variables_regex

        self.manifest = {}
        self.remove_variables_regex_re = re.compile(self.remove_variables_regex)

        self.make_dir(self.output_dir)

    @staticmethod
    def make_dir(directory):
        """Makes directory if not existing.

        Parameters
        ----------
        directory : str
            Path to directory
        """
        if not os.path.exists(directory):
            os.makedirs(directory)

    def should_ignore(self, name):
        """Checks whether name should be ignored or not.

        Parameters
        ----------
        name : str
            Name to be checked

        Returns
        -------
        bool
            Whether to ignore the name or not
        """
        return bool(self.remove_variables_regex and re.match(self.remove_variables_regex_re, name))

    def dump_weights(self, variable_name, filename, shape, weights):
        """Creates a file with given name and dumps byte weights in it.

        Parameters
        ----------
        variable_name : str
            Name of given variable
        filename : str
            File name for given variable
        shape : list
            Shape of given variable
        weights : ndarray
            Weights for given variable
        """
        self.manifest[variable_name] = {'filename': filename, 'shape': shape}

        print('Writing variable ' + variable_name + '...')
        with open(os.path.join(self.output_dir, filename), 'wb') as f:
            f.write(weights.tobytes())

    def dump_manifest(self, filename='manifest.json'):
        """Creates a manifest file with given name and dumps meta information
        related to model.

        Parameters
        ----------
        filename : str, optional
            Manifest file name
        """
        manifest_fpath = os.path.join(self.output_dir, filename)

        print('Writing manifest to ' + manifest_fpath)
        with open(manifest_fpath, 'w') as f:
            f.write(json.dumps(self.manifest, indent=2, sort_keys=True))
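
# Illustrative manifest.json produced by dump_manifest() (the variable name
# and shape below are hypothetical):
#
# {
#   "generator/encoder_1/conv/filter": {
#     "filename": "generator_encoder_1_conv_filter",
#     "shape": [4, 4, 3, 64]
#   }
# }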
@ -0,0 +1,95 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
This script is an entry point for dumping checkpoints for various deeplearning
frameworks.
"""
from __future__ import print_function

import argparse


def get_checkpoint_dumper(model_type, checkpoint_file, output_dir, remove_variables_regex):
    """Returns Checkpoint dumper instance for a given model type.

    Parameters
    ----------
    model_type : str
        Type of deeplearning framework
    checkpoint_file : str
        Path to checkpoint file
    output_dir : str
        Path to output directory
    remove_variables_regex : str
        Regex for variables to be ignored

    Returns
    -------
    TensorflowCheckpointDumper or PytorchCheckpointDumper
        Checkpoint dumper instance for the corresponding model type

    Raises
    ------
    ValueError
        If the given model type is not supported
    """
    if model_type == 'tensorflow':
        from tensorflow_checkpoint_dumper import TensorflowCheckpointDumper

        return TensorflowCheckpointDumper(
            checkpoint_file, output_dir, remove_variables_regex)
    elif model_type == 'pytorch':
        from pytorch_checkpoint_dumper import PytorchCheckpointDumper

        return PytorchCheckpointDumper(
            checkpoint_file, output_dir, remove_variables_regex)
    else:
        raise ValueError('Currently, "%s" models are not supported' % model_type)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_type',
        type=str,
        required=True,
        help='Model checkpoint type')
    parser.add_argument(
        '--checkpoint_file',
        type=str,
        required=True,
        help='Path to the model checkpoint')
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='The output directory where to store the converted weights')
    parser.add_argument(
        '--remove_variables_regex',
        type=str,
        default='',
        help='A regular expression to match against variable names that should '
             'not be included')
    FLAGS, unparsed = parser.parse_known_args()

    if unparsed:
        parser.print_help()
        print('Unrecognized flags: ', unparsed)
        exit(-1)

    checkpoint_dumper = get_checkpoint_dumper(
        FLAGS.model_type, FLAGS.checkpoint_file, FLAGS.output_dir, FLAGS.remove_variables_regex)
    checkpoint_dumper.build_and_dump_vars()
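
# Example invocation (paths are illustrative):
#
#   python dump_checkpoint_vars.py --model_type tensorflow \
#       --checkpoint_file ~/checkpoints/model-100 --output_dir weights/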
@ -0,0 +1,104 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""This script defines PytorchCheckpointDumper class.

This class takes a pytorch checkpoint file and writes all of the variables in the
checkpoint to a directory which deeplearnjs can take as input.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six import iteritems

import torch

from checkpoint_dumper import CheckpointDumper


class PytorchCheckpointDumper(CheckpointDumper):
    """Class for dumping Pytorch Checkpoints.

    Attributes
    ----------
    state_dictionary : dict
        Dictionary defining checkpoint variables and weights
    """

    def __init__(self, checkpoint_file, output_dir, remove_variables_regex):
        """Constructs object for Pytorch Checkpoint Dumper.

        Parameters
        ----------
        checkpoint_file : str
            Path to the model checkpoint
        output_dir : str
            Output directory path
        remove_variables_regex : str
            Regex expression for variables to be ignored
        """
        super(PytorchCheckpointDumper, self).__init__(
            checkpoint_file, output_dir, remove_variables_regex)

        self.state_dictionary = torch.load(self.checkpoint_file)

    def var_name_to_filename(self, var_name):
        """Converts variable names to standard file names.

        Parameters
        ----------
        var_name : str
            Variable name to be converted

        Returns
        -------
        str
            Standardized file name
        """
        chars = []

        for c in var_name:
            if c in CheckpointDumper.FILENAME_CHARS:
                chars.append(c)
            elif c == '.':
                chars.append('_')

        return ''.join(chars)

    def build_and_dump_vars(self):
        """Builds and dumps variables and a manifest file."""
        for (var_name, var_weights) in iteritems(self.state_dictionary):
            if self.should_ignore(var_name):
                print('Ignoring ' + var_name)
                continue

            var_filename = self.var_name_to_filename(var_name)
            var_shape = list(map(int, list(var_weights.size())))
            tensor = var_weights.cpu().numpy()

            self.dump_weights(var_name, var_filename, var_shape, tensor)

        self.dump_manifest()
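
# For example, a state_dict key like "conv1.weight" (illustrative) is written
# to a file named "conv1_weight": dots become underscores and any other
# character outside FILENAME_CHARS is dropped.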
@ -0,0 +1,103 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""This script defines TensorflowCheckpointDumper class.

This class takes a tensorflow checkpoint file and writes all of the variables in the
checkpoint to a directory which deeplearnjs can take as input.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six import iteritems

import tensorflow as tf

from checkpoint_dumper import CheckpointDumper


class TensorflowCheckpointDumper(CheckpointDumper):
    """Class for dumping Tensorflow Checkpoints.

    Attributes
    ----------
    reader : NewCheckpointReader
        Reader for given tensorflow checkpoint
    """

    def __init__(self, checkpoint_file, output_dir, remove_variables_regex):
        """Constructs object for Tensorflow Checkpoint Dumper.

        Parameters
        ----------
        checkpoint_file : str
            Path to the model checkpoint
        output_dir : str
            Output directory path
        remove_variables_regex : str
            Regex expression for variables to be ignored
        """
        super(TensorflowCheckpointDumper, self).__init__(
            checkpoint_file, output_dir, remove_variables_regex)

        self.reader = tf.train.NewCheckpointReader(self.checkpoint_file)

    def var_name_to_filename(self, var_name):
        """Converts variable names to standard file names.

        Parameters
        ----------
        var_name : str
            Variable name to be converted

        Returns
        -------
        str
            Standardized file name
        """
        chars = []

        for c in var_name:
            if c in CheckpointDumper.FILENAME_CHARS:
                chars.append(c)
            elif c == '/':
                chars.append('_')

        return ''.join(chars)

    def build_and_dump_vars(self):
        """Builds and dumps variables and a manifest file."""
        var_to_shape_map = self.reader.get_variable_to_shape_map()

        for (var_name, var_shape) in iteritems(var_to_shape_map):
            if self.should_ignore(var_name) or var_name == 'global_step':
                print('Ignoring ' + var_name)
                continue

            var_filename = self.var_name_to_filename(var_name)
            tensor = self.reader.get_tensor(var_name)
            self.dump_weights(var_name, var_filename, var_shape, tensor)

        self.dump_manifest()
@ -0,0 +1,101 @@
import argparse
import os
import tempfile
import subprocess as sp
import json
import struct
import time

import numpy as np

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def log_quantize(data, mu, bins):
    # mu-law encoding: compress the dynamic range so quantization bins are
    # denser near zero, where most weights live
    scale = np.max(np.abs(data))
    norm_data = data / scale
    log_data = np.sign(data) * np.log(1 + mu * np.abs(norm_data)) / np.log(1 + mu)

    _counts, edges = np.histogram(log_data, bins=bins)
    log_points = (edges[:-1] + edges[1:]) / 2
    # invert the mu-law mapping to get bin centers back in weight space
    return np.sign(log_points) * (1 / mu) * ((1 + mu)**np.abs(log_points) - 1) * scale
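
# Worked example of the companding (values illustrative): with mu=255, a
# weight equal to the largest magnitude `scale` maps to log value 1.0, and
# inverting a bin center l gives sign(l) * ((256**abs(l) - 1) / 255) * scale,
# so l = 1.0 maps back to exactly `scale` while small |l| stays near-linear,
# packing most of the 256 bins around zero where the weights concentrate.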


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True, help="directory with checkpoint to resume training from or use for testing")
    parser.add_argument("--output_file", required=True, help="where to write output")
    args = parser.parse_args()

    model_path = None
    with open(os.path.join(args.checkpoint, "checkpoint")) as f:
        for line in f:
            line = line.strip()
            if line == "":
                continue
            key, _sep, val = line.partition(": ")
            val = val[1:-1]  # remove quotes
            if key == "model_checkpoint_path":
                model_path = val

    if model_path is None:
        raise Exception("failed to find model path")

    checkpoint_file = os.path.join(args.checkpoint, model_path)
    with tempfile.TemporaryDirectory() as tmp_dir:
        cmd = ["python", "-u", os.path.join(SCRIPT_DIR, "dump_checkpoints/dump_checkpoint_vars.py"), "--model_type", "tensorflow", "--output_dir", tmp_dir, "--checkpoint_file", checkpoint_file]
        sp.check_call(cmd)

        with open(os.path.join(tmp_dir, "manifest.json")) as f:
            manifest = json.loads(f.read())

        names = []
        for key in manifest.keys():
            if not key.startswith("generator") or "Adam" in key or "_loss" in key or "_train" in key or "_moving_" in key:
                continue
            names.append(key)
        names = sorted(names)

        arrays = []
        for name in names:
            value = manifest[name]
            with open(os.path.join(tmp_dir, value["filename"]), "rb") as f:
                arr = np.frombuffer(f.read(), dtype=np.float32).copy().reshape(value["shape"])
            arrays.append(arr)

    shapes = []
    for name, arr in zip(names, arrays):
        shapes.append(dict(
            name=name,
            shape=arr.shape,
        ))

    flat = np.hstack([arr.reshape(-1) for arr in arrays])

    start = time.time()
    index = log_quantize(flat, mu=255, bins=256).astype(np.float32)
    print("index found in %0.2fs" % (time.time() - start))

    print("quantizing")
    encoded = np.zeros(flat.shape, dtype=np.uint8)
    elem_count = 0
    for i, x in enumerate(flat):
        distances = np.abs(index - x)
        nearest = np.argmin(distances)
        encoded[i] = nearest
        elem_count += 1
        if elem_count % 1000000 == 0:
            print("rate", int(elem_count / (time.time() - start)))

    with open(args.output_file, "wb") as f:
        def write(name, buf):
            print("%s bytes %d" % (name, len(buf)))
            f.write(struct.pack(">L", len(buf)))
            f.write(buf)

        write("shape", json.dumps(shapes).encode("utf8"))
        write("index", index.tobytes())
        write("encoded", encoded.tobytes())


main()
@ -0,0 +1,112 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import shlex


# from python 3.3 source
# https://github.com/python/cpython/blob/master/Lib/shutil.py
def which(cmd, mode=os.F_OK | os.X_OK, path=None):
    """Given a command, mode, and a PATH string, return the path which
    conforms to the given mode on the PATH, or None if there is no such
    file.

    `mode` defaults to os.F_OK | os.X_OK. `path` defaults to the result
    of os.environ.get("PATH"), or can be overridden with a custom search
    path.
    """
    # Check that a given file can be accessed with the correct mode.
    # Additionally check that `file` is not a directory, as on Windows
    # directories pass the os.access check.
    def _access_check(fn, mode):
        return (os.path.exists(fn) and os.access(fn, mode)
                and not os.path.isdir(fn))

    # If we're given a path with a directory part, look it up directly rather
    # than referring to PATH directories. This includes checking relative to the
    # current directory, e.g. ./script
    if os.path.dirname(cmd):
        if _access_check(cmd, mode):
            return cmd
        return None

    if path is None:
        path = os.environ.get("PATH", os.defpath)
    if not path:
        return None
    path = path.split(os.pathsep)

    if sys.platform == "win32":
        # The current directory takes precedence on Windows.
        if not os.curdir in path:
            path.insert(0, os.curdir)

        # PATHEXT is necessary to check on Windows.
        pathext = os.environ.get("PATHEXT", "").split(os.pathsep)
        # See if the given file matches any of the expected path extensions.
        # This will allow us to short circuit when given "python.exe".
        # If it does match, only test that one, otherwise we have to try
        # others.
        if any(cmd.lower().endswith(ext.lower()) for ext in pathext):
            files = [cmd]
        else:
            files = [cmd + ext for ext in pathext]
    else:
        # On other platforms you don't have things like PATHEXT to tell you
        # what file suffixes are executable, so just pass on cmd as-is.
        files = [cmd]

    seen = set()
    for dir in path:
        normdir = os.path.normcase(dir)
        if not normdir in seen:
            seen.add(normdir)
            for thefile in files:
                name = os.path.join(dir, thefile)
                if _access_check(name, mode):
                    return name
    return None


def main():
    cmd = sys.argv[1:]

    # check if nvidia-docker or docker are on path
    docker_path = which("nvidia-docker")
    if docker_path is None:
        docker_path = which("docker")

    if docker_path is None:
        raise Exception("docker not found")

    docker_args = [
        "--rm",
        "--volume",
        "/:/host",
        "--workdir",
        "/host" + os.getcwd(),
        "--env",
        "PYTHONUNBUFFERED=x",
        "--env",
        "CUDA_CACHE_PATH=/host/tmp/cuda-cache",
    ]

    if "CUDA_VISIBLE_DEVICES" in os.environ:
        docker_args.extend(["--env", "CUDA_VISIBLE_DEVICES=%s" % os.environ["CUDA_VISIBLE_DEVICES"]])

    for i, arg in enumerate(cmd):
        # change absolute paths
        if arg.startswith("/"):
            cmd[i] = "/host" + arg

    args = [docker_path, "run"] + docker_args + ["affinelayer/pix2pix-tensorflow:v3"] + cmd

    if not os.access("/var/run/docker.sock", os.R_OK):
        args = ["sudo"] + args

    print("running", " ".join(shlex.quote(a) for a in args))
    os.execvp(args[0], args)


main()
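
# Example (assuming this wrapper is saved as tools/dockrun.py): running
#
#   python tools/dockrun.py python pix2pix.py --help
#
# execs roughly:
#
#   nvidia-docker run --rm --volume /:/host --workdir /host$PWD \
#       --env PYTHONUNBUFFERED=x --env CUDA_CACHE_PATH=/host/tmp/cuda-cache \
#       affinelayer/pix2pix-tensorflow:v3 python pix2pix.py --help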
@ -0,0 +1,24 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

try:
    from urllib.request import urlopen  # python 3
except ImportError:
    from urllib2 import urlopen  # python 2
import sys
import tarfile
import tempfile
import shutil

dataset = sys.argv[1]
url = "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz" % dataset
with tempfile.TemporaryFile() as tmp:
    print("downloading", url)
    shutil.copyfileobj(urlopen(url), tmp)
    print("extracting")
    tmp.seek(0)
    tar = tarfile.open(fileobj=tmp)
    tar.extractall()
    tar.close()
    print("done")
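
# Example (script path illustrative): `python tools/download-dataset.py facades`
# fetches facades.tar.gz from the URL above and extracts it into the current
# directory.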
@ -0,0 +1,306 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import argparse
import os
import tempfile
import subprocess
import tensorflow as tf
import numpy as np
import tfimage as im
import threading
import time
import multiprocessing

edge_pool = None


parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", required=True, help="path to folder containing images")
parser.add_argument("--output_dir", required=True, help="output path")
parser.add_argument("--operation", required=True, choices=["grayscale", "resize", "blank", "combine", "edges"])
parser.add_argument("--workers", type=int, default=1, help="number of workers")
# resize
parser.add_argument("--pad", action="store_true", help="pad instead of crop for resize operation")
parser.add_argument("--size", type=int, default=256, help="size to use for resize operation")
# combine
parser.add_argument("--b_dir", type=str, help="path to folder containing B images for combine operation")
a = parser.parse_args()
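
# Example invocations (paths are illustrative):
#
#   python tools/process.py --input_dir photos/original --operation resize \
#       --output_dir photos/resized
#   python tools/process.py --input_dir photos/resized --b_dir photos/blank \
#       --operation combine --output_dir photos/combined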


def resize(src):
    height, width, _ = src.shape
    dst = src
    if height != width:
        if a.pad:
            size = max(height, width)
            # pad to correct ratio
            oh = (size - height) // 2
            ow = (size - width) // 2
            dst = im.pad(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)
        else:
            # crop to correct ratio
            size = min(height, width)
            oh = (height - size) // 2
            ow = (width - size) // 2
            dst = im.crop(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)

    assert(dst.shape[0] == dst.shape[1])

    size, _, _ = dst.shape
    if size > a.size:
        dst = im.downscale(images=dst, size=[a.size, a.size])
    elif size < a.size:
        dst = im.upscale(images=dst, size=[a.size, a.size])
    return dst


def blank(src):
    height, width, _ = src.shape
    if height != width:
        raise Exception("non-square image")

    image_size = width
    size = int(image_size * 0.3)
    offset = int(image_size / 2 - size / 2)

    dst = src
    dst[offset:offset + size, offset:offset + size, :] = np.ones([size, size, 3])
    return dst


def combine(src, src_path):
    if a.b_dir is None:
        raise Exception("missing b_dir")

    # find corresponding file in b_dir, could have a different extension
    basename, _ = os.path.splitext(os.path.basename(src_path))
    for ext in [".png", ".jpg"]:
        sibling_path = os.path.join(a.b_dir, basename + ext)
        if os.path.exists(sibling_path):
            sibling = im.load(sibling_path)
            break
    else:
        raise Exception("could not find sibling image for " + src_path)

    # make sure that dimensions are correct
    height, width, _ = src.shape
    if height != sibling.shape[0] or width != sibling.shape[1]:
        raise Exception("differing sizes")

    # convert both images to RGB if necessary
    if src.shape[2] == 1:
        src = im.grayscale_to_rgb(images=src)

    if sibling.shape[2] == 1:
        sibling = im.grayscale_to_rgb(images=sibling)

    # remove alpha channel
    if src.shape[2] == 4:
        src = src[:, :, :3]

    if sibling.shape[2] == 4:
        sibling = sibling[:, :, :3]

    return np.concatenate([src, sibling], axis=1)


def grayscale(src):
    return im.grayscale_to_rgb(images=im.rgb_to_grayscale(images=src))


net = None


def run_caffe(src):
    # lazy load caffe and create net
    global net
    if net is None:
        # don't require caffe unless we are doing edge detection
        os.environ["GLOG_minloglevel"] = "2"  # disable logging from caffe
        import caffe
        # using this requires using the docker image or assembling a bunch of dependencies
        # and then changing these hardcoded paths
        net = caffe.Net("/opt/caffe/examples/hed/deploy.prototxt", "/opt/caffe/hed_pretrained_bsds.caffemodel", caffe.TEST)

    net.blobs["data"].reshape(1, *src.shape)
    net.blobs["data"].data[...] = src
    net.forward()
    return net.blobs["sigmoid-fuse"].data[0][0, :, :]


def edges(src):
    # based on https://github.com/phillipi/pix2pix/blob/master/scripts/edges/batch_hed.py
    # and https://github.com/phillipi/pix2pix/blob/master/scripts/edges/PostprocessHED.m
    import scipy.io
    src = src * 255
    border = 128  # put a padding around images since edge detection seems to detect edge of image
    src = src[:, :, :3]  # remove alpha channel if present
    src = np.pad(src, ((border, border), (border, border), (0, 0)), "reflect")
    src = src[:, :, ::-1]
    src -= np.array((104.00698793, 116.66876762, 122.67891434))
    # [height, width, channels] => [batch, channel, height, width]
    src = src.transpose((2, 0, 1))

    fuse = edge_pool.apply(run_caffe, [src])
    fuse = fuse[border:-border, border:-border]

    with tempfile.NamedTemporaryFile(suffix=".png") as png_file, tempfile.NamedTemporaryFile(suffix=".mat") as mat_file:
        scipy.io.savemat(mat_file.name, {"input": fuse})

        octave_code = r"""
E = 1-load(input_path).input;
E = imresize(E, [image_width,image_width]);
E = 1 - E;
E = single(E);
[Ox, Oy] = gradient(convTri(E, 4), 1);
[Oxx, ~] = gradient(Ox, 1);
[Oxy, Oyy] = gradient(Oy, 1);
O = mod(atan(Oyy .* sign(-Oxy) ./ (Oxx + 1e-5)), pi);
E = edgesNmsMex(E, O, 1, 5, 1.01, 1);
E = double(E >= max(eps, threshold));
E = bwmorph(E, 'thin', inf);
E = bwareaopen(E, small_edge);
E = 1 - E;
E = uint8(E * 255);
imwrite(E, output_path);
"""

        config = dict(
            input_path="'%s'" % mat_file.name,
            output_path="'%s'" % png_file.name,
            image_width=256,
            threshold=25.0 / 255.0,
            small_edge=5,
        )

        args = ["octave"]
        for k, v in config.items():
            args.extend(["--eval", "%s=%s;" % (k, v)])

        args.extend(["--eval", octave_code])
        try:
            subprocess.check_output(args, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            print("octave failed")
            print("returncode:", e.returncode)
            print("output:", e.output)
            raise
        return im.load(png_file.name)


def process(src_path, dst_path):
    src = im.load(src_path)

    if a.operation == "grayscale":
        dst = grayscale(src)
    elif a.operation == "resize":
        dst = resize(src)
    elif a.operation == "blank":
        dst = blank(src)
    elif a.operation == "combine":
        dst = combine(src, src_path)
    elif a.operation == "edges":
        dst = edges(src)
    else:
        raise Exception("invalid operation")

    im.save(dst, dst_path)


complete_lock = threading.Lock()
start = None
num_complete = 0
total = 0


def complete():
    global num_complete, rate, last_complete

    with complete_lock:
        num_complete += 1
        now = time.time()
        elapsed = now - start
        rate = num_complete / elapsed
        if rate > 0:
            remaining = (total - num_complete) / rate
        else:
            remaining = 0

        print("%d/%d complete %0.2f images/sec %dm%ds elapsed %dm%ds remaining" % (num_complete, total, rate, elapsed // 60, elapsed % 60, remaining // 60, remaining % 60))

        last_complete = now


def main():
    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    src_paths = []
    dst_paths = []

    skipped = 0
    for src_path in im.find(a.input_dir):
        name, _ = os.path.splitext(os.path.basename(src_path))
        dst_path = os.path.join(a.output_dir, name + ".png")
        if os.path.exists(dst_path):
            skipped += 1
        else:
            src_paths.append(src_path)
            dst_paths.append(dst_path)

    print("skipping %d files that already exist" % skipped)

    global total
    total = len(src_paths)

    print("processing %d files" % total)

    global start
    start = time.time()

    if a.operation == "edges":
        # use a multiprocessing pool for this operation so it can use multiple CPUs
        # create the pool before we launch the processing threads
        global edge_pool
        edge_pool = multiprocessing.Pool(a.workers)

    if a.workers == 1:
        with tf.Session() as sess:
            for src_path, dst_path in zip(src_paths, dst_paths):
                process(src_path, dst_path)
                complete()
    else:
        # list() so the producer also works under python 3, where zip is lazy
        queue = tf.train.input_producer(list(zip(src_paths, dst_paths)), shuffle=False, num_epochs=1)
        dequeue_op = queue.dequeue()

        def worker(coord):
            # `sess` is bound by the enclosing `with tf.Session()` block below,
            # which is entered before any worker thread starts
            with sess.as_default():
                while not coord.should_stop():
                    try:
                        src_path, dst_path = sess.run(dequeue_op)
                    except tf.errors.OutOfRangeError:
                        coord.request_stop()
                        break

                    process(src_path, dst_path)
                    complete()

        # init epoch counter for the queue
        local_init_op = tf.local_variables_initializer()
        with tf.Session() as sess:
            sess.run(local_init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            for i in range(a.workers):
                t = threading.Thread(target=worker, args=(coord,))
                t.start()
                threads.append(t)

            try:
                coord.join(threads)
            except KeyboardInterrupt:
                coord.request_stop()
                coord.join(threads)


main()
@ -0,0 +1,45 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import argparse
import glob
import os


parser = argparse.ArgumentParser()
parser.add_argument("--dir", type=str, required=True, help="path to folder containing images")
parser.add_argument("--train_frac", type=float, default=0.8, help="fraction of images to use for the training set")
parser.add_argument("--test_frac", type=float, default=0.0, help="fraction of images to use for the test set")
parser.add_argument("--sort", action="store_true", help="if set, sort the images instead of shuffling them")
a = parser.parse_args()


def main():
    random.seed(0)

    files = glob.glob(os.path.join(a.dir, "*.png"))
    files.sort()

    assignments = []
    assignments.extend(["train"] * int(a.train_frac * len(files)))
    assignments.extend(["test"] * int(a.test_frac * len(files)))
    assignments.extend(["val"] * int(len(files) - len(assignments)))

    if not a.sort:
        random.shuffle(assignments)

    for name in ["train", "val", "test"]:
        if name in assignments:
            d = os.path.join(a.dir, name)
            if not os.path.exists(d):
                os.makedirs(d)

    print(len(files), len(assignments))
    for inpath, assignment in zip(files, assignments):
        outpath = os.path.join(a.dir, assignment, os.path.basename(inpath))
        print(inpath, "->", outpath)
        os.rename(inpath, outpath)


main()
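
# For example, with 10 matching .png files, --train_frac 0.8 and the default
# --test_frac 0.0 assign 8 files to train/, 0 to test/, and the remaining 2
# to val/ (counts illustrative).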
@ -0,0 +1,63 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import subprocess
import os
import sys
import time
import shutil
import shlex

INPUT_DIR = os.path.abspath("../data")
OUTPUT_DIR = os.path.expanduser("~/data/pix2pix/test")


def main():
    start = time.time()

    images = {
        "affinelayer": "affinelayer/pix2pix-tensorflow:v3",
        # "py2-tensorflow": "tensorflow/tensorflow:1.4.1-gpu",
        # "py3-tensorflow": "tensorflow/tensorflow:1.4.1-gpu-py3",
    }

    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)

    for image_name, image in images.items():
        def run(cmd):
            docker = "docker"
            if sys.platform.startswith("linux"):
                docker = "nvidia-docker"

            prefix = [docker, "run", "--rm", "--volume", os.getcwd() + ":/prj", "--volume", INPUT_DIR + ":/input", "--volume", os.path.join(OUTPUT_DIR, image_name) + ":/output", "--workdir", "/prj", "--env", "PYTHONUNBUFFERED=x", "--volume", "/tmp/cuda-cache:/cuda-cache", "--env", "CUDA_CACHE_PATH=/cuda-cache", image]
            args = prefix + shlex.split(cmd)
            print(" ".join(args))
            subprocess.check_call(args)

        run("python tools/process.py --input_dir /input/pusheen/original --operation resize --output_dir /output/process_resize")
        if image_name == "affinelayer":
            run("python tools/process.py --input_dir /output/process_resize --operation edges --output_dir /output/process_edges")

        for direction in ["AtoB", "BtoA"]:
            for dataset in ["facades", "maps"]:
                name = dataset + "_" + direction
                run("python pix2pix.py --mode train --input_dir /input/official/%s/train --output_dir /output/%s_train --display_freq 1 --max_steps 1 --which_direction %s --seed 0" % (dataset, name, direction))
                run("python pix2pix.py --mode test --input_dir /input/official/%s/val --output_dir /output/%s_test --display_freq 1 --max_steps 1 --checkpoint /output/%s_train --seed 0" % (dataset, name, name))

            dataset = "color-lab"
            name = dataset + "_" + direction
            run("python pix2pix.py --mode train --input_dir /input/%s/train --output_dir /output/%s_train --display_freq 1 --max_steps 1 --which_direction %s --lab_colorization --seed 0" % (dataset, name, direction))
            run("python pix2pix.py --mode test --input_dir /input/%s/val --output_dir /output/%s_test --display_freq 1 --max_steps 1 --checkpoint /output/%s_train --seed 0" % (dataset, name, name))

        # using pretrained model
        # for dataset, direction in [("facades", "BtoA")]:
        #     name = dataset + "_" + direction
        #     run("python pix2pix.py --mode test --output_dir test/%s_pretrained_test --input_dir /input/official/%s/val --max_steps 100 --which_direction %s --seed 0 --checkpoint /input/pretrained/%s" % (name, dataset, direction, name))
        #     run("python pix2pix.py --mode export --output_dir test/%s_pretrained_export --checkpoint /input/pretrained/%s" % (name, name))

    print("elapsed", int(time.time() - start))


main()
@ -0,0 +1,144 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import os


def create_op(func, **placeholders):
    op = func(**placeholders)

    def f(**kwargs):
        feed_dict = {}
        for argname, argvalue in kwargs.items():
            placeholder = placeholders[argname]
            feed_dict[placeholder] = argvalue
        return tf.get_default_session().run(op, feed_dict=feed_dict)

    return f
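
# create_op builds each graph node once at import time and returns a plain
# function that feeds values into the placeholders, so callers only need an
# active default session. Illustrative usage with the `downscale` op defined
# below:
#
#   with tf.Session():
#       small = downscale(images=img, size=[256, 256])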

downscale = create_op(
    func=tf.image.resize_images,
    images=tf.placeholder(tf.float32, [None, None, None]),
    size=tf.placeholder(tf.int32, [2]),
    method=tf.image.ResizeMethod.AREA,
)

upscale = create_op(
    func=tf.image.resize_images,
    images=tf.placeholder(tf.float32, [None, None, None]),
    size=tf.placeholder(tf.int32, [2]),
    method=tf.image.ResizeMethod.BICUBIC,
)

decode_jpeg = create_op(
    func=tf.image.decode_jpeg,
    contents=tf.placeholder(tf.string),
)

decode_png = create_op(
    func=tf.image.decode_png,
    contents=tf.placeholder(tf.string),
)

rgb_to_grayscale = create_op(
    func=tf.image.rgb_to_grayscale,
    images=tf.placeholder(tf.float32),
)

grayscale_to_rgb = create_op(
    func=tf.image.grayscale_to_rgb,
    images=tf.placeholder(tf.float32),
)

encode_jpeg = create_op(
    func=tf.image.encode_jpeg,
    image=tf.placeholder(tf.uint8),
)

encode_png = create_op(
    func=tf.image.encode_png,
    image=tf.placeholder(tf.uint8),
)

crop = create_op(
    func=tf.image.crop_to_bounding_box,
    image=tf.placeholder(tf.float32),
    offset_height=tf.placeholder(tf.int32, []),
    offset_width=tf.placeholder(tf.int32, []),
    target_height=tf.placeholder(tf.int32, []),
    target_width=tf.placeholder(tf.int32, []),
)

pad = create_op(
    func=tf.image.pad_to_bounding_box,
    image=tf.placeholder(tf.float32),
    offset_height=tf.placeholder(tf.int32, []),
    offset_width=tf.placeholder(tf.int32, []),
    target_height=tf.placeholder(tf.int32, []),
    target_width=tf.placeholder(tf.int32, []),
)

to_uint8 = create_op(
    func=tf.image.convert_image_dtype,
    image=tf.placeholder(tf.float32),
    dtype=tf.uint8,
    saturate=True,
)

to_float32 = create_op(
    func=tf.image.convert_image_dtype,
    image=tf.placeholder(tf.uint8),
    dtype=tf.float32,
)


def load(path):
    with open(path, "rb") as f:
        contents = f.read()

    _, ext = os.path.splitext(path.lower())

    if ext == ".jpg":
        image = decode_jpeg(contents=contents)
    elif ext == ".png":
        image = decode_png(contents=contents)
    else:
        raise Exception("invalid image suffix")

    return to_float32(image=image)


def find(d):
    result = []
    for filename in os.listdir(d):
        _, ext = os.path.splitext(filename.lower())
        if ext == ".jpg" or ext == ".png":
            result.append(os.path.join(d, filename))
    result.sort()
    return result


def save(image, path, replace=False):
    _, ext = os.path.splitext(path.lower())
    image = to_uint8(image=image)
    if ext == ".jpg":
        encoded = encode_jpeg(image=image)
    elif ext == ".png":
        encoded = encode_png(image=image)
    else:
        raise Exception("invalid image suffix")

    dirname = os.path.dirname(path)
    if dirname != "" and not os.path.exists(dirname):
        os.makedirs(dirname)

    if os.path.exists(path):
        if replace:
            os.remove(path)
        else:
            raise Exception("file already exists at " + path)

    with open(path, "wb") as f:
        f.write(encoded)