Factor out Runner into a separate module.
Remove halt signaling, self-consistency looping, and histogram matching.

PiperOrigin-RevId: 600796058
mjanusz authored and copybara-github committed Jan 23, 2024
1 parent f92852d commit 721001a
Showing 11 changed files with 1,047 additions and 1,054 deletions.
ffn/inference/executor.py (42 changes: 26 additions & 16 deletions)
@@ -41,9 +41,11 @@ def __init__(self, model, session, counters, batch_size):
     self.active_clients = 0
 
     # Cache input/output sizes.
-    self._input_seed_size = np.array(model.input_seed_size[::-1]).tolist()
-    self._input_image_size = np.array(model.input_image_size[::-1]).tolist()
-    self._pred_size = np.array(model.pred_mask_size[::-1]).tolist()
+    self._input_seed_size = np.array(model.info.input_seed_size[::-1]).tolist()
+    self._input_image_size = np.array(
+        model.info.input_image_size[::-1]
+    ).tolist()
+    self._pred_size = np.array(model.info.pred_mask_size[::-1]).tolist()
 
     self._initialize_model()
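The hunk above moves the model's size metadata behind a model.info accessor instead of attributes on the model object itself. A minimal sketch of the new lookup pattern, assuming info carries the same XYZ-ordered size triples as before; the ModelInfo class and the 33-voxel sizes are illustrative stand-ins, not code from this commit:

    import dataclasses

    import numpy as np


    @dataclasses.dataclass
    class ModelInfo:
      """Hypothetical stand-in for the object behind model.info."""
      input_seed_size: tuple   # XYZ order
      input_image_size: tuple  # XYZ order
      pred_mask_size: tuple    # XYZ order


    info = ModelInfo((33, 33, 33), (33, 33, 33), (33, 33, 33))
    # [::-1] flips XYZ to ZYX, the index order of the numpy arrays fed to TF.
    input_seed_size = np.array(info.input_seed_size[::-1]).tolist()
    assert input_seed_size == [33, 33, 33]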

@@ -111,8 +113,9 @@ class ThreadingBatchExecutor(BatchExecutor):
   """
 
   def __init__(self, model, session, counters, batch_size, expected_clients=1):
-    super(ThreadingBatchExecutor, self).__init__(model, session, counters,
-                                                 batch_size)
+    super(ThreadingBatchExecutor, self).__init__(
+        model, session, counters, batch_size
+    )
     self._lock = threading.Lock()
     self.outputs = {}  # Will be populated by Queues as clients register.
     # Used by clients to communiate with the executor. The protocol is
@@ -131,10 +134,12 @@ def __init__(self, model, session, counters, batch_size, expected_clients=1):
     self.expected_clients = expected_clients
 
     # Arrays fed to TF.
-    self.input_seed = np.zeros([batch_size] + self._input_seed_size + [1],
-                               dtype=np.float32)
-    self.input_image = np.zeros([batch_size] + self._input_image_size + [1],
-                                dtype=np.float32)
+    self.input_seed = np.zeros(
+        [batch_size] + self._input_seed_size + [1], dtype=np.float32
+    )
+    self.input_image = np.zeros(
+        [batch_size] + self._input_image_size + [1], dtype=np.float32
+    )
     self.th_executor = None
 
   def start_server(self):
@@ -146,7 +151,8 @@ def start_server(self):
     """
     if self.th_executor is None:
       self.th_executor = threading.Thread(
-          target=self._run_executor_log_exceptions)
+          target=self._run_executor_log_exceptions
+      )
       self.th_executor.start()
 
   def stop_server(self):
@@ -166,8 +172,10 @@ def _run_executor(self):
 
     with timer_counter(self.counters, 'executor-input'):
       ready = []
-      while (len(ready) < min(self.active_clients, self.batch_size) or
-             not self.active_clients):
+      while (
+          len(ready) < min(self.active_clients, self.batch_size)
+          or not self.active_clients
+      ):
         try:
           data = self.input_queue.get(timeout=5)
         except queue.Empty:
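The batching loop reformatted above keeps polling the input queue until enough requests are ready: fewer queued items than min(active_clients, batch_size) means keep waiting, and the not self.active_clients clause keeps the executor parked while no clients are registered. A self-contained sketch of the same gathering pattern, with illustrative names and a pre-filled queue standing in for real clients:

    import queue

    input_queue = queue.Queue()
    active_clients = 3
    batch_size = 8

    # Simulate three clients submitting one request each.
    for client_id in range(active_clients):
      input_queue.put(client_id)

    ready = []
    while (
        len(ready) < min(active_clients, batch_size)
        or not active_clients
    ):
      try:
        ready.append(input_queue.get(timeout=5))
      except queue.Empty:
        continue  # The real executor also checks for shutdown here.

    print(ready)  # [0, 1, 2] -- a batch sized to the active client count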
@@ -201,9 +209,12 @@ def _schedule_batch(self, client_ids, fetches):
     with timer_counter(self.counters, 'executor-inference'):
       try:
         ret = self.session.run(
-            fetches, {
+            fetches,
+            {
                 self.model.input_seed: self.input_seed,
-                self.model.input_patches: self.input_image})
+                self.model.input_patches: self.input_image,
+            },
+        )
       except Exception as e:  # pylint:disable=broad-except
         logging.exception(e)
         # If calling TF didn't work (faulty hardware, misconfiguration, etc),
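The session.run call restructured above passes a dict of fetches as the first argument and the feed dict as the second; TF1 returns a dict of numpy arrays under the same keys, which is what lets the executor later iterate over ret.items(). A minimal standalone illustration, assuming a TF1-style (tensorflow.compat.v1) runtime; the placeholder and fetch names are made up:

    import numpy as np
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32, shape=[None, 3], name='x')
    fetches = {'doubled': x * 2.0}

    with tf.Session() as sess:
      # Dict fetches come back as a dict keyed the same way, so downstream
      # code can iterate over ret.items(), as _schedule_batch does.
      ret = sess.run(fetches, {x: np.ones([2, 3], np.float32)})
      print(ret['doubled'].shape)  # (2, 3)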
@@ -215,8 +226,7 @@ def _schedule_batch(self, client_ids, fetches):
     with self._lock:
       for i, client_id in enumerate(client_ids):
         try:
-          self.outputs[client_id].put(
-              {k: v[i, ...] for k, v in ret.items()})
+          self.outputs[client_id].put({k: v[i, ...] for k, v in ret.items()})
         except KeyError:
           # This could happen if a client unregistered itself
           # while inference was running.
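The condensed put() above demultiplexes one batched result: for client i it slices row i out of every fetched array and forwards a per-client dict. A tiny numpy-only demonstration with made-up names:

    import numpy as np

    batch_size = 2
    # Fake batched inference output: one 'logits' array for the whole batch.
    ret = {'logits': np.arange(batch_size * 4).reshape(batch_size, 4)}

    for i, client_id in enumerate(['client-a', 'client-b']):
      per_client = {k: v[i, ...] for k, v in ret.items()}
      print(client_id, per_client['logits'])
    # client-a [0 1 2 3]
    # client-b [4 5 6 7]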