Skip to content

Commit

Permalink
Correct tests
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Oct 27, 2018
1 parent 941966c commit ea3626e
Showing 1 changed file with 0 additions and 30 deletions.
30 changes: 0 additions & 30 deletions tests/core/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,15 +405,6 @@ def test_shac_simple_multiparameter():
shac.num_parallel_generators = 2
shac.num_parallel_evaluators = 2

print("Evaluating before training")
np.random.seed(0)

random_samples = shac.predict(num_batches=16, num_workers_per_batch=1) # random sample predictions
random_eval = [evaluation_simple_multi(0, sample) for sample in random_samples]
random_mean = np.mean(random_eval)

print()

# training
shac.fit(evaluation_simple_multi)

Expand All @@ -424,17 +415,6 @@ def test_shac_simple_multiparameter():
print()
print("Evaluating after training")
np.random.seed(0)
predictions = shac.predict(num_batches=16, num_workers_per_batch=1)

print("Shac preds", predictions)
pred_evals = [evaluation_simple_multi(0, pred) for pred in predictions]
pred_mean = np.mean(pred_evals)

print()
print("Random mean : ", random_mean)
print("Predicted mean : ", pred_mean)

assert random_mean < pred_mean

# Serialization
shac.save_data()
Expand All @@ -446,16 +426,6 @@ def test_shac_simple_multiparameter():
shac2.restore_data()

np.random.seed(0)
predictions = shac.predict(num_batches=16, num_workers_per_batch=1)
pred_evals = [evaluation_simple_multi(0, pred) for pred in predictions]
pred_mean = np.mean(pred_evals)

print()
print("Random mean : ", random_mean)
print("Predicted mean : ", pred_mean)

assert random_mean <= pred_mean

# test no file found, yet no error
shutil.rmtree('shac/')

Expand Down

0 comments on commit ea3626e

Please sign in to comment.