Skip to content

Commit

Permalink
Fix a reference to gfile.glob.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654714199
  • Loading branch information
mjanusz authored and copybara-github committed Jul 23, 2024
1 parent 1e95fbf commit 10a0a7b
Showing 1 changed file with 0 additions and 29 deletions.
29 changes: 0 additions & 29 deletions ffn/training/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,35 +199,6 @@ def parse_tf_coords(x):
)


def load_coordinates_from_tfex(
coord_pattern: str,
shuffle: bool = True,
shuffle_size: Optional[int] = 4096,
shuffle_seed: Optional[int] = None,
parse_fn: Callable[[Any], dict[str, Any]] = parse_tf_coords,
reshuffle_each_iteration: bool = True,
) -> tf.data.Dataset:
"""Loads coordinates from a RecordIO of tf.Example protos."""
coord_paths = sorted(gfile.Glob(coord_pattern))
if shuffle:
if shuffle_seed:
random.Random(shuffle_seed).shuffle(coord_paths)
else:
random.shuffle(coord_paths)
logging.info('Loading data from: %r', coord_paths)
ds = tf.data.RecordIODataset(tf.constant(coord_paths, dtype=tf.string))

ds = ds.map(parse_fn, deterministic=True)
if shuffle:
ds = ds.shuffle(
shuffle_size,
seed=shuffle_seed,
reshuffle_each_iteration=reshuffle_each_iteration,
)

return ds.repeat()


def load_patch_coordinates(coordinates_file_pattern,
shuffle=True,
scope='load_patch_coordinates',
Expand Down

0 comments on commit 10a0a7b

Please sign in to comment.