You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
2.9 KiB

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This script is an entry point for dumping checkpoints for various deeplearning
frameworks.
"""
from __future__ import print_function
import argparse
def get_checkpoint_dumper(model_type, checkpoint_file, output_dir, remove_variables_regex):
"""Returns Checkpoint dumper instance for a given model type.
Parameters
----------
model_type : str
Type of deeplearning framework
checkpoint_file : str
Path to checkpoint file
output_dir : str
Path to output directory
remove_variables_regex : str
Regex for variables to be ignored
Returns
-------
(TensorflowCheckpointDumper, PytorchCheckpointDumper)
Checkpoint Dumper Instance for corresponding model type
Raises
------
Error
If particular model type is not supported
"""
if model_type == 'tensorflow':
from tensorflow_checkpoint_dumper import TensorflowCheckpointDumper
return TensorflowCheckpointDumper(
checkpoint_file, output_dir, remove_variables_regex)
elif model_type == 'pytorch':
from pytorch_checkpoint_dumper import PytorchCheckpointDumper
return PytorchCheckpointDumper(
checkpoint_file, output_dir, remove_variables_regex)
else:
raise Error('Currently, "%s" models are not supported'.format(model_type))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_type',
type=str,
required=True,
help='Model checkpoint type')
parser.add_argument(
'--checkpoint_file',
type=str,
required=True,
help='Path to the model checkpoint')
parser.add_argument(
'--output_dir',
type=str,
required=True,
help='The output directory where to store the converted weights')
parser.add_argument(
'--remove_variables_regex',
type=str,
default='',
help='A regular expression to match against variable names that should '
'not be included')
FLAGS, unparsed = parser.parse_known_args()
if unparsed:
parser.print_help()
print('Unrecognized flags: ', unparsed)
exit(-1)
checkpoint_dumper = get_checkpoint_dumper(
FLAGS.model_type, FLAGS.checkpoint_file, FLAGS.output_dir, FLAGS.remove_variables_regex)
checkpoint_dumper.build_and_dump_vars()