Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix weakref pickle issue. #167

Merged
merged 2 commits into from
Sep 30, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 66 additions & 52 deletions tf/chunkparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#
# You should have received a copy of the GNU General Public License
# along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.

"""
General comments on how chunkparser works.

Expand Down Expand Up @@ -133,9 +132,54 @@ def __init__(self,
value_focus_min=1,
value_focus_slope=0,
workers=None):
self.inner = ChunkParserInner(self, chunks, expected_input_format,
shuffle_size, sample, buffer_size,
batch_size, value_focus_min,
value_focus_slope, workers)

def shutdown(self):
    """
    Stop every worker process and close the pipe ends paired with it.

    Each process in self.processes is terminated and joined, then the
    reader and writer held by self.inner at the same index are closed.
    The chunk reader process is terminated and joined last.
    """
    for idx, worker in enumerate(self.processes):
        worker.terminate()
        worker.join()
        # The pipe ends live on the inner parser, looked up by the
        # worker's position in the process list.
        self.inner.readers[idx].close()
        self.inner.writers[idx].close()

    chunk_proc = self.chunk_process
    chunk_proc.terminate()
    chunk_proc.join()

@staticmethod
def parse_function(planes, probs, winner, q, plies_left):
    """
    Convert unpacked record batches to tensors for tensorflow training.

    Every argument is decoded from raw bytes as float32 and then
    reshaped to a batch-major layout:
    planes (BATCH_SIZE, 112, 64), probs (BATCH_SIZE, 1858),
    winner (BATCH_SIZE, 3), q (BATCH_SIZE, 3), plies_left (BATCH_SIZE,).

    Returns the 5-tuple (planes, probs, winner, q, plies_left).
    """
    # First pass: raw bytes -> flat float32 tensors, in argument order.
    decoded = [
        tf.io.decode_raw(raw, tf.float32)
        for raw in (planes, probs, winner, q, plies_left)
    ]

    # Second pass: give each flat tensor its per-batch shape.
    batch = ChunkParser.BATCH_SIZE
    target_shapes = (
        (batch, 112, 8 * 8),
        (batch, 1858),
        (batch, 3),
        (batch, 3),
        (batch, ),
    )
    return tuple(
        tf.reshape(tensor, shape)
        for tensor, shape in zip(decoded, target_shapes))

def parse(self):
    """
    Delegate to the inner parser and return whatever its parse() gives.
    """
    inner_parser = self.inner
    return inner_parser.parse()


class ChunkParserInner:
def __init__(self, parent, chunks, expected_input_format, shuffle_size,
sample, buffer_size, batch_size, value_focus_min,
value_focus_slope, workers):
"""
Read data and yield batches of raw tensors.

