0.2.6 (#126)
* 0.2.5 setup.py (#111)

* [WIP] Attempts to Fix Memory Error  (#112)

* Add Website Badge in README.md, apply timeout in search function in search.py

* Add timeout in maximize_acq function in search.py

* Update unit test to allow timeout to raise TimeoutError

* Add unit test for timeout resume

* Remove TimeoutError from expectation

* Check Timeout exception in search() in search.py

* 0.2.5 setup.py (#110)

* Prevent GPU memory from being copied to the main process after train() finishes

* Cast loss from tensor to float

* Add pass() in MockProcess

* [MRG] Search Space limited to avoid out of memory (#121)

* limited the search space

* limited the search space

* reduce search space

* test added

* [MRG] Pytorch mp (#124)

* Change multiprocessing to torch.multiprocessing

* Replace multiprocessing.Pool with torch.multiprocessing.Pool in tests

* 0.2.6 (#125)

* new release

* auto deploy
haifeng-jin authored Aug 23, 2018
1 parent 3c8d1c1 commit f1b5366
Showing 11 changed files with 65 additions and 35 deletions.
21 changes: 16 additions & 5 deletions .travis.yml
@@ -1,11 +1,22 @@
language: python
python:
- "3.6"
- '3.6'
before_install:
- sudo apt-get install graphviz
- sudo apt-get install graphviz
install:
- pip install -r requirements.txt --quiet
- pip install -r requirements.txt --quiet
script:
- pytest tests --cov=autokeras
- pytest tests --cov=autokeras
after_success:
- pip install python-coveralls && coveralls
- pip install python-coveralls && coveralls
deploy:
provider: pypi
user: "jhfjhfj1"
password:
secure: U8U+CHDkQlCv6s0ZjbwAqW1RWxU8sV+EOeXVfCIOFJt9Igk16/Onoi8Bbdc52w3rPlgdtR78MxLgR6HsaVWSoZn0dardl6C/F3u+aGzh3dsQF2lfEH+Dl4Y+EasPmo4uKsH7rttT2kpn+w8zmL4IZz/W/ML1WkXv0Dhz6iH1uPS2dsC+uxdfz+W1iii/BACaPPzZdu79N+QBepRJs3VurXyB9UpV/FmsYSTB/2eP1yD/AzkV4RbhfSVRXDqIwvwpHOBqpY/4GJFONkC7oi7JtOMs/6ZjHw/2bpIziWPK1s7HbrYtqfs9w5RG279iwakU94O45xh2v73E++/oJtRqpFAGHAw0SNhm7dprSEU/EwJUaNKj2HGRmIQ2dvr56h3W1lsteS1GR9sNcHrR/AtA+sEiHDFUxlhALEVb9xI7Yk6Bl7/rpoqIBWKSyeu7LomGoCEs6ZE2ljPcRB8OIZSE7FzjiPpWfxabwPG84XufhR834z1C3qnepTkDDRYM/5M954lEbBp8eB8GJGj51DMNifXfHvtJosQPLChFrDrZlRYKLcvKuYASFeJc/v2+Q1p1vPyjHy1eMyM+bES57qBK1KbMVa+sY7dre26b0NMpcUpk6YczR7AIiD330ME7QHglJAHaacXUyBIfflRnPvmsAwB1HY0waH4/OvTFYMU+TLo=
on:
tags: true
branch: master
repo: jhfjhfj1/autokeras
python: 3.6

8 changes: 4 additions & 4 deletions autokeras/bayesian.py
@@ -245,15 +245,15 @@ def optimize_acq(self, model_ids, descriptors, timeout):
t_min = self.t_min
alpha = 0.9
opt_acq = self._get_init_opt_acq_value()
remaining_time = timeout - (time.time() - start_time)
remaining_time = timeout
while not pq.empty() and t > t_min and remaining_time > 0:
elem = pq.get()
if self.metric.higher_better():
temp_exp = min((elem.metric_value - opt_acq) / t, 709.0)
temp_exp = min((elem.metric_value - opt_acq) / t, 1.0)
else:
temp_exp = min((opt_acq - elem.metric_value) / t, 709.0)
temp_exp = min((opt_acq - elem.metric_value) / t, 1.0)
ap = math.exp(temp_exp)
if ap > random.uniform(0, 1):
if ap >= random.uniform(0, 1):
for temp_graph in transform(elem.graph):
if contain(descriptors, temp_graph.extract_descriptor()):
continue
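For context, a minimal sketch (not autokeras's exact code) of the annealing-style acceptance test this hunk tunes: the exponent is clamped before math.exp is called, since float64 overflows near exp(709), and a candidate is kept when the resulting probability beats a uniform draw:

import math
import random

def accept(candidate_value, best_value, temperature, higher_better=True, clamp=1.0):
    # Simulated-annealing style acceptance: improvements always pass, while
    # worse candidates pass with a probability that shrinks as temperature drops.
    gain = candidate_value - best_value if higher_better else best_value - candidate_value
    # Clamp the exponent so math.exp cannot overflow (float64 overflows near exp(709)).
    exponent = min(gain / temperature, clamp)
    return math.exp(exponent) >= random.uniform(0, 1)
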
4 changes: 2 additions & 2 deletions autokeras/constant.py
@@ -10,8 +10,8 @@ class Constant:
KERNEL_LAMBDA = 0.1
T_MIN = 0.0001
N_NEIGHBOURS = 8
# T_MIN = 0.8
# N_NEIGHBOURS = 1
MAX_MODEL_WIDTH = 2048
MAX_MODEL_DEPTH = 15

# Model Defaults

11 changes: 10 additions & 1 deletion autokeras/net_transformer.py
@@ -5,11 +5,16 @@
from autokeras.graph import NetworkDescriptor

from autokeras.constant import Constant
from autokeras.layers import is_layer
from autokeras.layers import is_layer, layer_width


def to_wider_graph(graph):
weighted_layer_ids = graph.wide_layer_ids()
weighted_layer_ids = list(filter(lambda x: layer_width(graph.layer_list[x]) * 2 <= Constant.MAX_MODEL_WIDTH,
weighted_layer_ids))

if len(weighted_layer_ids) == 0:
return None
# n_wider_layer = randint(1, len(weighted_layer_ids))
# wider_layers = sample(weighted_layer_ids, n_wider_layer)
wider_layers = sample(weighted_layer_ids, 1)
@@ -56,6 +61,9 @@ def to_skip_connection_graph(graph):

def to_deeper_graph(graph):
weighted_layer_ids = graph.deep_layer_ids()
if len(weighted_layer_ids) >= Constant.MAX_MODEL_DEPTH:
return None

deeper_layer_ids = sample(weighted_layer_ids, 1)
# n_deeper_layer = randint(1, len(weighted_layer_ids))
# deeper_layer_ids = sample(weighted_layer_ids, n_deeper_layer)
@@ -87,6 +95,7 @@ def transform(graph):
graphs.append(to_wider_graph(deepcopy(graph)))
elif a == 2:
graphs.append(to_skip_connection_graph(deepcopy(graph)))
graphs = list(filter(lambda x: x, graphs))
return list(filter(lambda x: legal_graph(x), graphs))


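As a rough illustration of the new caps (only the two constant values come from the diff; the helper names here are made up), a mutation is dropped by returning None once widening would exceed MAX_MODEL_WIDTH, and deepening is skipped once the graph already holds MAX_MODEL_DEPTH deepenable layers:

from random import sample

MAX_MODEL_WIDTH = 2048   # mirrors Constant.MAX_MODEL_WIDTH
MAX_MODEL_DEPTH = 15     # mirrors Constant.MAX_MODEL_DEPTH

def pick_layer_to_widen(layer_widths):
    # layer_widths: {layer_id: current width}; returns a layer id or None.
    # Only layers that stay within the cap after doubling remain candidates.
    candidates = [i for i, w in layer_widths.items() if w * 2 <= MAX_MODEL_WIDTH]
    if not candidates:
        return None          # the "wider" direction of the search space is exhausted
    return sample(candidates, 1)[0]

def may_deepen(deep_layer_ids):
    # Stop inserting layers once the depth cap is reached.
    return len(deep_layer_ids) < MAX_MODEL_DEPTH
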
15 changes: 7 additions & 8 deletions autokeras/search.py
@@ -9,7 +9,7 @@
from autokeras.net_transformer import default_transform
from autokeras.utils import ModelTrainer, pickle_to_file, pickle_from_file

import multiprocessing
import torch.multiprocessing as mp


class Searcher:
@@ -161,8 +161,8 @@ def search(self, train_data, test_data, timeout=60 * 60 * 24):
graph, father_id, model_id = self.training_queue.pop(0)
if self.verbose:
print('Training model ', model_id)
multiprocessing.set_start_method('spawn', force=True)
pool = multiprocessing.Pool(1)
mp.set_start_method('spawn', force=True)
pool = mp.Pool(1)
train_results = pool.map_async(train, [(graph, train_data, test_data, self.trainer_args,
os.path.join(self.path, str(model_id) + '.png'),
self.metric, self.loss, self.verbose)])
@@ -189,13 +189,12 @@ def search(self, train_data, test_data, timeout=60 * 60 * 24):
metric_value, loss, graph = train_results.get(timeout=remaining_time)[0]
else:
raise TimeoutError
except (multiprocessing.TimeoutError, TimeoutError) as e:
except (mp.TimeoutError, TimeoutError) as e:
raise TimeoutError from e
finally:
# terminate and join the subprocess to prevent any resource leak
pool.terminate()
pool.close()
pool.join()

self.add_model(metric_value, loss, graph, model_id)
self.search_tree.add_child(father_id, model_id)
self.bo.fit(self.x_queue, self.y_queue)
@@ -252,14 +251,14 @@ def train(args):
model = graph.produce_model()
# if path is not None:
# plot_model(model, to_file=path, show_shapes=True)
loss, metric_value = ModelTrainer(model,
loss, mertic_value = ModelTrainer(model,
train_data,
test_data,
metric,
loss,
verbose).train_model(**trainer_args)
model.set_weight_to_graph()
return metric_value, loss, model.graph
return mertic_value, loss, model.graph


def same_graph(des1, des2):
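A minimal sketch of the pattern this file moves to (the helpers run_with_timeout and train_job are illustrative, not autokeras API): start the worker with the CUDA-safe 'spawn' method via torch.multiprocessing, bound the wait with a timeout, and always tear the pool down so no worker keeps GPU memory alive:

import torch.multiprocessing as mp

def train_job(args):
    # Stand-in for the real training function; it should return only picklable,
    # CPU-side values so no CUDA tensors are copied back to the parent process.
    graph, train_data, test_data = args
    return 0.9, 0.1  # (metric_value, loss)

def run_with_timeout(args, remaining_time):
    mp.set_start_method('spawn', force=True)   # required for CUDA in subprocesses
    pool = mp.Pool(1)
    try:
        async_result = pool.map_async(train_job, [args])
        return async_result.get(timeout=remaining_time)[0]
    except mp.TimeoutError as e:
        raise TimeoutError('training exceeded the remaining budget') from e
    finally:
        pool.terminate()   # kill the worker so its GPU memory is released
        pool.join()
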
3 changes: 2 additions & 1 deletion autokeras/utils.py
@@ -153,7 +153,8 @@ def _test(self):
for batch_idx, (inputs, targets) in enumerate(deepcopy(loader)):
inputs, targets = inputs.to(self.device), targets.to(self.device)
outputs = self.model(inputs)
test_loss += self.loss_function(outputs, targets)
# cast tensor to float
test_loss += float(self.loss_function(outputs, targets))

all_predicted.append(outputs.cpu().numpy())
all_targets.append(targets.cpu().numpy())
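To show why the cast matters, a generic evaluation-loop sketch (not the library's ModelTrainer): accumulating the raw loss tensor keeps each batch's tensor, and on GPU its device memory, alive for the whole loop, while float(loss) or loss.item() stores only a Python number:

import torch

def total_test_loss(model, loss_fn, loader, device):
    total = 0.0
    with torch.no_grad():                 # evaluation only, no autograd graph needed
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            loss = loss_fn(model(inputs), targets)
            # Accumulate a plain float; keeping the tensor would hold a reference
            # to GPU memory for every batch of the loop.
            total += loss.item()
    return total
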
4 changes: 2 additions & 2 deletions setup.py
@@ -5,12 +5,12 @@
packages=['autokeras'], # this must be the same as the name above
install_requires=['torch==0.4.1', 'torchvision==0.2.1', 'numpy==1.14.5', 'keras==2.2.2', 'scikit-learn==0.19.1',
'tensorflow==1.8.0'],
version='0.2.5',
version='0.2.6',
description='AutoML for deep learning',
author='Haifeng Jin',
author_email='[email protected]',
url='http://autokeras.com',
download_url='https://github.com/jhfjhfj1/autokeras/archive/0.2.5.tar.gz',
download_url='https://github.com/jhfjhfj1/autokeras/archive/0.2.6.tar.gz',
keywords=['automl'], # arbitrary keywords
classifiers=[]
)
4 changes: 3 additions & 1 deletion tests/common.py
@@ -243,8 +243,10 @@ def get(self, timeout=None):
def terminate(self):
pass

def close(self):
pass


def simple_transform(graph):
graph.to_wider_model(5, 64)
return [deepcopy(graph)]
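
The new close() stub exists because the pool teardown in search.py calls terminate(), close(), and join(), so a test double standing in for torch.multiprocessing.Pool must expose all three. A minimal sketch of such a stand-in, in the spirit of MockProcess (FakePool and FakeAsyncResult are illustrative names):

class FakePool:
    # Synchronous stand-in for torch.multiprocessing.Pool: runs the job in-process
    # and exposes the teardown methods the searcher expects.
    def __init__(self, processes=None):
        self.processes = processes

    def map_async(self, func, iterable):
        return FakeAsyncResult(list(map(func, iterable)))

    def terminate(self):
        pass

    def close(self):
        pass

    def join(self):
        pass


class FakeAsyncResult:
    def __init__(self, results):
        self._results = results

    def get(self, timeout=None):
        return self._results

# In a test it would be swapped in the same way the suites below patch MockProcess:
# @patch('torch.multiprocessing.Pool', new=FakePool)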

16 changes: 8 additions & 8 deletions tests/test_image_supervised.py
@@ -33,7 +33,7 @@ def test_x_float_exception():
assert str(info.value) == 'x_train should only contain numerical data.'


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_fit_predict(_):
Constant.MAX_ITER_NUM = 1
@@ -52,11 +52,11 @@ def test_fit_predict(_):
clean_dir(path)


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
def test_timeout():
# Constant.MAX_MODEL_NUM = 4
Constant.SEARCH_MAX_ITER = 1000
Constant.T_MIN = 0.8
Constant.T_MIN = 0.0001
Constant.DATA_AUGMENTATION = False
path = 'tests/resources/temp'
clean_dir(path)
@@ -68,7 +68,7 @@ def test_timeout():
clean_dir(path)


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_timeout_resume(_):
Constant.MAX_ITER_NUM = 1
@@ -99,7 +99,7 @@ def test_timeout_resume(_):
clean_dir(path)


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.bayesian.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_final_fit(_, _1):
@@ -123,7 +123,7 @@ def test_final_fit(_, _1):
clean_dir(path)


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_save_continue(_):
Constant.MAX_ITER_NUM = 1
@@ -156,7 +156,7 @@ def test_save_continue(_):
clean_dir(path)


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.bayesian.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_fit_csv_file(_, _1):
@@ -182,7 +182,7 @@ def test_init_image_classifier_with_none_path(_):
assert clf.path == 'dummy_path/'


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_fit_predict_regression(_):
Constant.MAX_ITER_NUM = 1
8 changes: 8 additions & 0 deletions tests/test_net_transformer.py
@@ -53,3 +53,11 @@ def test_default_transform():
model(torch.Tensor(get_conv_data()))
assert len(graphs) == 1
assert len(graphs[0].layer_list) == 43


def test_search_space_limit():
graph = CnnGenerator(10, (32, 32, 3)).generate(model_len=3, model_width=2048)
assert to_wider_graph(graph) is None

graph = CnnGenerator(10, (32, 32, 3)).generate(model_len=14)
assert to_deeper_graph(graph) is None
6 changes: 3 additions & 3 deletions tests/test_search.py
@@ -14,7 +14,7 @@ def mock_train(**_):
return 1, 0


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.bayesian.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_bayesian_searcher(_, _1):
@@ -38,7 +38,7 @@ def test_search_tree():
assert len(tree.adj_list) == 3


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.bayesian.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_export_json(_, _1):
@@ -66,7 +66,7 @@ def test_graph_duplicate():
assert not same_graph(get_concat_skip_model().extract_descriptor(), get_add_skip_model().extract_descriptor())


@patch('multiprocessing.Pool', new=MockProcess)
@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.bayesian.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_max_acq(_, _1):