import argparse
import os
import tempfile
import subprocess as sp
import json
import struct
import time

import numpy as np

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def log_quantize(data, mu, bins):
    # mu-law encoding: compress the dynamic range so quantization levels
    # cluster near zero, where most of the values live
    scale = np.max(np.abs(data))
    norm_data = data / scale
    log_data = np.sign(data) * np.log(1 + mu * np.abs(norm_data)) / np.log(1 + mu)

    # histogram the companded values, take the bin midpoints as code points,
    # and map them back through the inverse mu-law transform
    _counts, edges = np.histogram(log_data, bins=bins)
    log_points = (edges[:-1] + edges[1:]) / 2
    return np.sign(log_points) * (1 / mu) * ((1 + mu)**np.abs(log_points) - 1) * scale
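

# Illustrative check of log_quantize (a hypothetical helper, never called by
# this script): with mu=255 and bins=256 it yields a 256-entry codebook,
# sorted ascending and denser near zero, so small values quantize with less
# error than large ones.
def _demo_log_quantize():
    data = np.random.randn(100000).astype(np.float32)
    codebook = log_quantize(data, mu=255, bins=256)
    assert codebook.shape == (256,)
    assert np.all(np.diff(codebook) > 0)  # strictly increasing code points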


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True, help="directory with checkpoint to resume training from or use for testing")
    parser.add_argument("--output_file", required=True, help="where to write output")
    args = parser.parse_args()
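
    # The TensorFlow "checkpoint" state file is plain text; the line this
    # loop looks for has the form (path illustrative):
    #   model_checkpoint_path: "model-100000"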
    model_path = None
    with open(os.path.join(args.checkpoint, "checkpoint")) as f:
        for line in f:
            line = line.strip()
            if line == "":
                continue
            key, _sep, val = line.partition(": ")
            val = val[1:-1]  # remove quotes
            if key == "model_checkpoint_path":
                model_path = val

    if model_path is None:
        raise Exception("failed to find model path")

    checkpoint_file = os.path.join(args.checkpoint, model_path)
    with tempfile.TemporaryDirectory() as tmp_dir:
        cmd = [
            "python", "-u",
            os.path.join(SCRIPT_DIR, "dump_checkpoints/dump_checkpoint_vars.py"),
            "--model_type", "tensorflow",
            "--output_dir", tmp_dir,
            "--checkpoint_file", checkpoint_file,
        ]
        sp.check_call(cmd)
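
        # dump_checkpoint_vars.py writes one raw binary file per variable plus
        # a manifest mapping each variable name to its file and shape; the
        # structure assumed here is inferred from the reads below:
        #   {"generator/...": {"filename": "...", "shape": [...]}, ...}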
        with open(os.path.join(tmp_dir, "manifest.json")) as f:
            manifest = json.loads(f.read())

        # keep only generator weights needed for inference; skip optimizer
        # slots (Adam) and training-only bookkeeping variables
        names = []
        for key in manifest.keys():
            if not key.startswith("generator") or "Adam" in key or "_loss" in key or "_train" in key or "_moving_" in key:
                continue
            names.append(key)
        names = sorted(names)

        # each variable is dumped as raw float32 bytes
        arrays = []
        for name in names:
            value = manifest[name]
            with open(os.path.join(tmp_dir, value["filename"]), "rb") as f:
                arr = np.frombuffer(f.read(), dtype=np.float32).copy().reshape(value["shape"])
            arrays.append(arr)

    shapes = []
    for name, arr in zip(names, arrays):
        shapes.append(dict(
            name=name,
            shape=arr.shape,
        ))

    flat = np.hstack([arr.reshape(-1) for arr in arrays])

    start = time.time()
    index = log_quantize(flat, mu=255, bins=256).astype(np.float32)
    print("index found in %0.2fs" % (time.time() - start))
print("quantizing")
|
|
encoded = np.zeros(flat.shape, dtype=np.uint8)
|
|
elem_count = 0
|
|
for i, x in enumerate(flat):
|
|
distances = np.abs(index - x)
|
|
nearest = np.argmin(distances)
|
|
encoded[i] = nearest
|
|
elem_count += 1
|
|
if elem_count % 1000000 == 0:
|
|
print("rate", int(elem_count / (time.time() - start)))
|
|
|
|
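
    # A vectorized alternative to the loop above (a sketch, not what this
    # script runs; results can differ only on exact distance ties). It relies
    # on the codebook being sorted ascending, which holds because it comes
    # from increasing histogram bin midpoints passed through a monotonic
    # inverse mu-law map:
    #   pos = np.searchsorted(index, flat)
    #   pos = np.clip(pos, 1, len(index) - 1)
    #   left_closer = (flat - index[pos - 1]) < (index[pos] - flat)
    #   encoded = (pos - left_closer).astype(np.uint8)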

    with open(args.output_file, "wb") as f:
        def write(name, buf):
            print("%s bytes %d" % (name, len(buf)))
            f.write(struct.pack(">L", len(buf)))
            f.write(buf)

        write("shape", json.dumps(shapes).encode("utf8"))
        write("index", index.tobytes())
        write("encoded", encoded.tobytes())


main()
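

# A minimal sketch of reading the exported file back (a hypothetical helper,
# not used by this script): each section is a big-endian uint32 byte length
# followed by the payload, written in the order shape, index, encoded.
def _load_exported(path):
    def read_section(f):
        (n,) = struct.unpack(">L", f.read(4))
        return f.read(n)

    with open(path, "rb") as f:
        shapes = json.loads(read_section(f).decode("utf8"))
        index = np.frombuffer(read_section(f), dtype=np.float32)
        encoded = np.frombuffer(read_section(f), dtype=np.uint8)
    flat = index[encoded]  # dequantize back to float32 values
    return shapes, flat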