'parent' the outer chunk parser to store processes. Must not be stored by self directly or indirectly.
'chunks' list of chunk filenames.
'shuffle_size' is the size of the shuffle buffer.
'sample' is the rate to down-sample.
Expand Down Expand Up @@ -182,38 +226,26 @@ def __init__(self,
# Start the child workers running
self.readers = []
self.writers = []
self.processes = []
parent.processes = []
self.chunk_filename_queue = mp.Queue(maxsize=4096)
for _ in range(workers):
read, write = mp.Pipe(duplex=False)
p = mp.Process(target=self.task,
args=(self.chunk_filename_queue, write))
p.daemon = True
self.processes.append(p)
parent.processes.append(p)
p.start()
self.readers.append(read)
self.writers.append(write)

self.chunk_process = mp.Process(target=chunk_reader,
args=(chunks,
self.chunk_filename_queue))
self.chunk_process.daemon = True
self.chunk_process.start()
parent.chunk_process = mp.Process(target=chunk_reader,
args=(chunks,
self.chunk_filename_queue))
parent.chunk_process.daemon = True
parent.chunk_process.start()

self.init_structs()

def shutdown(self):
    """
    Terminate all worker processes and close their pipes.

    Walks self.readers by index: the process at that index is
    terminated and joined, then the reader and the writer at the same
    index are closed. Afterwards the chunk reader process is terminated
    and joined.
    """
    for idx, reader in enumerate(self.readers):
        worker = self.processes[idx]
        worker.terminate()
        worker.join()
        reader.close()
        self.writers[idx].close()

    chunk_proc = self.chunk_process
    chunk_proc.terminate()
    chunk_proc.join()

def init_structs(self):
"""
struct.Struct doesn't pickle, so it needs to be separately
Expand All @@ -224,25 +256,6 @@ def init_structs(self):
self.v4_struct = struct.Struct(V4_STRUCT_STRING)
self.v3_struct = struct.Struct(V3_STRUCT_STRING)

@staticmethod
def parse_function(planes, probs, winner, q, plies_left):
    """
    Convert unpacked record batches to tensors for tensorflow training.

    Each input is decoded from raw bytes as float32, then reshaped so
    that the leading dimension is ChunkParser.BATCH_SIZE:
    planes -> (BATCH_SIZE, 112, 64), probs -> (BATCH_SIZE, 1858),
    winner -> (BATCH_SIZE, 3), q -> (BATCH_SIZE, 3),
    plies_left -> (BATCH_SIZE,).

    Returns (planes, probs, winner, q, plies_left) as a tuple.
    """
    raw_inputs = (planes, probs, winner, q, plies_left)

    # Decode everything before reshaping anything, in argument order.
    flat = [tf.io.decode_raw(item, tf.float32) for item in raw_inputs]

    n = ChunkParser.BATCH_SIZE
    layouts = [(n, 112, 8 * 8), (n, 1858), (n, 3), (n, 3), (n, )]

    shaped = []
    for tensor, layout in zip(flat, layouts):
        shaped.append(tf.reshape(tensor, layout))
    return tuple(shaped)

def convert_v6_to_tuple(self, content):
"""
Unpack a v6 binary record to 5-tuple (state, policy pi, result, q, m)
Expand Down Expand Up @@ -294,11 +307,12 @@ def convert_v6_to_tuple(self, content):
"""
# unpack the V6 content from raw byte array, arbitrarily chose 4 2-byte values
# for the 8 "reserved" bytes
(ver, input_format, probs, planes, us_ooo, us_oo, them_ooo, them_oo, stm,
rule50_count, invariance_info, dep_result, root_q, best_q, root_d, best_d, root_m,
best_m, plies_left, result_q, result_d, played_q, played_d, played_m, orig_q,
orig_d, orig_m, visits, played_idx, best_idx, reserved1, reserved2, reserved3,
reserved4) = self.v6_struct.unpack(content)
(ver, input_format, probs, planes, us_ooo, us_oo, them_ooo, them_oo,
stm, rule50_count, invariance_info, dep_result, root_q, best_q,
root_d, best_d, root_m, best_m, plies_left, result_q, result_d,
played_q, played_d, played_m, orig_q, orig_d, orig_m, visits,
played_idx, best_idx, reserved1, reserved2, reserved3,
reserved4) = self.v6_struct.unpack(content)
"""
v5 struct format was (8308 bytes total)
int32 version (4 bytes)
Expand All @@ -321,7 +335,7 @@ def convert_v6_to_tuple(self, content):
float32 best_m (4 bytes)
float32 plies_left (4 bytes)
"""
# v3/4 data sometimes has a useful value in dep_ply_count (now invariance_info),
# v3/4 data sometimes has a useful value in dep_ply_count (now invariance_info),
# so copy that over if the new ply_count is not populated.
if plies_left == 0:
plies_left = invariance_info
Expand Down Expand Up @@ -370,7 +384,8 @@ def convert_v6_to_tuple(self, content):
# Concatenate all byteplanes. Make the last plane all 1's so the NN can
# detect edges of the board more easily
aux_plus_6_plane = self.flat_planes[0]
if (input_format == 132 or input_format == 133) and invariance_info >= 128:
if (input_format == 132
or input_format == 133) and invariance_info >= 128:
aux_plus_6_plane = self.flat_planes[1]
planes = planes.tobytes() + \
middle_planes + \
Expand All @@ -381,13 +396,13 @@ def convert_v6_to_tuple(self, content):
assert len(planes) == ((8 * 13 * 1 + 8 * 1 * 1) * 8 * 8 * 4)

if ver == V6_VERSION:
winner = struct.pack('fff', 0.5 * (1.0 - result_d + result_q), result_d,
0.5 * (1.0 - result_d - result_q))
winner = struct.pack('fff', 0.5 * (1.0 - result_d + result_q),
result_d, 0.5 * (1.0 - result_d - result_q))
else:
dep_result = float(dep_result)
assert dep_result == 1.0 or dep_result == -1.0 or dep_result == 0.0
winner = struct.pack('fff', dep_result == 1.0, dep_result == 0.0,
dep_result == -1.0)
dep_result == -1.0)

best_q_w = 0.5 * (1.0 - best_d + best_q)
best_q_l = 0.5 * (1.0 - best_d - best_q)
Expand All @@ -396,7 +411,6 @@ def convert_v6_to_tuple(self, content):

return (planes, probs, winner, best_q, plies_left)


def sample_record(self, chunkdata):
"""
Randomly sample through the v3/4/5/6 chunk data and select records in v6 format
Expand Down Expand Up @@ -439,7 +453,7 @@ def sample_record(self, chunkdata):
# value focus code, peek at best_q and orig_q from record (unpacks as tuple with one item)
best_q = struct.unpack('f', record[8284:8288])[0]
orig_q = struct.unpack('f', record[8328:8332])[0]

# if orig_q is NaN, accept, else accept based on value focus
if not np.isnan(orig_q):
diff_q = abs(best_q - orig_q)
Expand Down Expand Up @@ -534,7 +548,7 @@ def parse(self):
"""
Read data from child workers and yield batches of unpacked records
"""
gen = self.v6_gen() # read from workers
gen = self.v6_gen() # read from workers
gen = self.tuple_gen(gen) # convert v6->tuple
gen = self.batch_gen(gen) # assemble into batches
for b in gen:
Expand Down