Skip to content

Commit

Permalink
Migrate filter_oob into inputs.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 655319877
  • Loading branch information
imxj authored and copybara-github committed Aug 22, 2024
1 parent 2060884 commit f41e937
Showing 1 changed file with 93 additions and 6 deletions.
99 changes: 93 additions & 6 deletions ffn/training/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from connectomics.common import bounding_box
from connectomics.common import box_generator
from connectomics.segmentation import labels as label_utils
from connectomics.volume import metadata
from ffn.training import augmentation
from ffn.training import variables
import numpy as np
Expand All @@ -33,13 +34,11 @@ def create_filename_queue(coordinates_file_pattern, shuffle=True):
"""Creates a queue for reading coordinates from coordinate file.
Args:
coordinates_file_pattern: File pattern for TFRecords of
input examples of the form of a glob
pattern or path@shards
or comma-separated file patterns.
coordinates_file_pattern: File pattern for TFRecords of input examples of
the form of a glob pattern or path@shards or comma-separated file
patterns.
shuffle: Whether to shuffle the coordinate file list. Note that the expanded
coordinates_file_pattern is not guaranteed to be sorted
alphabetically.
coordinates_file_pattern is not guaranteed to be sorted alphabetically.
Returns:
Tensorflow queue with coordinate filenames
Expand Down Expand Up @@ -702,3 +701,91 @@ def sample(
if repeat:
sampled_dataset = sampled_dataset.repeat()
return sampled_dataset


def coordinates_in_bounds(
    coordinates: tf.Tensor,
    volname: tf.Tensor,
    radius: Sequence[int],
    volinfo_map_string: str,
    use_bboxes: bool = True,
    name: str = 'coordinates_in_bounds',
) -> tf.Tensor:
  """Tensorflow Python Op returning boolean whether coordinates are in bounds.

  Args:
    coordinates: int64 Tensor of shape `[1, 3]` representing center coordinates
      from which to retrieve patches.
    volname: string Tensor of shape `[1]` giving volume from which patch should
      be retrieved.
    radius: length 3 sequence indicating the radius of the patches to be
      retrieved around each coordinates (xyz).
    volinfo_map_string: comma delimited string mapping volname:volinfo_path,
      where volinfo_path points to the metadata of the volume from which
      patches should be extracted. Only paths ending in 'metadata.json' are
      loaded; other entries are ignored.
    use_bboxes: whether to use the bounding boxes declared in the volume; if
      False, the physical size of the volume is used as the bounding box.
    name: passed to name_scope.

  Returns:
    Boolean scalar Tensor indicating whether the patch specified by coordinates
    and radius fits within the volume specified by volname and
    volinfo_map_string. This can be passed to tf.cond to select either
    coordinates or an empty constant of shape `[0, 3]`, which can then be
    passed to batching (e.g. see tests).

  Raises:
    ValueError: if no bounding boxes could be loaded from volinfo_map_string,
      or if an entry of volinfo_map_string contains no ':' separator.
  """
  boxes_by_volname = {}
  for mapping in volinfo_map_string.split(','):
    # Split on the first colon only so that paths which themselves contain
    # colons (e.g. URL-style paths) are kept intact instead of failing to
    # unpack.
    k, volinfo_path = mapping.split(':', 1)
    # Keys are stored as bytes because _in_bounds_fn looks them up with the
    # element of the string tensor it receives from the py_func.
    k = k.encode('utf-8')
    assert k not in boxes_by_volname

    if volinfo_path.endswith('metadata.json'):
      # Context manager guarantees the file handle is released even if parsing
      # the metadata raises.
      with open(volinfo_path, 'r') as f:
        meta = metadata.VolumeMetadata.from_json(f.read())
      if use_bboxes:
        bboxes = meta.bounding_boxes
      else:
        # Fall back to a single box spanning the whole volume.
        bboxes = [
            bounding_box.BoundingBox(
                (0, 0, 0),
                (meta.volume_size.x, meta.volume_size.y, meta.volume_size.z),
            )
        ]
      boxes_by_volname[k] = bboxes

  if not boxes_by_volname:
    raise ValueError('boxes_by_volname is empty.')

  def _in_bounds_fn(coordinates, volname):
    """Returns True iff the patch fits entirely inside one of the boxes."""
    boxes = boxes_by_volname[volname[0]]
    center = np.array(coordinates)
    patch_start = center - radius
    # Last voxel covered by the patch; compared strictly against box.end.
    patch_back = center + radius
    for box in boxes:
      if (box.start <= patch_start).all() and (box.end > patch_back).all():
        return True
    return False

  with tf.name_scope(name, values=[coordinates, volname]) as scope:
    # Use the public TensorShape API to validate the static shapes.
    assert coordinates.shape.as_list() == [1, 3]
    assert volname.shape.as_list() == [1]
    in_bounds = tf.py_func(
        _in_bounds_fn,
        [coordinates, volname],
        [tf.bool],
        name=scope,
        stateful=False,
    )[0]
    # py_func outputs have unknown static shape; the result is a scalar.
    in_bounds.set_shape([])
    return in_bounds


def filter_oob(
    item: dict[str, tf.Tensor], volinfo_map_string: str, patch_size: list[int],
    use_bboxes: bool = True
) -> tf.Tensor:
  """Checks whether the patch centered on an item lies within its volume.

  Thin wrapper around `coordinates_in_bounds` that derives the patch radius
  from the full patch size.

  Args:
    item: dict providing the 'coord' and 'volname' tensors that are forwarded
      to `coordinates_in_bounds`.
    volinfo_map_string: comma delimited string mapping volname:volinfo_path,
      forwarded to `coordinates_in_bounds`.
    patch_size: full size of the patch; each entry is floor-divided by 2 to
      obtain the radius.
    use_bboxes: forwarded to `coordinates_in_bounds`.

  Returns:
    The tensor returned by `coordinates_in_bounds` for this item.
  """
  half_extent = np.floor_divide(patch_size, 2)
  center, volume_name = item['coord'], item['volname']
  return coordinates_in_bounds(
      center,
      volume_name,
      half_extent,
      volinfo_map_string,
      use_bboxes=use_bboxes,
  )

0 comments on commit f41e937

Please sign in to comment.