def coordinates_in_bounds(
    coordinates: tf.Tensor,
    volname: tf.Tensor,
    radius: Sequence[int],
    volinfo_map_string: str,
    use_bboxes: bool = True,
    name: str = 'coordinates_in_bounds',
) -> tf.Tensor:
  """Tensorflow Python Op returning boolean whether coordinates are in bounds.

  The volume metadata files are read once, when this function is called (i.e.
  at graph construction time); the returned op only consults the bounding
  boxes cached from them.

  Args:
    coordinates: int64 Tensor of shape `[1, 3]` representing center coordinates
      from which to retrieve patches.
    volname: string Tensor of shape `[1]` giving volume from which patch should
      be retrieved.
    radius: length 3 sequence indicating the radius of the patches to be
      retrieved around each coordinates (xyz).
    volinfo_map_string: comma delimited string mapping volname:volinfo_path,
      where volinfo_path points to the metadata of the volume from which
      patches should be extracted. Only paths ending in `metadata.json` are
      loaded; other entries are skipped, so looking up their volname when the
      op runs raises a KeyError.
    use_bboxes: whether to use the bounding boxes declared in the volume; if
      False, the physical size of the volume is used as the bounding box.
    name: passed to name_scope.

  Returns:
    Boolean scalar Tensor indicating whether the patch specified by coordinates
    and radius fits within the volume specified by volname and
    volinfo_map_string. This can be passed to tf.cond to select either
    coordinates or an empty constant of shape `[0, 3]`, which can then be
    passed to batching (e.g. see tests).

  Raises:
    ValueError: if an entry of volinfo_map_string is not of the form
      `volname:volinfo_path`, or if no bounding boxes could be loaded.
  """
  boxes_by_volname = {}
  for mapping in volinfo_map_string.split(','):
    # Split on the first colon only, so that a path which itself contains a
    # colon is kept intact instead of breaking the unpacking.
    k, sep, volinfo_path = mapping.partition(':')
    if not sep:
      raise ValueError(
          'Expected a volname:volinfo_path entry, got: %r' % (mapping,)
      )
    # py_func hands string tensors to Python as bytes, so key the lookup table
    # by the encoded volume name.
    k = k.encode('utf-8')
    assert k not in boxes_by_volname, 'Duplicate volname: %r' % (k,)

    if volinfo_path.endswith('metadata.json'):
      # Context manager guarantees the file handle is released even if parsing
      # the metadata fails.
      with open(volinfo_path, 'r') as f:
        meta = metadata.VolumeMetadata.from_json(f.read())
      if use_bboxes:
        bboxes = meta.bounding_boxes
      else:
        # Fall back to a single box spanning the whole volume.
        bboxes = [
            bounding_box.BoundingBox(
                (0, 0, 0),
                (meta.volume_size.x, meta.volume_size.y, meta.volume_size.z),
            )
        ]
      boxes_by_volname[k] = bboxes

  if not boxes_by_volname:
    raise ValueError('boxes_by_volname is empty.')

  # Converted once here rather than on every invocation of the op.
  radius_xyz = np.asarray(radius)

  def _in_bounds_fn(coordinates, volname):
    """Returns True iff the patch fits entirely inside one bounding box."""
    boxes = boxes_by_volname[volname[0]]
    center = np.array(coordinates)
    patch_start = center - radius_xyz
    # Last voxel covered by the patch (inclusive), hence the strict comparison
    # against `box.end` below.
    patch_back = center + radius_xyz
    for box in boxes:
      if (box.start <= patch_start).all() and (box.end > patch_back).all():
        return True
    return False

  with tf.name_scope(name, values=[coordinates, volname]) as scope:
    # The static shape lives on `Tensor.shape` (a TensorShape); there is no
    # public `Tensor.shape_as_list()` method.
    assert coordinates.shape.as_list() == [1, 3]
    assert volname.shape.as_list() == [1]
    in_bounds = tf.py_func(
        _in_bounds_fn,
        [coordinates, volname],
        [tf.bool],
        name=scope,
        stateful=False,
    )[0]
    # py_func outputs have unknown static shape; declare the scalar explicitly.
    in_bounds.set_shape([])
    return in_bounds


def filter_oob(
    item: dict[str, tf.Tensor],
    volinfo_map_string: str,
    patch_size: list[int],
    use_bboxes: bool = True,
) -> tf.Tensor:
  """Returns whether the patch centered on an item's coordinates is in bounds.

  Thin wrapper around `coordinates_in_bounds` that derives the patch radius
  from the patch size and pulls the coordinates and volume name out of `item`.

  Args:
    item: dict providing the `coord` and `volname` tensors, which are forwarded
      to `coordinates_in_bounds` as `coordinates` and `volname`.
    volinfo_map_string: comma delimited string mapping volname:volinfo_path;
      forwarded to `coordinates_in_bounds`.
    patch_size: size of the patch (xyz); the radius used for the bounds check
      is `patch_size // 2`, computed elementwise.
    use_bboxes: forwarded to `coordinates_in_bounds`.

  Returns:
    Boolean scalar Tensor as returned by `coordinates_in_bounds`.
  """
  radius = np.floor_divide(patch_size, 2)
  return coordinates_in_bounds(
      item['coord'], item['volname'], radius, volinfo_map_string, use_bboxes
  )