Skip to content

Commit

Permalink
Develop (#146)
Browse files Browse the repository at this point in the history
* merge (#128)

* 0.2.6 (#126)

* 0.2.5 setup.py (#111)

* [WIP] Attempts to Fix Memory Error  (#112)

* Add Website Badge in README.md, apply timeout in search function in search.py

* Add timeout in maximize_acq function in search.py

* Update unit test to allow timeout to raise TimeoutError

* Add unit test for timeout resume

* Remove TimeoutError from expectation

* Check Timeout exception in search() in search.py

* 0.2.5 setup.py (#110)

* Prevent gpu memory copy to main process after train() finished

* Cast loss from tensor to float

* Add pass() in MockProcess

* [MRG] Search Space limited to avoid out of memory (#121)

* limited the search space

* limited the search space

* reduce search space

* test added

* [MRG]Pytorch mp (#124)

* Change multiprcoessing to torch.multiprocessing

* Replace multiprocessing.Pool with torch.multiprocessing.Pool in tests

* 0.2.6 (#125)

* new release

* auto deploy

* auto deploy of docs

* fix the docs auto deploy

* Create CNAME

* deploy docs fixed

* update

* bug fix (#127)

* setup.py

* rm print

* Issue#37 and Issue #79 Save keras model/autokeras model (#122)

* Issue #37 Export Keras model

* Issue #79 Save autokeras model

* Issue #37 and Issue#79 Fixed comments

* Issue #37 and Issue #79

* Issue #37 and Issue #79

* Issue #37 and Issue #79 Fixed pytests

* Issue #37 and Issue #79

* quick fix test

* Progbar (#143)

* 0.2.6 (#126)

* 0.2.5 setup.py (#111)

* [WIP] Attempts to Fix Memory Error  (#112)

* Add Website Badge in README.md, apply timeout in search function in search.py

* Add timeout in maximize_acq function in search.py

* Update unit test to allow timeout to raise TimeoutError

* Add unit test for timeout resume

* Remove TimeoutError from expectation

* Check Timeout exception in search() in search.py

* 0.2.5 setup.py (#110)

* Prevent gpu memory copy to main process after train() finished

* Cast loss from tensor to float

* Add pass() in MockProcess

* [MRG] Search Space limited to avoid out of memory (#121)

* limited the search space

* limited the search space

* reduce search space

* test added

* [MRG]Pytorch mp (#124)

* Change multiprcoessing to torch.multiprocessing

* Replace multiprocessing.Pool with torch.multiprocessing.Pool in tests

* 0.2.6 (#125)

* new release

* auto deploy

* auto deploy of docs

* fix the docs auto deploy

* Create CNAME

* deploy docs fixed

* update

* bug fix (#127)

* setup.py

* contribute guide

* Add Progress Bar

* Update utils.py

* Update search.py

* update constant (#145)
  • Loading branch information
haifeng-jin authored Aug 26, 2018
1 parent b6a9812 commit 112dddb
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 26 deletions.
6 changes: 3 additions & 3 deletions autokeras/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ class Constant:
# Searcher

MAX_MODEL_NUM = 1000
BETA = 10.576
BETA = 2.576
KERNEL_LAMBDA = 0.1
T_MIN = 0.0001
N_NEIGHBOURS = 8
MAX_MODEL_WIDTH = 2048
MAX_MODEL_DEPTH = 15
MAX_MODEL_WIDTH = 1024
MAX_MODEL_DEPTH = 10

# Model Defaults

Expand Down
1 change: 0 additions & 1 deletion autokeras/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,6 @@ def extract_descriptor(self):

# The position of each node, how many Conv and Dense layers before it.
pos = [0] * len(topological_node_list)
print(sorted(topological_node_list))
for v in topological_node_list:
layer_count = 0
for u, layer_id in self.reverse_adj_list[v]:
Expand Down
53 changes: 52 additions & 1 deletion autokeras/image_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sklearn.model_selection import train_test_split

from autokeras.loss_function import classification_loss, regression_loss
from autokeras.supervised import Supervised
from autokeras.supervised import Supervised, PortableClass
from autokeras.constant import Constant
from autokeras.metric import Accuracy, MSE
from autokeras.preprocessor import OneHotEncoder, DataTransformer
Expand Down Expand Up @@ -323,6 +323,16 @@ def get_best_model_id(self):
""" Return an integer indicating the id of the best model."""
return self.load_searcher().get_best_model_id()

def export_keras_model(self, model_file_name):
""" Exports the best Keras model to the given filename. """
self.load_searcher().load_best_model().produce_keras_model().save(model_file_name)

def export_autokeras_model(self, model_file_name):
""" Creates and Exports the AutoKeras model to the given filename. """
portable_model = PortableImageSupervised(graph=self.load_searcher().load_best_model(), \
y_encoder=self.y_encoder, data_transformer=self.data_transformer)
pickle_to_file(portable_model, model_file_name)


class ImageClassifier(ImageSupervised):
@property
Expand Down Expand Up @@ -365,3 +375,44 @@ def transform_y(self, y_train):

def inverse_transform_y(self, output):
return output.flatten()


class PortableImageSupervised(PortableClass):
def __init__(self, graph, data_transformer, y_encoder):
"""Initialize the instance.
Args:
graph: The graph form of the learned model
"""
super().__init__(graph)
self.data_transformer = data_transformer
self.y_encoder = y_encoder

def predict(self, x_test):
"""Return predict results for the testing data.
Args:
x_test: An instance of numpy.ndarray containing the testing data.
Returns:
A numpy.ndarray containing the results.
"""
if Constant.LIMIT_MEMORY:
pass
test_loader = self.data_transformer.transform_test(x_test)
model = self.graph.produce_model()
model.eval()

outputs = []
with torch.no_grad():
for index, inputs in enumerate(test_loader):
outputs.append(model(inputs).numpy())
output = reduce(lambda x, y: np.concatenate((x, y)), outputs)
return self.inverse_transform_y(output)

def inverse_transform_y(self, output):
return self.y_encoder.inverse_transform(output)

def evaluate(self, x_test, y_test):
"""Return the accuracy score between predict value and `y_test`."""
y_predict = self.predict(x_test)
return accuracy_score(y_test, y_predict)
43 changes: 32 additions & 11 deletions autokeras/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,24 +109,30 @@ def replace_model(self, graph, model_id):

def add_model(self, metric_value, loss, graph, model_id):
if self.verbose:
print('Saving model.')
print('\nSaving model.')

pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.h5'))

# Update best_model text file

if self.verbose:
print('Model ID:', model_id)
print('Loss:', loss)
print('Metric Value:', metric_value)

ret = {'model_id': model_id, 'loss': loss, 'metric_value': metric_value}
self.history.append(ret)
if model_id == self.get_best_model_id():
file = open(os.path.join(self.path, 'best_model.txt'), 'w')
file.write('best model: ' + str(model_id))
file.close()

if self.verbose:
idx = ['model_id', 'loss', 'metric_value']
header = ['Model ID', 'Loss', 'Metric Value']
line = '|'.join(x.center(24) for x in header)
print('+' + '-' * len(line) + '+')
print('|' + line + '|')
for i, r in enumerate(self.history):
print('+' + '-' * len(line) + '+')
line = '|'.join(str(r[x]).center(24) for x in idx)
print('|' + line + '|')
print('+' + '-' * len(line) + '+')

descriptor = graph.extract_descriptor()
self.x_queue.append(descriptor)
self.y_queue.append(metric_value)
Expand All @@ -135,7 +141,7 @@ def add_model(self, metric_value, loss, graph, model_id):

def init_search(self):
if self.verbose:
print('Initializing search.')
print('\nInitializing search.')
graph = CnnGenerator(self.n_classes,
self.input_shape).generate(self.default_model_len,
self.default_model_width)
Expand All @@ -160,7 +166,10 @@ def search(self, train_data, test_data, timeout=60 * 60 * 24):
# Start the new process for training.
graph, father_id, model_id = self.training_queue.pop(0)
if self.verbose:
print('Training model ', model_id)
print('\n')
print('╒' + '=' * 46 + '╕')
print('|' + 'Training model {}'.format(model_id).center(46) + '|')
print('╘' + '=' * 46 + '╛')
mp.set_start_method('spawn', force=True)
pool = mp.Pool(1)
train_results = pool.map_async(train, [(graph, train_data, test_data, self.trainer_args,
Expand All @@ -182,8 +191,20 @@ def search(self, train_data, test_data, timeout=60 * 60 * 24):
self.descriptors.append(new_graph.extract_descriptor())

if self.verbose:
print('Father ID: ', new_father_id)
print(new_graph.operation_history)
cell_size = [24, 49]
header = ['Father Model ID', 'Added Operation']
line = '|'.join(str(x).center(cell_size[i]) for i, x in enumerate(header))
print('\n' + '+' + '-' * len(line) + '+')
print('|' + line + '|')
print('+' + '-' * len(line) + '+')
for i in range(len(new_graph.operation_history)):
if i == len(new_graph.operation_history)//2:
r = [new_father_id, new_graph.operation_history[i]]
else:
r = [' ', new_graph.operation_history[i]]
line = '|'.join(str(x).center(cell_size[i]) for i, x in enumerate(r))
print('|' + line + '|')
print('+' + '-' * len(line) + '+')
remaining_time = timeout - (time.time() - start_time)
if remaining_time > 0:
metric_value, loss, graph = train_results.get(timeout=remaining_time)[0]
Expand Down
28 changes: 28 additions & 0 deletions autokeras/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,31 @@ def predict(self, x_test):
def evaluate(self, x_test, y_test):
"""Return the accuracy score between predict value and `y_test`."""
pass


class PortableClass(ABC):
def __init__(self, graph):
"""Initialize the instance.
Args:
graph: The graph form of the learned model
"""
self.graph = graph

@abstractmethod
def predict(self, x_test):
"""Return predict results for the testing data.
Args:
x_test: An instance of numpy.ndarray containing the testing data.
Returns:
A numpy.ndarray containing the results.
"""
pass

@abstractmethod
def evaluate(self, x_test, y_test):
"""Return the accuracy score between predict value and `y_test`."""
pass
45 changes: 38 additions & 7 deletions autokeras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch

from autokeras.constant import Constant

from tqdm.autonotebook import tqdm

class NoImprovementError(Exception):
def __init__(self, message):
Expand Down Expand Up @@ -110,26 +110,58 @@ def train_model(self,
test_metric_value_list = []
test_loss_list = []
self.optimizer = torch.optim.Adam(self.model.parameters())

if self.verbose:
pbar = tqdm(total=max_iter_num,
desc=' Model ',
file=sys.stdout,
leave=False,
ncols=75,
position=1,
unit=' epoch')

for epoch in range(max_iter_num):
self._train()
test_loss, metric_value = self._test()
test_metric_value_list.append(metric_value)
test_loss_list.append(test_loss)
if self.verbose:
print('Epoch {}: loss {}, metric_value {}'.format(epoch + 1, test_loss, metric_value))
pbar.update(1)
if epoch == 0:
header = ['Epoch', 'Loss', 'Accuracy']
line = '|'.join(x.center(24) for x in header)
pbar.write('+' + '-' * len(line) + '+')
pbar.write('|' + line + '|')
pbar.write('+' + '-' * len(line) + '+')
r = [epoch + 1, test_loss, metric_value]
line = '|'.join(str(x).center(24) for x in r)
pbar.write('|' + line + '|')
pbar.write('+' + '-' * len(line) + '+')
decreasing = self.early_stop.on_epoch_end(test_loss)
if not decreasing:
if self.verbose:
print('No loss decrease after {} epochs'.format(max_no_improvement_num))
print('\nNo loss decrease after {} epochs.\n'.format(max_no_improvement_num))
break
if self.verbose:
pbar.close()
return (sum(test_loss_list[-max_no_improvement_num:]) / max_no_improvement_num,
sum(test_metric_value_list[-max_no_improvement_num:]) / max_no_improvement_num)

def _train(self):
self.model.train()
loader = self.train_loader

for batch_idx, (inputs, targets) in enumerate(deepcopy(loader)):
cp_loader = deepcopy(loader)
if self.verbose:
pbar = tqdm(total=len(cp_loader),
desc='Current Epoch',
file=sys.stdout,
leave=False,
ncols=75,
position=0,
unit=' batch')

for batch_idx, (inputs, targets) in enumerate(cp_loader):
inputs, targets = inputs.to(self.device), targets.to(self.device)
self.optimizer.zero_grad()
outputs = self.model(inputs)
Expand All @@ -138,10 +170,9 @@ def _train(self):
self.optimizer.step()
if self.verbose:
if batch_idx % 10 == 0:
print('.', end='')
sys.stdout.flush()
pbar.update(10)
if self.verbose:
print()
pbar.close()

def _test(self):
self.model.eval()
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ torchvision==0.2.1
numpy==1.15.1
scikit-learn==0.19.1
keras==2.2.2
tqdm==4.25.0
tensorflow==1.10.1
pytest
pytest-cov
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
packages=['autokeras'], # this must be the same as the name above
install_requires=['torch==0.4.1', 'torchvision==0.2.1', 'numpy>=1.15.1', 'keras==2.2.2', 'scikit-learn==0.19.1',
'tensorflow>=1.10.1'],
version='0.2.7',
version='0.2.8',
description='AutoML for deep learning',
author='Haifeng Jin',
author_email='[email protected]',
url='http://autokeras.com',
download_url='https://github.com/jhfjhfj1/autokeras/archive/0.2.7.tar.gz',
download_url='https://github.com/jhfjhfj1/autokeras/archive/0.2.8.tar.gz',
keywords=['automl'], # arbitrary keywords
classifiers=[]
)
38 changes: 37 additions & 1 deletion tests/test_image_supervised.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from unittest.mock import patch


import pytest

from autokeras.image_supervised import *
Expand Down Expand Up @@ -199,3 +198,40 @@ def test_fit_predict_regression(_):
results = clf.predict(train_x)
assert len(results) == len(train_x)
clean_dir(path)


@patch('torch.multiprocessing.Pool', new=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_export_keras_model(_):
Constant.MAX_ITER_NUM = 1
Constant.MAX_MODEL_NUM = 1
Constant.SEARCH_MAX_ITER = 1
Constant.T_MIN = 0.8
train_x = np.random.rand(100, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
test_x = np.random.rand(100, 25, 25, 1)
path = 'tests/resources/temp'
clean_dir(path)
clf = ImageClassifier(path=path, verbose=False, resume=False)
clf.n_epochs = 100
clf.fit(train_x, train_y)
score = clf.evaluate(train_x, train_y)
assert score <= 1.0

model_file_name = os.path.join(path, 'test_keras_model.h5')
clf.export_keras_model(model_file_name)
from keras.models import load_model
model = load_model(model_file_name)
results = model.predict(test_x)
assert len(results) == len(test_x)
del model, results, model_file_name

model_file_name = os.path.join(path, 'test_autokeras_model.pkl')
clf.export_autokeras_model(model_file_name)
from autokeras.utils import pickle_from_file
model = pickle_from_file(model_file_name)
results = model.predict(test_x)
assert len(results) == len(test_x)
score = model.evaluate(train_x, train_y)
assert score <= 1.0
clean_dir(path)

0 comments on commit 112dddb

Please sign in to comment.