Migrate load_patch_coordinates_from_filename_queue and weighted_load_patch_coordinates methods into inputs.py

PiperOrigin-RevId: 647705619
imxj authored and copybara-github committed Jun 28, 2024
1 parent b9ae681 commit e5d3d79
Showing 1 changed file with 92 additions and 12 deletions.
ffn/training/inputs.py (104 changes: 92 additions & 12 deletions)
@@ -23,6 +23,7 @@
 from connectomics.common import box_generator
 from connectomics.segmentation import labels as label_utils
 from ffn.training import augmentation
+from ffn.training import variables
 import numpy as np
 import tensorflow.compat.v1 as tf
 from tensorflow.io import gfile
@@ -61,24 +62,31 @@ def create_filename_queue(coordinates_file_pattern, shuffle=True):
   return tf.train.string_input_producer(coord_file_list, shuffle=shuffle)
 
 
-def load_patch_coordinates_from_filename_queue(filename_queue):
+def load_patch_coordinates_from_filename_queue(filename_queue,
+                                               file_format='tfrecord'):
   """Loads coordinates and volume names from filename queue.
 
   Args:
     filename_queue: Tensorflow queue created from create_filename_queue()
+    file_format: String indicating the format of the files in the queue.
+      Can be 'sstable' or 'tfrecord'. Defaults to 'tfrecord'.
 
   Returns:
     Tuple of coordinates (shape `[1, 3]`) and volume name (shape `[1]`) tensors.
   """
-  record_options = tf.python_io.TFRecordOptions(
-      tf.python_io.TFRecordCompressionType.GZIP)
-  _, protos = tf.TFRecordReader(options=record_options).read(filename_queue)
-  examples = tf.parse_single_example(protos, features=dict(
-      center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
-      label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string),
-  ))
-  coord = examples['center']
-  volname = examples['label_volume_name']
+  if file_format == 'tfrecord':
+    record_options = tf.python_io.TFRecordOptions(
+        tf.python_io.TFRecordCompressionType.GZIP)
+    _, protos = tf.TFRecordReader(options=record_options).read(filename_queue)
+    examples = tf.parse_single_example(protos, features=dict(
+        center=tf.FixedLenFeature(shape=[1, 3], dtype=tf.int64),
+        label_volume_name=tf.FixedLenFeature(shape=[1], dtype=tf.string),
+    ))
+    coord = examples['center']
+    volname = examples['label_volume_name']
+  else:
+    raise ValueError(f'Unsupported file format: {file_format}.')
 
   return coord, volname
 
@@ -222,7 +230,8 @@ def load_coordinates_from_tfex(
 
 def load_patch_coordinates(coordinates_file_pattern,
                            shuffle=True,
-                           scope='load_patch_coordinates'):
+                           scope='load_patch_coordinates',
+                           file_format='tfrecord'):
   """Loads coordinates and volume names from tables of VolumeStoreInputExamples.
 
   Args:
@@ -233,14 +242,85 @@ def load_patch_coordinates(coordinates_file_pattern,
       coordinates_file_pattern is not guaranteed to be sorted
       alphabetically.
     scope: Passed to name_scope.
+    file_format: String indicating the format of the files in the queue.
+      Can be 'sstable' or 'tfrecord'. Defaults to 'tfrecord'.
 
   Returns:
     Tuple of coordinates (shape `[1, 3]`) and volume name (shape `[1]`) tensors.
   """
   with tf.name_scope(scope):
     filename_queue = create_filename_queue(
         coordinates_file_pattern, shuffle=shuffle)
-    return load_patch_coordinates_from_filename_queue(filename_queue)
+    return load_patch_coordinates_from_filename_queue(
+        filename_queue, file_format=file_format)
+
+
+def weighted_load_patch_coordinates(
+    coord_paths: Sequence[str],
+    weights: Optional[Sequence[float]] = None,
+    scope: str = 'weighted_load_patch_coordinates',
+    file_format: str = 'tfrecord',
+):
+  """Like the unweighted version, but pulls data from multiple sources.
+
+  Args:
+    coord_paths: glob patterns for SSTables of VolumeStoreInputExamples
+    weights: weights determining the relative frequency the corresponding paths
+      will be sampled; needs to be same length as coord_paths; weights do not
+      need to be normalized
+    scope: passed to name_scope
+    file_format: String indicating the format of the files in the queue.
+      Can be 'sstable' or 'tfrecord'. Defaults to 'tfrecord'.
+
+  Returns:
+    TF op to pull a tuple of coordinates and volume name from a queue.
+  """
+  if weights is None:
+    weights = [1.0] * len(coord_paths)
+  if len(coord_paths) != len(weights):
+    raise ValueError(
+        '# of coord paths: %d does not match # of weights %d'
+        % (len(coord_paths), len(weights))
+    )
+
+  weights = np.array(weights)
+  weights /= np.sum(weights)
+  cum_weights = np.cumsum(weights)
+
+  with tf.name_scope(scope):
+    # Filename queues have to be created in the main graph, outside of
+    # tf.switch_case branches.
+    load_queues = []
+    for path in coord_paths:
+      load_queues.append(create_filename_queue(path, shuffle=True))
+
+    with tf.variable_scope(None, 'coord_source'):
+      dist = variables.DistributionTracker(len(coord_paths))
+      rates = dist.get_rates()
+      for i in range(len(coord_paths)):
+        tf.summary.scalar('source_%d' % i, rates[i])
+
+    # Choose source at random and pull coordinates from the associated queue.
+    source_num = tf.cast(
+        tf.reduce_min(tf.where(tf.random.uniform(shape=[1]) < cum_weights)),
+        tf.int32,
+    )
+    with tf.control_dependencies([dist.record_class(source_num)]):
+      return tf.switch_case(
+          source_num,
+          [
+              lambda qq=q: load_patch_coordinates_from_filename_queue(
+                  qq, file_format=file_format)
+              for q in load_queues
+          ],
+          # Use a default invalid value so that the process crashes if
+          # no valid case is found instead of silently selecting the
+          # last branch.
+          default=lambda: (  # pylint:disable=g-long-lambda
+              tf.constant([[0, 0, 0]], dtype=tf.int64),
+              tf.constant(['missing'], dtype=tf.string),
+          ),
+      )
 
 
 def load_from_numpylike(coordinates, volume_names, shape, volume_map,
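A minimal usage sketch (not part of the commit) of the migrated loader, assuming a TF1 graph with queue runners; the coordinate file pattern is a hypothetical placeholder:

import tensorflow.compat.v1 as tf
from ffn.training import inputs

tf.disable_eager_execution()  # the loaders build TF1 input queues, so graph mode is required

# Build the loading op; file_format defaults to 'tfrecord', shown explicitly here.
coord, volname = inputs.load_patch_coordinates(
    '/tmp/coords/coordinates.tfrecord*',  # hypothetical glob pattern
    shuffle=True,
    file_format='tfrecord')

with tf.Session() as sess:
  coordinator = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
  c, v = sess.run([coord, volname])  # c: int64, shape [1, 3]; v: bytes, shape [1]
  coordinator.request_stop()
  coordinator.join(threads)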

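A similar sketch for the newly added weighted_load_patch_coordinates, with hypothetical coordinate sources; the weights are relative frequencies and are normalized inside the function, so roughly three of every four coordinates would come from the first source:

from ffn.training import inputs

coord, volname = inputs.weighted_load_patch_coordinates(
    coord_paths=['/tmp/coords/dataset_a.tfrecord*',   # hypothetical source A
                 '/tmp/coords/dataset_b.tfrecord*'],  # hypothetical source B
    weights=[3.0, 1.0],  # relative sampling frequencies; no need to sum to 1
    file_format='tfrecord')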
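For reference, the source-selection rule inside weighted_load_patch_coordinates (tf.reduce_min over tf.where(uniform < cum_weights)) corresponds to the following plain-NumPy logic:

import numpy as np

weights = np.array([3.0, 1.0])
weights /= np.sum(weights)        # normalize -> [0.75, 0.25]
cum_weights = np.cumsum(weights)  # cumulative -> [0.75, 1.0]

u = np.random.uniform()
# First bucket whose cumulative weight exceeds the uniform draw:
# index 0 with probability 0.75, index 1 with probability 0.25.
source_num = int(np.min(np.where(u < cum_weights)[0]))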