Skip to content

Commit

Permalink
Reset future after cleanup (#1115)
Browse files Browse the repository at this point in the history
  • Loading branch information
Delaunay authored Aug 17, 2023
1 parent f485f62 commit e7d03c9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/orion/client/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,23 +369,22 @@ def scatter(self, new_trials):
for trial in new_trials:
try:
self.prepare_trial(self.client, trial)
prepared = True
# pylint:disable=broad-except
except Exception as e:
future = self.client.executor.submit(delayed_exception, e)
prepared = False

if prepared:
future = self.client.executor.submit(
_optimize, trial, self.fct, self.trial_arg, **self.kwargs
)

# pylint:disable=broad-except
except Exception as e:
future = self.client.executor.submit(delayed_exception, e)

self.pending_trials[future] = trial
new_futures.append(future)

self.futures.extend(new_futures)
if new_futures:
log.debug("Scheduled new trials")

return len(new_futures)

def gather(self):
Expand All @@ -401,7 +400,10 @@ def gather(self):
# NOTE: For Ptera instrumentation
trials = 0 # pylint:disable=unused-variable
for result in results:
trial = self.pending_trials.pop(result.future)
trial = self.pending_trials.pop(result.future, None)

if trial is None:
log.warning(f"Future does not have a matching trial, {result}")

if isinstance(result, AsyncResult):
try:
Expand Down Expand Up @@ -462,13 +464,15 @@ def _release_all(self):
"""
# Sanity check
for _, trial in self.pending_trials.items():
for future, trial in self.pending_trials.items():
self.client.executor.cancel(future)
try:
self.client.release(trial, status="interrupted")
except AlreadyReleased:
pass

self.pending_trials = {}
self.futures = []

def _suggest_trials(self, count):
"""Suggest a bunch of trials to be dispatched to the workers"""
Expand Down
10 changes: 10 additions & 0 deletions src/orion/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ def __getstate__(self):
def __setstate__(self, state):
self.n_workers = state["n_workers"]

def cancel(self, future):
"""Cancel a given future
Parameters
----------
future: `concurrent.futures.Futures` or equivalent interface
The future to be cancelled
"""

def wait(self, futures):
"""Wait for all futures to complete execution.
Expand Down

0 comments on commit e7d03c9

Please sign in to comment.