diff --git a/tf/chunkparser.py b/tf/chunkparser.py index e634ac4d..39c1c87f 100644 --- a/tf/chunkparser.py +++ b/tf/chunkparser.py @@ -80,13 +80,13 @@ def __init__(self, expected_input_format, shuffle_size=1, sample=1, - buffer_size=1, batch_size=256, workers=None): """ Read data and yield batches of raw tensors. 'chunks' list of chunk filenames. + 'expected_input_format' is an int, one of [1, 2, 3]. Determines the middle planes in convert_v5_to_tuple 'shuffle_size' is the size of the shuffle buffer. 'sample' is the rate to down-sample. 'workers' is the number of child workers to use. @@ -99,7 +99,7 @@ def __init__(self, chunkdata: type Bytes. Multiple records of v5 format where each record consists of (state, policy, result, q) - raw: A byte string holding raw tensors contenated together. This is + raw: A byte string holding raw tensors concatenated together. This is used to pass data from the workers to the parent. Exists because TensorFlow doesn't have a fast way to unpack bit vectors. 7950 bytes long. diff --git a/tf/decode_training.py b/tf/decode_training.py index bd9a8688..274ae0c0 100755 --- a/tf/decode_training.py +++ b/tf/decode_training.py @@ -293,10 +293,11 @@ def describe(self): class TrainingStep: - def __init__(self, version): + def __init__(self, version, input_format=1): self.version = version # Construct a fake parser just to get access to it's variables self.parser = chunkparser.ChunkParser(chunkparser.ChunkDataSrc([]), + expected_input_format=input_format, workers=1) self.NUM_HIST = 8 self.NUM_PIECE_TYPES = 6