[Improvement] Raise RuntimeError if predict is used with no parameters.
[Improvement] `None` can now be used as a value for discrete hyperparameters.

[Bugfix] `None` as a value for discrete hyperparameters caused crashes that were not caught by the tests.
titu1994 committed Oct 29, 2018
1 parent 5d9880c · commit 0dc9718
Showing 4 changed files with 61 additions and 8 deletions.
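In short: `None` is now a legal value for discrete hyperparameters, and `predict()` fails fast with a `RuntimeError` when the engine has no parameters to sample from. A minimal sketch of the new hyperparameter behaviour, mirroring the tests added below (the import path follows the file layout in this commit; the `hp` alias is the one used in the test file):

import pyshac.config.hyperparameters as hp

# None may now appear in the list of allowed values.
h = hp.DiscreteHyperParameter('h1', [None, 1, 2, 3])

encoded = h.encode(None)     # -> 0, the index of None in the value list
decoded = h.decode(encoded)  # round-trips back to None instead of crashing
assert decoded is None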
20 changes: 18 additions & 2 deletions pyshac/config/hyperparameters.py
@@ -15,6 +15,18 @@
_CUSTOM_PARAMETERS = OrderedDict()


+class _NoneTypeWrapper(object):
+    """
+    A wrapper to handle cases when `None` is passed as a possible parameter
+    value to the engine.
+    """
+    def __init__(self):
+        pass
+
+    def __call__(self, *args, **kwargs):
+        return args[0]


class AbstractHyperParameter(ABC):
    """
    Abstract Hyper Parameter that defines the methods that all hyperparameters
@@ -143,8 +155,12 @@ def _build_maps(self, values):
self.id2param[i] = v

# prepare a type map from string to its type, for fast checks
-self.param2type[v] = type(v)
-self.param2type[str(v)] = type(v)
+if v is not None:
+    self.param2type[v] = type(v)
+    self.param2type[str(v)] = type(v)
+else:
+    self.param2type[v] = _NoneTypeWrapper()
+    self.param2type[str(v)] = _NoneTypeWrapper()

def __repr__(self):
    s = self.name + " : "
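For context, `param2type` maps each allowed value (and its string form) to a callable that is later used to cast values back to their original type. That works for types such as `int`, `float` and `str`, but `type(None)` raises a `TypeError` when called with an argument, which appears to be the crash the bugfix note refers to; `_NoneTypeWrapper` simply returns the value unchanged. A standalone sketch of the pattern (the map construction mirrors `_build_maps` above, while the lookups at the end are illustrative rather than the engine's exact decode path):

class _NoneTypeWrapper(object):
    """Identity 'cast' used in place of type(None)."""
    def __call__(self, *args, **kwargs):
        return args[0]


param2type = {}
for v in [None, 1, 2.5, 'relu']:
    if v is not None:
        # e.g. int('1') or float('2.5') recovers the original value
        param2type[v] = type(v)
        param2type[str(v)] = type(v)
    else:
        param2type[v] = _NoneTypeWrapper()
        param2type[str(v)] = _NoneTypeWrapper()

print(param2type[1]('1'))      # -> 1
print(param2type[None](None))  # -> None; type(None)(None) would raise a TypeError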
7 changes: 7 additions & 0 deletions pyshac/core/engine.py
@@ -532,10 +532,17 @@ def predict(self, num_samples=None, num_batches=None, num_workers_per_batch=None
# Raises:
ValueError: If `max_classifiers` is larger than the number of available
classifiers.
+RuntimeError: If `classifiers` are not available, either due to not being
+trained or not being loaded into the engine.
# Returns:
batches of samples in the form of an OrderedDict
"""
+if self.parameters is None:
+    raise RuntimeError("Unable to find any parameters. Please make sure "
+                       "to set the parameters for the engine first, or "
+                       "load them into the engine prior to calling `predict`")

if max_classfiers is not None and max_classfiers > len(self.classifiers):
    raise ValueError("Maximum number of classifiers (%d) must be less than the number of "
                     "classifiers (%d)" % (max_classfiers, len(self.classifiers)))
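In use, an engine constructed without hyperparameters (or one whose parameters were never restored) now reports the problem at the call site. A hedged sketch mirroring the new test in tests/core/test_engine.py; the module import follows the file path above, and the budget, batch and objective values are illustrative rather than the fixture values used by the test:

import pytest
import pyshac.core.engine as engine

# An engine with no hyperparameters defined...
shac = engine.SHAC(None, total_budget=16, num_batches=4, objective='max')

# ...now raises immediately instead of failing later during sampling.
with pytest.raises(RuntimeError):
    shac.predict()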
35 changes: 29 additions & 6 deletions tests/config/test_hyperparameters.py
@@ -86,16 +86,27 @@ def test_discrete_encode_decode():
decoded = h1.decode(encoded)
assert decoded == values[encoded]

+# Test for None input
+values = [None, 1, 2, 3]
+h2 = hp.DiscreteHyperParameter('h1', values)
+sample = None
+
+encoded = h2.encode(sample)
+assert encoded == 0
+
+decoded = h2.decode(encoded)
+assert decoded == values[encoded]


def test_discrete_serialization_deserialization():
-h1 = hp.DiscreteHyperParameter('h1', [0, 1])
+h1 = hp.DiscreteHyperParameter('h1', [0, 1, None])

config = h1.get_config()
assert 'name' in config
assert 'values' in config

values = config['values']
-assert len(values) == 2
+assert len(values) == 3

h2 = hp.DiscreteHyperParameter.load_from_config(config)
config = h2.get_config()
@@ -104,7 +115,7 @@ def test_discrete_serialization_deserialization():
assert 'values' in config

values = config['values']
-assert len(values) == 2
+assert len(values) == 3


def test_multi_discrete():
@@ -156,17 +167,29 @@ def test_multi_discrete_encode_decode():
for i in range(len(decoded)):
    assert decoded[i] == values[encoded[i]]

+# Test for None input
+values = [None, 1, 2, 3]
+h2 = hp.MultiDiscreteHyperParameter('h1', values, sample_count=10)
+sample = h2.sample()
+
+encoded = h2.encode(sample)
+assert encoded == [3, 1, 3, 1, 2, 0, 3, 2, 0, 0]
+
+decoded = h2.decode(encoded)
+for i in range(len(decoded)):
+    assert decoded[i] == values[encoded[i]]


def test_multi_discrete_serialization_deserialization():
-h1 = hp.MultiDiscreteHyperParameter('h1', [0, 1], sample_count=5)
+h1 = hp.MultiDiscreteHyperParameter('h1', [0, 1, None], sample_count=5)

config = h1.get_config()
assert 'name' in config
assert 'values' in config
assert 'sample_count' in config

values = config['values']
-assert len(values) == 2
+assert len(values) == 3
assert config['sample_count'] == 5

h2 = hp.MultiDiscreteHyperParameter.load_from_config(config)
@@ -177,7 +200,7 @@ def test_multi_discrete_serialization_deserialization():
assert 'sample_count' in config

values = config['values']
-assert len(values) == 2
+assert len(values) == 3
assert config['sample_count'] == 5


7 changes: 7 additions & 0 deletions tests/core/test_engine.py
@@ -230,6 +230,13 @@ def test_shac_initialization():
with pytest.raises(ValueError):
    shac.evaluator_backend = 'random'

+shac = engine.SHAC(None, total_budget=total_budget,
+                   num_batches=batch_size, objective=objective)
+
+# No parameters
+with pytest.raises(RuntimeError):
+    shac.predict()


@optimizer_wrapper
def test_shac_simple():
