From e5d3d79b591a8e2774a0b22e39f23ccbde58648a Mon Sep 17 00:00:00 2001 From: Jin Xu Date: Fri, 28 Jun 2024 09:22:52 -0700 Subject: [PATCH] Migrate load_patch_coordinates_from_filename_queue and weighted_load_patch_coordinates methods into inputs.py PiperOrigin-RevId: 647705619 --- ffn/training/inputs.py | 104 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 92 insertions(+), 12 deletions(-) diff --git a/ffn/training/inputs.py b/ffn/training/inputs.py index 30863a3..5363712 100644 --- a/ffn/training/inputs.py +++ b/ffn/training/inputs.py @@ -23,6 +23,7 @@ from connectomics.common import box_generator from connectomics.segmentation import labels as label_utils from ffn.training import augmentation +from ffn.training import variables import numpy as np import tensorflow.compat.v1 as tf from tensorflow.io import gfile @@ -61,24 +62,31 @@ def create_filename_queue(coordinates_file_pattern, shuffle=True): return tf.train.string_input_producer(coord_file_list, shuffle=shuffle) -def load_patch_coordinates_from_filename_queue(filename_queue): +def load_patch_coordinates_from_filename_queue(filename_queue, + file_format='tfrecord'): """Loads coordinates and volume names from filename queue. Args: filename_queue: Tensorflow queue created from create_filename_queue() + file_format: String indicating the format of the files in the queue. + Can be 'sstable' or 'tfrecord'. Defaults to 'tfrecord'. Returns: Tuple of coordinates (shape `[1, 3]`) and volume name (shape `[1]`) tensors. """ - record_options = tf.python_io.TFRecordOptions( - tf.python_io.TFRecordCompressionType.GZIP) - _, protos = tf.TFRecordReader(options=record_options).read(filename_queue) - examples = tf.parse_single_example(protos, features=dict( - center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64), - label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string), - )) - coord = examples['center'] - volname = examples['label_volume_name'] + if file_format == 'tfrecord': + record_options = tf.python_io.TFRecordOptions( + tf.python_io.TFRecordCompressionType.GZIP) + _, protos = tf.TFRecordReader(options=record_options).read(filename_queue) + examples = tf.parse_single_example(protos, features=dict( + center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64), + label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string), + )) + coord = examples['center'] + volname = examples['label_volume_name'] + else: + raise ValueError(f'Unsupported file format: {file_format}.') + return coord, volname @@ -222,7 +230,8 @@ def load_coordinates_from_tfex( def load_patch_coordinates(coordinates_file_pattern, shuffle=True, - scope='load_patch_coordinates'): + scope='load_patch_coordinates', + file_format='tfrecord'): """Loads coordinates and volume names from tables of VolumeStoreInputExamples. Args: @@ -233,6 +242,8 @@ def load_patch_coordinates(coordinates_file_pattern, coordinates_file_pattern is not guaranteed to be sorted alphabetically. scope: Passed to name_scope. + file_format: String indicating the format of the files in the queue. + Can be 'sstable' or 'tfrecord'. Defaults to 'tfrecord'. Returns: Tuple of coordinates (shape `[1, 3]`) and volume name (shape `[1]`) tensors. @@ -240,7 +251,76 @@ def load_patch_coordinates(coordinates_file_pattern, with tf.name_scope(scope): filename_queue = create_filename_queue( coordinates_file_pattern, shuffle=shuffle) - return load_patch_coordinates_from_filename_queue(filename_queue) + return load_patch_coordinates_from_filename_queue( + filename_queue, file_format=file_format) + + +def weighted_load_patch_coordinates( + coord_paths: Sequence[str], + weights: Optional[Sequence[float]] = None, + scope: str = 'weighted_load_patch_coordinates', + file_format: str = 'tfrecord', +): + """Like the unweighted version, but pulls data from multiple sources. + + Args: + coord_paths: glob patterns for SSTables of VolumeStoreInputExamples + weights: weights determining the relative frequency the corresponding paths + will be sampled; needs to be same length as coord_paths; weights do not + need to be normalized + scope: passed to name_scape + file_format: String indicating the format of the files in the queue. + Can be 'sstable' or 'tfrecord'. Defaults to 'tfrecord'. + + Returns: + TF op to pull a tuple of coordinates and volume name from a queue. + """ + if weights is None: + weights = [1.0] * len(coord_paths) + if len(coord_paths) != len(weights): + raise ValueError( + '# of coord paths: %d does not match # of weights %d' + % (len(coord_paths), len(weights)) + ) + + weights = np.array(weights) + weights /= np.sum(weights) + cum_weights = np.cumsum(weights) + + with tf.name_scope(scope): + # Filename queues have to be created in the main graph, outside of + # tf.switch_case branches. + load_queues = [] + for path in coord_paths: + load_queues.append(create_filename_queue(path, shuffle=True)) + + with tf.variable_scope(None, 'coord_source'): + dist = variables.DistributionTracker(len(coord_paths)) + rates = dist.get_rates() + for i in range(len(coord_paths)): + tf.summary.scalar('source_%d' % i, rates[i]) + + # Choose source at random and pull coordinates from the associated queue. + source_num = tf.cast( + tf.reduce_min(tf.where(tf.random.uniform(shape=[1]) < cum_weights)), + tf.int32, + ) + with tf.control_dependencies([dist.record_class(source_num)]): + return tf.switch_case( + source_num, + [ + lambda qq=q: load_patch_coordinates_from_filename_queue( + qq, file_format=file_format) + for q in load_queues + ], + # Use a default invalid value so that the process crashes if + # no valid case is found instead of silently selecting the + # last branch. + default=lambda: ( # pylint:disable=g-long-lambda + tf.constant([[0, 0, 0]], dtype=tf.int64), + tf.constant(['missing'], dtype=tf.string), + ), + ) def load_from_numpylike(coordinates, volume_names, shape, volume_map,