Lr search #82

Draft: wants to merge 7 commits into base: master
69 changes: 63 additions & 6 deletions tf/tfprocess.py
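Reading the diff as a whole (this summary is pieced together from the code below, not from a PR description): the change adds an optional periodic learning-rate search to the training loop. When training.lr_search is enabled, every lr_search_freq steps the trainer snapshots the weights and optimizer momentum, takes a sweep of throw-away optimizer steps at nearby learning rates, keeps the rate whose step gave the lowest regularization term, and then eases the live learning rate toward that target until the next search. A sketch of the training-config keys the new code path reads (the numeric values are placeholders, not recommendations from this PR):

# Sketch only: config keys read by the new lr-search path; values are placeholders.
training_cfg = {
    'lr_search': True,        # opt in to the periodic learning-rate search
    'lr_search_freq': 10000,  # steps between searches (placeholder value)
    # existing keys such as 'lr_values', 'lr_boundaries', 'total_steps',
    # 'max_grad_norm' and 'num_batch_splits' are still honoured
}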
@@ -121,6 +121,7 @@ def __init__(self, cfg):
self.training = tf.placeholder(tf.bool)
self.global_step = tf.Variable(0, name='global_step', trainable=False)
self.learning_rate = tf.placeholder(tf.float32)
self.target_lr = None

def init(self, dataset, train_iterator, test_iterator):
# TF variables
@@ -200,7 +201,7 @@ def init_net(self, next_batch):

# You need to change the learning rate here if you are training
# from a self-play training set, for example start with 0.005 instead.
opt_op = tf.train.MomentumOptimizer(
self.opt_op = tf.train.MomentumOptimizer(
learning_rate=self.learning_rate, momentum=0.9, use_nesterov=True)

# Do SWA after we construct the net
@@ -227,19 +228,30 @@ def init_net(self, next_batch):
var.initialized_value()), trainable=False) for var in tf.trainable_variables()]
self.zero_op = [var.assign(tf.zeros_like(var))
for var in gradient_accum]
if self.cfg['training'].get('lr_search', False):
self.backup_vars = [tf.Variable(tf.zeros_like(
var.initialized_value()), trainable=False) for var in tf.trainable_variables()]
self.backup_momentums = [tf.Variable(tf.zeros_like(
var.initialized_value()), trainable=False) for var in tf.trainable_variables()]
self.opt_backup_op = None
self.opt_restore_op = None
self.last_lr = tf.Variable(0., name='last_lr', trainable=False)
self.last_lr_cached = None

self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(self.update_ops):
gradients = opt_op.compute_gradients(loss)
gradients = self.opt_op.compute_gradients(loss)
self.accum_op = [accum.assign_add(
gradient[0]) for accum, gradient in zip(gradient_accum, gradients)]
# gradients are num_batch_splits times higher due to accumulation by summing, so the norm will be too
max_grad_norm = self.cfg['training'].get(
'max_grad_norm', 10000.0) * self.cfg['training'].get('num_batch_splits', 1)
gradient_accum, self.grad_norm = tf.clip_by_global_norm(
gradient_accum, max_grad_norm)
self.train_op = opt_op.apply_gradients(
self.train_op = self.opt_op.apply_gradients(
[(accum, gradient[1]) for accum, gradient in zip(gradient_accum, gradients)], global_step=self.global_step)
self.quiet_train_op = self.opt_op.apply_gradients(
[(accum, gradient[1]) for accum, gradient in zip(gradient_accum, gradients)])

correct_policy_prediction = \
tf.equal(tf.argmax(self.y_conv, 1), tf.argmax(self.y_, 1))
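For context on the additions in the hunk above: promoting opt_op to self.opt_op lets later code reach the optimizer's momentum slots, backup_vars and backup_momentums hold a snapshot so a probe training step can be rolled back, and quiet_train_op applies the accumulated gradients without advancing global_step. A rough standalone sketch of the same backup/restore idea, assuming a TF1 graph and a MomentumOptimizer passed in as opt (names here are illustrative, not copied from the PR):

import tensorflow as tf  # TF1-style graph API, matching the file being patched

def make_backup_restore_ops(opt):
    # Sketch only: snapshot every trainable variable plus its momentum slot so a
    # probe training step can be undone. Momentum slots exist only after the
    # optimizer has been asked to apply gradients at least once.
    trainable = tf.trainable_variables()
    state = trainable + [opt.get_slot(v, "momentum") for v in trainable]
    backups = [tf.Variable(tf.zeros_like(v.initialized_value()), trainable=False)
               for v in state]
    backup_op = [b.assign(v) for b, v in zip(backups, state)]   # save current state
    restore_op = [v.assign(b) for b, v in zip(backups, state)]  # roll back after probing
    return backup_op, restore_op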
@@ -274,6 +286,7 @@ def init_net(self, next_batch):
self.session.run(self.init)

def replace_weights(self, new_weights):
all_evals = []
for e, weights in enumerate(self.weights):
if weights.shape.ndims == 4:
# Rescale rule50 related weights as clients do not normalize the input.
@@ -295,7 +308,7 @@ def replace_weights(self, new_weights):
s = weights.shape.as_list()
shape = [s[i] for i in [3, 2, 0, 1]]
new_weight = tf.constant(new_weights[e], shape=shape)
self.session.run(weights.assign(
all_evals.append(weights.assign(
tf.transpose(new_weight, [2, 3, 1, 0])))
elif weights.shape.ndims == 2:
# Fully connected layers are [in, out] in TF
@@ -305,12 +318,13 @@ def replace_weights(self, new_weights):
s = weights.shape.as_list()
shape = [s[i] for i in [1, 0]]
new_weight = tf.constant(new_weights[e], shape=shape)
self.session.run(weights.assign(
all_evals.append(weights.assign(
tf.transpose(new_weight, [1, 0])))
else:
# Biases, batchnorm etc
new_weight = tf.constant(new_weights[e], shape=weights.shape)
self.session.run(tf.assign(weights, new_weight))
all_evals.append(tf.assign(weights, new_weight))
self.session.run(all_evals)
# This should result in an identical file to the starting one
# self.save_leelaz_weights('restored.txt')
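A side note on the replace_weights changes above: instead of issuing one session.run per tensor, the assign ops are now collected into all_evals and evaluated in a single session.run call, saving a graph round trip per weight. A minimal sketch of the same pattern with illustrative names (sess, variables and new_values are not taken from the PR):

def assign_all(sess, variables, new_values):
    # Sketch: build every assign op first, then evaluate them in one call
    # instead of one sess.run round trip per tensor.
    assign_ops = [var.assign(val) for var, val in zip(variables, new_values)]
    sess.run(assign_ops)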

@@ -360,6 +374,20 @@ def process(self, batch_size, test_batches, batch_splits=1):
lr_boundaries = self.cfg['training']['lr_boundaries']
steps_total = steps % self.cfg['training']['total_steps']
self.lr = lr_values[bisect.bisect_right(lr_boundaries, steps_total)]
lr_search = self.cfg['training'].get('lr_search', False)
lr_searching = lr_search and steps % self.cfg['training']['lr_search_freq'] == 0 and steps != 0
if lr_search:
if self.last_lr_cached is None:
self.last_lr_cached = self.session.run(self.last_lr)
if self.last_lr_cached > 0:
self.lr = self.last_lr_cached
if self.target_lr is not None:
target_progress = steps % self.cfg['training']['lr_search_freq']
if target_progress == 0:
self.lr = self.target_lr
else:
self.lr = self.lr + (self.target_lr - self.lr) * (target_progress / self.cfg['training']['lr_search_freq'])
raw_lr = self.lr
if self.warmup_steps > 0 and steps < self.warmup_steps:
self.lr = self.lr * (steps + 1) / self.warmup_steps
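A quick numeric illustration of the blend above, with made-up values: if lr_search_freq is 1000, the cached rate is 0.10 and the searched target is 0.05, then at target_progress = 250 the applied rate slides a quarter of the way to the target (warmup scaling, when active, is applied on top of this).

# Made-up numbers: how the interpolation above eases lr toward target_lr.
lr, target_lr, freq, progress = 0.10, 0.05, 1000, 250
lr = lr + (target_lr - lr) * (progress / freq)   # -> 0.0875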

@@ -383,6 +411,32 @@ def process(self, batch_size, test_batches, batch_splits=1):
self.avg_value_loss.append(value_loss)
self.avg_mse_loss.append(mse_loss)
self.avg_reg_term.append(reg_term)
if lr_searching:
if self.opt_backup_op is None:
self.opt_backup_op = [var.assign(val) for var, val in zip(self.backup_vars, tf.trainable_variables())] +\
[var.assign(val) for var, val in zip(self.backup_momentums, [self.opt_op.get_slot(var, "momentum") for var in tf.trainable_variables()])]
if self.opt_restore_op is None:
self.opt_restore_op = [val.assign(var) for var, val in zip(self.backup_vars, tf.trainable_variables())] +\
[val.assign(var) for var, val in zip(self.backup_momentums, [self.opt_op.get_slot(var, "momentum") for var in tf.trainable_variables()])]
self.session.run(self.opt_backup_op)
best_reg_term = None
best_x = 0
for x in np.arange(-1, 1, 0.1):
corrected_lr = raw_lr*(2**x) / batch_splits
_, grad_norm = self.session.run([self.quiet_train_op, self.grad_norm],
feed_dict={self.learning_rate: corrected_lr, self.training: True, self.handle: self.train_handle})
new_reg = self.session.run(self.reg_term)
if best_reg_term is None or new_reg < best_reg_term:
best_reg_term = new_reg
best_x = x
print("LR Search {} {}".format(raw_lr*(2**x), new_reg))
self.session.run(self.opt_restore_op)
self.last_lr_cached = raw_lr
self.session.run(self.last_lr.assign(raw_lr))
best_x = best_x / 5
self.target_lr = raw_lr*(2**best_x)
print("LR Target {} {}".format(raw_lr*(2**best_x), best_reg_term))

# Gradients of batch splits are summed, not averaged like usual, so need to scale lr accordingly to correct for this.
corrected_lr = self.lr / batch_splits
_, grad_norm = self.session.run([self.train_op, self.grad_norm],
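Restating the lr_searching block above as a standalone sketch (names are illustrative; the real code drives quiet_train_op and the backup/restore ops against the live graph): the sweep probes learning rates raw_lr * 2**x for x from -1 up to just under 1 in steps of 0.1, takes one throw-away optimizer step per probe without restoring in between, keeps the x whose step produced the smallest regularization term, restores the saved weights and momentum, and damps the winner by a factor of five in log space, so the target rate moves by no more than roughly 15% per search.

import numpy as np

def pick_target_lr(raw_lr, probe_step, restore):
    # Sketch of the search above. probe_step(lr) should take one throw-away
    # optimizer step at the given rate and return the resulting regularization
    # term; restore() should roll weights and momentum back to the snapshot.
    best_x, best_reg = 0.0, None
    for x in np.arange(-1, 1, 0.1):
        reg = probe_step(raw_lr * (2 ** x))
        if best_reg is None or reg < best_reg:
            best_x, best_reg = x, reg
    restore()
    return raw_lr * (2 ** (best_x / 5))  # damped move toward the best probe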
@@ -437,6 +491,9 @@ def process(self, batch_size, test_batches, batch_splits=1):
if self.swa_enabled:
self.calculate_swa_summaries(test_batches, steps)

if lr_search and steps % self.cfg['training']['total_steps'] == 0:
self.session.run(self.last_lr.assign(self.target_lr))

# Save session and weights at end, and also optionally every 'checkpoint_steps'.
if steps % self.cfg['training']['total_steps'] == 0 or (
'checkpoint_steps' in self.cfg['training'] and steps % self.cfg['training']['checkpoint_steps'] == 0):