Skip to content

Commit

Permalink
Migrate filter_oob into inputs.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 655319877
  • Loading branch information
imxj authored and copybara-github committed Jul 23, 2024
1 parent 1e95fbf commit 458cc90
Showing 1 changed file with 86 additions and 6 deletions.
92 changes: 86 additions & 6 deletions ffn/training/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
# ==============================================================================
"""Tensorflow Python ops and utilities for generating network inputs."""

import os
import random
import re
from typing import Any, Callable, Optional, Sequence

from absl import logging
from connectomics.common import bounding_box
from connectomics.common import box_generator
from connectomics.common import ts_utils
from connectomics.segmentation import labels as label_utils
from ffn.training import augmentation
from ffn.training import variables
Expand All @@ -33,13 +35,11 @@ def create_filename_queue(coordinates_file_pattern, shuffle=True):
"""Creates a queue for reading coordinates from coordinate file.
Args:
coordinates_file_pattern: File pattern for TFRecords of
input examples of the form of a glob
pattern or path@shards
or comma-separated file patterns.
coordinates_file_pattern: File pattern for TFRecords of input examples of
the form of a glob pattern or path@shards or comma-separated file
patterns.
shuffle: Whether to shuffle the coordinate file list. Note that the expanded
coordinates_file_pattern is not guaranteed to be sorted
alphabetically.
coordinates_file_pattern is not guaranteed to be sorted alphabetically.
Returns:
Tensorflow queue with coordinate filenames
Expand Down Expand Up @@ -731,3 +731,83 @@ def sample(
if repeat:
sampled_dataset = sampled_dataset.repeat()
return sampled_dataset


def coordinates_in_bounds(
    coordinates: tf.Tensor,
    volname: tf.Tensor,
    radius: Sequence[int],
    volinfo_map_string: str,
    use_bboxes: bool = True,
    name: str = 'coordinates_in_bounds',
) -> tf.Tensor:
  """Tensorflow Python Op returning boolean whether coordinates are in bounds.

  Args:
    coordinates: int64 Tensor of shape `[1, 3]` representing center coordinates
      from which to retrieve patches.
    volname: string Tensor of shape `[1]` giving volume from which patch should
      be retrieved.
    radius: length 3 sequence indicating the radius of the patches to be
      retrieved around each coordinates (xyz).
    volinfo_map_string: comma delimited string mapping volname:volume_path.
      For every entry, `<volume_path>/metadata.json` is read and its 'bboxes'
      entry is used as the list of valid bounding boxes for that volume. Only
      the first ':' of an entry separates the name from the path, so paths
      that themselves contain ':' (e.g. `gs://...`) are accepted.
    use_bboxes: whether to use the bounding boxes declared in the volume; if
      False, the physical size of the volume is meant to be used as the
      bounding box. NOTE(review): this argument is currently not consulted by
      the implementation; the declared bounding boxes are always used.
    name: passed to name_scope.

  Returns:
    Boolean scalar Tensor indicating whether the patch specified by coordinates
    and radius fits within the volume specified by volname and
    volinfo_map_string. This can be passed to tf.cond to select either
    coordinates or an empty constant of shape `[0, 3]`, which can then be
    passed to batching (e.g. see tests).

  Raises:
    ValueError: if an entry of volinfo_map_string has no ':' separator, or if
      no bounding boxes could be loaded for any volume.
  """
  boxes_by_volname = {}
  for mapping in volinfo_map_string.split(','):
    # Split on the first ':' only; the path part may contain further colons.
    volume_label, volume_path = mapping.split(':', 1)
    # Keys are stored as bytes because they are looked up with the value that
    # py_func passes in for the string tensor `volname`.
    key = volume_label.encode('utf-8')
    assert key not in boxes_by_volname, (
        'Duplicate volume name in volinfo_map_string: %r' % volume_label
    )

    metadata_path = os.path.join(volume_path, 'metadata.json')
    if gfile.Exists(metadata_path):
      metadata = ts_utils.read_metadata(metadata_path)
      boxes_by_volname[key] = metadata['bboxes']
    else:
      # Existing behaviour: stop at the first volume without metadata. Make
      # the truncation visible instead of dropping the remaining volumes
      # silently.
      logging.warning(
          'No metadata found at %s for volume %r; this and all subsequent '
          'volumes in volinfo_map_string are ignored.',
          metadata_path,
          volume_label,
      )
      break
  if not boxes_by_volname:
    raise ValueError('boxes_by_volname is empty.')

  # Loop invariant for the per-example callback below.
  radius_array = np.asarray(radius)

  def _in_bounds_fn(coordinates, volname):
    """Returns True iff the patch lies completely inside one of the boxes."""
    boxes = boxes_by_volname[volname[0]]
    center = np.array(coordinates)
    patch_start = center - radius_array
    # Last voxel covered by the patch (inclusive), hence the strict comparison
    # against `box.end` below.
    patch_back = center + radius_array
    for box in boxes:
      if (box.start <= patch_start).all() and (box.end > patch_back).all():
        return True
    return False

  with tf.name_scope(name, values=[coordinates, volname]) as scope:
    # Tensors expose their static shape through `.shape` (a TensorShape);
    # there is no `shape_as_list` method on tf.Tensor.
    assert coordinates.shape.as_list() == [1, 3]
    assert volname.shape.as_list() == [1]
    in_bounds = tf.py_func(
        _in_bounds_fn,
        [coordinates, volname],
        [tf.bool],
        name=scope,
        stateful=False,
    )[0]
    # py_func outputs have unknown static shape; the result is a scalar.
    in_bounds.set_shape([])
    return in_bounds


def filter_oob(
    item: dict[str, tf.Tensor],
    volinfo_map_string: str,
    patch_size: list[int],
    use_bboxes: bool = True,
) -> tf.Tensor:
  """Checks whether the patch described by `item` lies within its volume.

  Args:
    item: dict providing the patch center under 'coord' and the name of the
      source volume under 'volname'.
    volinfo_map_string: forwarded unchanged to `coordinates_in_bounds`.
    patch_size: full patch size along each axis; half of it (rounded down) is
      used as the patch radius.
    use_bboxes: forwarded unchanged to `coordinates_in_bounds`.

  Returns:
    The tensor returned by `coordinates_in_bounds` for this item.
  """
  half_extent = np.asarray(patch_size) // 2
  return coordinates_in_bounds(
      coordinates=item['coord'],
      volname=item['volname'],
      radius=half_extent,
      volinfo_map_string=volinfo_map_string,
      use_bboxes=use_bboxes,
  )

0 comments on commit 458cc90

Please sign in to comment.