You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

139 lines
3.8 KiB

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This script defines CheckpointDumper class.
This class serves as a base class for other deeplearning checkpoint dumper
classes and defines common methods, attributes etc.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import re
import string
class CheckpointDumper(object):
    """Base Checkpoint Dumper class.

    Attributes
    ----------
    checkpoint_file : str
        Path to the model checkpoint, with a leading ``~`` expanded
    FILENAME_CHARS : str
        Allowed file name characters (ASCII letters, digits and underscore)
    manifest : dict
        Manifest defining variables; maps each dumped variable name to a dict
        with ``'filename'`` and ``'shape'`` keys
    output_dir : str
        Output directory path, with a leading ``~`` expanded
    remove_variables_regex : str or None
        Regex expression for variables to be ignored; a falsy value means
        no variable is ignored
    remove_variables_regex_re : compiled regex pattern or None
        Compiled `remove variable` regex, or None when
        `remove_variables_regex` is None
    """

    FILENAME_CHARS = string.ascii_letters + string.digits + '_'

    def __init__(self, checkpoint_file, output_dir, remove_variables_regex):
        """Constructs object for Checkpoint Dumper.

        Creates `output_dir` (and any missing parents) if it does not exist.

        Parameters
        ----------
        checkpoint_file : str
            Path to the model checkpoint
        output_dir : str
            Output directory path
        remove_variables_regex : str or None
            Regex expression for variables to be ignored; None or an empty
            string disables ignoring
        """
        self.checkpoint_file = os.path.expanduser(checkpoint_file)
        self.output_dir = os.path.expanduser(output_dir)
        self.remove_variables_regex = remove_variables_regex
        self.manifest = {}
        # `should_ignore` treats a falsy regex as "ignore nothing", so None has
        # to be accepted here as well; re.compile(None) would raise TypeError.
        if remove_variables_regex is None:
            self.remove_variables_regex_re = None
        else:
            self.remove_variables_regex_re = re.compile(remove_variables_regex)
        self.make_dir(self.output_dir)

    @staticmethod
    def make_dir(directory):
        """Makes directory (including missing parents) if not existing.

        Parameters
        ----------
        directory : str
            Path to directory. An empty string refers to the current
            directory and is left untouched.

        Raises
        ------
        OSError
            If the directory cannot be created, or the path exists but is
            not a directory.
        """
        if not directory:
            # os.path.join('', name) == name, i.e. the current directory,
            # which always exists; os.makedirs('') would raise instead.
            return
        try:
            os.makedirs(directory)
        except OSError:
            # Attempt-then-verify avoids the race of a separate existence
            # check; an already-existing directory is fine, anything else
            # (permissions, a regular file in the way) is re-raised.
            if not os.path.isdir(directory):
                raise

    def should_ignore(self, name):
        """Checks whether name should be ignored or not.

        Parameters
        ----------
        name : str
            Name to be checked

        Returns
        -------
        bool
            True if a non-empty `remove_variables_regex` matches at the
            start of `name`, False otherwise
        """
        if not self.remove_variables_regex:
            return False
        # Pattern.match anchors at the beginning of the string, as re.match does.
        return self.remove_variables_regex_re.match(name) is not None

    def dump_weights(self, variable_name, filename, shape, weights):
        """Creates a file with given name and dumps byte weights in it.

        Also records the variable in `manifest`. Any existing file with the
        same name in `output_dir` is overwritten.

        Parameters
        ----------
        variable_name : str
            Name of given variable
        filename : str
            File name for given variable, relative to `output_dir`
        shape : list
            Shape of given variable
        weights : ndarray
            Weights for given variable; written via its ``tobytes()`` method
        """
        self.manifest[variable_name] = {'filename': filename, 'shape': shape}
        print('Writing variable ' + variable_name + '...')
        with open(os.path.join(self.output_dir, filename), 'wb') as f:
            f.write(weights.tobytes())

    def dump_manifest(self, filename='manifest.json'):
        """Creates a manifest file with given name and dumps meta information
        related to model.

        The manifest is written as JSON with 2-space indentation and sorted
        keys.

        Parameters
        ----------
        filename : str, optional
            Manifest file name, relative to `output_dir`
        """
        manifest_fpath = os.path.join(self.output_dir, filename)
        print('Writing manifest to ' + manifest_fpath)
        with open(manifest_fpath, 'w') as f:
            f.write(json.dumps(self.manifest, indent=2, sort_keys=True))