diff --git a/ffn/training/augmentation.py b/ffn/training/augmentation.py index 0e9f518..bb77146 100644 --- a/ffn/training/augmentation.py +++ b/ffn/training/augmentation.py @@ -48,7 +48,7 @@ def xy_transpose(data, decision): """ with tf.name_scope('augment_xy_transpose'): rank = data.get_shape().ndims - perm = range(rank) + perm = list(range(rank)) perm[rank - 3], perm[rank - 2] = perm[rank - 2], perm[rank - 3] return tf.cond(decision, lambda: tf.transpose(data, perm), diff --git a/ffn/training/examples.py b/ffn/training/examples.py new file mode 100644 index 0000000..f162fdb --- /dev/null +++ b/ffn/training/examples.py @@ -0,0 +1,328 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for building training examples for FFN training.""" + +import collections +from concurrent import futures +import itertools +from typing import Callable, Iterable, Optional, Sequence + +import numpy as np +from scipy import special + +from ..inference import movement +from . import mask +from . import model as ffn_model +from . import tracker + +GetOffsets = Callable[ + [ffn_model.ModelInfo, np.ndarray, np.ndarray, tracker.EvalTracker], + Iterable[tuple[int, int, int]]] + + +def get_example(load_example, eval_tracker: tracker.EvalTracker, + info: ffn_model.ModelInfo, get_offsets: GetOffsets, + seed_pad: float, seed_shape: tuple[int, int, int]): + """Generates individual training examples. + + Args: + load_example: callable returning a tuple of image and label ndarrays as well + as the seed coordinate and volume name of the example + eval_tracker: tracker.EvalTracker object + info: ModelInfo metadata about the model + get_offsets: callable returning an iterable of (x, y, z) offsets to + investigate within the training patch + seed_pad: value to fill the empty areas of the seed with + seed_shape: z, y, x shape of the seed + + Yields: + tuple of [1, z, y, x, 1]-shaped arrays for: + seed, image, label, weights + """ + while True: + ex = load_example() + full_patches, full_labels, loss_weights, coord, volname = ex + + # Start with a clean seed. + seed = special.logit(mask.make_seed(seed_shape, 1, pad=seed_pad)) + + for off in get_offsets(info, seed, full_labels, eval_tracker): + predicted = mask.crop_and_pad(seed, off, info.input_seed_size[::-1]) + patches = mask.crop_and_pad(full_patches, off, + info.input_image_size[::-1]) + labels = mask.crop_and_pad(full_labels, off, info.pred_mask_size[::-1]) + weights = mask.crop_and_pad(loss_weights, off, info.pred_mask_size[::-1]) + + # Necessary, since the caller is going to update the array and these + # changes need to be visible in the following iterations. 
+      assert predicted.base is seed
+      yield predicted, patches, labels, weights
+
+    eval_tracker.add_patch(full_labels, seed, loss_weights, coord, volname,
+                           full_patches)
+
+
+ExampleGenerator = Iterable[tuple[np.ndarray, np.ndarray, np.ndarray,
+                                  np.ndarray]]
+_BatchGenerator = Iterable[tuple[Sequence[np.ndarray], Sequence[np.ndarray],
+                                 Sequence[np.ndarray], Sequence[np.ndarray]]]
+
+
+def _batch_gen(make_example_generator_fn: Callable[[], ExampleGenerator],
+               batch_size: int) -> _BatchGenerator:
+  """Generates batches of training examples."""
+  # Create a separate generator for every element in the batch. This generator
+  # will automatically advance to a different training example once the
+  # allowed moves for the current location are exhausted.
+  example_gens = [make_example_generator_fn() for _ in range(batch_size)]
+
+  with futures.ThreadPoolExecutor(max_workers=batch_size) as tpe:
+    while True:
+      fs = []
+      for gen in example_gens:
+        fs.append(tpe.submit(next, gen))
+
+      # `batch` is a sequence of `batch_size` tuples returned by the
+      # `get_example` generator, to which we apply the following
+      # transformation:
+      #   [(a0, b0), (a1, b1), .. (an, bn)] -> [(a0, a1, .., an),
+      #                                         (b0, b1, .., bn)]
+      # (where n is the batch size) to get a sequence, each element of which
+      # represents a batch of values of a given type (e.g., seed, image, etc.)
+      batch = [f.result() for f in fs]
+      yield tuple(zip(*batch))
+
+
+class BatchExampleIter:
+  """Generates batches of training examples."""
+
+  def __init__(self, example_generator_fn: Callable[[], ExampleGenerator],
+               eval_tracker: tracker.EvalTracker, batch_size: int,
+               info: ffn_model.ModelInfo):
+    self._eval_tracker = eval_tracker
+    self._batch_generator = _batch_gen(example_generator_fn, batch_size)
+    self._seeds = None
+    self._info = info
+
+  def __iter__(self):
+    return self
+
+  def __next__(self):
+    seeds, patches, labels, weights = next(self._batch_generator)
+    self._seeds = seeds
+    batched_seeds = np.concatenate(seeds)
+    batched_weights = np.concatenate(weights)
+    self._eval_tracker.track_weights(batched_weights)
+    return (batched_seeds, np.concatenate(patches), np.concatenate(labels),
+            batched_weights)
+
+  def update_seeds(self, batched_seeds: np.ndarray):
+    """Distributes updated predictions back to the generator buffers.
+
+    Args:
+      batched_seeds: [b, z, y, x, c] array of the part of the seed updated by
+        the model
+    """
+    assert self._seeds is not None
+
+    # Convert to numpy array in case this function was called with an
+    # array-like object backed by accelerator memory.
+    batched_seeds = np.asarray(batched_seeds)
+
+    dx = self._info.input_seed_size[0] - self._info.pred_mask_size[0]
+    dy = self._info.input_seed_size[1] - self._info.pred_mask_size[1]
+    dz = self._info.input_seed_size[2] - self._info.pred_mask_size[2]
+
+    if dz == 0 and dy == 0 and dx == 0:
+      for i in range(len(self._seeds)):
+        self._seeds[i][:] = batched_seeds[i, ...]
+    else:
+      for i in range(len(self._seeds)):
+        self._seeds[i][:,  #
+                       dz // 2:-(dz - dz // 2),  #
+                       dy // 2:-(dy - dy // 2),  #
+                       dx // 2:-(dx - dx // 2),  #
+                       :] = batched_seeds[i, ...]
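Note (not part of the patch): the batching scheme above leans on two Python behaviors worth spelling out — `zip(*batch)` transposes a list of per-example tuples into per-kind sequences, and the arrays yielded by `get_example` are numpy views into each generator's seed, so the writes made by `update_seeds` are visible on the next offset iteration. A minimal sketch, assuming `crop_and_pad` returns a pure slice view for in-bounds crops:

    import numpy as np

    # [(a0, b0), (a1, b1)] -> [(a0, a1), (b0, b1)]
    batch = [('seed0', 'img0'), ('seed1', 'img1')]
    seeds, images = zip(*batch)
    assert seeds == ('seed0', 'seed1') and images == ('img0', 'img1')

    # A centered slice is a view; writing through it mutates the parent
    # seed, which later iterations then read.
    seed = np.zeros((1, 5, 5, 5, 1), dtype=np.float32)
    predicted = seed[:, 1:4, 1:4, 1:4, :]  # stand-in for crop_and_pad output
    assert predicted.base is seed
    predicted[...] = 1.0
    assert seed[0, 2, 2, 2, 0] == 1.0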
+
+
+def _eval_move(seed: np.ndarray, labels: np.ndarray,
+               off_xyz: tuple[int, int, int], seed_threshold: float,
+               label_threshold: float) -> tuple[bool, bool]:
+  """Evaluates a FOV move."""
+  valid_move = seed[:,  #
+                    seed.shape[1] // 2 + off_xyz[2],  #
+                    seed.shape[2] // 2 + off_xyz[1],  #
+                    seed.shape[3] // 2 + off_xyz[0],  #
+                    0] >= seed_threshold
+  wanted_move = (
+      labels[:,  #
+             labels.shape[1] // 2 + off_xyz[2],  #
+             labels.shape[2] // 2 + off_xyz[1],  #
+             labels.shape[3] // 2 + off_xyz[0],  #
+             0] >= label_threshold)
+
+  return valid_move, wanted_move
+
+
+FovShifts = Optional[Iterable[tuple[int, int, int]]]
+
+
+def fixed_offsets(info: ffn_model.ModelInfo,
+                  seed: np.ndarray,
+                  labels: np.ndarray,
+                  eval_tracker: tracker.EvalTracker,
+                  threshold: float,
+                  fov_shifts: FovShifts = None):
+  """Generates offsets based on a fixed list."""
+  del info
+
+  label_threshold = special.expit(threshold)
+  for off in itertools.chain([(0, 0, 0)], fov_shifts):  # xyz
+    valid_move, wanted_move = _eval_move(seed, labels, off, threshold,
+                                         label_threshold)
+    eval_tracker.record_move(wanted_move, valid_move, off)
+    if not valid_move:
+      continue
+
+    yield off
+
+
+def fixed_offsets_window(info: ffn_model.ModelInfo,
+                         seed: np.ndarray,
+                         labels: np.ndarray,
+                         eval_tracker: tracker.EvalTracker,
+                         threshold: float,
+                         fov_shifts: FovShifts = None,
+                         radius: int = 4):
+  """Like fixed_offsets, but allows more flexible moves.
+
+  Instead of looking at the single voxel pointed to by the offset vector,
+  considers a small window in the plane orthogonal to the movement direction.
+
+  This helps with training on thin processes that might not be followed by
+  the 'fixed_offsets' movement policy.
+
+  Args:
+    info: ModelInfo object
+    seed: seed array (logits)
+    labels: label array (probabilities)
+    eval_tracker: EvalTracker object
+    threshold: value that the seed needs to match or exceed in order to be
+      considered a valid move target
+    fov_shifts: list of XYZ moves to evaluate
+    radius: max distance away from the offset vector to look for voxels
+      crossing the threshold (within a plane orthogonal to that vector)
+
+  Yields:
+    XYZ offset tuples indicating moves to take relative to the center of
+    'seed'
+  """
+  off = 0, 0, 0
+  label_threshold = special.expit(threshold)
+  valid_move, wanted_move = _eval_move(seed, labels, off, threshold,
+                                       label_threshold)
+  eval_tracker.record_move(wanted_move, valid_move, off)
+  if valid_move:
+    yield off
+
+  seed_center = np.array(seed.shape[1:4]) // 2
+  label_center = np.array(labels.shape[1:4]) // 2
+
+  # Define a thin shell at a distance of 'delta' around the center.
+  hz, hy, hx = np.mgrid[:seed.shape[1], :seed.shape[2], :seed.shape[3]]
+  hz -= seed_center[0]
+  hy -= seed_center[1]
+  hx -= seed_center[2]
+  halo = ((np.abs(hx) <= info.deltas[0]) &  #
+          (np.abs(hy) <= info.deltas[1]) &  #
+          (np.abs(hz) <= info.deltas[2]) & (  #
+              (np.abs(hx) == info.deltas[0]) |  #
+              (np.abs(hy) == info.deltas[1]) |  #
+              (np.abs(hz) == info.deltas[2])))
+
+  for off in fov_shifts:  # xyz
+    # Look for voxels within a window of radius 'radius' around the standard
+    # move point. We can extend this window in any direction below since
+    # the 'halo' array is set up to restrict us to relevant voxels only.
+    off_center = seed_center + off[::-1]
+    pre = off_center - radius
+    post = off_center + radius + 1
+    zz, yy, xx = np.where(halo[pre[0]:post[0], pre[1]:post[1], pre[2]:post[2]])
+
+    zz_s = zz + pre[0]
+    yy_s = yy + pre[1]
+    xx_s = xx + pre[2]
+    xx_l = xx_s + label_center[2] - seed_center[2]
+    yy_l = yy_s + label_center[1] - seed_center[1]
+    zz_l = zz_s + label_center[0] - seed_center[0]
+
+    # Under 'fixed_offsets' the exact voxel at the offset vector would
+    # have to cross the threshold. Here it is instead sufficient that any
+    # voxel within the specified radius does.
+    valid_move = np.any(seed[:, zz_s, yy_s, xx_s, :] >= threshold)
+    wanted_move = np.any(labels[:, zz_l, yy_l, xx_l, :] >= label_threshold)
+    eval_tracker.record_move(wanted_move, valid_move, off)
+    if valid_move:
+      yield off
+
+
+def no_offsets(info: ffn_model.ModelInfo, seed: np.ndarray, labels: np.ndarray,
+               eval_tracker: tracker.EvalTracker):
+  del info, labels, seed
+  eval_tracker.record_move(True, True, (0, 0, 0))
+  yield (0, 0, 0)
+
+
+def max_pred_offsets(info: ffn_model.ModelInfo, seed: np.ndarray,
+                     labels: np.ndarray, eval_tracker: tracker.EvalTracker,
+                     threshold: float, max_radius: np.ndarray):
+  """Generates offsets with the policy used for inference."""
+  # Always start at the center.
+  queue = collections.deque([(0, 0, 0)])  # xyz
+  done = set()
+
+  label_threshold = special.expit(threshold)
+  deltas = np.array(info.deltas)
+  while queue:
+    offset = np.array(queue.popleft())
+
+    # Drop any offsets that would take us beyond the image fragment we
+    # loaded for training.
+    if np.any(np.abs(np.array(offset)) > max_radius):
+      continue
+
+    # Ignore locations that were visited previously.
+    quantized_offset = tuple((offset + deltas / 2) // np.maximum(deltas, 1))
+
+    if quantized_offset in done:
+      continue
+
+    valid, wanted = _eval_move(seed, labels, tuple(offset), threshold,
+                               label_threshold)
+    eval_tracker.record_move(wanted, valid, (0, 0, 0))
+
+    if not valid or (not wanted and quantized_offset != (0, 0, 0)):
+      continue
+
+    done.add(quantized_offset)
+
+    yield tuple(offset)
+
+    # Look for new offsets within the updated seed.
+    curr_seed = mask.crop_and_pad(seed, offset, info.pred_mask_size[::-1])
+    todos = sorted(
+        movement.get_scored_move_offsets(
+            info.deltas[::-1], curr_seed[0, ..., 0], threshold=threshold),
+        reverse=True)
+    queue.extend((x[2] + offset[0], x[1] + offset[1], x[0] + offset[2])
+                 for _, x in todos)
diff --git a/ffn/training/inputs.py b/ffn/training/inputs.py
index 5c359fe..5f6b674 100644
--- a/ffn/training/inputs.py
+++ b/ffn/training/inputs.py
@@ -18,8 +18,8 @@
 import numpy as np
 import tensorflow.compat.v1 as tf
 
+from connectomics.common import bounding_box
 from tensorflow.io import gfile
-from ..utils import bounding_box
 
 
 def create_filename_queue(coordinates_file_pattern, shuffle=True):
@@ -139,7 +139,7 @@ def _load_from_numpylike(coord, volname):
     volume = volume_map[volname.decode('ascii')]
 
     # Get data, including all channels if volume is 4d.
starts = np.array(coord) - start_offset - slc = bounding_box.BoundingBox(start=starts, size=shape).to_slice() + slc = bounding_box.BoundingBox(start=starts, size=shape).to_slice3d() if volume.ndim == 4: slc = np.index_exp[:] + slc data = volume[slc] diff --git a/ffn/training/model.py b/ffn/training/model.py index d3b6449..aa62a2b 100644 --- a/ffn/training/model.py +++ b/ffn/training/model.py @@ -14,33 +14,45 @@ # ============================================================================== """Classes for FFN model definition.""" -import tensorflow.compat.v1 as tf -from . import optimizer +import dataclasses +import numpy as np +import tensorflow.compat.v1 as tf -class FFNModel(object): - """Base class for FFN models.""" +from . import optimizer - # Dimensionality of the model (2 or 3). - dim = None - ############################################################################ - # (x, y, z) tuples defining various properties of the network. - # Note that 3-tuples should be used even for 2D networks, in which case - # the third (z) value is ignored. +@dataclasses.dataclass +class ModelInfo: + """Basic geometric information about the network. + Arrays are (x, y, z), even for 2D models, in which case the z value is + ignored. + """ # How far to move the field of view in the respective directions. - deltas = None + deltas: np.ndarray + + # Size of the predicted patch as returned by the model. + pred_mask_size: np.ndarray # Size of the input image and seed subvolumes to be used during inference. # This is enough information to execute a single prediction step, without # moving the field of view. - input_image_size = None - input_seed_size = None + input_seed_size: np.ndarray + input_image_size: np.ndarray - # Size of the predicted patch as returned by the model. - pred_mask_size = None - ########################################################################### + # For JAX models only: whether the predicted seed should be added to + # its initial state. + additive: bool = False + + +class FFNModel: + """Base class for FFN models.""" + + info: ModelInfo + + # Dimensionality of the model (2 or 3). + dim: int = None # TF op to compute loss optimized during training. This should include all # loss components in case more than just the pixelwise loss is used. @@ -49,18 +61,21 @@ class FFNModel(object): # TF op to call to perform loss optimization on the model. train_op = None - def __init__(self, deltas, batch_size=None, define_global_step=True): + def __init__(self, + info: ModelInfo, + batch_size=None, + define_global_step=True): assert self.dim is not None - self.deltas = deltas + self.info = info self.batch_size = batch_size # Initialize the shift collection. This is used during training with the # fixed step size policy. self.shifts = [] - for dx in (-self.deltas[0], 0, self.deltas[0]): - for dy in (-self.deltas[1], 0, self.deltas[1]): - for dz in (-self.deltas[2], 0, self.deltas[2]): + for dx in (-self.info.deltas[0], 0, self.info.deltas[0]): + for dy in (-self.info.deltas[1], 0, self.info.deltas[1]): + for dz in (-self.info.deltas[2], 0, self.info.deltas[2]): if dx == 0 and dy == 0 and dz == 0: continue self.shifts.append((dx, dy, dz)) @@ -80,49 +95,30 @@ def __init__(self, deltas, batch_size=None, define_global_step=True): # If specified, should have the same shape as self.labels. self.loss_weights = None - self.logits = None # type: tf.Operation + # Updated part of the seed, in logit form. + self.logits: tf.Operation = None # List of image tensors to save in summaries. 
The images are concatenated # along the X axis. self._images = [] - def set_uniform_io_size(self, patch_size): - """Initializes unset input/output sizes to 'patch_size', sets input shapes. - - This assumes that the inputs and outputs are of equal size, and that exactly - one step is executed in every direction during training. - - Args: - patch_size: (x, y, z) specifying the input/output patch size - - Returns: - None - """ - if self.pred_mask_size is None: - self.pred_mask_size = patch_size - if self.input_seed_size is None: - self.input_seed_size = patch_size - if self.input_image_size is None: - self.input_image_size = patch_size - self.set_input_shapes() - def set_input_shapes(self): """Sets the shape inference for input_seed and input_patches. Assumes input_seed_size and input_image_size are already set. """ self.input_seed.set_shape([self.batch_size] + - list(self.input_seed_size[::-1]) + [1]) + list(self.info.input_seed_size[::-1]) + [1]) self.input_patches.set_shape([self.batch_size] + - list(self.input_image_size[::-1]) + [1]) + list(self.info.input_image_size[::-1]) + [1]) def set_up_sigmoid_pixelwise_loss(self, logits): """Sets up the loss function of the model.""" assert self.labels is not None assert self.loss_weights is not None - pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, - labels=self.labels) + pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits( + logits=logits, labels=self.labels) pixel_loss *= self.loss_weights self.loss = tf.reduce_mean(pixel_loss) tf.summary.scalar('pixel_loss', self.loss) @@ -135,6 +131,8 @@ def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7): tf.summary.scalar('optimizer_loss', self.loss) opt = optimizer.optimizer_from_flags() + self.opt = opt + grads_and_vars = opt.compute_gradients(loss) for g, v in grads_and_vars: @@ -142,8 +140,7 @@ def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7): tf.logging.error('Gradient is None: %s', v.op.name) if max_gradient_entry_mag > 0.0: - grads_and_vars = [(tf.clip_by_value(g, - -max_gradient_entry_mag, + grads_and_vars = [(tf.clip_by_value(g, -max_gradient_entry_mag, +max_gradient_entry_mag), v) for g, v, in grads_and_vars] @@ -152,14 +149,12 @@ def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7): for var in trainables: tf.summary.histogram(var.name.replace(':0', ''), var) for grad, var in grads_and_vars: - tf.summary.histogram( - 'gradients/%s' % var.name.replace(':0', ''), grad) + tf.summary.histogram('gradients/%s' % var.name.replace(':0', ''), grad) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): - self.train_op = opt.apply_gradients(grads_and_vars, - global_step=self.global_step, - name='train') + self.train_op = opt.apply_gradients( + grads_and_vars, global_step=self.global_step, name='train') def show_center_slice(self, image, sigmoid=True): image = image[:, image.get_shape().dims[1] // 2, :, :, :] @@ -172,18 +167,19 @@ def add_summaries(self): def update_seed(self, seed, update): """Updates the initial 'seed' with 'update'.""" - dx = self.input_seed_size[0] - self.pred_mask_size[0] - dy = self.input_seed_size[1] - self.pred_mask_size[1] - dz = self.input_seed_size[2] - self.pred_mask_size[2] + dx = self.info.input_seed_size[0] - self.info.pred_mask_size[0] + dy = self.info.input_seed_size[1] - self.info.pred_mask_size[1] + dz = self.info.input_seed_size[2] - self.info.pred_mask_size[2] if dx == 0 and dy == 0 and dz == 0: seed += update else: - seed += tf.pad(update, [[0, 0], - [dz // 2, dz 
- dz // 2], - [dy // 2, dy - dy // 2], - [dx // 2, dx - dx // 2], - [0, 0]]) + seed += tf.pad(update, + [[0, 0], # + [dz // 2, dz - dz // 2], # + [dy // 2, dy - dy // 2], # + [dx // 2, dx - dx // 2], # + [0, 0]]) return seed def define_tf_graph(self): diff --git a/ffn/training/models/convstack_3d.py b/ffn/training/models/convstack_3d.py index 675145b..b6ee826 100644 --- a/ffn/training/models/convstack_3d.py +++ b/ffn/training/models/convstack_3d.py @@ -15,6 +15,7 @@ """Simplest FFN model, as described in https://arxiv.org/abs/1611.00421.""" import functools +import itertools import tensorflow.compat.v1 as tf import tf_slim from .. import model @@ -22,49 +23,70 @@ # Note: this model was originally trained with conv3d layers initialized with # TruncatedNormalInitializedVariable with stddev = 0.01. -def _predict_object_mask(net, depth=9): +def _predict_object_mask(net, depth=9, features=32): """Computes single-object mask prediction.""" + conv = tf_slim.convolution3d conv = functools.partial( - tf_slim.convolution3d, - num_outputs=32, kernel_size=(3, 3, 3), padding='SAME') - net = conv(net, scope='conv0_a') - net = conv(net, scope='conv0_b', activation_fn=None) + tf_slim.convolution3d, kernel_size=(3, 3, 3), padding='SAME' + ) + + if isinstance(features, int): + feats = itertools.repeat(features) + else: + feats = iter(features) + + net = conv(net, scope='conv0_a', num_outputs=next(feats)) + net = conv(net, scope='conv0_b', activation_fn=None, num_outputs=next(feats)) for i in range(1, depth): with tf.name_scope('residual%d' % i): in_net = net net = tf.nn.relu(net) - net = conv(net, scope='conv%d_a' % i) - net = conv(net, scope='conv%d_b' % i, activation_fn=None) + net = conv(net, scope='conv%d_a' % i, num_outputs=next(feats)) + net = conv( + net, scope='conv%d_b' % i, activation_fn=None, num_outputs=next(feats) + ) net += in_net net = tf.nn.relu(net) logits = tf_slim.convolution3d( - net, 1, (1, 1, 1), activation_fn=None, scope='conv_lom') + net, 1, (1, 1, 1), activation_fn=None, scope='conv_lom' + ) return logits class ConvStack3DFFNModel(model.FFNModel): + """A simple conv-stack FFN model. + + The model is composed of `depth` residual modules, operating at a + constant spatial resolution. 
+ """ + dim = 3 - def __init__(self, fov_size=None, deltas=None, batch_size=None, depth=9): - super(ConvStack3DFFNModel, self).__init__(deltas, batch_size) - self.set_uniform_io_size(fov_size) + def __init__( + self, + fov_size=None, + deltas=None, + batch_size=None, + depth: int = 9, + features: int = 32, + **kwargs + ): + info = model.ModelInfo(deltas, fov_size, fov_size, fov_size) + super().__init__(info, batch_size, **kwargs) + self.set_input_shapes() self.depth = depth + self.features = features def define_tf_graph(self): self.show_center_slice(self.input_seed) - if self.input_patches is None: - self.input_patches = tf.placeholder( - tf.float32, [1] + list(self.input_image_size[::-1]) +[1], - name='patches') - net = tf.concat([self.input_patches, self.input_seed], 4) with tf.variable_scope('seed_update', reuse=False): - logit_update = _predict_object_mask(net, self.depth) + logit_update = _predict_object_mask(net, self.depth, self.features) logit_seed = self.update_seed(self.input_seed, logit_update) @@ -78,5 +100,3 @@ def define_tf_graph(self): self.show_center_slice(logit_seed) self.show_center_slice(self.labels, sigmoid=False) self.add_summaries() - - self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=1) diff --git a/ffn/training/optimizer.py b/ffn/training/optimizer.py index 6cc2417..bcd5e84 100644 --- a/ffn/training/optimizer.py +++ b/ffn/training/optimizer.py @@ -14,44 +14,115 @@ # ============================================================================== """Utilities to configure TF optimizers.""" -import tensorflow.compat.v1 as tf - from absl import flags +import tensorflow.compat.v1 as tf +_OPTIMIZER = flags.DEFINE_enum( + 'optimizer', + 'sgd', + ['momentum', 'sgd', 'adagrad', 'adam', 'rmsprop'], + 'Which optimizer to use.', +) +_LEARNING_RATE = flags.DEFINE_float( + 'learning_rate', 0.001, 'Initial learning rate.' +) +_MOMENTUM = flags.DEFINE_float('momentum', 0.9, 'Momentum.') +_LEARNING_RATE_DECAY_FACTOR = flags.DEFINE_float( + 'learning_rate_decay_factor', None, 'Learning rate decay factor.' +) +_DECAY_STEPS = flags.DEFINE_integer( + 'decay_steps', + None, + ( + 'How many steps the model needs to train for in order for ' + 'the decay factor to be applied to the learning rate.' + ), +) +_NUM_EPOCHS_PER_DECAY = flags.DEFINE_float( + 'num_epochs_per_decay', + 2.0, + 'Number of epochs after which learning rate decays.', +) +_RMSPROP_DECAY = flags.DEFINE_float( + 'rmsprop_decay', 0.9, 'Decay term for RMSProp.' +) +_ADAM_BETA1 = flags.DEFINE_float( + 'adam_beta1', 0.9, 'Gradient decay term for Adam.' +) +_ADAM_BETA2 = flags.DEFINE_float( + 'adam_beta2', 0.999, 'Gradient^2 decay term for Adam.' +) +_EPSILON = flags.DEFINE_float( + 'epsilon', 1e-8, 'Epsilon term for RMSProp and Adam.' +) +_SYNC_SGD = flags.DEFINE_boolean( + 'sync_sgd', False, 'Whether to use synchronous SGD.' +) +_REPLICAS_TO_AGGREGATE = flags.DEFINE_integer( + 'replicas_to_aggregate', + None, + 'When using sync SGD, over how many replicas to aggregate the gradients.', +) +_TOTAL_REPLICAS = flags.DEFINE_integer( + 'total_replicas', + None, + 'When using sync SGD, total number of replicas in the training pool.', +) -FLAGS = flags.FLAGS -flags.DEFINE_string('optimizer', 'sgd', - 'Which optimizer to use. 
Valid values are: ' - 'momentum, sgd, adagrad, adam, rmsprop.') -flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.') -flags.DEFINE_float('momentum', 0.9, 'Momentum.') -flags.DEFINE_float('learning_rate_decay_factor', 0.94, - 'Learning rate decay factor.') -flags.DEFINE_float('num_epochs_per_decay', 2.0, - 'Number of epochs after which learning rate decays.') -flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.') -flags.DEFINE_float('adam_beta1', 0.9, 'Gradient decay term for Adam.') -flags.DEFINE_float('adam_beta2', 0.999, 'Gradient^2 decay term for Adam.') -flags.DEFINE_float('epsilon', 1e-8, 'Epsilon term for RMSProp and Adam.') +def _optimizer_from_flags(): + """Defines a TF optimizer based on flag settings.""" + lr = _LEARNING_RATE.value + if ( + _LEARNING_RATE_DECAY_FACTOR.value is not None + and _DECAY_STEPS.value is not None + ): + lr = tf.train.exponential_decay( + _LEARNING_RATE.value, + tf.train.get_or_create_global_step(), + _DECAY_STEPS.value, + _LEARNING_RATE_DECAY_FACTOR.value, + staircase=True, + ) + tf.summary.scalar('learning_rate', lr) -def optimizer_from_flags(): - lr = FLAGS.learning_rate - if FLAGS.optimizer == 'momentum': - return tf.train.MomentumOptimizer(lr, FLAGS.momentum) - elif FLAGS.optimizer == 'sgd': + if _OPTIMIZER.value == 'momentum': + return tf.train.MomentumOptimizer(lr, _MOMENTUM.value) + elif _OPTIMIZER.value == 'sgd': return tf.train.GradientDescentOptimizer(lr) - elif FLAGS.optimizer == 'adagrad': + elif _OPTIMIZER.value == 'adagrad': return tf.train.AdagradOptimizer(lr) - elif FLAGS.optimizer == 'adam': - return tf.train.AdamOptimizer(learning_rate=lr, - beta1=FLAGS.adam_beta1, - beta2=FLAGS.adam_beta2, - epsilon=FLAGS.epsilon) - elif FLAGS.optimizer == 'rmsprop': - return tf.train.RMSPropOptimizer(lr, FLAGS.rmsprop_decay, - momentum=FLAGS.momentum, - epsilon=FLAGS.epsilon) + elif _OPTIMIZER.value == 'adam': + return tf.train.AdamOptimizer( + learning_rate=lr, + beta1=_ADAM_BETA1.value, + beta2=_ADAM_BETA2.value, + epsilon=_EPSILON.value, + ) + elif _OPTIMIZER.value == 'rmsprop': + return tf.train.RMSPropOptimizer( + lr, + _RMSPROP_DECAY.value, + momentum=_MOMENTUM.value, + epsilon=_EPSILON.value, + ) + else: + raise ValueError('Unknown optimizer: %s' % _OPTIMIZER.value) + + +def optimizer_from_flags(): + """Defines a TF optimizer based on command-line flags.""" + opt = _optimizer_from_flags() + if _SYNC_SGD.value: + assert _REPLICAS_TO_AGGREGATE.value is not None + if _TOTAL_REPLICAS.value is not None: + assert _TOTAL_REPLICAS.value >= _REPLICAS_TO_AGGREGATE.value + + return tf.train.SyncReplicasOptimizer( + opt, + replicas_to_aggregate=_REPLICAS_TO_AGGREGATE.value, + total_num_replicas=_TOTAL_REPLICAS.value, + ) else: - raise ValueError('Unknown optimizer: %s' % FLAGS.optimizer) + return opt diff --git a/ffn/training/tracker.py b/ffn/training/tracker.py new file mode 100644 index 0000000..e71c672 --- /dev/null +++ b/ffn/training/tracker.py @@ -0,0 +1,362 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for tracking and reporting the training status."""
+
+import collections
+import enum
+import io
+import itertools
+from typing import Optional, Sequence
+
+import numpy as np
+
+import PIL
+import PIL.Image
+import PIL.ImageDraw
+from scipy import special
+from skimage import measure
+
+import tensorflow.compat.v1 as tf
+from . import mask
+from . import variables
+
+
+if tf.executing_eagerly():
+  tf.compat.v2.experimental.numpy.experimental_enable_numpy_behavior()
+
+
+class MoveType(enum.IntEnum):
+  CORRECT = 0
+  MISSED = 1
+  SPURIOUS = 2
+
+
+class VoxelType(enum.IntEnum):
+  TOTAL = 0
+  MASKED = 1
+
+
+class PredictionType(enum.IntEnum):
+  TP = 0
+  TN = 1
+  FP = 2
+  FN = 3
+
+
+class FovStat(enum.IntEnum):
+  TOTAL_VOXELS = 0
+  MASKED_VOXELS = 1
+  WEIGHTS_SUM = 2
+
+
+class EvalTracker:
+  """Tracks eval results over multiple training steps."""
+
+  def __init__(self,
+               eval_shape: list[int],
+               shifts: Sequence[tuple[int, int, int]]):
+    # TODO(mjanusz): Remove this TFv1 code once no longer used.
+    if not tf.executing_eagerly():
+      self.eval_labels = tf.compat.v1.placeholder(
+          tf.float32, [1] + eval_shape + [1], name='eval_labels')
+      self.eval_preds = tf.compat.v1.placeholder(
+          tf.float32, [1] + eval_shape + [1], name='eval_preds')
+      self.eval_weights = tf.compat.v1.placeholder(
+          tf.float32, [1] + eval_shape + [1], name='eval_weights')
+      self.eval_loss = tf.reduce_mean(
+          self.eval_weights * tf.nn.sigmoid_cross_entropy_with_logits(
+              logits=self.eval_preds, labels=self.eval_labels))
+    self.sess = None
+    self.eval_threshold = special.logit(0.9)
+    self._eval_shape = eval_shape  # zyx
+    self._define_tf_vars(shifts)
+    self._patch_count = 0
+
+    self.reset()
+
+  def _add_tf_var(self, name, shape, dtype):
+    v = variables.TFSyncVariable(name, shape, dtype)
+    setattr(self, name, v)
+    self._tf_vars.append(v)
+    return v
+
+  def _define_tf_vars(self, fov_shifts: Sequence[tuple[int, int, int]]):
+    """Defines TFSyncVariables."""
+    self._tf_vars = []
+    self._add_tf_var('moves', [3], tf.int64)
+    self._add_tf_var('loss', [1], tf.float32)
+    self._add_tf_var('num_voxels', [2], tf.int64)
+    self._add_tf_var('num_patches', [1], tf.int64)
+    self._add_tf_var('prediction_counts', [4], tf.int64)
+    self._add_tf_var('fov_stats', [3], tf.float32)
+
+    radii = set([int(np.linalg.norm(s)) for s in fov_shifts])
+    radii.add(0)
+    self.moves_by_r = {}
+    for r in radii:
+      self.moves_by_r[r] = self._add_tf_var('moves_%d' % r, [3], tf.int64)
+
+  def to_tf(self):
+    ops = []
+    feed_dict = {}
+
+    for var in self._tf_vars:
+      var.to_tf(ops, feed_dict)
+
+    assert self.sess is not None
+    self.sess.run(ops, feed_dict)
+
+  def from_tf(self):
+    ops = [var.from_tf for var in self._tf_vars]
+    assert self.sess is not None
+    values = self.sess.run(ops)
+
+    for value, var in zip(values, self._tf_vars):
+      var.tf_value = value
+
+  def reset(self):
+    """Resets status of the tracker."""
+    self.images_xy = collections.deque(maxlen=16)
+    self.images_xz = collections.deque(maxlen=16)
+    self.images_yz = collections.deque(maxlen=16)
+    self.meshes = collections.deque(maxlen=16 * 3)
+    for var in self._tf_vars:
+      var.reset()
+
+  def track_weights(self, weights: np.ndarray):
+    self.fov_stats.value[FovStat.TOTAL_VOXELS] += weights.size
+    self.fov_stats.value[FovStat.MASKED_VOXELS] += np.sum(weights == 0.0)
+    self.fov_stats.value[FovStat.WEIGHTS_SUM] +=
np.sum(weights) + + def record_move(self, wanted: bool, executed: bool, + offset_xyz: Sequence[int]): + """Records an FFN FOV move.""" + r = int(np.linalg.norm(offset_xyz)) + assert r in self.moves_by_r, ('%d not in %r' % + (r, list(self.moves_by_r.keys()))) + + if wanted: + if executed: + self.moves.value[MoveType.CORRECT] += 1 + self.moves_by_r[r].value[MoveType.CORRECT] += 1 + else: + self.moves.value[MoveType.MISSED] += 1 + self.moves_by_r[r].value[MoveType.MISSED] += 1 + elif executed: + self.moves.value[MoveType.SPURIOUS] += 1 + self.moves_by_r[r].value[MoveType.SPURIOUS] += 1 + + + def slice_image(self, coord: np.ndarray, labels: np.ndarray, + predicted: np.ndarray, weights: np.ndarray, + slice_axis: int) -> tf.Summary.Value: + """Builds a tf.Summary showing a slice of an object mask. + + The object mask slice is shown side by side with the corresponding + ground truth mask. + + Args: + coord: [1, 3] xyz coordinate as ndarray + labels: ndarray of ground truth data, shape [1, z, y, x, 1] + predicted: ndarray of predicted data, shape [1, z, y, x, 1] + weights: ndarray of loss weights, shape [1, z, y, x, 1] + slice_axis: axis in the middle of which to place the cutting plane for + which the summary image will be generated, valid values are 2 ('x'), 1 + ('y'), and 0 ('z'). + + Returns: + tf.Summary.Value object with the image. + """ + zyx = list(labels.shape[1:-1]) + selector = [0, slice(None), slice(None), slice(None), 0] + selector[slice_axis + 1] = zyx[slice_axis] // 2 + selector = tuple(selector) # for numpy indexing + + del zyx[slice_axis] + h, w = zyx + + buf = io.BytesIO() + labels = (labels[selector] * 255).astype(np.uint8) + predicted = (predicted[selector] * 255).astype(np.uint8) + weights = (weights[selector] * 255).astype(np.uint8) + + im = PIL.Image.fromarray( + np.repeat( + np.concatenate([labels, predicted, weights], axis=1)[..., + np.newaxis], + 3, + axis=2), 'RGB') + draw = PIL.ImageDraw.Draw(im) + + x, y, z = coord.squeeze() + draw.text((1, 1), '%d %d %d' % (x, y, z), fill='rgb(255,64,64)') + del draw + + im.save(buf, 'PNG') + + axis_names = 'zyx' + axis_names = axis_names.replace(axis_names[slice_axis], '') + + return tf.Summary.Value( + tag='final_%s' % axis_names[::-1], + image=tf.Summary.Image( + height=h, + width=w * 3, + colorspace=3, # RGB + encoded_image_string=buf.getvalue())) + + def add_patch(self, + labels: np.ndarray, + predicted: np.ndarray, + weights: np.ndarray, + coord: Optional[np.ndarray] = None, + volname: Optional[str] = None, + patches: Optional[np.ndarray] = None, + image_summaries: bool = True): + """Evaluates single-object segmentation quality.""" + + predicted = mask.crop_and_pad(predicted, (0, 0, 0), self._eval_shape) + weights = mask.crop_and_pad(weights, (0, 0, 0), self._eval_shape) + labels = mask.crop_and_pad(labels, (0, 0, 0), self._eval_shape) + + if not tf.executing_eagerly(): + assert self.sess is not None + loss, = self.sess.run( + [self.eval_loss], { + self.eval_labels: labels, + self.eval_preds: predicted, + self.eval_weights: weights + }) + else: + loss = tf.reduce_mean(weights * tf.nn.sigmoid_cross_entropy_with_logits( + logits=predicted, labels=labels)) + + self.loss.value[:] += loss + self.num_voxels.value[VoxelType.TOTAL] += labels.size + self.num_voxels.value[VoxelType.MASKED] += np.sum(weights == 0.0) + + pred_mask = predicted >= self.eval_threshold + true_mask = labels > 0.5 + pred_bg = np.logical_not(pred_mask) + true_bg = np.logical_not(true_mask) + + self.prediction_counts.value[PredictionType.TP] += np.sum(pred_mask + & 
true_mask) + self.prediction_counts.value[PredictionType.TN] += np.sum(pred_bg & true_bg) + self.prediction_counts.value[PredictionType.FP] += np.sum(pred_mask + & true_bg) + self.prediction_counts.value[PredictionType.FN] += np.sum(pred_bg + & true_mask) + self.num_patches.value[:] += 1 + + if image_summaries: + predicted = special.expit(predicted) + self.images_xy.append( + self.slice_image(coord, labels, predicted, weights, 0)) + self.images_xz.append( + self.slice_image(coord, labels, predicted, weights, 1)) + self.images_yz.append( + self.slice_image(coord, labels, predicted, weights, 2)) + + + def _compute_classification_metrics(self, prediction_counts, prefix): + """Computes standard classification metrics.""" + tp = prediction_counts.tf_value[PredictionType.TP] + fp = prediction_counts.tf_value[PredictionType.FP] + tn = prediction_counts.tf_value[PredictionType.TN] + fn = prediction_counts.tf_value[PredictionType.FN] + + precision = tp / max(tp + fp, 1) + recall = tp / max(tp + fn, 1) + + if precision > 0 or recall > 0: + f1 = (2.0 * precision * recall / (precision + recall)) + else: + f1 = 0.0 + + return [ + tf.Summary.Value( + tag='%s/accuracy' % prefix, + simple_value=(tp + tn) / max(tp + tn + fp + fn, 1)), + tf.Summary.Value(tag='%s/precision' % prefix, simple_value=precision), + tf.Summary.Value(tag='%s/recall' % prefix, simple_value=recall), + tf.Summary.Value( + tag='%s/specificity' % prefix, simple_value=tn / max(tn + fp, 1)), + tf.Summary.Value(tag='%s/f1' % prefix, simple_value=f1) + ] + + def get_summaries(self) -> list[tf.Summary.Value]: + """Gathers tensorflow summaries into single list.""" + + self.from_tf() + if not self.num_voxels.tf_value[VoxelType.TOTAL]: + return [] + + for images in self.images_xy, self.images_xz, self.images_yz: + for i, summary in enumerate(images): + summary.tag += '/%d' % i + + total_moves = sum(self.moves.tf_value) + move_summaries = [] + for mt in MoveType: + move_summaries.append( + tf.Summary.Value( + tag='moves/all/%s' % mt.name.lower(), + simple_value=self.moves.tf_value[mt] / total_moves)) + + summaries = [ + tf.Summary.Value( + tag='fov/masked_voxel_fraction', + simple_value=(self.fov_stats.tf_value[FovStat.MASKED_VOXELS] / + self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])), + tf.Summary.Value( + tag='fov/average_weight', + simple_value=(self.fov_stats.tf_value[FovStat.WEIGHTS_SUM] / + self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])), + tf.Summary.Value( + tag='masked_voxel_fraction', + simple_value=(self.num_voxels.tf_value[VoxelType.MASKED] / + self.num_voxels.tf_value[VoxelType.TOTAL])), + tf.Summary.Value( + tag='eval/patch_loss', + simple_value=self.loss.tf_value[0] / self.num_patches.tf_value[0]), + tf.Summary.Value( + tag='eval/patches', simple_value=self.num_patches.tf_value[0]), + tf.Summary.Value(tag='moves/total', simple_value=total_moves) + ] + move_summaries + ( + list(self.meshes) + list(self.images_xy) + list(self.images_xz) + + list(self.images_yz)) + + summaries.extend( + self._compute_classification_metrics(self.prediction_counts, + 'eval/all')) + + for r, r_moves in self.moves_by_r.items(): + total_moves = sum(r_moves.tf_value) + summaries.extend([ + tf.Summary.Value( + tag='moves/r=%d/correct' % r, + simple_value=r_moves.tf_value[MoveType.CORRECT] / total_moves), + tf.Summary.Value( + tag='moves/r=%d/spurious' % r, + simple_value=r_moves.tf_value[MoveType.SPURIOUS] / total_moves), + tf.Summary.Value( + tag='moves/r=%d/missed' % r, + simple_value=r_moves.tf_value[MoveType.MISSED] / total_moves), + tf.Summary.Value( + 
tag='moves/r=%d/total' % r, simple_value=total_moves) + ]) + + return summaries diff --git a/ffn/training/variables.py b/ffn/training/variables.py index eb4d993..904a5d6 100644 --- a/ffn/training/variables.py +++ b/ffn/training/variables.py @@ -14,13 +14,14 @@ # ============================================================================== """Customized variables for tracking ratios, rates, etc.""" +import numpy as np import tensorflow.compat.v1 as tf -class FractionTracker(object): +class FractionTracker: """Helper for tracking fractions.""" - def __init__(self, name='fraction'): + def __init__(self, name: str = 'fraction'): # Values are: total, hits. self.var = tf.get_variable(name, [2], tf.int64, tf.constant_initializer([0, 0]), trainable=False) @@ -42,3 +43,89 @@ def get_hit_rate(self): update_var = self.var.assign_add([-total, -hits]) with tf.control_dependencies([update_var]): return tf.identity(hit_rate) + + +class DistributionTracker: + """Helper for tracking distributions.""" + + def __init__(self, num_classes: int, name: str = 'distribution'): + self.num_classes = num_classes + self.var = tf.get_variable( + name, [num_classes], + tf.int64, + tf.constant_initializer([0] * num_classes), + trainable=False) + + def record_class(self, class_id, count=1): + return self.var.assign_add( + tf.one_hot(class_id, self.num_classes, dtype=tf.int64) * count) + + def record_classes(self, labels): + delta = tf.math.bincount( + labels, + minlength=self.num_classes, + maxlength=self.num_classes, + dtype=tf.int64) + return self.var.assign_add(delta) + + def get_rates(self, reset=True): + """Queries the class frequencies. + + Args: + reset: whether to reset the class counters to 0 after query + + Returns: + TF op for class frequencies + """ + total = tf.reduce_sum(self.var) + rates = tf.cast(self.var, tf.float32) / tf.maximum( + tf.constant(1, dtype=tf.float32), tf.cast(total, tf.float32)) + if not reset: + return rates + + with tf.control_dependencies([rates]): + update_var = self.var.assign_add(-self.var) + with tf.control_dependencies([update_var]): + return tf.identity(rates) + + +def get_and_reset_value(var): + readout = var + 0 + with tf.control_dependencies([readout]): + update_var = var.assign_add(-readout) + with tf.control_dependencies([update_var]): + return tf.identity(readout) + + +class TFSyncVariable: + """A local variable which can be periodically synchronized to a TF one.""" + + def __init__(self, name, shape, dtype): + self._value = np.zeros(shape, dtype=dtype.as_numpy_dtype) + self._tf_var = tf.get_variable( + name, + shape, + dtype, + tf.constant_initializer(self.value), + trainable=False) + self._update_placeholder = tf.placeholder( + dtype, shape, name='plc_%s' % name) + self._to_tf = self._tf_var.assign_add(self._update_placeholder) + self._from_tf = get_and_reset_value(self._tf_var) + self.tf_value = None + + @property + def from_tf(self): + return self._from_tf + + @property + def value(self): + return self._value + + def to_tf(self, ops, feed_dict): + ops.append(self._to_tf) + feed_dict[self._update_placeholder] = self._value + self._value = np.zeros_like(self._value) + + def reset(self): + self._value = np.zeros_like(self._value) diff --git a/train.py b/train.py index d8a3988..6aaef93 100644 --- a/train.py +++ b/train.py @@ -19,41 +19,30 @@ of view in a way dependent on the initial predictions. 
""" -from collections import deque -from io import BytesIO from functools import partial -import itertools import json import logging import os import random import time +from typing import Optional +from absl import app +from absl import flags +from ffn.training import augmentation +from ffn.training import examples +from ffn.training import inputs +from ffn.training import model as ffn_model +# Necessary so that optimizer flags are defined. +from ffn.training import optimizer # pylint: disable=unused-import +from ffn.training import tracker +from ffn.training.import_util import import_symbol import h5py import numpy as np - -import PIL -import PIL.Image - -import six - -from scipy.special import expit -from scipy.special import logit +from scipy import special import tensorflow.compat.v1 as tf - -from absl import app -from absl import flags from tensorflow.io import gfile -from ffn.inference import movement -from ffn.training import mask -from ffn.training.import_util import import_symbol -from ffn.training import inputs -from ffn.training import augmentation -# Necessary so that optimizer flags are defined. -# pylint: disable=unused-import -from ffn.training import optimizer -# pylint: enable=unused-import FLAGS = flags.FLAGS @@ -143,148 +132,9 @@ FLAGS = flags.FLAGS -class EvalTracker(object): - """Tracks eval results over multiple training steps.""" - - def __init__(self, eval_shape): - self.eval_labels = tf.placeholder( - tf.float32, [1] + eval_shape + [1], name='eval_labels') - self.eval_preds = tf.placeholder( - tf.float32, [1] + eval_shape + [1], name='eval_preds') - self.eval_loss = tf.reduce_mean( - tf.nn.sigmoid_cross_entropy_with_logits( - logits=self.eval_preds, labels=self.eval_labels)) - self.reset() - self.eval_threshold = logit(0.9) - self.sess = None - self._eval_shape = eval_shape - - def reset(self): - """Resets status of the tracker.""" - self.loss = 0 - self.num_patches = 0 - self.tp = 0 - self.tn = 0 - self.fn = 0 - self.fp = 0 - self.total_voxels = 0 - self.masked_voxels = 0 - self.images_xy = deque(maxlen=16) - self.images_xz = deque(maxlen=16) - self.images_yz = deque(maxlen=16) - - def slice_image(self, labels, predicted, weights, slice_axis): - """Builds a tf.Summary showing a slice of an object mask. - - The object mask slice is shown side by side with the corresponding - ground truth mask. - - Args: - labels: ndarray of ground truth data, shape [1, z, y, x, 1] - predicted: ndarray of predicted data, shape [1, z, y, x, 1] - weights: ndarray of loss weights, shape [1, z, y, x, 1] - slice_axis: axis in the middle of which to place the cutting plane - for which the summary image will be generated, valid values are - 2 ('x'), 1 ('y'), and 0 ('z'). - - Returns: - tf.Summary.Value object with the image. 
- """ - zyx = list(labels.shape[1:-1]) - selector = [0, slice(None), slice(None), slice(None), 0] - selector[slice_axis + 1] = zyx[slice_axis] // 2 - selector = tuple(selector) - - del zyx[slice_axis] - h, w = zyx - - buf = BytesIO() - labels = (labels[selector] * 255).astype(np.uint8) - predicted = (predicted[selector] * 255).astype(np.uint8) - weights = (weights[selector] * 255).astype(np.uint8) - - im = PIL.Image.fromarray(np.concatenate([labels, predicted, - weights], axis=1), 'L') - im.save(buf, 'PNG') - - axis_names = 'zyx' - axis_names = axis_names.replace(axis_names[slice_axis], '') - - return tf.Summary.Value( - tag='final_%s' % axis_names[::-1], - image=tf.Summary.Image( - height=h, width=w * 3, colorspace=1, # greyscale - encoded_image_string=buf.getvalue())) - - def add_patch(self, labels, predicted, weights, - coord=None, volname=None, patches=None): - """Evaluates single-object segmentation quality.""" - - predicted = mask.crop_and_pad(predicted, (0, 0, 0), self._eval_shape) - weights = mask.crop_and_pad(weights, (0, 0, 0), self._eval_shape) - labels = mask.crop_and_pad(labels, (0, 0, 0), self._eval_shape) - loss, = self.sess.run([self.eval_loss], {self.eval_labels: labels, - self.eval_preds: predicted}) - self.loss += loss - self.total_voxels += labels.size - self.masked_voxels += np.sum(weights == 0.0) - - pred_mask = predicted >= self.eval_threshold - true_mask = labels > 0.5 - pred_bg = np.logical_not(pred_mask) - true_bg = np.logical_not(true_mask) - - self.tp += np.sum(pred_mask & true_mask) - self.fp += np.sum(pred_mask & true_bg) - self.fn += np.sum(pred_bg & true_mask) - self.tn += np.sum(pred_bg & true_bg) - self.num_patches += 1 - - predicted = expit(predicted) - self.images_xy.append(self.slice_image(labels, predicted, weights, 0)) - self.images_xz.append(self.slice_image(labels, predicted, weights, 1)) - self.images_yz.append(self.slice_image(labels, predicted, weights, 2)) - - def get_summaries(self): - """Gathers tensorflow summaries into single list.""" - - if not self.total_voxels: - return [] - - precision = self.tp / max(self.tp + self.fp, 1) - recall = self.tp / max(self.tp + self.fn, 1) - - for images in self.images_xy, self.images_xz, self.images_yz: - for i, summary in enumerate(images): - summary.tag += '/%d' % i - - summaries = ( - list(self.images_xy) + list(self.images_xz) + list(self.images_yz) + [ - tf.Summary.Value(tag='masked_voxel_fraction', - simple_value=(self.masked_voxels / - self.total_voxels)), - tf.Summary.Value(tag='eval/patch_loss', - simple_value=self.loss / self.num_patches), - tf.Summary.Value(tag='eval/patches', - simple_value=self.num_patches), - tf.Summary.Value(tag='eval/accuracy', - simple_value=(self.tp + self.tn) / ( - self.tp + self.tn + self.fp + self.fn)), - tf.Summary.Value(tag='eval/precision', - simple_value=precision), - tf.Summary.Value(tag='eval/recall', - simple_value=recall), - tf.Summary.Value(tag='eval/specificity', - simple_value=self.tn / max(self.tn + self.fp, 1)), - tf.Summary.Value(tag='eval/f1', - simple_value=(2.0 * precision * recall / - (precision + recall))) - ]) - - return summaries - - -def run_training_step(sess, model, fetch_summary, feed_dict): +def run_training_step(sess: tf.Session, model: ffn_model.FFNModel, + fetch_summary: Optional[tf.Operation], + feed_dict: dict[str, np.ndarray]): """Runs one training step for a single FFN FOV.""" ops_to_run = [model.train_op, model.global_step, model.logits] @@ -302,34 +152,34 @@ def run_training_step(sess, model, fetch_summary, feed_dict): return 
prediction, step, summ -def fov_moves(): +def fov_moves() -> int: # Add one more move to get a better fill of the evaluation area. if FLAGS.fov_policy == 'max_pred_moves': return FLAGS.fov_moves + 1 return FLAGS.fov_moves -def train_labels_size(model): - return (np.array(model.pred_mask_size) + - np.array(model.deltas) * 2 * fov_moves()) +def train_labels_size(info: ffn_model.ModelInfo) -> np.ndarray: + return (np.array(info.pred_mask_size) + + np.array(info.deltas) * 2 * fov_moves()) -def train_eval_size(model): - return (np.array(model.pred_mask_size) + - np.array(model.deltas) * 2 * FLAGS.fov_moves) +def train_eval_size(info: ffn_model.ModelInfo) -> np.ndarray: + return (np.array(info.pred_mask_size) + + np.array(info.deltas) * 2 * FLAGS.fov_moves) -def train_image_size(model): - return (np.array(model.input_image_size) + - np.array(model.deltas) * 2 * fov_moves()) +def train_image_size(info: ffn_model.ModelInfo) -> np.ndarray: + return (np.array(info.input_image_size) + + np.array(info.deltas) * 2 * fov_moves()) -def train_canvas_size(model): - return (np.array(model.input_seed_size) + - np.array(model.deltas) * 2 * fov_moves()) +def train_canvas_size(info: ffn_model.ModelInfo) -> np.ndarray: + return (np.array(info.input_seed_size) + + np.array(info.deltas) * 2 * fov_moves()) -def _get_offset_and_scale_map(): +def _get_offset_and_scale_map() -> dict[str, tuple[float, float]]: if not FLAGS.image_offset_scale_map: return {} @@ -436,167 +286,15 @@ def define_data_input(model, queue_batch=None): return patches, labels, loss_weights, coord, volname -def prepare_ffn(model): +def prepare_ffn(model: ffn_model.FFNModel): """Creates the TF graph for an FFN.""" - shape = [FLAGS.batch_size] + list(model.pred_mask_size[::-1]) + [1] + shape = [FLAGS.batch_size] + list(model.info.pred_mask_size[::-1]) + [1] model.labels = tf.placeholder(tf.float32, shape, name='labels') model.loss_weights = tf.placeholder(tf.float32, shape, name='loss_weights') model.define_tf_graph() -def fixed_offsets(model, seed, fov_shifts=None): - """Generates offsets based on a fixed list.""" - for off in itertools.chain([(0, 0, 0)], fov_shifts): - if model.dim == 3: - is_valid_move = seed[:, - seed.shape[1] // 2 + off[2], - seed.shape[2] // 2 + off[1], - seed.shape[3] // 2 + off[0], - 0] >= logit(FLAGS.threshold) - else: - is_valid_move = seed[:, - seed.shape[1] // 2 + off[1], - seed.shape[2] // 2 + off[0], - 0] >= logit(FLAGS.threshold) - - if not is_valid_move: - continue - - yield off - - -def max_pred_offsets(model, seed): - """Generates offsets with the policy used for inference.""" - # Always start at the center. - queue = deque([(0, 0, 0)]) - done = set() - - train_image_radius = train_image_size(model) // 2 - input_image_radius = np.array(model.input_image_size) // 2 - - while queue: - offset = queue.popleft() - - # Drop any offsets that would take us beyond the image fragment we - # loaded for training. - if np.any(np.abs(np.array(offset)) + input_image_radius > - train_image_radius): - continue - - # Ignore locations that were visited previously. - quantized_offset = ( - offset[0] // max(model.deltas[0], 1), - offset[1] // max(model.deltas[1], 1), - offset[2] // max(model.deltas[2], 1)) - - if quantized_offset in done: - continue - - done.add(quantized_offset) - - yield offset - - # Look for new offsets within the updated seed. 
- curr_seed = mask.crop_and_pad(seed, offset, model.pred_mask_size[::-1]) - todos = sorted( - movement.get_scored_move_offsets( - model.deltas[::-1], - curr_seed[0, ..., 0], - threshold=logit(FLAGS.threshold)), reverse=True) - queue.extend((x[2] + offset[0], - x[1] + offset[1], - x[0] + offset[2]) for _, x in todos) - - -def get_example(load_example, eval_tracker, model, get_offsets): - """Generates individual training examples. - - Args: - load_example: callable returning a tuple of image and label ndarrays - as well as the seed coordinate and volume name of the example - eval_tracker: EvalTracker object - model: FFNModel object - get_offsets: iterable of (x, y, z) offsets to investigate within the - training patch - - Yields: - tuple of: - seed array, shape [1, z, y, x, 1] - image array, shape [1, z, y, x, 1] - label array, shape [1, z, y, x, 1] - """ - seed_shape = train_canvas_size(model).tolist()[::-1] - - while True: - full_patches, full_labels, loss_weights, coord, volname = load_example() - # Always start with a clean seed. - seed = logit(mask.make_seed(seed_shape, 1, pad=FLAGS.seed_pad)) - - for off in get_offsets(model, seed): - predicted = mask.crop_and_pad(seed, off, model.input_seed_size[::-1]) - patches = mask.crop_and_pad(full_patches, off, model.input_image_size[::-1]) - labels = mask.crop_and_pad(full_labels, off, model.pred_mask_size[::-1]) - weights = mask.crop_and_pad(loss_weights, off, model.pred_mask_size[::-1]) - - # Necessary, since the caller is going to update the array and these - # changes need to be visible in the following iterations. - assert predicted.base is seed - yield predicted, patches, labels, weights - - eval_tracker.add_patch( - full_labels, seed, loss_weights, coord, volname, full_patches) - - -def get_batch(load_example, eval_tracker, model, batch_size, get_offsets): - """Generates batches of training examples. - - Args: - load_example: callable returning a tuple of image and label ndarrays - as well as the seed coordinate and volume name of the example - eval_tracker: EvalTracker object - model: FFNModel object - batch_size: desidred batch size - get_offsets: iterable of (x, y, z) offsets to investigate within the - training patch - - Yields: - tuple of: - seed array, shape [b, z, y, x, 1] - image array, shape [b, z, y, x, 1] - label array, shape [b, z, y, x, 1] - - where 'b' is the batch_size. - """ - def _batch(iterable): - for batch_vals in iterable: - # `batch_vals` is sequence of `batch_size` tuples returned by the - # `get_example` generator, to which we apply the following transformation: - # [(a0, b0), (a1, b1), .. (an, bn)] -> [(a0, a1, .., an), - # (b0, b1, .., bn)] - # (where n is the batch size) to get a sequence, each element of which - # represents a batch of values of a given type (e.g., seed, image, etc.) - yield zip(*batch_vals) - - # Create a separate generator for every element in the batch. This generator - # will automatically advance to a different training example once the allowed - # moves for the current location are exhausted. - for seeds, patches, labels, weights in _batch(six.moves.zip( - *[get_example(load_example, eval_tracker, model, get_offsets) for _ - in range(batch_size)])): - - batched_seeds = np.concatenate(seeds) - - yield (batched_seeds, np.concatenate(patches), np.concatenate(labels), - np.concatenate(weights)) - - # batched_seed is updated in place with new predictions by the code - # calling get_batch. Here we distribute these updated predictions back - # to the buffer of every generator. 
- for i in range(batch_size): - seeds[i][:] = batched_seeds[i, ...] - - def save_flags(): gfile.makedirs(FLAGS.train_dir) with gfile.GFile( @@ -614,9 +312,9 @@ def train_ffn(model_cls, **model_kwargs): # The constructor might define TF ops/placeholders, so it is important # that the FFN is instantiated within the current context. model = model_cls(**model_kwargs) - eval_shape_zyx = train_eval_size(model).tolist()[::-1] + eval_shape_zyx = train_eval_size(model.info).tolist()[::-1] - eval_tracker = EvalTracker(eval_shape_zyx) + eval_tracker = tracker.EvalTracker(eval_shape_zyx, model.shifts) load_data_ops = define_data_input(model, queue_batch=1) prepare_ffn(model) merge_summaries_op = tf.summary.merge_all() @@ -656,13 +354,35 @@ def train_ffn(model_cls, **model_kwargs): if FLAGS.shuffle_moves: random.shuffle(fov_shifts) + train_image_radius = train_image_size(model.info) // 2 + input_image_radius = np.array(model.info.input_image_size) // 2 policy_map = { - 'fixed': partial(fixed_offsets, fov_shifts=fov_shifts), - 'max_pred_moves': max_pred_offsets + 'fixed': + partial( + examples.fixed_offsets, + fov_shifts=fov_shifts, + threshold=special.logit(FLAGS.threshold)), + 'max_pred_moves': + partial( + examples.max_pred_offsets, + max_radius=train_image_radius - input_image_radius, + threshold=special.logit(FLAGS.threshold)), + 'no_step': + examples.no_offsets } - batch_it = get_batch(lambda: sess.run(load_data_ops), - eval_tracker, model, FLAGS.batch_size, - policy_map[FLAGS.fov_policy]) + policy_fn = policy_map[FLAGS.fov_policy] + + def _make_ffn_example(): + return examples.get_example( + lambda: sess.run(load_data_ops), + eval_tracker, + model.info, + policy_fn, + FLAGS.seed_pad, + seed_shape=tuple(train_canvas_size(model.info).tolist()[::-1])) + + batch_it = examples.BatchExampleIter(_make_ffn_example, eval_tracker, + FLAGS.batch_size, model.info) t_last = time.time() @@ -677,6 +397,7 @@ def train_ffn(model_cls, **model_kwargs): seed, patches, labels, weights = next(batch_it) + eval_tracker.to_tf() updated_seed, step, summ = run_training_step( sess, model, summ_op, feed_dict={ @@ -688,7 +409,7 @@ def train_ffn(model_cls, **model_kwargs): # Save prediction results in the original seed array so that # they can be used in subsequent steps. - mask.update_at(seed, (0, 0, 0), updated_seed) + batch_it.update_seeds(updated_seed) # Record summaries. if summ is not None:
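Note (not part of the patch): the centered-update arithmetic that appears in both `FFNModel.update_seed` and `BatchExampleIter.update_seeds` can be sanity-checked in isolation. The sizes below are illustrative assumptions (a 49-voxel seed axis and a 33-voxel prediction), not values taken from the patch:

    import numpy as np

    input_size, pred_size = 49, 33  # assumed input_seed_size / pred_mask_size
    d = input_size - pred_size      # 16 voxels of margin along this axis
    lo, hi = d // 2, -(d - d // 2)  # the update lands in seed[8:-8]
    seed = np.zeros(input_size, dtype=np.float32)
    seed[lo:hi] += 1.0              # mirrors the dz // 2 : -(dz - dz // 2) slices
    assert (seed != 0).sum() == pred_size
    assert seed[lo] == 1.0 and seed[hi] == 0.0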