diff --git a/.travis.yml b/.travis.yml
index 9075656ce..b685005dd 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,11 +1,22 @@
 language: python
 python:
-  - "3.6"
+- '3.6'
 before_install:
-  - sudo apt-get install graphviz
+- sudo apt-get install graphviz
 install:
-  - pip install -r requirements.txt --quiet
+- pip install -r requirements.txt --quiet
 script:
-  - pytest tests --cov=autokeras
+- pytest tests --cov=autokeras
 after_success:
-  - pip install python-coveralls && coveralls
+- pip install python-coveralls && coveralls
+deploy:
+  provider: pypi
+  user: "jhfjhfj1"
+  password:
+    secure: U8U+CHDkQlCv6s0ZjbwAqW1RWxU8sV+EOeXVfCIOFJt9Igk16/Onoi8Bbdc52w3rPlgdtR78MxLgR6HsaVWSoZn0dardl6C/F3u+aGzh3dsQF2lfEH+Dl4Y+EasPmo4uKsH7rttT2kpn+w8zmL4IZz/W/ML1WkXv0Dhz6iH1uPS2dsC+uxdfz+W1iii/BACaPPzZdu79N+QBepRJs3VurXyB9UpV/FmsYSTB/2eP1yD/AzkV4RbhfSVRXDqIwvwpHOBqpY/4GJFONkC7oi7JtOMs/6ZjHw/2bpIziWPK1s7HbrYtqfs9w5RG279iwakU94O45xh2v73E++/oJtRqpFAGHAw0SNhm7dprSEU/EwJUaNKj2HGRmIQ2dvr56h3W1lsteS1GR9sNcHrR/AtA+sEiHDFUxlhALEVb9xI7Yk6Bl7/rpoqIBWKSyeu7LomGoCEs6ZE2ljPcRB8OIZSE7FzjiPpWfxabwPG84XufhR834z1C3qnepTkDDRYM/5M954lEbBp8eB8GJGj51DMNifXfHvtJosQPLChFrDrZlRYKLcvKuYASFeJc/v2+Q1p1vPyjHy1eMyM+bES57qBK1KbMVa+sY7dre26b0NMpcUpk6YczR7AIiD330ME7QHglJAHaacXUyBIfflRnPvmsAwB1HY0waH4/OvTFYMU+TLo=
+  on:
+    tags: true
+    branch: master
+    repo: jhfjhfj1/autokeras
+    python: 3.6
+
diff --git a/autokeras/bayesian.py b/autokeras/bayesian.py
index eaf27a52e..9acba7117 100644
--- a/autokeras/bayesian.py
+++ b/autokeras/bayesian.py
@@ -245,15 +245,15 @@ def optimize_acq(self, model_ids, descriptors, timeout):
         t_min = self.t_min
         alpha = 0.9
         opt_acq = self._get_init_opt_acq_value()
-        remaining_time = timeout - (time.time() - start_time)
+        remaining_time = timeout
         while not pq.empty() and t > t_min and remaining_time > 0:
             elem = pq.get()
             if self.metric.higher_better():
-                temp_exp = min((elem.metric_value - opt_acq) / t, 709.0)
+                temp_exp = min((elem.metric_value - opt_acq) / t, 1.0)
             else:
-                temp_exp = min((opt_acq - elem.metric_value) / t, 709.0)
+                temp_exp = min((opt_acq - elem.metric_value) / t, 1.0)
             ap = math.exp(temp_exp)
-            if ap > random.uniform(0, 1):
+            if ap >= random.uniform(0, 1):
                 for temp_graph in transform(elem.graph):
                     if contain(descriptors, temp_graph.extract_descriptor()):
                         continue
diff --git a/autokeras/constant.py b/autokeras/constant.py
index 999c6f3e1..83c6c1a06 100644
--- a/autokeras/constant.py
+++ b/autokeras/constant.py
@@ -10,8 +10,8 @@ class Constant:
     KERNEL_LAMBDA = 0.1
     T_MIN = 0.0001
    N_NEIGHBOURS = 8
-    # T_MIN = 0.8
-    # N_NEIGHBOURS = 1
+    MAX_MODEL_WIDTH = 2048
+    MAX_MODEL_DEPTH = 15
 
     # Model Defaults
 
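# A minimal, standalone sketch (illustrative names, not autokeras' BayesianOptimizer
# API) of the acceptance rule the bayesian.py hunk above settles on. Capping the
# exponent at 1.0 keeps math.exp() far from overflow while still letting a better
# candidate (positive delta) reach ap >= 1 and be accepted unconditionally; using
# >= rather than > also accepts the boundary case where ap equals the random draw.
import math
import random


def accept(candidate_value, best_value, temperature, higher_better=True):
    """Annealing-style acceptance test for a candidate acquisition value."""
    delta = candidate_value - best_value if higher_better else best_value - candidate_value
    ap = math.exp(min(delta / temperature, 1.0))  # bounded above by e ~ 2.72
    return ap >= random.uniform(0, 1)


# A slightly worse candidate is accepted with probability exp(delta / temperature),
# so a high temperature keeps the search exploratory and a low one makes it greedy.
print(accept(candidate_value=0.90, best_value=0.92, temperature=1.0))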
diff --git a/autokeras/net_transformer.py b/autokeras/net_transformer.py
index d6488bfcd..49bb6936d 100644
--- a/autokeras/net_transformer.py
+++ b/autokeras/net_transformer.py
@@ -5,11 +5,16 @@
 from autokeras.graph import NetworkDescriptor
 
 from autokeras.constant import Constant
-from autokeras.layers import is_layer
+from autokeras.layers import is_layer, layer_width
 
 
 def to_wider_graph(graph):
     weighted_layer_ids = graph.wide_layer_ids()
+    weighted_layer_ids = list(filter(lambda x: layer_width(graph.layer_list[x]) * 2 <= Constant.MAX_MODEL_WIDTH,
+                                     weighted_layer_ids))
+
+    if len(weighted_layer_ids) == 0:
+        return None
     # n_wider_layer = randint(1, len(weighted_layer_ids))
     # wider_layers = sample(weighted_layer_ids, n_wider_layer)
     wider_layers = sample(weighted_layer_ids, 1)
@@ -56,6 +61,9 @@ def to_skip_connection_graph(graph):
 
 def to_deeper_graph(graph):
     weighted_layer_ids = graph.deep_layer_ids()
+    if len(weighted_layer_ids) >= Constant.MAX_MODEL_DEPTH:
+        return None
+
     deeper_layer_ids = sample(weighted_layer_ids, 1)
     # n_deeper_layer = randint(1, len(weighted_layer_ids))
     # deeper_layer_ids = sample(weighted_layer_ids, n_deeper_layer)
@@ -87,6 +95,7 @@ def transform(graph):
             graphs.append(to_wider_graph(deepcopy(graph)))
         elif a == 2:
             graphs.append(to_skip_connection_graph(deepcopy(graph)))
+    graphs = list(filter(lambda x: x, graphs))
 
     return list(filter(lambda x: legal_graph(x), graphs))
 
diff --git a/autokeras/search.py b/autokeras/search.py
index 0ecc1f6f5..297f9fe20 100644
--- a/autokeras/search.py
+++ b/autokeras/search.py
@@ -9,7 +9,7 @@
 from autokeras.net_transformer import default_transform
 from autokeras.utils import ModelTrainer, pickle_to_file, pickle_from_file
 
-import multiprocessing
+import torch.multiprocessing as mp
 
 
 class Searcher:
@@ -161,8 +161,8 @@ def search(self, train_data, test_data, timeout=60 * 60 * 24):
         graph, father_id, model_id = self.training_queue.pop(0)
         if self.verbose:
             print('Training model ', model_id)
-        multiprocessing.set_start_method('spawn', force=True)
-        pool = multiprocessing.Pool(1)
+        mp.set_start_method('spawn', force=True)
+        pool = mp.Pool(1)
         train_results = pool.map_async(train, [(graph, train_data, test_data, self.trainer_args,
                                                 os.path.join(self.path, str(model_id) + '.png'),
                                                 self.metric, self.loss, self.verbose)])
@@ -189,13 +189,12 @@ def search(self, train_data, test_data, timeout=60 * 60 * 24):
                 metric_value, loss, graph = train_results.get(timeout=remaining_time)[0]
             else:
                 raise TimeoutError
-        except (multiprocessing.TimeoutError, TimeoutError) as e:
+        except (mp.TimeoutError, TimeoutError) as e:
             raise TimeoutError from e
         finally:
             # terminate and join the subprocess to prevent any resource leak
-            pool.terminate()
+            pool.close()
             pool.join()
-
         self.add_model(metric_value, loss, graph, model_id)
         self.search_tree.add_child(father_id, model_id)
         self.bo.fit(self.x_queue, self.y_queue)
@@ -252,14 +251,14 @@ def train(args):
     model = graph.produce_model()
     # if path is not None:
    #     plot_model(model, to_file=path, show_shapes=True)
-    loss, metric_value = ModelTrainer(model,
+    loss, mertic_value = ModelTrainer(model,
                                       train_data,
                                       test_data,
                                       metric,
                                       loss,
                                       verbose).train_model(**trainer_args)
     model.set_weight_to_graph()
-    return metric_value, loss, model.graph
+    return mertic_value, loss, model.graph
 
 
 def same_graph(des1, des2):
diff --git a/autokeras/utils.py b/autokeras/utils.py
index fc540d92b..86ab47a57 100644
--- a/autokeras/utils.py
+++ b/autokeras/utils.py
@@ -153,7 +153,8 @@ def _test(self):
         for batch_idx, (inputs, targets) in enumerate(deepcopy(loader)):
             inputs, targets = inputs.to(self.device), targets.to(self.device)
             outputs = self.model(inputs)
-            test_loss += self.loss_function(outputs, targets)
+            # cast tensor to float
+            test_loss += float(self.loss_function(outputs, targets))
             all_predicted.append(outputs.cpu().numpy())
             all_targets.append(targets.cpu().numpy())
diff --git a/setup.py b/setup.py
index fc25435fa..f4635b11e 100644
--- a/setup.py
+++ b/setup.py
@@ -5,12 +5,12 @@
     packages=['autokeras'],  # this must be the same as the name above
     install_requires=['torch==0.4.1', 'torchvision==0.2.1', 'numpy==1.14.5', 'keras==2.2.2',
                       'scikit-learn==0.19.1', 'tensorflow==1.8.0'],
-    version='0.2.5',
+    version='0.2.6',
     description='AutoML for deep learning',
     author='Haifeng Jin',
     author_email='jhfjhfj1@gmail.com',
     url='http://autokeras.com',
-    download_url='https://github.com/jhfjhfj1/autokeras/archive/0.2.5.tar.gz',
+    download_url='https://github.com/jhfjhfj1/autokeras/archive/0.2.6.tar.gz',
     keywords=['automl'],  # arbitrary keywords
     classifiers=[]
 )
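# A minimal, standalone sketch (assumed helper names, not the searcher's exact code)
# of the pattern the search.py hunks above move to: torch.multiprocessing with the
# 'spawn' start method, map_async with a timeout, and pool.close()/join() instead of
# pool.terminate() so the worker process can exit cleanly.
import torch.multiprocessing as mp


def train_stub(args):
    # stand-in for autokeras.search.train; returns (metric_value, loss)
    value, = args
    return value * 2, 0.0


if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)
    pool = mp.Pool(1)
    try:
        train_results = pool.map_async(train_stub, [(21,)])
        metric_value, loss = train_results.get(timeout=60)[0]
        print(metric_value, loss)
    except mp.TimeoutError:
        raise TimeoutError('training subprocess timed out')
    finally:
        pool.close()  # let the worker finish its current task instead of killing it
        pool.join()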
diff --git a/tests/common.py b/tests/common.py
index b3a785de4..dc89d3626 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -243,8 +243,10 @@ def get(self, timeout=None):
     def terminate(self):
         pass
 
+    def close(self):
+        pass
+
 
 def simple_transform(graph):
     graph.to_wider_model(5, 64)
     return [deepcopy(graph)]
-
diff --git a/tests/test_image_supervised.py b/tests/test_image_supervised.py
index ef7a20252..91a702f0b 100644
--- a/tests/test_image_supervised.py
+++ b/tests/test_image_supervised.py
@@ -33,7 +33,7 @@ def test_x_float_exception():
     assert str(info.value) == 'x_train should only contain numerical data.'
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 @patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
 def test_fit_predict(_):
     Constant.MAX_ITER_NUM = 1
@@ -52,11 +52,11 @@ def test_fit_predict(_):
     clean_dir(path)
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 def test_timeout():
     # Constant.MAX_MODEL_NUM = 4
     Constant.SEARCH_MAX_ITER = 1000
-    Constant.T_MIN = 0.8
+    Constant.T_MIN = 0.0001
     Constant.DATA_AUGMENTATION = False
     path = 'tests/resources/temp'
     clean_dir(path)
@@ -68,7 +68,7 @@ def test_timeout():
     clean_dir(path)
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 @patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
 def test_timeout_resume(_):
     Constant.MAX_ITER_NUM = 1
@@ -99,7 +99,7 @@ def test_timeout_resume(_):
     clean_dir(path)
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 @patch('autokeras.bayesian.transform', side_effect=simple_transform)
 @patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
 def test_final_fit(_, _1):
@@ -123,7 +123,7 @@ def test_final_fit(_, _1):
     clean_dir(path)
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 @patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
 def test_save_continue(_):
     Constant.MAX_ITER_NUM = 1
@@ -156,7 +156,7 @@ def test_save_continue(_):
     clean_dir(path)
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 @patch('autokeras.bayesian.transform', side_effect=simple_transform)
 @patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
 def test_fit_csv_file(_, _1):
@@ -182,7 +182,7 @@ def test_init_image_classifier_with_none_path(_):
     assert clf.path == 'dummy_path/'
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 @patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
 def test_fit_predict_regression(_):
     Constant.MAX_ITER_NUM = 1
diff --git a/tests/test_net_transformer.py b/tests/test_net_transformer.py
index 0b2a1dda0..3bb2b25d0 100644
--- a/tests/test_net_transformer.py
+++ b/tests/test_net_transformer.py
@@ -53,3 +53,11 @@ def test_default_transform():
     model(torch.Tensor(get_conv_data()))
     assert len(graphs) == 1
     assert len(graphs[0].layer_list) == 43
+
+
+def test_search_space_limit():
+    graph = CnnGenerator(10, (32, 32, 3)).generate(model_len=3, model_width=2048)
+    assert to_wider_graph(graph) is None
+
+    graph = CnnGenerator(10, (32, 32, 3)).generate(model_len=14)
+    assert to_deeper_graph(graph) is None
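# A minimal, standalone sketch (assumed shape, not the real tests/common.py MockProcess)
# of the fake Pool these tests patch in. Because search.py now calls pool.close() in its
# finally block rather than pool.terminate(), the test double must stub close() as well,
# which is why the tests/common.py hunk above adds it.
class FakePool:
    def __init__(self, processes=None):
        self._results = None

    def map_async(self, func, iterable):
        # run synchronously in-process so tests stay deterministic
        self._results = [func(args) for args in iterable]
        return self

    def get(self, timeout=None):
        return self._results

    def terminate(self):
        pass

    def close(self):
        pass

    def join(self):
        pass


# Used in tests roughly as: @patch('torch.multiprocessing.Pool', new=FakePool)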
diff --git a/tests/test_search.py b/tests/test_search.py
index 2a53f8a39..3691356e2 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -14,7 +14,7 @@ def mock_train(**_):
     return 1, 0
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 @patch('autokeras.bayesian.transform', side_effect=simple_transform)
 @patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
 def test_bayesian_searcher(_, _1):
@@ -38,7 +38,7 @@ def test_search_tree():
     assert len(tree.adj_list) == 3
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 @patch('autokeras.bayesian.transform', side_effect=simple_transform)
 @patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
 def test_export_json(_, _1):
@@ -66,7 +66,7 @@ def test_graph_duplicate():
     assert not same_graph(get_concat_skip_model().extract_descriptor(), get_add_skip_model().extract_descriptor())
 
 
-@patch('multiprocessing.Pool', new=MockProcess)
+@patch('torch.multiprocessing.Pool', new=MockProcess)
 @patch('autokeras.bayesian.transform', side_effect=simple_transform)
 @patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
 def test_max_acq(_, _1):
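# A minimal, standalone sketch (generic PyTorch, not autokeras' ModelTrainer) of why
# the utils.py hunk casts the per-batch loss to float before accumulating: summing raw
# loss tensors keeps the running total as a device tensor (and, outside no_grad, keeps
# each batch's autograd graph alive), so memory grows with the number of batches;
# float() or .item() accumulates a plain Python number instead.
import torch

loss_function = torch.nn.MSELoss()
test_loss = 0.0
with torch.no_grad():
    for _ in range(3):
        outputs = torch.randn(8, 1)
        targets = torch.randn(8, 1)
        test_loss += float(loss_function(outputs, targets))  # plain Python float
print(test_loss)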