# Learning-rate warmup schedules and an Adam optimizer with decoupled weight
# decay, written against the TensorFlow 1.x graph-mode API.
import math

import tensorflow as tf


def warmup_cosine(x, warmup=0.002):
    # Ramp linearly from 0 to 1 over the first `warmup` fraction of training,
    # then follow a cosine decay 0.5*(1 + cos(pi*x)) down to 0 at x = 1.
    s = tf.cast(x <= warmup, tf.float32)
    return s*(x/warmup) + (1-s)*(0.5 * (1 + tf.cos(math.pi * x)))


def warmup_constant(x, warmup=0.002):
    # Ramp linearly from 0 to 1 over the first `warmup` fraction of training,
    # then hold the multiplier constant at 1.
    s = tf.cast(x <= warmup, tf.float32)
    return s*(x/warmup) + (1-s)*1


def warmup_linear(x, warmup=0.002):
    # Ramp up over the first `warmup` fraction of training, then decay
    # linearly from 1 down to 0 at x = 1.
    s = tf.cast(x <= warmup, tf.float32)
    return (s*(x/warmup) + (1-s))*(1-x)


schedules = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
}
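
# A minimal sketch of how these schedule multipliers can be inspected
# (assumes the TensorFlow 1.x graph-mode API used above; the evaluation
# points are arbitrary). Each schedule maps the fraction of training
# completed, x in [0, 1], to a learning-rate multiplier:
#
#   x = tf.placeholder(tf.float32, [])
#   multipliers = {name: fn(x) for name, fn in schedules.items()}
#   with tf.Session() as sess:
#       for frac in [0.0, 0.001, 0.002, 0.5, 1.0]:
#           print(frac, sess.run(multipliers, feed_dict={x: frac}))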


def adam(params, grads, lr, schedule, t_total, b1=0.9, b2=0.999, e=1e-8,
         l2=0, vector_l2=False, max_grad_norm=-1, **kwargs):
    """
    Adam with the weight decay fix: L2 regularization is applied directly to
    the parameters rather than being folded into the adaptive moments.
    """
    # Global step; incremented once each time the returned op runs.
    t = tf.Variable(0, dtype=tf.float32, trainable=False)
    tt = t+1
    updates = [t.assign(tt)]
    if max_grad_norm > 0:
        grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)
    for p, g in zip(params, grads):
        if p is None or g is None:
            print("can't train", None if p is None else p.name, g)
        else:
            if isinstance(g, tf.IndexedSlices):
                # Densify sparse (e.g. embedding-lookup) gradients.
                g = tf.convert_to_tensor(g)
            # Per-parameter first- and second-moment accumulators.
            m = tf.Variable(p*0, dtype=tf.float32, trainable=False)
            v = tf.Variable(p*0, dtype=tf.float32, trainable=False)
            # Bias-corrected step size, scaled by the warmup schedule
            # evaluated at the fraction of training completed.
            lrt = lr*tf.sqrt(1-b2**tt)/(1-b1**tt)
            lrt *= schedule(t/t_total)
            mt = b1*m + (1-b1)*g
            vt = b2*v + (1-b2)*g*g
            if (len(p.get_shape()) > 1 or vector_l2) and l2 > 0:
                # Decoupled weight decay: by default only matrices (rank > 1)
                # are decayed; vector_l2 extends it to biases and gains.
                pt = p - lrt * (mt / (tf.sqrt(vt) + e) + l2*p)
            else:
                pt = p - lrt * (mt / (tf.sqrt(vt) + e))
            updates.extend([m.assign(mt), v.assign(vt), p.assign(pt)])
    return tf.group(*updates)
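
# A hypothetical end-to-end wiring of the optimizer above (TensorFlow 1.x
# graph mode). `loss`, `n_updates_total`, `batches`, and `make_feed` are
# illustrative placeholders, and the hyperparameter values are arbitrary:
#
#   params = tf.trainable_variables()
#   grads = tf.gradients(loss, params)
#   train_op = adam(params, grads, lr=6.25e-5,
#                   schedule=schedules['warmup_linear'],
#                   t_total=n_updates_total, l2=0.01, max_grad_norm=1)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       for batch in batches:
#           sess.run(train_op, feed_dict=make_feed(batch))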