From ae3377b7077446f36132803e523d4d15432a9354 Mon Sep 17 00:00:00 2001
From: Jin Xu
Date: Mon, 24 Jun 2024 12:39:04 -0700
Subject: [PATCH] Migrate load_from_volume into inputs.py

PiperOrigin-RevId: 646187131
---
 ffn/training/inputs.py | 99 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 93 insertions(+), 6 deletions(-)

diff --git a/ffn/training/inputs.py b/ffn/training/inputs.py
index 4a521b2..6cfbf9c 100644
--- a/ffn/training/inputs.py
+++ b/ffn/training/inputs.py
@@ -27,6 +27,7 @@ import numpy as np
 import tensorflow.compat.v1 as tf
 from tensorflow.io import gfile
+import tensorstore as ts
 
 
 def create_filename_queue(coordinates_file_pattern, shuffle=True):
@@ -323,21 +324,107 @@ def weighted_load_patch_coordinates(
   )
 
 
-def load_from_numpylike(coordinates, volume_names, shape, volume_map,
-                        name=None):
+def _filter_masked(item, volinfo_map_string: str):
+  mask_value = load_from_volume(
+      item['coord'],
+      item['volname'],
+      patch_size=(1, 1, 1),
+      dtype=tf.int64,
+      num_channels=1,
+      volinfo_map_string=volinfo_map_string,
+  )
+  return mask_value[0, 0, 0, 0, 0] > 0
+
+
+def load_from_volume(
+    coord, volname, patch_size, dtype, num_channels, volinfo_map_string: str
+):
+  """Loads data from a volume using TensorStore.
+
+  Args:
+    coord: The coordinates to load from.
+    volname: The name of the volume.
+    patch_size: The size of the patch to load.
+    dtype: The data type of the volume.
+    num_channels: The number of channels in the volume.
+    volinfo_map_string: A string representation of the volume info map with the
+      format "volname1:volinfo_path1,volname2:volinfo_path2".
+
+  Returns:
+    A tensor containing the loaded data.
+  """
+  if num_channels != 1:
+    raise ValueError('Only num_channels=1 is currently supported.')
+
+  volinfo_map = {}
+  for pair in volinfo_map_string.split(','):
+    name, path = pair.split(':')
+    volinfo_map[name.strip()] = path.strip()
+
+  def _load_single_volume(inputs):
+    coord, volinfo_path = inputs
+    print('volinfo_path:', volinfo_path)
+    print('coord:', coord)
+    volinfo_path = volinfo_path.numpy().decode('utf-8')
+    coord = coord.numpy()
+    spec = {'driver': 'volumestore', 'volinfo_path': volinfo_path}
+
+    store = ts.open(spec, open=True).result()
+
+    start_coord = [max(0, c - (p // 2)) for c, p in zip(coord, patch_size)]
+    stop_coord = [
+        min(store.shape[i], c + (p // 2) + (p % 2))
+        for i, (c, p) in enumerate(zip(coord, patch_size))
+    ]
+
+    data = (
+        store[
+            start_coord[0] : stop_coord[0],
+            start_coord[1] : stop_coord[1],
+            start_coord[2] : stop_coord[2],
+        ]
+        .read()
+        .result()
+    )
+
+    data = data[:, :, :, 0].transpose(2, 1, 0).astype(dtype.as_numpy_dtype)
+    data = data[..., tf.newaxis]
+    return data
+
+  patch_size = list(patch_size)
+  # Convert lists to tensors for tf.map_fn
+  coords_tensor = tf.convert_to_tensor(coord)
+  volinfo_paths_tensor = tf.convert_to_tensor(
+      [volinfo_map[v].encode('utf-8') for v in volname], dtype=tf.string
+  )
+
+  # Use tf.map_fn to process each volume
+  data_tensor = tf.map_fn(
+      _load_single_volume,
+      (coords_tensor, volinfo_paths_tensor),
+      fn_output_signature=dtype,
+      dtype=dtype,
+  )
+
+  return data_tensor
+
+
+def load_from_numpylike(
+    coordinates, volume_names, shape, volume_map, name=None
+):
   """TensorFlow Python op that loads data from Numpy-like volumes.
 
   The volume object must support Numpy-like indexing, as well as shape,
   ndim, and dtype properties.  The volume can be 3d or 4d.
 
   Args:
-    coordinates: tensor of shape [1, 3] containing XYZ coordinates of the
-      center of the subvolume to load.
+    coordinates: tensor of shape [1, 3] containing XYZ coordinates of the center
+      of the subvolume to load.
     volume_names: tensor of shape [1] containing names of volumes to load data
-      from.
+        from.
     shape: a 3-sequence giving the XYZ shape of the data to load.
     volume_map: a dictionary mapping volume names to volume objects.  See above
-      for API requirements of the Numpy-like volume objects.
+        for API requirements of the Numpy-like volume objects.
     name: the op name.
 
   Returns: