aum stuff with jjosh #209

Draft · wants to merge 1 commit into master
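Threads a per-sample id through the chunk parser and training data pipeline, and adds an `AUMTracker` that accumulates Area Under the Margin statistics for the policy head during training, plotting the curve periodically.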
116 changes: 116 additions & 0 deletions tf/aum.py
@@ -0,0 +1,116 @@
"""Track per-sample Area Under the Margin (AUM) statistics during training."""
from dataclasses import dataclass

import tensorflow as tf

@dataclass
class AUMRecord:
    sample_id: int
    visits: int
    target_val: float
    pred_val: float
    margin: float
    aum: float

class AUMTracker:
    def __init__(self, verbose: bool = True):
        self.verbose = verbose
        self.records = []  # every record logged so far (verbose mode only)
        self.sums = {}     # sample_id -> running sum of margins
        self.visits = {}   # sample_id -> number of times the sample was seen

    def update(self, logits, targets, sample_ids):
        # AUM margin (Pleiss et al., 2020): the logit of the assigned class
        # minus the largest logit among all other classes. Assumes `targets`
        # is (near-)one-hot with the same shape as `logits`, and that this
        # runs eagerly, one batch at a time.
        target_vals = tf.reduce_sum(targets * logits, axis=-1)
        masked = logits - targets * 1e9  # push the target class out of the max
        max_other = tf.reduce_max(masked, axis=-1)
        margins = target_vals - max_other

        updated = {}

        for i, (margin, sample_id) in enumerate(zip(margins, sample_ids)):
            # Key on the integer id: Tensor.ref() is identity-based, so the
            # same sample would get a fresh key on every batch it appears in.
            key = int(sample_id.numpy()) if hasattr(sample_id, "numpy") else int(sample_id)
            margin = float(margin)
            if key in self.sums:
                self.sums[key] += margin
                self.visits[key] += 1
            else:
                self.sums[key] = margin
                self.visits[key] = 1

            record = AUMRecord(
                key,
                self.visits[key],
                float(target_vals[i]),
                float(max_other[i]),
                margin,
                self.sums[key] / self.visits[key],
            )
            updated[key] = record

            if self.verbose:
                self.records.append(record)

        return updated

    def graph(self) -> None:
        if not self.verbose:
            print("Not in verbose mode, did not log anything")
            return

        from matplotlib import pyplot as plt

        assert len(self.records) != 0
        plt.figure()
        # Plot every recorded running-AUM value in the order it was logged.
        plt.plot(self.get_all_scores())
        plt.savefig("poop.png")
        print("saved graph")

    def get_all_scores(self):
        return [record.aum for record in self.records]

    def clear(self):
        self.records = []
        self.sums = {}
        self.visits = {}

if __name__ == "__main__":
    # Smoke test with illustrative shapes: four samples over ten classes,
    # one-hot targets.
    tracker = AUMTracker()
    logits = tf.random.normal([4, 10])
    targets = tf.one_hot([1, 2, 3, 0], depth=10)
    tracker.update(logits, targets, [1, 2, 3, 4])
    tracker.graph()
6 changes: 4 additions & 2 deletions tf/chunkparsefunc.py
@@ -18,10 +18,11 @@
import tensorflow as tf


def parse_function(planes, probs, winner, q, plies_left):
def parse_function(planes, probs, winner, q, plies_left, sample_ids):
    """
    Convert unpacked record batches to tensors for tensorflow training
    """
    sample_ids = tf.io.decode_raw(sample_ids, tf.int64)
    planes = tf.io.decode_raw(planes, tf.float32)
    probs = tf.io.decode_raw(probs, tf.float32)
    winner = tf.io.decode_raw(winner, tf.float32)
@@ -31,7 +32,8 @@ def parse_function(planes, probs, winner, q, plies_left):
    planes = tf.reshape(planes, (-1, 112, 8, 8))
    probs = tf.reshape(probs, (-1, 1858))
    winner = tf.reshape(winner, (-1, 3))
    sample_ids = tf.reshape(sample_ids, (-1, 1))
    q = tf.reshape(q, (-1, 3))
    plies_left = tf.reshape(plies_left, (-1, 1))

    return (planes, probs, winner, q, plies_left)
    return (planes, probs, winner, q, plies_left, sample_ids)
4 changes: 2 additions & 2 deletions tf/chunkparser.py
@@ -400,7 +400,7 @@ def convert_v6_to_tuple(self, content):
        assert -1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0
        best_q = struct.pack('fff', best_q_w, best_d, best_q_l)

        return (planes, probs, winner, best_q, plies_left)
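        # NOTE: hash() of bytes is salted per process (PYTHONHASHSEED), so
        # these ids are only stable within a single training run.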
        return (planes, probs, winner, best_q, plies_left,
                hash(planes).to_bytes(8, 'little', signed=True))

    def sample_record(self, chunkdata):
        """
@@ -553,7 +553,7 @@ def batch_gen(self, gen, allow_partial=True):
                return
            yield (b''.join([x[0] for x in s]), b''.join([x[1] for x in s]),
                   b''.join([x[2] for x in s]), b''.join([x[3] for x in s]),
                   b''.join([x[4] for x in s]))
                   b''.join([x[4] for x in s]), b''.join([x[5] for x in s]))

def parse(self):
"""
23 changes: 18 additions & 5 deletions tf/tfprocess.py
@@ -27,6 +27,7 @@
import attention_policy_map as apm
import proto.net_pb2 as pb
from functools import reduce
from aum import AUMTracker
import operator

from net import Net
@@ -129,6 +130,8 @@ def __init__(self, cfg):
        self.virtual_batch_size = self.cfg['model'].get(
            'virtual_batch_size', None)

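        # Tracks per-sample Area Under the Margin ("Identifying Mislabeled
        # Data using the Area Under the Margin Ranking", Pleiss et al., 2020).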
        self.aum_tracker = AUMTracker()

        if precision == 'single':
            self.model_dtype = tf.float32
        elif precision == 'half':
@@ -319,6 +322,7 @@ def correct_policy(target, output):

        def policy_loss(target, output):
            target, output = correct_policy(target, output)

            policy_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                labels=tf.stop_gradient(target), logits=output)
            return tf.reduce_mean(input_tensor=policy_cross_entropy)
@@ -629,7 +633,7 @@ def read_weights(self):
        return [w.read_value() for w in self.model.weights]

    @tf.function()
    def process_inner_loop(self, x, y, z, q, m):
    def process_inner_loop(self, x, y, z, q, m, s):
        with tf.GradientTape() as tape:
            outputs = self.model(x, training=True)
            policy = outputs[0]
@@ -651,6 +655,7 @@ def process_inner_loop(self, x, y, z, q, m):
                                     reg_term)
            if self.loss_scale != 1:
                total_loss = self.optimizer.get_scaled_loss(total_loss)

            if self.wdl:
                mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
            else:
@@ -665,7 +670,7 @@ def process_inner_loop(self, x, y, z, q, m):
                # get comparable values.
                mse_loss / 4.0,
            ]
        return metrics, tape.gradient(total_loss, self.model.trainable_weights)
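        # Also return the policy logits so train_step can feed the AUM tracker.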
        return policy, metrics, tape.gradient(total_loss, self.model.trainable_weights)

    @tf.function()
    def strategy_process_inner_loop(self, x, y, z, q, m):
@@ -719,12 +724,14 @@ def train_step(self, steps, batch_size, batch_splits):
        # Run training for this batch
        grads = None
        for _ in range(batch_splits):
            x, y, z, q, m = next(self.train_iter)
            x, y, z, q, m, s = next(self.train_iter)
            if self.strategy is not None:
                metrics, new_grads = self.strategy_process_inner_loop(
                    x, y, z, q, m)
            else:
                metrics, new_grads = self.process_inner_loop(x, y, z, q, m)
                logits, metrics, new_grads = self.process_inner_loop(x, y, z, q, m, s)
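                # Sampling every 10th step keeps the eager AUM bookkeeping
                # cheap; `logits` here are the policy head outputs.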
                if steps % 10 == 0:
                    self.aum_tracker.update(logits, y, s)
            if not grads:
                grads = new_grads
            else:
@@ -754,9 +761,15 @@ def train_step(self, steps, batch_size, batch_splits):
        self.global_step.assign_add(1)
        steps = self.global_step.read_value()

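        # Plot the accumulated AUM curve every 2000 steps; reset the
        # tracker's statistics every 10000.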
        if steps % 2000 == 0:
            self.aum_tracker.graph()
        if steps % 10000 == 0:
            self.aum_tracker.clear()

        if steps % self.cfg['training'][
                'train_avg_report_steps'] == 0 or steps % self.cfg['training'][
                    'total_steps'] == 0:

            time_end = time.time()
            speed = 0
            if self.time_start:
@@ -930,7 +943,7 @@ def calculate_test_summaries(self, test_batches, steps):
        for metric in self.test_metrics:
            metric.reset()
        for _ in range(0, test_batches):
            x, y, z, q, m = next(self.test_iter)
            x, y, z, q, m, s = next(self.test_iter)
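            # Sample ids are unpacked to match the new tuple shape but are
            # unused when computing test summaries.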
            if self.strategy is not None:
                metrics = self.strategy_calculate_test_summaries_inner_loop(
                    x, y, z, q, m)
4 changes: 2 additions & 2 deletions tf/train.py
@@ -188,11 +188,11 @@ def main(cmd):
    tfprocess = TFProcess(cfg)
    train_dataset = tf.data.Dataset.from_generator(
        train_parser.parse,
        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string))
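        # The sixth tf.string carries the packed per-sample ids.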
        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string, tf.string))
    train_dataset = train_dataset.map(parse_function)
    test_dataset = tf.data.Dataset.from_generator(
        test_parser.parse,
        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string))
        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string, tf.string))
    test_dataset = test_dataset.map(parse_function)

validation_dataset = None
Expand Down