Skip to content

Commit

Permalink
Fix nested dask #1097 (#1098)
Browse files Browse the repository at this point in the history
  • Loading branch information
Delaunay authored Aug 9, 2023
1 parent 4190078 commit 76c8ba3
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 7 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ jobs:
test-long-algos:
needs: [pre-commit, pretest]
runs-on: ${{ matrix.platform }}
continue-on-error: true
strategy:
matrix:
platform: [ubuntu-latest]
Expand Down
3 changes: 2 additions & 1 deletion src/orion/client/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import signal
import time
import typing
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable
Expand Down Expand Up @@ -61,7 +62,7 @@ def __enter__(self):
self.signal_installed = True

except ValueError: # ValueError: signal only works in main thread
log.warning(
warnings.warn(
"SIGINT/SIGTERM protection hooks could not be installed because "
"Runner is executing inside a thread/subprocess, results could get lost "
"on interruptions"
Expand Down
6 changes: 4 additions & 2 deletions src/orion/executor/dask_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ def __init__(self, n_workers=-1, client=None, **config):
self.client = client

def __getstate__(self):
return super().__getstate__()
state = super().__getstate__()
state["address"] = self.client.cluster.scheduler_address
return state

def __setstate__(self, state):
super().__setstate__(state)
self.client = get_client()
self.client = get_client(address=state["address"])

@property
def in_worker(self):
Expand Down
11 changes: 7 additions & 4 deletions src/orion/testing/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ class BaseAlgoTests:
# Fixtures available as class attributes:
phase: ClassVar[pytest.fixture] # type: ignore

objective: float = 10
"""Target objective"""

def __init_subclass__(cls) -> None:

# Set the `algo_type` attribute, if necessary.
Expand Down Expand Up @@ -230,13 +233,13 @@ def phase(cls, request: pytest.FixtureRequest):
@pytest.fixture()
def first_phase(self, phase: TestPhase):
if phase != type(self).phases[0]:
pytest.skip(reason="Test runs only on first phase.")
pytest.skip("Test runs only on first phase.")
return phase

@pytest.fixture()
def last_phase(self, phase: TestPhase):
if phase != type(self).phases[-1]:
pytest.skip(reason="Test runs only on last phase.")
pytest.skip("Test runs only on last phase.")
return phase

@classmethod
Expand Down Expand Up @@ -504,7 +507,7 @@ def test_seed_rng_init(self):
suggested trials are reproducible.
"""
if "seed" not in inspect.signature(self.algo_type).parameters:
pytest.skip(reason="algo does not have a seed as a constructor argument.")
pytest.skip("algo does not have a seed as a constructor argument.")

config = self.config.copy()
config["seed"] = 1
Expand Down Expand Up @@ -777,7 +780,7 @@ def test_optimize_branin(self):
)

assert algo.is_done
assert min(all_objectives) <= 10
assert min(all_objectives) <= self.objective


class BaseParallelStrategyTests:
Expand Down
3 changes: 3 additions & 0 deletions tests/unittests/algo/long/ax/test_axoptimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class TestAxOptimizer(BaseAlgoTests):
TestPhase("BO", N_INIT, "space.sample"),
]

# Ax is nto always that good
objective = 12

@first_phase_only
def test_configuration_fail(self):
"""Test that Ax configuration is valid"""
Expand Down
3 changes: 3 additions & 0 deletions tests/unittests/algo/long/nevergrad/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ class TestNevergradOptimizer(BaseAlgoTests):
TestPhase("optim", TEST_MANY_TRIALS, "space.sample"),
]

# Nevergrad is not very good
objective = 12

def test_normal_data(self):
"""Test that algorithm supports normal dimensions"""
self.assert_dim_type_supported({"x": "normal(2, 5)"})
Expand Down

0 comments on commit 76c8ba3

Please sign in to comment.