diff --git a/tf/aum.py b/tf/aum.py
new file mode 100644
index 00000000..48ffcb30
--- /dev/null
+++ b/tf/aum.py
@@ -0,0 +1,97 @@
+from tensorflow.random import normal
+import tensorflow as tf
+import tensorflow.math as tfmath
+from dataclasses import dataclass
+
+
+@dataclass
+class AUMRecord:
+    sample_id: int
+    visits: int
+    target_val: float
+    pred_val: float
+    margin: float
+    aum: float
+
+
+class AUMTracker:
+    """Tracks the area under the margin (AUM) of each training sample."""
+
+    def __init__(self, verbose: bool = True):
+        self.verbose = verbose
+        self.records = []
+        self.sums = {}
+        self.visits = {}
+
+    def update(self, logits, targets, sample_ids):
+        # Margin of a sample: logit of its target class minus the largest
+        # logit among the remaining classes. `targets` is expected to be a
+        # per-class distribution, so the target class is its argmax.
+        target_ndx = tfmath.argmax(targets, axis=-1)
+        target_vals = tf.gather(logits, target_ndx, axis=-1, batch_dims=1)
+
+        # Mask the target class out before taking the max over the others.
+        num_classes = tf.shape(logits)[-1]
+        mask = tf.one_hot(target_ndx, num_classes, dtype=logits.dtype)
+        other_max = tfmath.reduce_max(logits - mask * 1e9, axis=-1)
+
+        margins = target_vals - other_max
+
+        updated = {}
+
+        for i, sample_id in enumerate(sample_ids):
+            # Use plain Python values as dictionary keys so the same sample
+            # accumulates across batches.
+            sample_id = int(tf.reshape(sample_id, []).numpy())
+            margin = float(margins[i])
+
+            if sample_id in self.sums:
+                self.sums[sample_id] += margin
+                self.visits[sample_id] += 1
+            else:
+                self.sums[sample_id] = margin
+                self.visits[sample_id] = 1
+
+            record = AUMRecord(
+                sample_id,
+                self.visits[sample_id],
+                float(target_vals[i]),
+                float(other_max[i]),
+                margin,
+                self.sums[sample_id] / self.visits[sample_id],
+            )
+            updated[sample_id] = record
+
+            if self.verbose:
+                self.records.append(record)
+
+        return updated
+
+    def graph(self) -> None:
+        if not self.verbose:
+            print("Not in verbose mode, did not log anything")
+            return
+
+        from matplotlib import pyplot as plt
+
+        assert len(self.records) != 0
+
+        fig = plt.figure()
+        plt.plot(self.get_all_scores())
+        plt.savefig("aum.png")
+        plt.close(fig)
+        print("saved graph")
+
+    def get_all_scores(self):
+        return [record.aum for record in self.records]
+
+    def clear(self):
+        self.records = []
+        self.sums = {}
+        self.visits = {}
+
+
+if __name__ == "__main__":
+    tracker = AUMTracker()
+    tracker.update(normal([4, 8]), normal([4, 8]), [1, 2, 3, 4])
+    tracker.graph()
diff --git a/tf/chunkparsefunc.py b/tf/chunkparsefunc.py
index 3196ecb0..75acb103 100644
--- a/tf/chunkparsefunc.py
+++ b/tf/chunkparsefunc.py
@@ -18,10 +18,11 @@
 import tensorflow as tf
 
 
-def parse_function(planes, probs, winner, q, plies_left):
+def parse_function(planes, probs, winner, q, plies_left, sample_ids):
     """
     Convert unpacked record batches to tensors for tensorflow training
     """
+    sample_ids = tf.io.decode_raw(sample_ids, tf.int64)
     planes = tf.io.decode_raw(planes, tf.float32)
     probs = tf.io.decode_raw(probs, tf.float32)
     winner = tf.io.decode_raw(winner, tf.float32)
@@ -31,7 +32,8 @@ def parse_function(planes, probs, winner, q, plies_left):
     planes = tf.reshape(planes, (-1, 112, 8, 8))
     probs = tf.reshape(probs, (-1, 1858))
     winner = tf.reshape(winner, (-1, 3))
+    sample_ids = tf.reshape(sample_ids, (-1, 1))
     q = tf.reshape(q, (-1, 3))
     plies_left = tf.reshape(plies_left, (-1, 1))
 
-    return (planes, probs, winner, q, plies_left)
+    return (planes, probs, winner, q, plies_left, sample_ids)
diff --git a/tf/chunkparser.py b/tf/chunkparser.py
index 9f6066e3..774bfd2b 100644
--- a/tf/chunkparser.py
+++ b/tf/chunkparser.py
@@ -400,7 +400,8 @@ def convert_v6_to_tuple(self, content):
         assert -1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0
         best_q = struct.pack('fff', best_q_w, best_d, best_q_l)
 
-        return (planes, probs, winner, best_q, plies_left)
+        # Sample id derived from the raw planes, used by the AUM tracker.
+        return (planes, probs, winner, best_q, plies_left, hash(planes).to_bytes(8, 'little', signed=True))
 
     def sample_record(self, chunkdata):
         """
@@ -553,7 +554,7 @@ def batch_gen(self, gen, allow_partial=True):
                 return
             yield (b''.join([x[0] for x in s]), b''.join([x[1] for x in s]),
                    b''.join([x[2] for x in s]), b''.join([x[3] for x in s]),
-                   b''.join([x[4] for x in s]))
+                   b''.join([x[4] for x in s]), b''.join([x[5] for x in s]))
 
     def parse(self):
         """
diff --git a/tf/tfprocess.py b/tf/tfprocess.py
index e1b1cf9a..102af02f 100644
--- a/tf/tfprocess.py
+++ b/tf/tfprocess.py
@@ -27,6 +27,7 @@
 import attention_policy_map as apm
 import proto.net_pb2 as pb
 from functools import reduce
+from aum import AUMTracker
 import operator
 from net import Net
 
@@ -129,6 +130,8 @@ def __init__(self, cfg):
         self.virtual_batch_size = self.cfg['model'].get(
             'virtual_batch_size', None)
 
+        self.aum_tracker = AUMTracker()
+
         if precision == 'single':
             self.model_dtype = tf.float32
         elif precision == 'half':
@@ -319,6 +322,7 @@ def correct_policy(target, output):
 
         def policy_loss(target, output):
             target, output = correct_policy(target, output)
+
             policy_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                 labels=tf.stop_gradient(target), logits=output)
             return tf.reduce_mean(input_tensor=policy_cross_entropy)
@@ -629,7 +633,7 @@ def read_weights(self):
         return [w.read_value() for w in self.model.weights]
 
     @tf.function()
-    def process_inner_loop(self, x, y, z, q, m):
+    def process_inner_loop(self, x, y, z, q, m, s):
         with tf.GradientTape() as tape:
             outputs = self.model(x, training=True)
             policy = outputs[0]
@@ -651,6 +655,7 @@ def process_inner_loop(self, x, y, z, q, m):
                                       reg_term)
             if self.loss_scale != 1:
                 total_loss = self.optimizer.get_scaled_loss(total_loss)
+
             if self.wdl:
                 mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
             else:
@@ -665,7 +670,7 @@
             # get comparable values.
             mse_loss / 4.0,
         ]
-        return metrics, tape.gradient(total_loss, self.model.trainable_weights)
+        return policy, metrics, tape.gradient(total_loss, self.model.trainable_weights)
 
     @tf.function()
     def strategy_process_inner_loop(self, x, y, z, q, m):
@@ -719,12 +724,14 @@ def train_step(self, steps, batch_size, batch_splits):
         # Run training for this batch
         grads = None
         for _ in range(batch_splits):
-            x, y, z, q, m = next(self.train_iter)
+            x, y, z, q, m, s = next(self.train_iter)
             if self.strategy is not None:
                 metrics, new_grads = self.strategy_process_inner_loop(
                     x, y, z, q, m)
             else:
-                metrics, new_grads = self.process_inner_loop(x, y, z, q, m)
+                logits, metrics, new_grads = self.process_inner_loop(x, y, z, q, m, s)
+                if steps % 10 == 0:
+                    self.aum_tracker.update(logits, y, s)
             if not grads:
                 grads = new_grads
             else:
@@ -754,9 +761,15 @@
         self.global_step.assign_add(1)
         steps = self.global_step.read_value()
 
+        if steps % 2000 == 0:
+            self.aum_tracker.graph()
+        if steps % 10000 == 0:
+            self.aum_tracker.clear()
+
         if steps % self.cfg['training'][
                 'train_avg_report_steps'] == 0 or steps % self.cfg['training'][
                     'total_steps'] == 0:
+
             time_end = time.time()
             speed = 0
             if self.time_start:
@@ -930,7 +943,7 @@ def calculate_test_summaries(self, test_batches, steps):
         for metric in self.test_metrics:
             metric.reset()
         for _ in range(0, test_batches):
-            x, y, z, q, m = next(self.test_iter)
+            x, y, z, q, m, s = next(self.test_iter)
             if self.strategy is not None:
                 metrics = self.strategy_calculate_test_summaries_inner_loop(
                     x, y, z, q, m)
diff --git a/tf/train.py b/tf/train.py
index 2935c313..fb0ab7a7 100644
--- a/tf/train.py
+++ b/tf/train.py
@@ -188,11 +188,11 @@ def main(cmd):
     tfprocess = TFProcess(cfg)
     train_dataset = tf.data.Dataset.from_generator(
         train_parser.parse,
-        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string))
+        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string, tf.string))
     train_dataset = train_dataset.map(parse_function)
     test_dataset = tf.data.Dataset.from_generator(
         test_parser.parse,
-        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string))
+        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string, tf.string))
     test_dataset = test_dataset.map(parse_function)
     validation_dataset = None
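
Note (not part of the patch): the sketch below shows how the tracker added in tf/aum.py behaves outside the training loop and how its records could be inspected. The 4x8 logits, the one-hot targets, the id values, and the number of simulated steps are illustrative assumptions, not values used anywhere in the patch.

# Standalone illustration of AUMTracker, assuming the tf/aum.py file above.
import tensorflow as tf
from aum import AUMTracker

tracker = AUMTracker()
targets = tf.one_hot([0, 3, 5, 7], 8)                 # 4 samples, 8 classes
sample_ids = tf.constant([[11], [22], [33], [44]], tf.int64)

for _ in range(3):                                    # simulate three update steps
    logits = tf.random.normal([4, 8])
    latest = tracker.update(logits, targets, sample_ids)

# Samples with the lowest running AUM are the most suspicious (hard or
# possibly mislabelled); sort the latest records to inspect them.
for record in sorted(latest.values(), key=lambda r: r.aum)[:2]:
    print(record.sample_id, record.visits, record.aum)

tracker.graph()                                       # writes aum.png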