From 721001a9871e5f0b62302b412756dc282b338c5c Mon Sep 17 00:00:00 2001 From: Michal Januszewski Date: Tue, 23 Jan 2024 08:24:12 -0800 Subject: [PATCH] Factor out Runner into a separate module. Remove halt signaling, self-consistency looping, and histogram matching. PiperOrigin-RevId: 600796058 --- ffn/inference/executor.py | 42 ++- ffn/inference/inference.py | 548 +--------------------------- ffn/inference/movement.py | 86 +++-- ffn/inference/storage.py | 5 +- ffn/training/examples.py | 327 +++++++++++++++++ ffn/training/model.py | 139 +++---- ffn/training/models/convstack_3d.py | 58 ++- ffn/training/optimizer.py | 135 +++++-- ffn/training/tracker.py | 356 ++++++++++++++++++ ffn/utils/bounding_box.py | 2 +- train.py | 403 ++++---------------- 11 files changed, 1047 insertions(+), 1054 deletions(-) create mode 100644 ffn/training/examples.py create mode 100644 ffn/training/tracker.py diff --git a/ffn/inference/executor.py b/ffn/inference/executor.py index 9493534..fa7642d 100644 --- a/ffn/inference/executor.py +++ b/ffn/inference/executor.py @@ -41,9 +41,11 @@ def __init__(self, model, session, counters, batch_size): self.active_clients = 0 # Cache input/output sizes. 
- self._input_seed_size = np.array(model.input_seed_size[::-1]).tolist() - self._input_image_size = np.array(model.input_image_size[::-1]).tolist() - self._pred_size = np.array(model.pred_mask_size[::-1]).tolist() + self._input_seed_size = np.array(model.info.input_seed_size[::-1]).tolist() + self._input_image_size = np.array( + model.info.input_image_size[::-1] + ).tolist() + self._pred_size = np.array(model.info.pred_mask_size[::-1]).tolist() self._initialize_model() @@ -111,8 +113,9 @@ class ThreadingBatchExecutor(BatchExecutor): """ def __init__(self, model, session, counters, batch_size, expected_clients=1): - super(ThreadingBatchExecutor, self).__init__(model, session, counters, - batch_size) + super(ThreadingBatchExecutor, self).__init__( + model, session, counters, batch_size + ) self._lock = threading.Lock() self.outputs = {} # Will be populated by Queues as clients register. # Used by clients to communiate with the executor. The protocol is @@ -131,10 +134,12 @@ def __init__(self, model, session, counters, batch_size, expected_clients=1): self.expected_clients = expected_clients # Arrays fed to TF. 
- self.input_seed = np.zeros([batch_size] + self._input_seed_size + [1], - dtype=np.float32) - self.input_image = np.zeros([batch_size] + self._input_image_size + [1], - dtype=np.float32) + self.input_seed = np.zeros( + [batch_size] + self._input_seed_size + [1], dtype=np.float32 + ) + self.input_image = np.zeros( + [batch_size] + self._input_image_size + [1], dtype=np.float32 + ) self.th_executor = None def start_server(self): @@ -146,7 +151,8 @@ def start_server(self): """ if self.th_executor is None: self.th_executor = threading.Thread( - target=self._run_executor_log_exceptions) + target=self._run_executor_log_exceptions + ) self.th_executor.start() def stop_server(self): @@ -166,8 +172,10 @@ def _run_executor(self): with timer_counter(self.counters, 'executor-input'): ready = [] - while (len(ready) < min(self.active_clients, self.batch_size) or - not self.active_clients): + while ( + len(ready) < min(self.active_clients, self.batch_size) + or not self.active_clients + ): try: data = self.input_queue.get(timeout=5) except queue.Empty: @@ -201,9 +209,12 @@ def _schedule_batch(self, client_ids, fetches): with timer_counter(self.counters, 'executor-inference'): try: ret = self.session.run( - fetches, { + fetches, + { self.model.input_seed: self.input_seed, - self.model.input_patches: self.input_image}) + self.model.input_patches: self.input_image, + }, + ) except Exception as e: # pylint:disable=broad-except logging.exception(e) # If calling TF didn't work (faulty hardware, misconfiguration, etc), @@ -215,8 +226,7 @@ def _schedule_batch(self, client_ids, fetches): with self._lock: for i, client_id in enumerate(client_ids): try: - self.outputs[client_id].put( - {k: v[i, ...] for k, v in ret.items()}) + self.outputs[client_id].put({k: v[i, ...] for k, v in ret.items()}) except KeyError: # This could happen if a client unregistered itself # while inference was running. 
diff --git a/ffn/inference/inference.py b/ffn/inference/inference.py index 1315677..77def40 100644 --- a/ffn/inference/inference.py +++ b/ffn/inference/inference.py @@ -16,6 +16,7 @@ from collections import namedtuple import functools +from io import BytesIO import json import logging import os @@ -33,14 +34,16 @@ from .inference_utils import TimedIter from .inference_utils import timer_counter import numpy as np + from numpy.lib.stride_tricks import as_strided +from ..utils import ortho_plane_visualization from scipy.special import expit from scipy.special import logit import tensorflow.compat.v1 as tf from tensorflow.io import gfile +from ..training import model as ffn_model from ..training.import_util import import_symbol from ..utils import bounding_box -from ..utils import ortho_plane_visualization MSEC_IN_SEC = 1000 MAX_SELF_CONSISTENT_ITERS = 32 @@ -48,10 +51,10 @@ # Visualization. # --------------------------------------------------------------------------- -class DynamicImage(object): + +class DynamicImage: def UpdateFromPIL(self, new_img): - from io import BytesIO - from IPython import display + from IPython import display # pytype:disable=import-error display.clear_output(wait=True) image = BytesIO() new_img.save(image, format='png') @@ -116,74 +119,18 @@ def visualize_state(seed_logits, pos, movement_policy, dynimage): vis = Image.fromarray(val) dynimage.UpdateFromPIL(vis) - -# Self-prediction halting -# --------------------------------------------------------------------------- -HALT_SILENT = 0 -PRINT_HALTS = 1 -HALT_VERBOSE = 2 - -HaltInfo = namedtuple('HaltInfo', ['is_halt', 'extra_fetches']) - - -def no_halt(verbosity=HALT_SILENT, log_function=logging.info): - """Dummy HaltInfo.""" - def _halt_signaler(*unused_args, **unused_kwargs): - return False - - def _halt_signaler_verbose(fetches, pos, **unused_kwargs): - log_function('%s, %s' % (pos, fetches)) - return False - - if verbosity == HALT_VERBOSE: - return HaltInfo(_halt_signaler_verbose, []) - 
else: - return HaltInfo(_halt_signaler, []) - - -def self_prediction_halt( - threshold, orig_threshold=None, verbosity=HALT_SILENT, - log_function=logging.info): - """HaltInfo based on FFN self-predictions.""" - - def _halt_signaler(fetches, pos, orig_pos, counters, **unused_kwargs): - """Returns true if FFN prediction should be halted.""" - if pos == orig_pos and orig_threshold is not None: - t = orig_threshold - else: - t = threshold - - # [0] is by convention the total incorrect proportion prediction. - halt = fetches['self_prediction'][0] > t - - if halt: - counters['halts'].Increment() - - if verbosity == HALT_VERBOSE or ( - halt and verbosity == PRINT_HALTS): - log_function('%s, %s' % (pos, fetches)) - - return halt - - # Add self_prediction to the extra_fetches. - return HaltInfo(_halt_signaler, ['self_prediction']) - -# --------------------------------------------------------------------------- - - # TODO(mjanusz): Add support for sparse inference. -class Canvas(object): +class Canvas: """Tracks state of the inference progress and results within a subvolume.""" def __init__(self, - model, + model: ffn_model.FFNModel, tf_executor, image, options, counters=None, restrictor=None, movement_policy_fn=None, - halt_signaler=no_halt(), keep_history=False, checkpoint_path=None, checkpoint_interval_sec=0, @@ -202,7 +149,6 @@ def __init__(self, movement_policy_fn: callable taking the Canvas object as its only argument and returning a movement policy object (see movement.BaseMovementPolicy) - halt_signaler: HaltInfo object determining early stopping policy keep_history: whether to maintain a record of locations visited by the FFN, together with any associated metadata; note that this data is kept only for the object currently being segmented @@ -224,8 +170,6 @@ def __init__(self, 'segment_threshold'): setattr(self.options, attr, logit(getattr(self.options, attr))) - self.halt_signaler = halt_signaler - self.counters = counters if counters is not None else Counters() 
self.checkpoint_interval_sec = checkpoint_interval_sec self.checkpoint_path = checkpoint_path @@ -242,9 +186,9 @@ def __init__(self, # Cast to array to ensure we can do elementwise expressions later. # All of these are in zyx order. - self._pred_size = np.array(model.pred_mask_size[::-1]) - self._input_seed_size = np.array(model.input_seed_size[::-1]) - self._input_image_size = np.array(model.input_image_size[::-1]) + self._pred_size = np.array(model.info.pred_mask_size[::-1]) + self._input_seed_size = np.array(model.info.input_seed_size[::-1]) + self._input_image_size = np.array(model.info.input_image_size[::-1]) self.margin = self._input_image_size // 2 self._pred_delta = (self._input_seed_size - self._pred_size) // 2 @@ -277,7 +221,7 @@ def __init__(self, if movement_policy_fn is None: # The model.deltas are (for now) in xyz order and must be swapped to zyx. self.movement_policy = movement.FaceMaxMovementPolicy( - self, deltas=model.deltas[::-1], + self, deltas=model.info.deltas[::-1], score_threshold=self.options.move_threshold) else: self.movement_policy = movement_policy_fn(self) @@ -362,7 +306,7 @@ def _get_image(self, pos): img = self.image[tuple(slice(s, e) for s, e in zip(start, end))] return img - def predict(self, pos, logit_seed, extra_fetches): + def predict(self, pos, logit_seed): """Runs a single step of FFN prediction. 
Args: @@ -385,16 +329,15 @@ def predict(self, pos, logit_seed, extra_fetches): self.counters['inference-not-predict-ms'].IncrementBy( delta_t * MSEC_IN_SEC) - extra_fetches['logits'] = self.model.logits + fetches = {'logits': self.model.logits} with timer_counter(self.counters, 'inference'): fetches = self.executor.predict(self._exec_client_id, - logit_seed, img, extra_fetches) + logit_seed, img, fetches) self.t_last_predict = time.time() logits = fetches.pop('logits') - prob = expit(logits) - return (prob[..., 0], logits[..., 0]), fetches + return logits[..., 0] def update_at(self, pos, start_pos): """Updates object mask prediction at a specific position. @@ -421,27 +364,7 @@ def update_at(self, pos, start_pos): init_prediction = np.isnan(logit_seed) logit_seed[init_prediction] = np.float32(self.options.pad_value) - extra_fetches = {f: getattr(self.model, f) for f - in self.halt_signaler.extra_fetches} - - prob_seed = expit(logit_seed) - for _ in range(MAX_SELF_CONSISTENT_ITERS): - (prob, logits), fetches = self.predict(pos, logit_seed, - extra_fetches=extra_fetches) - if self.options.consistency_threshold <= 0: - break - - diff = np.average(np.abs(prob_seed - prob)) - if diff < self.options.consistency_threshold: - break - - prob_seed, logit_seed = prob, logits - - if self.halt_signaler.is_halt(fetches=fetches, pos=pos, - orig_pos=start_pos, - counters=self.counters): - logits[:] = np.float32(self.options.pad_value) - + logits = self.predict(pos, logit_seed) start += self._pred_delta end = start + self._pred_size sel = tuple(slice(s, e) for s, e in zip(start, end)) @@ -460,8 +383,8 @@ def update_at(self, pos, start_pos): self.options.disco_seed_threshold): # Because (x > NaN) is always False, this mask excludes positions that # were previously uninitialized (i.e. set to NaN in old_seed). 
+ old_err = np.seterr(invalid='ignore') try: - old_err = np.seterr(invalid='ignore') mask = ((old_seed < th_max) & (logits > old_seed)) finally: np.seterr(**old_err) @@ -787,436 +710,3 @@ def _maybe_save_checkpoint(self): self.save_checkpoint(self.checkpoint_path) self.checkpoint_last = time.time() - - -class Runner(object): - """Helper for managing FFN inference runs. - - Takes care of initializing the FFN model and any related functionality - (e.g. movement policies), as well as input/output of the FFN inference - data (loading inputs, saving segmentations). - """ - - ALL_MASKED = 1 - - def __init__(self): - self.counters = inference_utils.Counters() - self.executor = None - - def __del__(self): - self.stop_executor() - - def stop_executor(self): - """Shuts down the executor. - - No-op when no executor is active. - """ - if self.executor is not None: - self.executor.stop_server() - self.executor = None - - def _load_model_checkpoint(self, checkpoint_path): - """Restores the inference model from a training checkpoint. - - Args: - checkpoint_path: the string path to the checkpoint file to load - """ - with timer_counter(self.counters, 'restore-tf-checkpoint'): - logging.info('Loading checkpoint.') - self.model.saver.restore(self.session, checkpoint_path) - logging.info('Checkpoint loaded.') - - def start(self, request, batch_size=1, exec_cls=None, session=None): - """Opens input volumes and initializes the FFN.""" - self.request = request - assert self.request.segmentation_output_dir - - logging.debug('Received request:\n%s', request) - - if not gfile.exists(request.segmentation_output_dir): - gfile.makedirs(request.segmentation_output_dir) - - with timer_counter(self.counters, 'volstore-open'): - # Disabling cache compression can improve access times by 20-30% - # as of Aug 2016. 
- self._image_volume = storage.decorated_volume( - request.image, cache_max_bytes=int(1e8), - cache_compression=False) - assert self._image_volume is not None - - if request.HasField('init_segmentation'): - self.init_seg_volume = storage.decorated_volume( - request.init_segmentation, cache_max_bytes=int(1e8)) - else: - self.init_seg_volume = None - - def _open_or_none(settings): - if settings.WhichOneof('volume_path') is None: - return None - return storage.decorated_volume( - settings, cache_max_bytes=int(1e7), cache_compression=False) - self._mask_volumes = {} - self._shift_mask_volume = _open_or_none(request.shift_mask) - - alignment_options = request.alignment_options - null_alignment = inference_pb2.AlignmentOptions.NO_ALIGNMENT - if not alignment_options or alignment_options.type == null_alignment: - self._aligner = align.Aligner() - else: - type_name = inference_pb2.AlignmentOptions.AlignType.Name( - alignment_options.type) - error_string = 'Alignment for type %s is not implemented' % type_name - logging.error(error_string) - raise NotImplementedError(error_string) - - def _open_or_none(settings): - if settings.WhichOneof('volume_path') is None: - return None - return storage.decorated_volume( - settings, cache_max_bytes=int(1e7), cache_compression=False) - self._mask_volumes = {} - self._shift_mask_volume = _open_or_none(request.shift_mask) - - if request.reference_histogram: - with gfile.GFile(request.reference_histogram, 'rb') as f: - data = np.load(f) - self._reference_lut = data['lut'] - else: - self._reference_lut = None - - self.stop_executor() - - if session is None: - config = tf.ConfigProto() - tf.reset_default_graph() - session = tf.Session(config=config) - self.session = session - logging.info('Available TF devices: %r', self.session.list_devices()) - - # Initialize the FFN model. 
- model_class = import_symbol(request.model_name) - if request.model_args: - args = json.loads(request.model_args) - else: - args = {} - - args['batch_size'] = batch_size - self.model = model_class(**args) - - if exec_cls is None: - exec_cls = executor.ThreadingBatchExecutor - - self.executor = exec_cls( - self.model, self.session, self.counters, batch_size) - self.movement_policy_fn = movement.get_policy_fn(request, self.model) - - self.saver = tf.train.Saver() - self._load_model_checkpoint(request.model_checkpoint_path) - - self.executor.start_server() - - def make_restrictor(self, corner, subvol_size, image, alignment): - """Builds a MovementRestrictor object.""" - kwargs = {} - - if self.request.masks: - with timer_counter(self.counters, 'load-mask'): - final_mask = storage.build_mask(self.request.masks, - corner, subvol_size, - self._mask_volumes, - image, alignment) - - if np.all(final_mask): - logging.info('Everything masked.') - return self.ALL_MASKED - - kwargs['mask'] = final_mask - - if self.request.seed_masks: - with timer_counter(self.counters, 'load-seed-mask'): - seed_mask = storage.build_mask(self.request.seed_masks, - corner, subvol_size, - self._mask_volumes, - image, alignment) - - if np.all(seed_mask): - logging.info('All seeds masked.') - return self.ALL_MASKED - - kwargs['seed_mask'] = seed_mask - - if self._shift_mask_volume: - with timer_counter(self.counters, 'load-shift-mask'): - s = self.request.shift_mask_scale - shift_corner = np.array(corner) // (1, s, s) - shift_size = -(-np.array(subvol_size) // (1, s, s)) - - shift_alignment = alignment.rescaled( - np.array((1.0, 1.0, 1.0)) / (1, s, s)) - src_corner, src_size = shift_alignment.expand_bounds( - shift_corner, shift_size, forward=False) - src_corner, src_size = storage.clip_subvolume_to_bounds( - src_corner, src_size, self._shift_mask_volume) - src_end = src_corner + src_size - - expanded_shift_mask = self._shift_mask_volume[ - 0:2, # - src_corner[0]:src_end[0], # - 
src_corner[1]:src_end[1], # - src_corner[2]:src_end[2]] - shift_mask = np.array([ - shift_alignment.align_and_crop( - src_corner, expanded_shift_mask[i], shift_corner, shift_size) - for i in range(2)]) - shift_mask = alignment.transform_shift_mask(corner, s, shift_mask) - - if self.request.HasField('shift_mask_fov'): - shift_mask_fov = bounding_box.BoundingBox( - start=self.request.shift_mask_fov.start, - size=self.request.shift_mask_fov.size) - else: - shift_mask_diameter = np.array(self.model.input_image_size) - shift_mask_fov = bounding_box.BoundingBox( - start=-(shift_mask_diameter // 2), size=shift_mask_diameter) - - kwargs.update({ - 'shift_mask': shift_mask, - 'shift_mask_fov': shift_mask_fov, - 'shift_mask_scale': self.request.shift_mask_scale, - 'shift_mask_threshold': self.request.shift_mask_threshold}) - - if kwargs: - return movement.MovementRestrictor(**kwargs) - else: - return None - - def make_canvas(self, corner, subvol_size, **canvas_kwargs): - """Builds the Canvas object for inference on a subvolume. - - Args: - corner: start of the subvolume (z, y, x) - subvol_size: size of the subvolume (z, y, x) - **canvas_kwargs: passed to Canvas - - Returns: - A tuple of: - Canvas object - Alignment object - """ - subvol_counters = self.counters.get_sub_counters() - with timer_counter(subvol_counters, 'load-image'): - logging.info('Process subvolume: %r', corner) - - # A Subvolume with bounds defined by (src_size, src_corner) is guaranteed - # to result in no missing data when aligned to (dst_size, dst_corner). - # Likewise, one defined by (dst_size, dst_corner) is guaranteed to result - # in no missing data when reverse-aligned to (corner, subvol_size). - alignment = self._aligner.generate_alignment(corner, subvol_size) - - # Bounding box for the aligned destination subvolume. - dst_corner, dst_size = alignment.expand_bounds( - corner, subvol_size, forward=True) - # Bounding box for the pre-aligned imageset to be fetched from the volume. 
- src_corner, src_size = alignment.expand_bounds( - dst_corner, dst_size, forward=False) - # Ensure that the request bounds don't extend beyond volume bounds. - src_corner, src_size = storage.clip_subvolume_to_bounds( - src_corner, src_size, self._image_volume) - - logging.info('Requested bounds are %r + %r', corner, subvol_size) - logging.info('Destination bounds are %r + %r', dst_corner, dst_size) - logging.info('Fetch bounds are %r + %r', src_corner, src_size) - - # Fetch the image from the volume using the src bounding box. - def get_data_3d(volume, bbox): - slc = bbox.to_slice() - if volume.ndim == 4: - slc = np.index_exp[0:1] + slc - data = volume[slc] - if data.ndim == 4: - data = data.squeeze(axis=0) - return data - src_bbox = bounding_box.BoundingBox( - start=src_corner[::-1], size=src_size[::-1]) - src_image = get_data_3d(self._image_volume, src_bbox) - logging.info('Fetched image of size %r prior to transform', - src_image.shape) - - def align_and_crop(image): - return alignment.align_and_crop(src_corner, image, dst_corner, dst_size, - forward=True) - - # Align and crop to the dst bounding box. - image = align_and_crop(src_image) - # image now has corner dst_corner and size dst_size. - - logging.info('Image data loaded, shape: %r.', image.shape) - - restrictor = self.make_restrictor(dst_corner, dst_size, image, alignment) - - try: - if self._reference_lut is not None: - if self.request.histogram_masks: - histogram_mask = storage.build_mask(self.request.histogram_masks, - dst_corner, dst_size, - self._mask_volumes, - image, alignment) - else: - histogram_mask = None - - inference_utils.match_histogram(image, self._reference_lut, - mask=histogram_mask) - except ValueError as e: - # This can happen if the subvolume is relatively small because of tiling - # done by CLAHE. For now we just ignore these subvolumes. - # TODO(mjanusz): Handle these cases by reducing the number of tiles. 
- logging.info('Could not match histogram: %r', e) - return None, None - - image = (image.astype(np.float32) - - self.request.image_mean) / self.request.image_stddev - if restrictor == self.ALL_MASKED: - return None, None - - if self.request.HasField('self_prediction'): - halt_signaler = self_prediction_halt( - self.request.self_prediction.threshold, - orig_threshold=self.request.self_prediction.orig_threshold, - verbosity=PRINT_HALTS) - else: - halt_signaler = no_halt() - - canvas = Canvas( - self.model, - self.executor, - image, - self.request.inference_options, - counters=subvol_counters, - restrictor=restrictor, - movement_policy_fn=self.movement_policy_fn, - halt_signaler=halt_signaler, - checkpoint_path=storage.checkpoint_path( - self.request.segmentation_output_dir, corner), - checkpoint_interval_sec=self.request.checkpoint_interval, - corner_zyx=dst_corner, - **canvas_kwargs) - - if self.request.HasField('init_segmentation'): - canvas.init_segmentation_from_volume(self.init_seg_volume, src_corner, - src_bbox.end[::-1], align_and_crop) - return canvas, alignment - - def get_seed_policy(self, corner, subvol_size): - """Get seed policy generating callable. - - Args: - corner: the original corner of the requested subvolume, before any - modification e.g. dynamic alignment. - subvol_size: the original requested size. - - Returns: - A callable for generating seed policies. - """ - policy_cls = getattr(seed, self.request.seed_policy) - kwargs = {'corner': corner, 'subvol_size': subvol_size} - if self.request.seed_policy_args: - kwargs.update(json.loads(self.request.seed_policy_args)) - return functools.partial(policy_cls, **kwargs) - - def save_segmentation(self, canvas, alignment, target_path, prob_path): - """Saves segmentation to a file. 
- - Args: - canvas: Canvas object containing the segmentation - alignment: the local Alignment used with the canvas, or None - target_path: path to the file where the segmentation should - be saved - prob_path: path to the file where the segmentation probability - map should be saved - """ - def unalign_image(im3d): - if alignment is None: - return im3d - return alignment.align_and_crop( - canvas.corner_zyx, - im3d, - alignment.corner, - alignment.size, - forward=False) - - def unalign_origins(origins, canvas_corner): - out_origins = dict() - for key, value in origins.items(): - zyx = np.array(value.start_zyx) + canvas_corner - zyx = alignment.transform(zyx[:, np.newaxis], forward=False).squeeze() - zyx -= canvas_corner - out_origins[key] = value._replace(start_zyx=tuple(zyx)) - return out_origins - - # Remove markers. - canvas.segmentation[canvas.segmentation < 0] = 0 - - # Save segmentation results. Reduce # of bits per item if possible. - storage.save_subvolume( - unalign_image(canvas.segmentation), - unalign_origins(canvas.origins, np.array(canvas.corner_zyx)), - target_path, - request=self.request.SerializeToString(), - counters=canvas.counters.dumps(), - overlaps=canvas.overlaps) - - # Save probability map separately. This has to happen after the - # segmentation is saved, as `save_subvolume` will create any necessary - # directories. - prob = unalign_image(canvas.seg_prob) - with storage.atomic_file(prob_path) as fd: - np.savez_compressed(fd, qprob=prob) - - def run(self, corner, subvol_size, reset_counters=True): - """Runs FFN inference over a subvolume. - - Args: - corner: start of the subvolume (z, y, x) - subvol_size: size of the subvolume (z, y, x) - reset_counters: whether to reset the counters - - Returns: - Canvas object with the segmentation or None if the canvas could not - be created or the segmentation subvolume already exists. 
- """ - if reset_counters: - self.counters.reset() - - seg_path = storage.segmentation_path( - self.request.segmentation_output_dir, corner) - prob_path = storage.object_prob_path( - self.request.segmentation_output_dir, corner) - cpoint_path = storage.checkpoint_path( - self.request.segmentation_output_dir, corner) - - if gfile.exists(seg_path): - return None - - canvas, alignment = self.make_canvas(corner, subvol_size) - if canvas is None: - return None - - if gfile.exists(cpoint_path): - canvas.restore_checkpoint(cpoint_path) - - if self.request.alignment_options.save_raw: - image_path = storage.subvolume_path(self.request.segmentation_output_dir, - corner, 'align') - with storage.atomic_file(image_path) as fd: - np.savez_compressed(fd, im=canvas.image) - - canvas.segment_all(seed_policy=self.get_seed_policy(corner, subvol_size)) - self.save_segmentation(canvas, alignment, seg_path, prob_path) - - # Attempt to remove the checkpoint file now that we no longer need it. - try: - gfile.remove(cpoint_path) - except: # pylint: disable=bare-except - pass - - return canvas diff --git a/ffn/inference/movement.py b/ffn/inference/movement.py index 4adf649..634a966 100644 --- a/ffn/inference/movement.py +++ b/ffn/inference/movement.py @@ -16,9 +16,14 @@ from collections import deque import json +from typing import Optional import weakref + +from connectomics.common import bounding_box import numpy as np from scipy.special import logit + +from ..training import model as ffn_model from ..training.import_util import import_symbol # Unless stated otherwise, all shape/coordinate triples in this file are in zyx @@ -34,7 +39,11 @@ # face, and look at the max probability point in every connected component # within a face. Probably best to implement this in C++ and just use a Python # wrapper. 
-def get_scored_move_offsets(deltas, prob_map, threshold=0.9): +def get_scored_move_offsets( + deltas: tuple[int, int, int] | np.ndarray, + prob_map: np.ndarray, + threshold: float = 0.9, +): """Looks for potential moves for a FFN. The possible moves are determined by extracting probability map values @@ -45,7 +54,7 @@ def get_scored_move_offsets(deltas, prob_map, threshold=0.9): deltas: (z,y,x) tuple of base move offsets for the 3 axes prob_map: current probability map as a (z,y,x) numpy array threshold: minimum score required at the new FoV center for a move to be - considered valid + considered valid Yields: tuples of: @@ -59,8 +68,7 @@ def get_scored_move_offsets(deltas, prob_map, threshold=0.9): assert center.size == 3 # Selects a working subvolume no more than +/- delta away from the current # center point. - subvol_sel = [slice(c - dx, c + dx + 1) for c, dx - in zip(center, deltas)] + subvol_sel = [slice(c - dx, c + dx + 1) for c, dx in zip(center, deltas)] done = set() for axis, axis_delta in enumerate(deltas): @@ -92,7 +100,7 @@ def get_scored_move_offsets(deltas, prob_map, threshold=0.9): yield ret -class BaseMovementPolicy(object): +class BaseMovementPolicy: """Base class for movement policy queues. 
The principal usage is to initialize once with the policy's parameters and @@ -120,9 +128,12 @@ def __len__(self): def __iter__(self): return self - def next(self): + def __next__(self): raise StopIteration() + def next(self): + return self.__next__() + def append(self, item): self.scored_coords.append(item) @@ -132,7 +143,7 @@ def update(self, prob_map, position): Args: prob_map: object probability map returned by the FFN (in logit space) position: postiion of the center of the FoV where inference was performed - (z, y, x) + (z, y, x) """ raise NotImplementedError() @@ -167,10 +178,10 @@ def reset_state(self, start_pos): self._start_pos = start_pos def get_state(self): - return [(self.scored_coords, self.done_rounded_coords)] + return [(self.scored_coords, self.done_rounded_coords, self._start_pos)] def restore_state(self, state): - self.scored_coords, self.done_rounded_coords = state[0] + self.scored_coords, self.done_rounded_coords, self._start_pos = state[0] def __next__(self): """Pops positions from queue until a valid one is found and returns it.""" @@ -186,16 +197,13 @@ def __next__(self): return tuple(coord) - def next(self): - return self.__next__() - def quantize_pos(self, pos): """Quantizes the positions symmetrically to a grid downsampled by deltas.""" # Compute offset relative to the origin of the current segment and # shift by half delta size. This ensures that all directions are treated # approximately symmetrically -- i.e. the origin point lies in the middle of # a cell of the quantized lattice, as opposed to a corner of that cell. 
- rel_pos = (np.array(pos) - self._start_pos) + rel_pos = np.array(pos) - self._start_pos coord = (rel_pos + self.deltas // 2) // np.maximum(self.deltas, 1) return tuple(coord) @@ -204,8 +212,9 @@ def update(self, prob_map, position): qpos = self.quantize_pos(position) self.done_rounded_coords.add(qpos) - scored_coords = get_scored_move_offsets(self.deltas, prob_map, - threshold=self.score_threshold) + scored_coords = get_scored_move_offsets( + self.deltas, prob_map, threshold=self.score_threshold + ) scored_coords = sorted(scored_coords, reverse=True) for score, rel_coord in scored_coords: # convert to whole cube coordinates @@ -213,7 +222,7 @@ def update(self, prob_map, position): self.scored_coords.append((score, coord)) -def get_policy_fn(request, ffn_model): +def get_policy_fn(request, model_info: ffn_model.ModelInfo): """Returns a policy class based on the InferenceRequest proto.""" if request.movement_policy_name: @@ -228,32 +237,40 @@ def get_policy_fn(request, ffn_model): else: kwargs = {} if 'deltas' not in kwargs: - kwargs['deltas'] = ffn_model.deltas[::-1] + kwargs['deltas'] = model_info.deltas[::-1] if 'score_threshold' not in kwargs: kwargs['score_threshold'] = logit(request.inference_options.move_threshold) return lambda canvas: movement_policy_class(canvas, **kwargs) -class MovementRestrictor(object): +class MovementRestrictor: """Restricts the movement of the FFN FoV.""" - def __init__(self, mask=None, shift_mask=None, shift_mask_fov=None, - shift_mask_threshold=4, shift_mask_scale=1, seed_mask=None): + def __init__( + self, + mask: Optional[np.ndarray] = None, + shift_mask: Optional[np.ndarray] = None, + shift_mask_fov: Optional[bounding_box.BoundingBox] = None, + shift_mask_threshold: int = 4, + shift_mask_scale: int = 1, + seed_mask: Optional[np.ndarray] = None, + ): """Initializes the restrictor. 
Args: mask: 3d ndarray-like of shape (z, y, x); positive values indicate voxels - that are not going to be segmented + that are not going to be segmented shift_mask: 4d ndarray-like of shape (2, z, y, x) representing a 2d shift - vector field + vector field shift_mask_fov: bounding_box.BoundingBox around large shifts in which to - restrict movement. BoundingBox specified as XYZ, start can be - negative. - shift_mask_threshold: if any component of the shift vector exceeds this - value within the FoV, the location will not be segmented + restrict movement. BoundingBox specified as XYZ, start can be negative. + shift_mask_threshold: if any component of the shift vector equals or + exceeds this value within the FoV, the location will not be segmented shift_mask_scale: an integer factor specifying how much larger the pixels - of the shift mask are compared to the data set processed by the FFN + of the shift mask are compared to the data set processed by the FFN + seed_mask: 3d ndarray-like of shape (z, y, x); positive values indicate + voxels where seeds are not going to be placed """ self.mask = mask self.seed_mask = seed_mask @@ -261,8 +278,9 @@ def __init__(self, mask=None, shift_mask=None, shift_mask_fov=None, self._shift_mask_scale = shift_mask_scale self.shift_mask = None if shift_mask is not None: - self.shift_mask = (np.max(np.abs(shift_mask), axis=0) >= - shift_mask_threshold) + self.shift_mask = ( + np.max(np.abs(shift_mask), axis=0) >= shift_mask_threshold + ) assert shift_mask_fov is not None self._shift_mask_fov_pre_offset = shift_mask_fov.start[::-1] @@ -306,9 +324,13 @@ def is_valid_pos(self, pos): # Do not allow movement through highly distorted areas, which often # result in merge errors. In the simplest case, the distortion magnitude # is quantified with a patch-based cross-correlation map. 
- if np.any(self.shift_mask[fov_low[0]:(fov_high[0] + 1), - start[1]:(end[1] + 1), - start[2]:(end[2] + 1)]): + if np.any( + self.shift_mask[ + fov_low[0] : (fov_high[0] + 1), + start[1] : (end[1] + 1), + start[2] : (end[2] + 1), + ] + ): return False return True diff --git a/ffn/inference/storage.py b/ffn/inference/storage.py index bcbda59..5928009 100644 --- a/ffn/inference/storage.py +++ b/ffn/inference/storage.py @@ -21,7 +21,7 @@ import os import re import tempfile -from typing import Optional +from typing import Any, Optional from connectomics.common import bounding_box import h5py @@ -33,6 +33,7 @@ from . import segmentation OriginInfo = namedtuple('OriginInfo', ['start_zyx', 'iters', 'walltime_sec']) +Volume = Any class SyncAdapter: @@ -51,7 +52,7 @@ def __repr__(self): return f'{self.__class__.__name__}({repr(self.tstore)})' -def decorated_volume(settings, **kwargs): +def decorated_volume(settings, **kwargs) -> Volume: """Converts DecoratedVolume proto object into volume objects. Args: diff --git a/ffn/training/examples.py b/ffn/training/examples.py new file mode 100644 index 0000000..a729e94 --- /dev/null +++ b/ffn/training/examples.py @@ -0,0 +1,327 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utilities for building training examples for FFN training.""" + +import collections +from concurrent import futures +import itertools +from typing import Callable, Iterable, Optional, Sequence + +import numpy as np +from scipy import special + +from ..inference import movement +from . import mask +from . import model as ffn_model +from . import tracker + +GetOffsets = Callable[ + [ffn_model.ModelInfo, np.ndarray, np.ndarray, tracker.EvalTracker], + Iterable[tuple[int, int, int]]] + + +def get_example(load_example, eval_tracker: tracker.EvalTracker, + info: ffn_model.ModelInfo, get_offsets: GetOffsets, + seed_pad: float, seed_shape: tuple[int, int, int]): + """Generates individual training examples. + + Args: + load_example: callable returning a tuple of image and label ndarrays as well + as the seed coordinate and volume name of the example + eval_tracker: tracker.EvalTracker object + info: ModelInfo metadata about the model + get_offsets: callable returning an iterable of (x, y, z) offsets to + investigate within the training patch + seed_pad: value to fill the empty areas of the seed with + seed_shape: z, y, x shape of the seed + + Yields: + tuple of [1, z, y, x, 1]-shaped arrays for: + seed, image, label, weights + """ + while True: + ex = load_example() + full_patches, full_labels, loss_weights, coord, volname = ex + + # Start with a clean seed. 
+ seed = special.logit(mask.make_seed(seed_shape, 1, pad=seed_pad)) + + for off in get_offsets(info, seed, full_labels, eval_tracker): + predicted = mask.crop_and_pad(seed, off, info.input_seed_size[::-1]) + patches = mask.crop_and_pad(full_patches, off, + info.input_image_size[::-1]) + labels = mask.crop_and_pad(full_labels, off, info.pred_mask_size[::-1]) + weights = mask.crop_and_pad(loss_weights, off, info.pred_mask_size[::-1]) + + # Necessary, since the caller is going to update the array and these + # changes need to be visible in the following iterations. + assert predicted.base is seed + yield predicted, patches, labels, weights + + eval_tracker.add_patch(full_labels, seed, loss_weights, coord) + + +ExampleGenerator = Iterable[tuple[np.ndarray, np.ndarray, np.ndarray, + np.ndarray]] +_BatchGenerator = Iterable[tuple[Sequence[np.ndarray], Sequence[np.ndarray], + Sequence[np.ndarray], Sequence[np.ndarray]]] + + +def _batch_gen(make_example_generator_fn: Callable[[], ExampleGenerator], + batch_size: int) -> _BatchGenerator: + """Generates batches of training examples.""" + # Create a separate generator for every element in the batch. This generator + # will automatically advance to a different training example once the + # allowed moves for the current location are exhausted. + example_gens = [make_example_generator_fn() for _ in range(batch_size)] + + with futures.ThreadPoolExecutor(max_workers=batch_size) as tpe: + while True: + fs = [] + for gen in example_gens: + fs.append(tpe.submit(next, gen)) + + # `batch` is sequence of `batch_size` tuples returned by the + # `get_example` generator, to which we apply the following transformation: + # [(a0, b0), (a1, b1), .. (an, bn)] -> [(a0, a1, .., an), + # (b0, b1, .., bn)] + # (where n is the batch size) to get a sequence, each element of which + # represents a batch of values of a given type (e.g., seed, image, etc.) 
+ batch = [f.result() for f in fs] + yield tuple(zip(*batch)) + + +class BatchExampleIter: + """Generates batches of training examples.""" + + def __init__(self, example_generator_fn: Callable[[], ExampleGenerator], + eval_tracker: tracker.EvalTracker, batch_size: int, + info: ffn_model.ModelInfo): + self._eval_tracker = eval_tracker + self._batch_generator = _batch_gen(example_generator_fn, batch_size) + self._seeds = None + self._info = info + + def __iter__(self): + return self + + def __next__(self): + seeds, patches, labels, weights = next(self._batch_generator) + self._seeds = seeds + batched_seeds = np.concatenate(seeds) + batched_weights = np.concatenate(weights) + self._eval_tracker.track_weights(batched_weights) + return (batched_seeds, np.concatenate(patches), np.concatenate(labels), + batched_weights) + + def update_seeds(self, batched_seeds: np.ndarray): + """Distributes updated predictions back to the generator buffers. + + Args: + batched_seeds: [b, z, y, x, c] array of the part of the seed updated by + the model + """ + assert self._seeds is not None + + # Convert to numpy array in case this function was called with an array-like + # object backed by accelerator memory. + batched_seeds = np.asarray(batched_seeds) + + dx = self._info.input_seed_size[0] - self._info.pred_mask_size[0] + dy = self._info.input_seed_size[1] - self._info.pred_mask_size[1] + dz = self._info.input_seed_size[2] - self._info.pred_mask_size[2] + + if dz == 0 and dy == 0 and dx == 0: + for i in range(len(self._seeds)): + self._seeds[i][:] = batched_seeds[i, ...] + else: + for i in range(len(self._seeds)): + self._seeds[i][:, # + dz // 2:-(dz - dz // 2), # + dy // 2:-(dy - dy // 2), # + dx // 2:-(dx - dx // 2), # + :] = batched_seeds[i, ...] 
+ + +def _eval_move(seed: np.ndarray, labels: np.ndarray, + off_xyz: tuple[int, int, int], seed_threshold: float, + label_threshold: float) -> tuple[bool, bool]: + """Evaluates a FOV move.""" + valid_move = seed[:, # + seed.shape[1] // 2 + off_xyz[2], # + seed.shape[2] // 2 + off_xyz[1], # + seed.shape[3] // 2 + off_xyz[0], # + 0] >= seed_threshold + wanted_move = ( + labels[:, # + labels.shape[1] // 2 + off_xyz[2], # + labels.shape[2] // 2 + off_xyz[1], # + labels.shape[3] // 2 + off_xyz[0], # + 0] >= label_threshold) + + return valid_move, wanted_move + + +FovShifts = Optional[Iterable[tuple[int, int, int]]] + + +def fixed_offsets(info: ffn_model.ModelInfo, + seed: np.ndarray, + labels: np.ndarray, + eval_tracker: tracker.EvalTracker, + threshold: float, + fov_shifts: FovShifts = None): + """Generates offsets based on a fixed list.""" + del info + + label_threshold = special.expit(threshold) + for off in itertools.chain([(0, 0, 0)], fov_shifts): # xyz + valid_move, wanted_move = _eval_move(seed, labels, off, threshold, + label_threshold) + eval_tracker.record_move(wanted_move, valid_move, off) + if not valid_move: + continue + + yield off + + +def fixed_offsets_window(info: ffn_model.ModelInfo, + seed: np.ndarray, + labels: np.ndarray, + eval_tracker: tracker.EvalTracker, + threshold: float, + fov_shifts: FovShifts = None, + radius: int = 4): + """Like fixed_offsets, but allows more flexible moves. + + Instead of looking at the single voxel pointed to by the offset vector, + considers a small window in the plane orthogonal to the movement direction. + + This helps with training on thin processes that might not be followed by the + 'fixed_offsets' movement policy. 
+ + Args: + info: ModelInfo object + seed: seed array (logits) + labels: label array (probabilities) + eval_tracker: EvalTracker object + threshold: value that the seed needs to match or exceed in order to be + considered a valid move target + fov_shifts: list of XYZ moves to evaluate + radius: max distance away from the offset vector to look for voxels crossing + threshold (within a plane orthogonal to that vector) + + Yields: + XYZ offset tuples indicating moves to take relative to the center of 'seed' + """ + off = 0, 0, 0 + label_threshold = special.expit(threshold) + valid_move, wanted_move = _eval_move(seed, labels, off, threshold, + label_threshold) + eval_tracker.record_move(wanted_move, valid_move, off) + if valid_move: + yield off + + seed_center = np.array(seed.shape[1:4]) // 2 + label_center = np.array(labels.shape[1:4]) // 2 + + # Define a thin shell at distance of 'delta' around the center. + hz, hy, hx = np.mgrid[:seed.shape[1], :seed.shape[2], :seed.shape[3]] + hz -= seed_center[0] + hy -= seed_center[1] + hx -= seed_center[2] + halo = ((np.abs(hx) <= info.deltas[0]) & # + (np.abs(hy) <= info.deltas[1]) & # + (np.abs(hz) <= info.deltas[2]) & ( # + (np.abs(hx) == info.deltas[0]) | # + (np.abs(hy) == info.deltas[1]) | # + (np.abs(hz) == info.deltas[2]))) + + for off in fov_shifts: # xyz + # Look for voxels within a window of radius 'radius' around the standard + # move point. We can extend this window in any direction below since + # the 'halo' array is set up to restrict us to relevant voxels only. 
+ off_center = seed_center + off[::-1] + pre = off_center - radius + post = off_center + radius + 1 + zz, yy, xx = np.where(halo[pre[0]:post[0], pre[1]:post[1], pre[2]:post[2]]) + + zz_s = zz + pre[0] + yy_s = yy + pre[1] + xx_s = xx + pre[2] + xx_l = xx_s + label_center[2] - seed_center[2] + yy_l = yy_s + label_center[1] - seed_center[1] + zz_l = zz_s + label_center[0] - seed_center[0] + + # Under 'fixed_offsets' the exact voxel at the offset vector would + # have to cross the threshold. Here it is instead sufficient that any voxel + # within a specified radius does. + valid_move = np.any(seed[:, zz_s, yy_s, xx_s, :] >= threshold) + wanted_move = np.any(labels[:, zz_l, yy_l, xx_l, :] >= label_threshold) + eval_tracker.record_move(wanted_move, valid_move, off) + if valid_move: + yield off + + +def no_offsets(info: ffn_model.ModelInfo, seed: np.ndarray, labels: np.ndarray, + eval_tracker: tracker.EvalTracker): + del info, labels, seed + eval_tracker.record_move(True, True, (0, 0, 0)) + yield (0, 0, 0) + + +def max_pred_offsets(info: ffn_model.ModelInfo, seed: np.ndarray, + labels: np.ndarray, eval_tracker: tracker.EvalTracker, + threshold: float, max_radius: np.ndarray): + """Generates offsets with the policy used for inference.""" + # Always start at the center. + queue = collections.deque([(0, 0, 0)]) # xyz + done = set() + + label_threshold = special.expit(threshold) + deltas = np.array(info.deltas) + while queue: + offset = np.array(queue.popleft()) + + # Drop any offsets that would take us beyond the image fragment we + # loaded for training. + if np.any(np.abs(np.array(offset)) > max_radius): + continue + + # Ignore locations that were visited previously. 
+ quantized_offset = tuple((offset + deltas / 2) // np.maximum(deltas, 1)) + + if quantized_offset in done: + continue + + valid, wanted = _eval_move(seed, labels, tuple(offset), threshold, + label_threshold) + eval_tracker.record_move(wanted, valid, (0, 0, 0)) + + if not valid or (not wanted and quantized_offset != (0, 0, 0)): + continue + + done.add(quantized_offset) + + yield tuple(offset) + + # Look for new offsets within the updated seed. + curr_seed = mask.crop_and_pad(seed, offset, info.pred_mask_size[::-1]) + todos = sorted( + movement.get_scored_move_offsets( + info.deltas[::-1], curr_seed[0, ..., 0], threshold=threshold), + reverse=True) + queue.extend((x[2] + offset[0], x[1] + offset[1], x[0] + offset[2]) + for _, x in todos) diff --git a/ffn/training/model.py b/ffn/training/model.py index 6411d5c..aa62a2b 100644 --- a/ffn/training/model.py +++ b/ffn/training/model.py @@ -14,36 +14,45 @@ # ============================================================================== """Classes for FFN model definition.""" -from typing import Optional +import dataclasses +import numpy as np import tensorflow.compat.v1 as tf from . import optimizer -class FFNModel(object): - """Base class for FFN models.""" - - # Dimensionality of the model (2 or 3). - dim: int = None - - ############################################################################ - # (x, y, z) tuples defining various properties of the network. - # Note that 3-tuples should be used even for 2D networks, in which case - # the third (z) value is ignored. +@dataclasses.dataclass +class ModelInfo: + """Basic geometric information about the network. + Arrays are (x, y, z), even for 2D models, in which case the z value is + ignored. + """ # How far to move the field of view in the respective directions. - deltas: tuple[int, int, int] = None + deltas: np.ndarray + + # Size of the predicted patch as returned by the model. 
+ pred_mask_size: np.ndarray # Size of the input image and seed subvolumes to be used during inference. # This is enough information to execute a single prediction step, without # moving the field of view. - input_image_size: tuple[int, int, int] = None - input_seed_size: tuple[int, int, int] = None + input_seed_size: np.ndarray + input_image_size: np.ndarray - # Size of the predicted patch as returned by the model. - pred_mask_size: tuple[int, int, int] = None - ########################################################################### + # For JAX models only: whether the predicted seed should be added to + # its initial state. + additive: bool = False + + +class FFNModel: + """Base class for FFN models.""" + + info: ModelInfo + + # Dimensionality of the model (2 or 3). + dim: int = None # TF op to compute loss optimized during training. This should include all # loss components in case more than just the pixelwise loss is used. @@ -52,23 +61,21 @@ class FFNModel(object): # TF op to call to perform loss optimization on the model. train_op = None - def __init__( - self, - deltas: tuple[int, int, int], - batch_size: Optional[int] = None, - define_global_step: bool = True, - ): + def __init__(self, + info: ModelInfo, + batch_size=None, + define_global_step=True): assert self.dim is not None - self.deltas = deltas + self.info = info self.batch_size = batch_size # Initialize the shift collection. This is used during training with the # fixed step size policy. self.shifts = [] - for dx in (-self.deltas[0], 0, self.deltas[0]): - for dy in (-self.deltas[1], 0, self.deltas[1]): - for dz in (-self.deltas[2], 0, self.deltas[2]): + for dx in (-self.info.deltas[0], 0, self.info.deltas[0]): + for dy in (-self.info.deltas[1], 0, self.info.deltas[1]): + for dz in (-self.info.deltas[2], 0, self.info.deltas[2]): if dx == 0 and dy == 0 and dz == 0: continue self.shifts.append((dx, dy, dz)) @@ -88,43 +95,22 @@ def __init__( # If specified, should have the same shape as self.labels. 
self.loss_weights = None - self.logits = None # type: tf.Operation + # Updated part of the seed, in logit form. + self.logits: tf.Operation = None # List of image tensors to save in summaries. The images are concatenated # along the X axis. self._images = [] - def set_uniform_io_size(self, patch_size): - """Initializes unset input/output sizes to 'patch_size', sets input shapes. - - This assumes that the inputs and outputs are of equal size, and that exactly - one step is executed in every direction during training. - - Args: - patch_size: (x, y, z) specifying the input/output patch size - - Returns: - None - """ - if self.pred_mask_size is None: - self.pred_mask_size = patch_size - if self.input_seed_size is None: - self.input_seed_size = patch_size - if self.input_image_size is None: - self.input_image_size = patch_size - self.set_input_shapes() - def set_input_shapes(self): """Sets the shape inference for input_seed and input_patches. Assumes input_seed_size and input_image_size are already set. 
""" - self.input_seed.set_shape( - [self.batch_size] + list(self.input_seed_size[::-1]) + [1] - ) - self.input_patches.set_shape( - [self.batch_size] + list(self.input_image_size[::-1]) + [1] - ) + self.input_seed.set_shape([self.batch_size] + + list(self.info.input_seed_size[::-1]) + [1]) + self.input_patches.set_shape([self.batch_size] + + list(self.info.input_image_size[::-1]) + [1]) def set_up_sigmoid_pixelwise_loss(self, logits): """Sets up the loss function of the model.""" @@ -132,8 +118,7 @@ def set_up_sigmoid_pixelwise_loss(self, logits): assert self.loss_weights is not None pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits( - logits=logits, labels=self.labels - ) + logits=logits, labels=self.labels) pixel_loss *= self.loss_weights self.loss = tf.reduce_mean(pixel_loss) tf.summary.scalar('pixel_loss', self.loss) @@ -146,6 +131,8 @@ def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7): tf.summary.scalar('optimizer_loss', self.loss) opt = optimizer.optimizer_from_flags() + self.opt = opt + grads_and_vars = opt.compute_gradients(loss) for g, v in grads_and_vars: @@ -153,15 +140,9 @@ def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7): tf.logging.error('Gradient is None: %s', v.op.name) if max_gradient_entry_mag > 0.0: - grads_and_vars = [ - ( - tf.clip_by_value( - g, -max_gradient_entry_mag, +max_gradient_entry_mag - ), - v, - ) - for g, v, in grads_and_vars - ] + grads_and_vars = [(tf.clip_by_value(g, -max_gradient_entry_mag, + +max_gradient_entry_mag), v) + for g, v, in grads_and_vars] trainables = tf.trainable_variables() if trainables: @@ -173,8 +154,7 @@ def set_up_optimizer(self, loss=None, max_gradient_entry_mag=0.7): update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): self.train_op = opt.apply_gradients( - grads_and_vars, global_step=self.global_step, name='train' - ) + grads_and_vars, global_step=self.global_step, name='train') def show_center_slice(self, image, 
sigmoid=True): image = image[:, image.get_shape().dims[1] // 2, :, :, :] @@ -187,23 +167,19 @@ def add_summaries(self): def update_seed(self, seed, update): """Updates the initial 'seed' with 'update'.""" - dx = self.input_seed_size[0] - self.pred_mask_size[0] - dy = self.input_seed_size[1] - self.pred_mask_size[1] - dz = self.input_seed_size[2] - self.pred_mask_size[2] + dx = self.info.input_seed_size[0] - self.info.pred_mask_size[0] + dy = self.info.input_seed_size[1] - self.info.pred_mask_size[1] + dz = self.info.input_seed_size[2] - self.info.pred_mask_size[2] if dx == 0 and dy == 0 and dz == 0: seed += update else: - seed += tf.pad( - update, - [ - [0, 0], - [dz // 2, dz - dz // 2], - [dy // 2, dy - dy // 2], - [dx // 2, dx - dx // 2], - [0, 0], - ], - ) + seed += tf.pad(update, + [[0, 0], # + [dz // 2, dz - dz // 2], # + [dy // 2, dy - dy // 2], # + [dx // 2, dx - dx // 2], # + [0, 0]]) return seed def define_tf_graph(self): @@ -213,5 +189,4 @@ def define_tf_graph(self): computing and optimizing the loss. """ raise NotImplementedError( - 'DefineTFGraph needs to be defined by a subclass.' - ) + 'DefineTFGraph needs to be defined by a subclass.') diff --git a/ffn/training/models/convstack_3d.py b/ffn/training/models/convstack_3d.py index 675145b..b6ee826 100644 --- a/ffn/training/models/convstack_3d.py +++ b/ffn/training/models/convstack_3d.py @@ -15,6 +15,7 @@ """Simplest FFN model, as described in https://arxiv.org/abs/1611.00421.""" import functools +import itertools import tensorflow.compat.v1 as tf import tf_slim from .. import model @@ -22,49 +23,70 @@ # Note: this model was originally trained with conv3d layers initialized with # TruncatedNormalInitializedVariable with stddev = 0.01. 
-def _predict_object_mask(net, depth=9): +def _predict_object_mask(net, depth=9, features=32): """Computes single-object mask prediction.""" + conv = tf_slim.convolution3d conv = functools.partial( - tf_slim.convolution3d, - num_outputs=32, kernel_size=(3, 3, 3), padding='SAME') - net = conv(net, scope='conv0_a') - net = conv(net, scope='conv0_b', activation_fn=None) + tf_slim.convolution3d, kernel_size=(3, 3, 3), padding='SAME' + ) + + if isinstance(features, int): + feats = itertools.repeat(features) + else: + feats = iter(features) + + net = conv(net, scope='conv0_a', num_outputs=next(feats)) + net = conv(net, scope='conv0_b', activation_fn=None, num_outputs=next(feats)) for i in range(1, depth): with tf.name_scope('residual%d' % i): in_net = net net = tf.nn.relu(net) - net = conv(net, scope='conv%d_a' % i) - net = conv(net, scope='conv%d_b' % i, activation_fn=None) + net = conv(net, scope='conv%d_a' % i, num_outputs=next(feats)) + net = conv( + net, scope='conv%d_b' % i, activation_fn=None, num_outputs=next(feats) + ) net += in_net net = tf.nn.relu(net) logits = tf_slim.convolution3d( - net, 1, (1, 1, 1), activation_fn=None, scope='conv_lom') + net, 1, (1, 1, 1), activation_fn=None, scope='conv_lom' + ) return logits class ConvStack3DFFNModel(model.FFNModel): + """A simple conv-stack FFN model. + + The model is composed of `depth` residual modules, operating at a + constant spatial resolution. 
+ """ + dim = 3 - def __init__(self, fov_size=None, deltas=None, batch_size=None, depth=9): - super(ConvStack3DFFNModel, self).__init__(deltas, batch_size) - self.set_uniform_io_size(fov_size) + def __init__( + self, + fov_size=None, + deltas=None, + batch_size=None, + depth: int = 9, + features: int = 32, + **kwargs + ): + info = model.ModelInfo(deltas, fov_size, fov_size, fov_size) + super().__init__(info, batch_size, **kwargs) + self.set_input_shapes() self.depth = depth + self.features = features def define_tf_graph(self): self.show_center_slice(self.input_seed) - if self.input_patches is None: - self.input_patches = tf.placeholder( - tf.float32, [1] + list(self.input_image_size[::-1]) +[1], - name='patches') - net = tf.concat([self.input_patches, self.input_seed], 4) with tf.variable_scope('seed_update', reuse=False): - logit_update = _predict_object_mask(net, self.depth) + logit_update = _predict_object_mask(net, self.depth, self.features) logit_seed = self.update_seed(self.input_seed, logit_update) @@ -78,5 +100,3 @@ def define_tf_graph(self): self.show_center_slice(logit_seed) self.show_center_slice(self.labels, sigmoid=False) self.add_summaries() - - self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=1) diff --git a/ffn/training/optimizer.py b/ffn/training/optimizer.py index 6cc2417..bcd5e84 100644 --- a/ffn/training/optimizer.py +++ b/ffn/training/optimizer.py @@ -14,44 +14,115 @@ # ============================================================================== """Utilities to configure TF optimizers.""" -import tensorflow.compat.v1 as tf - from absl import flags +import tensorflow.compat.v1 as tf +_OPTIMIZER = flags.DEFINE_enum( + 'optimizer', + 'sgd', + ['momentum', 'sgd', 'adagrad', 'adam', 'rmsprop'], + 'Which optimizer to use.', +) +_LEARNING_RATE = flags.DEFINE_float( + 'learning_rate', 0.001, 'Initial learning rate.' 
+) +_MOMENTUM = flags.DEFINE_float('momentum', 0.9, 'Momentum.') +_LEARNING_RATE_DECAY_FACTOR = flags.DEFINE_float( + 'learning_rate_decay_factor', None, 'Learning rate decay factor.' +) +_DECAY_STEPS = flags.DEFINE_integer( + 'decay_steps', + None, + ( + 'How many steps the model needs to train for in order for ' + 'the decay factor to be applied to the learning rate.' + ), +) +_NUM_EPOCHS_PER_DECAY = flags.DEFINE_float( + 'num_epochs_per_decay', + 2.0, + 'Number of epochs after which learning rate decays.', +) +_RMSPROP_DECAY = flags.DEFINE_float( + 'rmsprop_decay', 0.9, 'Decay term for RMSProp.' +) +_ADAM_BETA1 = flags.DEFINE_float( + 'adam_beta1', 0.9, 'Gradient decay term for Adam.' +) +_ADAM_BETA2 = flags.DEFINE_float( + 'adam_beta2', 0.999, 'Gradient^2 decay term for Adam.' +) +_EPSILON = flags.DEFINE_float( + 'epsilon', 1e-8, 'Epsilon term for RMSProp and Adam.' +) +_SYNC_SGD = flags.DEFINE_boolean( + 'sync_sgd', False, 'Whether to use synchronous SGD.' +) +_REPLICAS_TO_AGGREGATE = flags.DEFINE_integer( + 'replicas_to_aggregate', + None, + 'When using sync SGD, over how many replicas to aggregate the gradients.', +) +_TOTAL_REPLICAS = flags.DEFINE_integer( + 'total_replicas', + None, + 'When using sync SGD, total number of replicas in the training pool.', +) -FLAGS = flags.FLAGS -flags.DEFINE_string('optimizer', 'sgd', - 'Which optimizer to use. 
Valid values are: ' - 'momentum, sgd, adagrad, adam, rmsprop.') -flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.') -flags.DEFINE_float('momentum', 0.9, 'Momentum.') -flags.DEFINE_float('learning_rate_decay_factor', 0.94, - 'Learning rate decay factor.') -flags.DEFINE_float('num_epochs_per_decay', 2.0, - 'Number of epochs after which learning rate decays.') -flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.') -flags.DEFINE_float('adam_beta1', 0.9, 'Gradient decay term for Adam.') -flags.DEFINE_float('adam_beta2', 0.999, 'Gradient^2 decay term for Adam.') -flags.DEFINE_float('epsilon', 1e-8, 'Epsilon term for RMSProp and Adam.') +def _optimizer_from_flags(): + """Defines a TF optimizer based on flag settings.""" + lr = _LEARNING_RATE.value + if ( + _LEARNING_RATE_DECAY_FACTOR.value is not None + and _DECAY_STEPS.value is not None + ): + lr = tf.train.exponential_decay( + _LEARNING_RATE.value, + tf.train.get_or_create_global_step(), + _DECAY_STEPS.value, + _LEARNING_RATE_DECAY_FACTOR.value, + staircase=True, + ) + tf.summary.scalar('learning_rate', lr) -def optimizer_from_flags(): - lr = FLAGS.learning_rate - if FLAGS.optimizer == 'momentum': - return tf.train.MomentumOptimizer(lr, FLAGS.momentum) - elif FLAGS.optimizer == 'sgd': + if _OPTIMIZER.value == 'momentum': + return tf.train.MomentumOptimizer(lr, _MOMENTUM.value) + elif _OPTIMIZER.value == 'sgd': return tf.train.GradientDescentOptimizer(lr) - elif FLAGS.optimizer == 'adagrad': + elif _OPTIMIZER.value == 'adagrad': return tf.train.AdagradOptimizer(lr) - elif FLAGS.optimizer == 'adam': - return tf.train.AdamOptimizer(learning_rate=lr, - beta1=FLAGS.adam_beta1, - beta2=FLAGS.adam_beta2, - epsilon=FLAGS.epsilon) - elif FLAGS.optimizer == 'rmsprop': - return tf.train.RMSPropOptimizer(lr, FLAGS.rmsprop_decay, - momentum=FLAGS.momentum, - epsilon=FLAGS.epsilon) + elif _OPTIMIZER.value == 'adam': + return tf.train.AdamOptimizer( + learning_rate=lr, + beta1=_ADAM_BETA1.value, + 
beta2=_ADAM_BETA2.value, + epsilon=_EPSILON.value, + ) + elif _OPTIMIZER.value == 'rmsprop': + return tf.train.RMSPropOptimizer( + lr, + _RMSPROP_DECAY.value, + momentum=_MOMENTUM.value, + epsilon=_EPSILON.value, + ) + else: + raise ValueError('Unknown optimizer: %s' % _OPTIMIZER.value) + + +def optimizer_from_flags(): + """Defines a TF optimizer based on command-line flags.""" + opt = _optimizer_from_flags() + if _SYNC_SGD.value: + assert _REPLICAS_TO_AGGREGATE.value is not None + if _TOTAL_REPLICAS.value is not None: + assert _TOTAL_REPLICAS.value >= _REPLICAS_TO_AGGREGATE.value + + return tf.train.SyncReplicasOptimizer( + opt, + replicas_to_aggregate=_REPLICAS_TO_AGGREGATE.value, + total_num_replicas=_TOTAL_REPLICAS.value, + ) else: - raise ValueError('Unknown optimizer: %s' % FLAGS.optimizer) + return opt diff --git a/ffn/training/tracker.py b/ffn/training/tracker.py new file mode 100644 index 0000000..2c1377c --- /dev/null +++ b/ffn/training/tracker.py @@ -0,0 +1,356 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for tracking and reporting the training status.""" + +import collections +import enum +import io +from typing import Optional, Sequence + +import numpy as np + +import PIL +import PIL.Image +import PIL.ImageDraw +from scipy import special + +import tensorflow.compat.v1 as tf +from . import mask +from . 
import variables + + +if tf.executing_eagerly(): + tf.compat.v2.experimental.numpy.experimental_enable_numpy_behavior() + + +class MoveType(enum.IntEnum): + CORRECT = 0 + MISSED = 1 + SPURIOUS = 2 + + +class VoxelType(enum.IntEnum): + TOTAL = 0 + MASKED = 1 + + +class PredictionType(enum.IntEnum): + TP = 0 + TN = 1 + FP = 2 + FN = 3 + + +class FovStat(enum.IntEnum): + TOTAL_VOXELS = 0 + MASKED_VOXELS = 1 + WEIGHTS_SUM = 2 + + +class EvalTracker: + """Tracks eval results over multiple training steps.""" + + def __init__(self, + eval_shape: list[int], + shifts: Sequence[tuple[int, int, int]]): + # TODO(mjanusz): Remove this TFv1 code once no longer used. + if not tf.executing_eagerly(): + self.eval_labels = tf.compat.v1.placeholder( + tf.float32, [1] + eval_shape + [1], name='eval_labels') + self.eval_preds = tf.compat.v1.placeholder( + tf.float32, [1] + eval_shape + [1], name='eval_preds') + self.eval_weights = tf.compat.v1.placeholder( + tf.float32, [1] + eval_shape + [1], name='eval_weights') + self.eval_loss = tf.reduce_mean( + self.eval_weights * tf.nn.sigmoid_cross_entropy_with_logits( + logits=self.eval_preds, labels=self.eval_labels)) + self.sess = None + self.eval_threshold = special.logit(0.9) + self._eval_shape = eval_shape # zyx + self._define_tf_vars(shifts) + self._patch_count = 0 + + self.reset() + + def _add_tf_var(self, name, shape, dtype): + v = variables.TFSyncVariable(name, shape, dtype) + setattr(self, name, v) + self._tf_vars.append(v) + return v + + def _define_tf_vars(self, fov_shifts: Sequence[tuple[int, int, int]]): + """Defines TFSyncVariables.""" + self._tf_vars = [] + self._add_tf_var('moves', [3], tf.int64) + self._add_tf_var('loss', [1], tf.float32) + self._add_tf_var('num_voxels', [2], tf.int64) + self._add_tf_var('num_patches', [1], tf.int64) + self._add_tf_var('prediction_counts', [4], tf.int64) + self._add_tf_var('fov_stats', [3], tf.float32) + + radii = set([int(np.linalg.norm(s)) for s in fov_shifts]) + radii.add(0) + 
self.moves_by_r = {} + for r in radii: + self.moves_by_r[r] = self._add_tf_var('moves_%d' % r, [3], tf.int64) + + def to_tf(self): + ops = [] + feed_dict = {} + + for var in self._tf_vars: + var.to_tf(ops, feed_dict) + + assert self.sess is not None + self.sess.run(ops, feed_dict) + + def from_tf(self): + ops = [var.from_tf for var in self._tf_vars] + assert self.sess is not None + values = self.sess.run(ops) + + for value, var in zip(values, self._tf_vars): + var.tf_value = value + + def reset(self): + """Resets status of the tracker.""" + self.images_xy = collections.deque(maxlen=16) + self.images_xz = collections.deque(maxlen=16) + self.images_yz = collections.deque(maxlen=16) + self.meshes = collections.deque(maxlen=16 * 3) + for var in self._tf_vars: + var.reset() + + def track_weights(self, weights: np.ndarray): + self.fov_stats.value[FovStat.TOTAL_VOXELS] += weights.size + self.fov_stats.value[FovStat.MASKED_VOXELS] += np.sum(weights == 0.0) + self.fov_stats.value[FovStat.WEIGHTS_SUM] += np.sum(weights) + + def record_move(self, wanted: bool, executed: bool, + offset_xyz: Sequence[int]): + """Records an FFN FOV move.""" + r = int(np.linalg.norm(offset_xyz)) + assert r in self.moves_by_r, ('%d not in %r' % + (r, list(self.moves_by_r.keys()))) + + if wanted: + if executed: + self.moves.value[MoveType.CORRECT] += 1 + self.moves_by_r[r].value[MoveType.CORRECT] += 1 + else: + self.moves.value[MoveType.MISSED] += 1 + self.moves_by_r[r].value[MoveType.MISSED] += 1 + elif executed: + self.moves.value[MoveType.SPURIOUS] += 1 + self.moves_by_r[r].value[MoveType.SPURIOUS] += 1 + + def slice_image(self, coord: np.ndarray, labels: np.ndarray, + predicted: np.ndarray, weights: np.ndarray, + slice_axis: int) -> tf.Summary.Value: + """Builds a tf.Summary showing a slice of an object mask. + + The object mask slice is shown side by side with the corresponding + ground truth mask. 
+ + Args: + coord: [1, 3] xyz coordinate as ndarray + labels: ndarray of ground truth data, shape [1, z, y, x, 1] + predicted: ndarray of predicted data, shape [1, z, y, x, 1] + weights: ndarray of loss weights, shape [1, z, y, x, 1] + slice_axis: axis in the middle of which to place the cutting plane for + which the summary image will be generated, valid values are 2 ('x'), 1 + ('y'), and 0 ('z'). + + Returns: + tf.Summary.Value object with the image. + """ + zyx = list(labels.shape[1:-1]) + selector = [0, slice(None), slice(None), slice(None), 0] + selector[slice_axis + 1] = zyx[slice_axis] // 2 + selector = tuple(selector) # for numpy indexing + + del zyx[slice_axis] + h, w = zyx + + buf = io.BytesIO() + labels = (labels[selector] * 255).astype(np.uint8) + predicted = (predicted[selector] * 255).astype(np.uint8) + weights = (weights[selector] * 255).astype(np.uint8) + + im = PIL.Image.fromarray( + np.repeat( + np.concatenate([labels, predicted, weights], axis=1)[..., + np.newaxis], + 3, + axis=2), 'RGB') + draw = PIL.ImageDraw.Draw(im) + + x, y, z = coord.squeeze() + draw.text((1, 1), '%d %d %d' % (x, y, z), fill='rgb(255,64,64)') + del draw + + im.save(buf, 'PNG') + + axis_names = 'zyx' + axis_names = axis_names.replace(axis_names[slice_axis], '') + + return tf.Summary.Value( + tag='final_%s' % axis_names[::-1], + image=tf.Summary.Image( + height=h, + width=w * 3, + colorspace=3, # RGB + encoded_image_string=buf.getvalue())) + + def add_patch(self, + labels: np.ndarray, + predicted: np.ndarray, + weights: np.ndarray, + coord: Optional[np.ndarray] = None, + image_summaries: bool = True): + """Evaluates single-object segmentation quality.""" + + predicted = mask.crop_and_pad(predicted, (0, 0, 0), self._eval_shape) + weights = mask.crop_and_pad(weights, (0, 0, 0), self._eval_shape) + labels = mask.crop_and_pad(labels, (0, 0, 0), self._eval_shape) + + if not tf.executing_eagerly(): + assert self.sess is not None + loss, = self.sess.run( + [self.eval_loss], { + 
self.eval_labels: labels, + self.eval_preds: predicted, + self.eval_weights: weights + }) + else: + loss = tf.reduce_mean(weights * tf.nn.sigmoid_cross_entropy_with_logits( + logits=predicted, labels=labels)) + + self.loss.value[:] += loss + self.num_voxels.value[VoxelType.TOTAL] += labels.size + self.num_voxels.value[VoxelType.MASKED] += np.sum(weights == 0.0) + + pred_mask = predicted >= self.eval_threshold + true_mask = labels > 0.5 + pred_bg = np.logical_not(pred_mask) + true_bg = np.logical_not(true_mask) + + self.prediction_counts.value[PredictionType.TP] += np.sum(pred_mask + & true_mask) + self.prediction_counts.value[PredictionType.TN] += np.sum(pred_bg & true_bg) + self.prediction_counts.value[PredictionType.FP] += np.sum(pred_mask + & true_bg) + self.prediction_counts.value[PredictionType.FN] += np.sum(pred_bg + & true_mask) + self.num_patches.value[:] += 1 + + if image_summaries: + predicted = special.expit(predicted) + self.images_xy.append( + self.slice_image(coord, labels, predicted, weights, 0)) + self.images_xz.append( + self.slice_image(coord, labels, predicted, weights, 1)) + self.images_yz.append( + self.slice_image(coord, labels, predicted, weights, 2)) + + def _compute_classification_metrics(self, prediction_counts, prefix): + """Computes standard classification metrics.""" + tp = prediction_counts.tf_value[PredictionType.TP] + fp = prediction_counts.tf_value[PredictionType.FP] + tn = prediction_counts.tf_value[PredictionType.TN] + fn = prediction_counts.tf_value[PredictionType.FN] + + precision = tp / max(tp + fp, 1) + recall = tp / max(tp + fn, 1) + + if precision > 0 or recall > 0: + f1 = (2.0 * precision * recall / (precision + recall)) + else: + f1 = 0.0 + + return [ + tf.Summary.Value( + tag='%s/accuracy' % prefix, + simple_value=(tp + tn) / max(tp + tn + fp + fn, 1)), + tf.Summary.Value(tag='%s/precision' % prefix, simple_value=precision), + tf.Summary.Value(tag='%s/recall' % prefix, simple_value=recall), + tf.Summary.Value( + 
tag='%s/specificity' % prefix, simple_value=tn / max(tn + fp, 1)), + tf.Summary.Value(tag='%s/f1' % prefix, simple_value=f1) + ] + + def get_summaries(self) -> list[tf.Summary.Value]: + """Gathers tensorflow summaries into single list.""" + + self.from_tf() + if not self.num_voxels.tf_value[VoxelType.TOTAL]: + return [] + + for images in self.images_xy, self.images_xz, self.images_yz: + for i, summary in enumerate(images): + summary.tag += '/%d' % i + + total_moves = sum(self.moves.tf_value) + move_summaries = [] + for mt in MoveType: + move_summaries.append( + tf.Summary.Value( + tag='moves/all/%s' % mt.name.lower(), + simple_value=self.moves.tf_value[mt] / total_moves)) + + summaries = [ + tf.Summary.Value( + tag='fov/masked_voxel_fraction', + simple_value=(self.fov_stats.tf_value[FovStat.MASKED_VOXELS] / + self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])), + tf.Summary.Value( + tag='fov/average_weight', + simple_value=(self.fov_stats.tf_value[FovStat.WEIGHTS_SUM] / + self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])), + tf.Summary.Value( + tag='masked_voxel_fraction', + simple_value=(self.num_voxels.tf_value[VoxelType.MASKED] / + self.num_voxels.tf_value[VoxelType.TOTAL])), + tf.Summary.Value( + tag='eval/patch_loss', + simple_value=self.loss.tf_value[0] / self.num_patches.tf_value[0]), + tf.Summary.Value( + tag='eval/patches', simple_value=self.num_patches.tf_value[0]), + tf.Summary.Value(tag='moves/total', simple_value=total_moves) + ] + move_summaries + ( + list(self.meshes) + list(self.images_xy) + list(self.images_xz) + + list(self.images_yz)) + + summaries.extend( + self._compute_classification_metrics(self.prediction_counts, + 'eval/all')) + + for r, r_moves in self.moves_by_r.items(): + total_moves = sum(r_moves.tf_value) + summaries.extend([ + tf.Summary.Value( + tag='moves/r=%d/correct' % r, + simple_value=r_moves.tf_value[MoveType.CORRECT] / total_moves), + tf.Summary.Value( + tag='moves/r=%d/spurious' % r, + 
simple_value=r_moves.tf_value[MoveType.SPURIOUS] / total_moves), + tf.Summary.Value( + tag='moves/r=%d/missed' % r, + simple_value=r_moves.tf_value[MoveType.MISSED] / total_moves), + tf.Summary.Value( + tag='moves/r=%d/total' % r, simple_value=total_moves) + ]) + + return summaries diff --git a/ffn/utils/bounding_box.py b/ffn/utils/bounding_box.py index 0b0b323..a3bd270 100644 --- a/ffn/utils/bounding_box.py +++ b/ffn/utils/bounding_box.py @@ -238,7 +238,7 @@ def containing(*boxes): """ if not boxes: raise ValueError('At least one bounding box must be specified') - boxes_objs = map(BoundingBox, boxes) + boxes_objs = list(map(BoundingBox, boxes)) start = boxes_objs[0].start end = boxes_objs[0].end for box in boxes_objs[1:]: diff --git a/train.py b/train.py index d8a3988..6aaef93 100644 --- a/train.py +++ b/train.py @@ -19,41 +19,30 @@ of view in a way dependent on the initial predictions. """ -from collections import deque -from io import BytesIO from functools import partial -import itertools import json import logging import os import random import time +from typing import Optional +from absl import app +from absl import flags +from ffn.training import augmentation +from ffn.training import examples +from ffn.training import inputs +from ffn.training import model as ffn_model +# Necessary so that optimizer flags are defined. 
+from ffn.training import optimizer # pylint: disable=unused-import +from ffn.training import tracker +from ffn.training.import_util import import_symbol import h5py import numpy as np - -import PIL -import PIL.Image - -import six - -from scipy.special import expit -from scipy.special import logit +from scipy import special import tensorflow.compat.v1 as tf - -from absl import app -from absl import flags from tensorflow.io import gfile -from ffn.inference import movement -from ffn.training import mask -from ffn.training.import_util import import_symbol -from ffn.training import inputs -from ffn.training import augmentation -# Necessary so that optimizer flags are defined. -# pylint: disable=unused-import -from ffn.training import optimizer -# pylint: enable=unused-import FLAGS = flags.FLAGS @@ -143,148 +132,9 @@ FLAGS = flags.FLAGS -class EvalTracker(object): - """Tracks eval results over multiple training steps.""" - - def __init__(self, eval_shape): - self.eval_labels = tf.placeholder( - tf.float32, [1] + eval_shape + [1], name='eval_labels') - self.eval_preds = tf.placeholder( - tf.float32, [1] + eval_shape + [1], name='eval_preds') - self.eval_loss = tf.reduce_mean( - tf.nn.sigmoid_cross_entropy_with_logits( - logits=self.eval_preds, labels=self.eval_labels)) - self.reset() - self.eval_threshold = logit(0.9) - self.sess = None - self._eval_shape = eval_shape - - def reset(self): - """Resets status of the tracker.""" - self.loss = 0 - self.num_patches = 0 - self.tp = 0 - self.tn = 0 - self.fn = 0 - self.fp = 0 - self.total_voxels = 0 - self.masked_voxels = 0 - self.images_xy = deque(maxlen=16) - self.images_xz = deque(maxlen=16) - self.images_yz = deque(maxlen=16) - - def slice_image(self, labels, predicted, weights, slice_axis): - """Builds a tf.Summary showing a slice of an object mask. - - The object mask slice is shown side by side with the corresponding - ground truth mask. 
- - Args: - labels: ndarray of ground truth data, shape [1, z, y, x, 1] - predicted: ndarray of predicted data, shape [1, z, y, x, 1] - weights: ndarray of loss weights, shape [1, z, y, x, 1] - slice_axis: axis in the middle of which to place the cutting plane - for which the summary image will be generated, valid values are - 2 ('x'), 1 ('y'), and 0 ('z'). - - Returns: - tf.Summary.Value object with the image. - """ - zyx = list(labels.shape[1:-1]) - selector = [0, slice(None), slice(None), slice(None), 0] - selector[slice_axis + 1] = zyx[slice_axis] // 2 - selector = tuple(selector) - - del zyx[slice_axis] - h, w = zyx - - buf = BytesIO() - labels = (labels[selector] * 255).astype(np.uint8) - predicted = (predicted[selector] * 255).astype(np.uint8) - weights = (weights[selector] * 255).astype(np.uint8) - - im = PIL.Image.fromarray(np.concatenate([labels, predicted, - weights], axis=1), 'L') - im.save(buf, 'PNG') - - axis_names = 'zyx' - axis_names = axis_names.replace(axis_names[slice_axis], '') - - return tf.Summary.Value( - tag='final_%s' % axis_names[::-1], - image=tf.Summary.Image( - height=h, width=w * 3, colorspace=1, # greyscale - encoded_image_string=buf.getvalue())) - - def add_patch(self, labels, predicted, weights, - coord=None, volname=None, patches=None): - """Evaluates single-object segmentation quality.""" - - predicted = mask.crop_and_pad(predicted, (0, 0, 0), self._eval_shape) - weights = mask.crop_and_pad(weights, (0, 0, 0), self._eval_shape) - labels = mask.crop_and_pad(labels, (0, 0, 0), self._eval_shape) - loss, = self.sess.run([self.eval_loss], {self.eval_labels: labels, - self.eval_preds: predicted}) - self.loss += loss - self.total_voxels += labels.size - self.masked_voxels += np.sum(weights == 0.0) - - pred_mask = predicted >= self.eval_threshold - true_mask = labels > 0.5 - pred_bg = np.logical_not(pred_mask) - true_bg = np.logical_not(true_mask) - - self.tp += np.sum(pred_mask & true_mask) - self.fp += np.sum(pred_mask & true_bg) - 
self.fn += np.sum(pred_bg & true_mask) - self.tn += np.sum(pred_bg & true_bg) - self.num_patches += 1 - - predicted = expit(predicted) - self.images_xy.append(self.slice_image(labels, predicted, weights, 0)) - self.images_xz.append(self.slice_image(labels, predicted, weights, 1)) - self.images_yz.append(self.slice_image(labels, predicted, weights, 2)) - - def get_summaries(self): - """Gathers tensorflow summaries into single list.""" - - if not self.total_voxels: - return [] - - precision = self.tp / max(self.tp + self.fp, 1) - recall = self.tp / max(self.tp + self.fn, 1) - - for images in self.images_xy, self.images_xz, self.images_yz: - for i, summary in enumerate(images): - summary.tag += '/%d' % i - - summaries = ( - list(self.images_xy) + list(self.images_xz) + list(self.images_yz) + [ - tf.Summary.Value(tag='masked_voxel_fraction', - simple_value=(self.masked_voxels / - self.total_voxels)), - tf.Summary.Value(tag='eval/patch_loss', - simple_value=self.loss / self.num_patches), - tf.Summary.Value(tag='eval/patches', - simple_value=self.num_patches), - tf.Summary.Value(tag='eval/accuracy', - simple_value=(self.tp + self.tn) / ( - self.tp + self.tn + self.fp + self.fn)), - tf.Summary.Value(tag='eval/precision', - simple_value=precision), - tf.Summary.Value(tag='eval/recall', - simple_value=recall), - tf.Summary.Value(tag='eval/specificity', - simple_value=self.tn / max(self.tn + self.fp, 1)), - tf.Summary.Value(tag='eval/f1', - simple_value=(2.0 * precision * recall / - (precision + recall))) - ]) - - return summaries - - -def run_training_step(sess, model, fetch_summary, feed_dict): +def run_training_step(sess: tf.Session, model: ffn_model.FFNModel, + fetch_summary: Optional[tf.Operation], + feed_dict: dict[str, np.ndarray]): """Runs one training step for a single FFN FOV.""" ops_to_run = [model.train_op, model.global_step, model.logits] @@ -302,34 +152,34 @@ def run_training_step(sess, model, fetch_summary, feed_dict): return prediction, step, summ -def 
fov_moves(): +def fov_moves() -> int: # Add one more move to get a better fill of the evaluation area. if FLAGS.fov_policy == 'max_pred_moves': return FLAGS.fov_moves + 1 return FLAGS.fov_moves -def train_labels_size(model): - return (np.array(model.pred_mask_size) + - np.array(model.deltas) * 2 * fov_moves()) +def train_labels_size(info: ffn_model.ModelInfo) -> np.ndarray: + return (np.array(info.pred_mask_size) + + np.array(info.deltas) * 2 * fov_moves()) -def train_eval_size(model): - return (np.array(model.pred_mask_size) + - np.array(model.deltas) * 2 * FLAGS.fov_moves) +def train_eval_size(info: ffn_model.ModelInfo) -> np.ndarray: + return (np.array(info.pred_mask_size) + + np.array(info.deltas) * 2 * FLAGS.fov_moves) -def train_image_size(model): - return (np.array(model.input_image_size) + - np.array(model.deltas) * 2 * fov_moves()) +def train_image_size(info: ffn_model.ModelInfo) -> np.ndarray: + return (np.array(info.input_image_size) + + np.array(info.deltas) * 2 * fov_moves()) -def train_canvas_size(model): - return (np.array(model.input_seed_size) + - np.array(model.deltas) * 2 * fov_moves()) +def train_canvas_size(info: ffn_model.ModelInfo) -> np.ndarray: + return (np.array(info.input_seed_size) + + np.array(info.deltas) * 2 * fov_moves()) -def _get_offset_and_scale_map(): +def _get_offset_and_scale_map() -> dict[str, tuple[float, float]]: if not FLAGS.image_offset_scale_map: return {} @@ -436,167 +286,15 @@ def define_data_input(model, queue_batch=None): return patches, labels, loss_weights, coord, volname -def prepare_ffn(model): +def prepare_ffn(model: ffn_model.FFNModel): """Creates the TF graph for an FFN.""" - shape = [FLAGS.batch_size] + list(model.pred_mask_size[::-1]) + [1] + shape = [FLAGS.batch_size] + list(model.info.pred_mask_size[::-1]) + [1] model.labels = tf.placeholder(tf.float32, shape, name='labels') model.loss_weights = tf.placeholder(tf.float32, shape, name='loss_weights') model.define_tf_graph() -def fixed_offsets(model, seed, 
fov_shifts=None): - """Generates offsets based on a fixed list.""" - for off in itertools.chain([(0, 0, 0)], fov_shifts): - if model.dim == 3: - is_valid_move = seed[:, - seed.shape[1] // 2 + off[2], - seed.shape[2] // 2 + off[1], - seed.shape[3] // 2 + off[0], - 0] >= logit(FLAGS.threshold) - else: - is_valid_move = seed[:, - seed.shape[1] // 2 + off[1], - seed.shape[2] // 2 + off[0], - 0] >= logit(FLAGS.threshold) - - if not is_valid_move: - continue - - yield off - - -def max_pred_offsets(model, seed): - """Generates offsets with the policy used for inference.""" - # Always start at the center. - queue = deque([(0, 0, 0)]) - done = set() - - train_image_radius = train_image_size(model) // 2 - input_image_radius = np.array(model.input_image_size) // 2 - - while queue: - offset = queue.popleft() - - # Drop any offsets that would take us beyond the image fragment we - # loaded for training. - if np.any(np.abs(np.array(offset)) + input_image_radius > - train_image_radius): - continue - - # Ignore locations that were visited previously. - quantized_offset = ( - offset[0] // max(model.deltas[0], 1), - offset[1] // max(model.deltas[1], 1), - offset[2] // max(model.deltas[2], 1)) - - if quantized_offset in done: - continue - - done.add(quantized_offset) - - yield offset - - # Look for new offsets within the updated seed. - curr_seed = mask.crop_and_pad(seed, offset, model.pred_mask_size[::-1]) - todos = sorted( - movement.get_scored_move_offsets( - model.deltas[::-1], - curr_seed[0, ..., 0], - threshold=logit(FLAGS.threshold)), reverse=True) - queue.extend((x[2] + offset[0], - x[1] + offset[1], - x[0] + offset[2]) for _, x in todos) - - -def get_example(load_example, eval_tracker, model, get_offsets): - """Generates individual training examples. 
- - Args: - load_example: callable returning a tuple of image and label ndarrays - as well as the seed coordinate and volume name of the example - eval_tracker: EvalTracker object - model: FFNModel object - get_offsets: iterable of (x, y, z) offsets to investigate within the - training patch - - Yields: - tuple of: - seed array, shape [1, z, y, x, 1] - image array, shape [1, z, y, x, 1] - label array, shape [1, z, y, x, 1] - """ - seed_shape = train_canvas_size(model).tolist()[::-1] - - while True: - full_patches, full_labels, loss_weights, coord, volname = load_example() - # Always start with a clean seed. - seed = logit(mask.make_seed(seed_shape, 1, pad=FLAGS.seed_pad)) - - for off in get_offsets(model, seed): - predicted = mask.crop_and_pad(seed, off, model.input_seed_size[::-1]) - patches = mask.crop_and_pad(full_patches, off, model.input_image_size[::-1]) - labels = mask.crop_and_pad(full_labels, off, model.pred_mask_size[::-1]) - weights = mask.crop_and_pad(loss_weights, off, model.pred_mask_size[::-1]) - - # Necessary, since the caller is going to update the array and these - # changes need to be visible in the following iterations. - assert predicted.base is seed - yield predicted, patches, labels, weights - - eval_tracker.add_patch( - full_labels, seed, loss_weights, coord, volname, full_patches) - - -def get_batch(load_example, eval_tracker, model, batch_size, get_offsets): - """Generates batches of training examples. - - Args: - load_example: callable returning a tuple of image and label ndarrays - as well as the seed coordinate and volume name of the example - eval_tracker: EvalTracker object - model: FFNModel object - batch_size: desidred batch size - get_offsets: iterable of (x, y, z) offsets to investigate within the - training patch - - Yields: - tuple of: - seed array, shape [b, z, y, x, 1] - image array, shape [b, z, y, x, 1] - label array, shape [b, z, y, x, 1] - - where 'b' is the batch_size. 
- """ - def _batch(iterable): - for batch_vals in iterable: - # `batch_vals` is sequence of `batch_size` tuples returned by the - # `get_example` generator, to which we apply the following transformation: - # [(a0, b0), (a1, b1), .. (an, bn)] -> [(a0, a1, .., an), - # (b0, b1, .., bn)] - # (where n is the batch size) to get a sequence, each element of which - # represents a batch of values of a given type (e.g., seed, image, etc.) - yield zip(*batch_vals) - - # Create a separate generator for every element in the batch. This generator - # will automatically advance to a different training example once the allowed - # moves for the current location are exhausted. - for seeds, patches, labels, weights in _batch(six.moves.zip( - *[get_example(load_example, eval_tracker, model, get_offsets) for _ - in range(batch_size)])): - - batched_seeds = np.concatenate(seeds) - - yield (batched_seeds, np.concatenate(patches), np.concatenate(labels), - np.concatenate(weights)) - - # batched_seed is updated in place with new predictions by the code - # calling get_batch. Here we distribute these updated predictions back - # to the buffer of every generator. - for i in range(batch_size): - seeds[i][:] = batched_seeds[i, ...] - - def save_flags(): gfile.makedirs(FLAGS.train_dir) with gfile.GFile( @@ -614,9 +312,9 @@ def train_ffn(model_cls, **model_kwargs): # The constructor might define TF ops/placeholders, so it is important # that the FFN is instantiated within the current context. 
model = model_cls(**model_kwargs) - eval_shape_zyx = train_eval_size(model).tolist()[::-1] + eval_shape_zyx = train_eval_size(model.info).tolist()[::-1] - eval_tracker = EvalTracker(eval_shape_zyx) + eval_tracker = tracker.EvalTracker(eval_shape_zyx, model.shifts) load_data_ops = define_data_input(model, queue_batch=1) prepare_ffn(model) merge_summaries_op = tf.summary.merge_all() @@ -656,13 +354,35 @@ def train_ffn(model_cls, **model_kwargs): if FLAGS.shuffle_moves: random.shuffle(fov_shifts) + train_image_radius = train_image_size(model.info) // 2 + input_image_radius = np.array(model.info.input_image_size) // 2 policy_map = { - 'fixed': partial(fixed_offsets, fov_shifts=fov_shifts), - 'max_pred_moves': max_pred_offsets + 'fixed': + partial( + examples.fixed_offsets, + fov_shifts=fov_shifts, + threshold=special.logit(FLAGS.threshold)), + 'max_pred_moves': + partial( + examples.max_pred_offsets, + max_radius=train_image_radius - input_image_radius, + threshold=special.logit(FLAGS.threshold)), + 'no_step': + examples.no_offsets } - batch_it = get_batch(lambda: sess.run(load_data_ops), - eval_tracker, model, FLAGS.batch_size, - policy_map[FLAGS.fov_policy]) + policy_fn = policy_map[FLAGS.fov_policy] + + def _make_ffn_example(): + return examples.get_example( + lambda: sess.run(load_data_ops), + eval_tracker, + model.info, + policy_fn, + FLAGS.seed_pad, + seed_shape=tuple(train_canvas_size(model.info).tolist()[::-1])) + + batch_it = examples.BatchExampleIter(_make_ffn_example, eval_tracker, + FLAGS.batch_size, model.info) t_last = time.time() @@ -677,6 +397,7 @@ def train_ffn(model_cls, **model_kwargs): seed, patches, labels, weights = next(batch_it) + eval_tracker.to_tf() updated_seed, step, summ = run_training_step( sess, model, summ_op, feed_dict={ @@ -688,7 +409,7 @@ def train_ffn(model_cls, **model_kwargs): # Save prediction results in the original seed array so that # they can be used in subsequent steps. 
- mask.update_at(seed, (0, 0, 0), updated_seed) + batch_it.update_seeds(updated_seed) # Record summaries. if summ is not None: