Skip to content

Commit

Permalink
Migrate make_oob_mask into inputs.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 667672345
  • Loading branch information
imxj authored and copybara-github committed Aug 28, 2024
1 parent 8a23f8c commit 08ac96f
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions ffn/training/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,3 +789,79 @@ def filter_oob(
item['coord'], item['volname'], radius, volinfo_map_string,
use_bboxes
)


def make_oob_mask(
coordinates: tf.Tensor,
volname: tf.Tensor,
volinfo_map_string: str,
radius: Optional[Sequence[int]] = None,
shape: Optional[Sequence[int]] = None,
name='make_oob_mask',
):
"""Builds a tensor masking voxels that are outside of bounding boxes.
Args:
coordinates: int64 Tensor of shape `[1, 3]` representing center coordinates
from which to retrieve patches.
volname: string Tensor of shape `[1]` giving volume from which patch should
be retrieved.
volinfo_map_string: comma delimited string mapping volname:volinfo_path,
where volinfo_path is a gfile with text_format VolumeInfo proto for the
volume from which patches should be extracted.
radius: XYZ radius of patches to extract; exclusive with 'shape'
shape: XYZ shape of patches to extract; exclusive with 'radius'
name: passed to name_scope.
Returns:
float32 tensor of shape [1, dz, dy, dx, 1], where every voxel has one of
two values:
1: if the voxel is inside one or more bounding boxes associated with
the volume specified by `volname`
0: otherwise
and where (dx, dy, dz) = 2 * radius + 1.
"""
boxes_by_volname = {}
for mapping in volinfo_map_string.split(','):
k, volinfo_path = mapping.split(':')
k = k.encode('utf-8')
assert k not in boxes_by_volname

if volinfo_path.endswith('metadata.json'):
f = open(volinfo_path, 'r')
meta = metadata.VolumeMetadata.from_json(f.read())
boxes_by_volname[k] = meta.bounding_boxes
if shape is None:
assert radius is not None
diameter_xyz = np.array(radius) * 2 + 1
else:
assert radius is None
diameter_xyz = np.array(shape)
radius = diameter_xyz // 2

mask_shape = [1] + diameter_xyz.tolist()[::-1] + [1]

def _oob_mask_fn(coordinates, volname): # pylint:disable=missing-docstring
boxes = boxes_by_volname[volname[0]]
patch_box = bounding_box.BoundingBox(
start=np.array(coordinates[0, :]) - radius, size=diameter_xyz
)
oob_mask = np.zeros(mask_shape, dtype=np.float32)
for box in boxes:
ibox = patch_box.intersection(box)
if ibox is None:
continue
rel_ibox = bounding_box.BoundingBox(
start=ibox.start - patch_box.start, size=ibox.size
)
oob_mask[np.index_exp[:] + rel_ibox.to_slice3d() + np.index_exp[:]] = 1
return oob_mask

with tf.name_scope(name, values=[coordinates, volname]) as scope:
assert coordinates.shape_as_list() == [1, 3]
assert volname.shape_as_list() == [1]
oob_mask = tf.py_func(
_oob_mask_fn, [coordinates, volname], [tf.float32], name=scope
)[0]
oob_mask.set_shape(mask_shape)
return oob_mask

0 comments on commit 08ac96f

Please sign in to comment.