Skip to content

Commit

Permalink
Migrate load_from_volume into inputs.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646187131
  • Loading branch information
imxj authored and copybara-github committed Jul 17, 2024
1 parent 1e95fbf commit ae3377b
Showing 1 changed file with 93 additions and 6 deletions.
99 changes: 93 additions & 6 deletions ffn/training/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.io import gfile
import tensorstore as ts


def create_filename_queue(coordinates_file_pattern, shuffle=True):
Expand Down Expand Up @@ -323,21 +324,107 @@ def weighted_load_patch_coordinates(
)


def load_from_numpylike(coordinates, volume_names, shape, volume_map,
name=None):
def _filter_masked(item, volinfo_map_string: str):
mask_value = load_from_volume(
item['coord'],
item['volname'],
patch_size=(1, 1, 1),
dtype=tf.int64,
num_channels=1,
volinfo_map_string=volinfo_map_string,
)
return mask_value[0, 0, 0, 0, 0] > 0


def load_from_volume(
coord, volname, patch_size, dtype, num_channels, volinfo_map_string: str
):
"""Loads data from a volume using TensorStore.
Args:
coord: The coordinates to load from.
volname: The name of the volume.
patch_size: The size of the patch to load.
dtype: The data type of the volume.
num_channels: The number of channels in the volume.
volinfo_map_string: A string representation of the volume info map with the
format "volname1:volinfo_path1,volname2:volinfo_path2".
Returns:
A tensor containing the loaded data.
"""
if num_channels != 1:
raise ValueError('Only num_channels=1 is currently supported.')

volinfo_map = {}
for pair in volinfo_map_string.split(','):
name, path = pair.split(':')
volinfo_map[name.strip()] = path.strip()

def _load_single_volume(inputs):
coord, volinfo_path = inputs
print('volinfo_path:', volinfo_path)
print('coord:', coord)
volinfo_path = volinfo_path.numpy().decode('utf-8')
coord = coord.numpy()
spec = {'driver': 'volumestore', 'volinfo_path': volinfo_path}

store = ts.open(spec, open=True).result()

start_coord = [max(0, c - (p // 2)) for c, p in zip(coord, patch_size)]
stop_coord = [
min(store.shape[i], c + (p // 2) + (p % 2))
for i, (c, p) in enumerate(zip(coord, patch_size))
]

data = (
store[
start_coord[0] : stop_coord[0],
start_coord[1] : stop_coord[1],
start_coord[2] : stop_coord[2],
]
.read()
.result()
)

data = data[:, :, :, 0].transpose(2, 1, 0).astype(dtype.as_numpy_dtype)
data = data[..., tf.newaxis]
return data

patch_size = list(patch_size)
# Convert lists to tensors for tf.map_fn
coords_tensor = tf.convert_to_tensor(coord)
volinfo_paths_tensor = tf.convert_to_tensor(
[volinfo_map[v].encode('utf-8') for v in volname], dtype=tf.string
)

# Use tf.map_fn to process each volume
data_tensor = tf.map_fn(
_load_single_volume,
(coords_tensor, volinfo_paths_tensor),
fn_output_signature=dtype,
dtype=dtype,
)

return data_tensor


def load_from_numpylike(
coordinates, volume_names, shape, volume_map, name=None
):
"""TensorFlow Python op that loads data from Numpy-like volumes.
The volume object must support Numpy-like indexing, as well as shape, ndim,
and dtype properties. The volume can be 3d or 4d.
Args:
coordinates: tensor of shape [1, 3] containing XYZ coordinates of the
center of the subvolume to load.
coordinates: tensor of shape [1, 3] containing XYZ coordinates of the center
of the subvolume to load.
volume_names: tensor of shape [1] containing names of volumes to load data
from.
from.
shape: a 3-sequence giving the XYZ shape of the data to load.
volume_map: a dictionary mapping volume names to volume objects. See above
for API requirements of the Numpy-like volume objects.
for API requirements of the Numpy-like volume objects.
name: the op name.
Returns:
Expand Down

0 comments on commit ae3377b

Please sign in to comment.