Update convstack model definition and extend optimizer config flags.
PiperOrigin-RevId: 600455981
mjanusz authored and copybara-github committed Jan 22, 2024
1 parent efad8c4 commit 150f0d1
Showing 12 changed files with 1,131 additions and 516 deletions.
42 changes: 26 additions & 16 deletions ffn/inference/executor.py
@@ -41,9 +41,11 @@ def __init__(self, model, session, counters, batch_size):
self.active_clients = 0

# Cache input/output sizes.
self._input_seed_size = np.array(model.input_seed_size[::-1]).tolist()
self._input_image_size = np.array(model.input_image_size[::-1]).tolist()
self._pred_size = np.array(model.pred_mask_size[::-1]).tolist()
self._input_seed_size = np.array(model.info.input_seed_size[::-1]).tolist()
self._input_image_size = np.array(
model.info.input_image_size[::-1]
).tolist()
self._pred_size = np.array(model.info.pred_mask_size[::-1]).tolist()

self._initialize_model()

@@ -111,8 +113,9 @@ class ThreadingBatchExecutor(BatchExecutor):
"""

def __init__(self, model, session, counters, batch_size, expected_clients=1):
super(ThreadingBatchExecutor, self).__init__(model, session, counters,
batch_size)
super(ThreadingBatchExecutor, self).__init__(
model, session, counters, batch_size
)
self._lock = threading.Lock()
self.outputs = {} # Will be populated by Queues as clients register.
# Used by clients to communicate with the executor. The protocol is
@@ -131,10 +134,12 @@ def __init__(self, model, session, counters, batch_size, expected_clients=1):
self.expected_clients = expected_clients

# Arrays fed to TF.
self.input_seed = np.zeros([batch_size] + self._input_seed_size + [1],
dtype=np.float32)
self.input_image = np.zeros([batch_size] + self._input_image_size + [1],
dtype=np.float32)
self.input_seed = np.zeros(
[batch_size] + self._input_seed_size + [1], dtype=np.float32
)
self.input_image = np.zeros(
[batch_size] + self._input_image_size + [1], dtype=np.float32
)
self.th_executor = None

def start_server(self):
Expand All @@ -146,7 +151,8 @@ def start_server(self):
"""
if self.th_executor is None:
self.th_executor = threading.Thread(
target=self._run_executor_log_exceptions)
target=self._run_executor_log_exceptions
)
self.th_executor.start()

def stop_server(self):
Expand All @@ -166,8 +172,10 @@ def _run_executor(self):

with timer_counter(self.counters, 'executor-input'):
ready = []
while (len(ready) < min(self.active_clients, self.batch_size) or
not self.active_clients):
while (
len(ready) < min(self.active_clients, self.batch_size)
or not self.active_clients
):
try:
data = self.input_queue.get(timeout=5)
except queue.Empty:
@@ -201,9 +209,12 @@ def _schedule_batch(self, client_ids, fetches):
with timer_counter(self.counters, 'executor-inference'):
try:
ret = self.session.run(
fetches, {
fetches,
{
self.model.input_seed: self.input_seed,
self.model.input_patches: self.input_image})
self.model.input_patches: self.input_image,
},
)
except Exception as e: # pylint:disable=broad-except
logging.exception(e)
# If calling TF didn't work (faulty hardware, misconfiguration, etc),
Expand All @@ -215,8 +226,7 @@ def _schedule_batch(self, client_ids, fetches):
with self._lock:
for i, client_id in enumerate(client_ids):
try:
self.outputs[client_id].put(
{k: v[i, ...] for k, v in ret.items()})
self.outputs[client_id].put({k: v[i, ...] for k, v in ret.items()})
except KeyError:
# This could happen if a client unregistered itself
# while inference was running.
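A minimal sketch (not part of this commit) of the size-caching pattern the executor hunks above switch to, assuming a hypothetical model.info container with xyz-ordered size tuples; the concrete values are illustrative only:

import numpy as np

class _FakeModelInfo:
  # Hypothetical xyz-ordered sizes; real values come from the model config.
  input_seed_size = (49, 49, 49)
  input_image_size = (49, 49, 49)
  pred_mask_size = (33, 33, 33)

info = _FakeModelInfo()
batch_size = 4

# Sizes are reversed into zyx order, mirroring the hunk above.
input_seed_size = np.array(info.input_seed_size[::-1]).tolist()
input_image_size = np.array(info.input_image_size[::-1]).tolist()

# Arrays fed to TF: leading batch dimension, trailing single channel.
input_seed = np.zeros([batch_size] + input_seed_size + [1], dtype=np.float32)
input_image = np.zeros([batch_size] + input_image_size + [1], dtype=np.float32)
print(input_seed.shape)  # (4, 49, 49, 49, 1)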
24 changes: 14 additions & 10 deletions ffn/inference/inference.py
@@ -38,6 +38,7 @@
from scipy.special import logit
import tensorflow.compat.v1 as tf
from tensorflow.io import gfile
from ..training import model as ffn_model
from ..training.import_util import import_symbol
from ..utils import bounding_box
from ..utils import ortho_plane_visualization
@@ -48,7 +49,7 @@

# Visualization.
# ---------------------------------------------------------------------------
class DynamicImage(object):
class DynamicImage:
def UpdateFromPIL(self, new_img):
from io import BytesIO
from IPython import display
@@ -172,11 +173,11 @@ def _halt_signaler(fetches, pos, orig_pos, counters, **unused_kwargs):


# TODO(mjanusz): Add support for sparse inference.
class Canvas(object):
class Canvas:
"""Tracks state of the inference progress and results within a subvolume."""

def __init__(self,
model,
model: ffn_model.FFNModel,
tf_executor,
image,
options,
@@ -242,9 +243,9 @@ def __init__(self,

# Cast to array to ensure we can do elementwise expressions later.
# All of these are in zyx order.
self._pred_size = np.array(model.pred_mask_size[::-1])
self._input_seed_size = np.array(model.input_seed_size[::-1])
self._input_image_size = np.array(model.input_image_size[::-1])
self._pred_size = np.array(model.info.pred_mask_size[::-1])
self._input_seed_size = np.array(model.info.input_seed_size[::-1])
self._input_image_size = np.array(model.info.input_image_size[::-1])
self.margin = self._input_image_size // 2

self._pred_delta = (self._input_seed_size - self._pred_size) // 2
@@ -277,7 +278,7 @@ def __init__(self,
if movement_policy_fn is None:
# The model.deltas are (for now) in xyz order and must be swapped to zyx.
self.movement_policy = movement.FaceMaxMovementPolicy(
self, deltas=model.deltas[::-1],
self, deltas=model.info.deltas[::-1],
score_threshold=self.options.move_threshold)
else:
self.movement_policy = movement_policy_fn(self)
@@ -789,7 +790,7 @@ def _maybe_save_checkpoint(self):
self.checkpoint_last = time.time()


class Runner(object):
class Runner:
"""Helper for managing FFN inference runs.
Takes care of initializing the FFN model and any related functionality
@@ -799,6 +800,9 @@ class Runner(object):

ALL_MASKED = 1

request: inference_pb2.InferenceRequest
executor: executor.BatchExecutor

def __init__(self):
self.counters = inference_utils.Counters()
self.executor = None
@@ -908,7 +912,7 @@ def _open_or_none(settings):

self.executor = exec_cls(
self.model, self.session, self.counters, batch_size)
self.movement_policy_fn = movement.get_policy_fn(request, self.model)
self.movement_policy_fn = movement.get_policy_fn(request, self.model.info)

self.saver = tf.train.Saver()
self._load_model_checkpoint(request.model_checkpoint_path)
@@ -975,7 +979,7 @@ def make_restrictor(self, corner, subvol_size, image, alignment):
start=self.request.shift_mask_fov.start,
size=self.request.shift_mask_fov.size)
else:
shift_mask_diameter = np.array(self.model.input_image_size)
shift_mask_diameter = np.array(self.model.info.input_image_size)
shift_mask_fov = bounding_box.BoundingBox(
start=-(shift_mask_diameter // 2), size=shift_mask_diameter)

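A hedged sketch (not from the commit) of the zyx geometry the Canvas derives from model.info in the hunks above; the model sizes below are hypothetical placeholders:

import numpy as np

# Hypothetical xyz-ordered model sizes.
pred_mask_size = (33, 33, 33)
input_seed_size = (49, 49, 49)
input_image_size = (49, 49, 49)

# Cast to arrays in zyx order, as in Canvas.__init__.
pred_size_zyx = np.array(pred_mask_size[::-1])
seed_size_zyx = np.array(input_seed_size[::-1])
image_size_zyx = np.array(input_image_size[::-1])

margin = image_size_zyx // 2                       # -> [24 24 24]
pred_delta = (seed_size_zyx - pred_size_zyx) // 2  # -> [8 8 8]
print(margin, pred_delta)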
86 changes: 54 additions & 32 deletions ffn/inference/movement.py
@@ -16,9 +16,14 @@

from collections import deque
import json
from typing import Optional
import weakref

from connectomics.common import bounding_box
import numpy as np
from scipy.special import logit

from ..training import model as ffn_model
from ..training.import_util import import_symbol

# Unless stated otherwise, all shape/coordinate triples in this file are in zyx
@@ -34,7 +39,11 @@
# face, and look at the max probability point in every connected component
# within a face. Probably best to implement this in C++ and just use a Python
# wrapper.
def get_scored_move_offsets(deltas, prob_map, threshold=0.9):
def get_scored_move_offsets(
deltas: tuple[int, int, int] | np.ndarray,
prob_map: np.ndarray,
threshold: float = 0.9,
):
"""Looks for potential moves for a FFN.
The possible moves are determined by extracting probability map values
Expand All @@ -45,7 +54,7 @@ def get_scored_move_offsets(deltas, prob_map, threshold=0.9):
deltas: (z,y,x) tuple of base move offsets for the 3 axes
prob_map: current probability map as a (z,y,x) numpy array
threshold: minimum score required at the new FoV center for a move to be
considered valid
considered valid
Yields:
tuples of:
@@ -59,8 +68,7 @@ def get_scored_move_offsets(deltas, prob_map, threshold=0.9):
assert center.size == 3
# Selects a working subvolume no more than +/- delta away from the current
# center point.
subvol_sel = [slice(c - dx, c + dx + 1) for c, dx
in zip(center, deltas)]
subvol_sel = [slice(c - dx, c + dx + 1) for c, dx in zip(center, deltas)]

done = set()
for axis, axis_delta in enumerate(deltas):
@@ -92,7 +100,7 @@ def get_scored_move_offsets(deltas, prob_map, threshold=0.9):
yield ret


class BaseMovementPolicy(object):
class BaseMovementPolicy:
"""Base class for movement policy queues.
The principal usage is to initialize once with the policy's parameters and
@@ -120,9 +128,12 @@ def __len__(self):
def __iter__(self):
return self

def next(self):
def __next__(self):
raise StopIteration()

def next(self):
return self.__next__()

def append(self, item):
self.scored_coords.append(item)

@@ -132,7 +143,7 @@ def update(self, prob_map, position):
Args:
prob_map: object probability map returned by the FFN (in logit space)
position: position of the center of the FoV where inference was performed
(z, y, x)
(z, y, x)
"""
raise NotImplementedError()

@@ -167,10 +178,10 @@ def reset_state(self, start_pos):
self._start_pos = start_pos

def get_state(self):
return [(self.scored_coords, self.done_rounded_coords)]
return [(self.scored_coords, self.done_rounded_coords, self._start_pos)]

def restore_state(self, state):
self.scored_coords, self.done_rounded_coords = state[0]
self.scored_coords, self.done_rounded_coords, self._start_pos = state[0]

def __next__(self):
"""Pops positions from queue until a valid one is found and returns it."""
@@ -186,16 +197,13 @@ def __next__(self):

return tuple(coord)

def next(self):
return self.__next__()

def quantize_pos(self, pos):
"""Quantizes the positions symmetrically to a grid downsampled by deltas."""
# Compute offset relative to the origin of the current segment and
# shift by half delta size. This ensures that all directions are treated
# approximately symmetrically -- i.e. the origin point lies in the middle of
# a cell of the quantized lattice, as opposed to a corner of that cell.
rel_pos = (np.array(pos) - self._start_pos)
rel_pos = np.array(pos) - self._start_pos
coord = (rel_pos + self.deltas // 2) // np.maximum(self.deltas, 1)
return tuple(coord)

@@ -204,16 +212,17 @@ def update(self, prob_map, position):
qpos = self.quantize_pos(position)
self.done_rounded_coords.add(qpos)

scored_coords = get_scored_move_offsets(self.deltas, prob_map,
threshold=self.score_threshold)
scored_coords = get_scored_move_offsets(
self.deltas, prob_map, threshold=self.score_threshold
)
scored_coords = sorted(scored_coords, reverse=True)
for score, rel_coord in scored_coords:
# convert to whole cube coordinates
coord = [rel_coord[i] + position[i] for i in range(3)]
self.scored_coords.append((score, coord))


def get_policy_fn(request, ffn_model):
def get_policy_fn(request, model_info: ffn_model.ModelInfo):
"""Returns a policy class based on the InferenceRequest proto."""

if request.movement_policy_name:
@@ -228,41 +237,50 @@ def get_policy_fn(request, ffn_model):
else:
kwargs = {}
if 'deltas' not in kwargs:
kwargs['deltas'] = ffn_model.deltas[::-1]
kwargs['deltas'] = model_info.deltas[::-1]
if 'score_threshold' not in kwargs:
kwargs['score_threshold'] = logit(request.inference_options.move_threshold)

return lambda canvas: movement_policy_class(canvas, **kwargs)


class MovementRestrictor(object):
class MovementRestrictor:
"""Restricts the movement of the FFN FoV."""

def __init__(self, mask=None, shift_mask=None, shift_mask_fov=None,
shift_mask_threshold=4, shift_mask_scale=1, seed_mask=None):
def __init__(
self,
mask: Optional[np.ndarray] = None,
shift_mask: Optional[np.ndarray] = None,
shift_mask_fov: Optional[bounding_box.BoundingBox] = None,
shift_mask_threshold: int = 4,
shift_mask_scale: int = 1,
seed_mask: Optional[np.ndarray] = None,
):
"""Initializes the restrictor.
Args:
mask: 3d ndarray-like of shape (z, y, x); positive values indicate voxels
that are not going to be segmented
that are not going to be segmented
shift_mask: 4d ndarray-like of shape (2, z, y, x) representing a 2d shift
vector field
vector field
shift_mask_fov: bounding_box.BoundingBox around large shifts in which to
restrict movement. BoundingBox specified as XYZ, start can be
negative.
shift_mask_threshold: if any component of the shift vector exceeds this
value within the FoV, the location will not be segmented
restrict movement. BoundingBox specified as XYZ, start can be negative.
shift_mask_threshold: if any component of the shift vector equals or
exceeds this value within the FoV, the location will not be segmented
shift_mask_scale: an integer factor specifying how much larger the pixels
of the shift mask are compared to the data set processed by the FFN
of the shift mask are compared to the data set processed by the FFN
seed_mask: 3d ndarray-like of shape (z, y, x); positive values indicate
voxels where seeds are not going to be placed
"""
self.mask = mask
self.seed_mask = seed_mask

self._shift_mask_scale = shift_mask_scale
self.shift_mask = None
if shift_mask is not None:
self.shift_mask = (np.max(np.abs(shift_mask), axis=0) >=
shift_mask_threshold)
self.shift_mask = (
np.max(np.abs(shift_mask), axis=0) >= shift_mask_threshold
)

assert shift_mask_fov is not None
self._shift_mask_fov_pre_offset = shift_mask_fov.start[::-1]
@@ -306,9 +324,13 @@ def is_valid_pos(self, pos):
# Do not allow movement through highly distorted areas, which often
# result in merge errors. In the simplest case, the distortion magnitude
# is quantified with a patch-based cross-correlation map.
if np.any(self.shift_mask[fov_low[0]:(fov_high[0] + 1),
start[1]:(end[1] + 1),
start[2]:(end[2] + 1)]):
if np.any(
self.shift_mask[
fov_low[0] : (fov_high[0] + 1),
start[1] : (end[1] + 1),
start[2] : (end[2] + 1),
]
):
return False

return True
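An illustrative sketch (not part of the commit) of the symmetric quantization used by the quantize_pos hunk above, with hypothetical deltas and segment start position:

import numpy as np

deltas = np.array((8, 8, 8))           # hypothetical zyx move deltas
start_pos = np.array((100, 100, 100))  # hypothetical segment origin

def quantize_pos(pos):
  # Offset relative to the segment origin, shifted by half a delta so the
  # origin falls in the middle of a lattice cell rather than at a corner.
  rel_pos = np.array(pos) - start_pos
  return tuple((rel_pos + deltas // 2) // np.maximum(deltas, 1))

print(quantize_pos((100, 100, 100)))  # (0, 0, 0)
print(quantize_pos((103, 98, 97)))    # (0, 0, 0) -- same cell as the origin
print(quantize_pos((104, 100, 96)))   # (1, 0, 0) -- next cell along z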
