You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
2.9 KiB
104 lines
2.9 KiB
# Copyright 2017 Google Inc. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""This script defines TensorflowCheckpointDumper class.
|
|
|
|
This class takes a tensorflow checkpoint file and writes all of the variables in the
|
|
checkpoint to a directory which deeplearnjs can take as input.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from six import iteritems
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
|
|
import tensorflow as tf
|
|
|
|
from checkpoint_dumper import CheckpointDumper
|
|
|
|
class TensorflowCheckpointDumper(CheckpointDumper):
    """Checkpoint dumper specialised for Tensorflow checkpoints.

    Attributes
    ----------
    reader : NewCheckpointReader
        Reader opened on the given tensorflow checkpoint
    """

    def __init__(self, checkpoint_file, output_dir, remove_variables_regex):
        """Initialises the dumper and opens a reader on the checkpoint.

        Parameters
        ----------
        checkpoint_file : str
            Path to the model checkpoint
        output_dir : str
            Output directory path
        remove_variables_regex : str
            Regular expression matching the variables to be ignored
        """
        super(TensorflowCheckpointDumper, self).__init__(
            checkpoint_file, output_dir, remove_variables_regex)

        # The reader is opened on self.checkpoint_file (not on the raw
        # argument), i.e. on whatever the base-class initialiser stored.
        self.reader = tf.train.NewCheckpointReader(self.checkpoint_file)

    def var_name_to_filename(self, var_name):
        """Maps a variable name onto a standardized file name.

        Characters listed in ``CheckpointDumper.FILENAME_CHARS`` are kept
        as-is, any other '/' becomes '_', and every remaining character is
        dropped.

        Parameters
        ----------
        var_name : str
            Variable name to be converted

        Returns
        -------
        str
            Standardized file name
        """
        allowed = CheckpointDumper.FILENAME_CHARS

        # The filter admits allowed characters and '/'; an admitted character
        # that is not allowed can therefore only be '/', which maps to '_'.
        return ''.join(
            ch if ch in allowed else '_'
            for ch in var_name
            if ch in allowed or ch == '/')

    def build_and_dump_vars(self):
        """Dumps every retained checkpoint variable, then the manifest.

        Variables for which ``should_ignore`` is true, and the variable named
        'global_step', are reported on stdout and skipped. Each remaining
        variable is recorded in ``self.manifest`` and handed to
        ``dump_weights``; ``dump_manifest`` is called once at the end.
        """
        shapes_by_name = self.reader.get_variable_to_shape_map()

        for name, shape in iteritems(shapes_by_name):
            if not self.should_ignore(name) and name != 'global_step':
                filename = self.var_name_to_filename(name)
                entry = {'filename': filename, 'shape': shape}
                self.manifest[name] = entry

                weights = self.reader.get_tensor(name)
                self.dump_weights(name, filename, shape, weights)
            else:
                print('Ignoring ' + name)

        self.dump_manifest()
|