-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
System Administrator
authored and
System Administrator
committed
Sep 27, 2017
1 parent
69aa210
commit bacb424
Showing
14 changed files
with
542 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import os, sys | ||
try: | ||
import mxnet as mx | ||
except ImportError: | ||
curr_path = os.path.abspath(os.path.dirname(__file__)) | ||
sys.path.append(os.path.join(curr_path, "../../../python")) | ||
import mxnet as mx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import mxnet as mx | ||
import logging | ||
import os | ||
import time | ||
|
||
def _get_lr_scheduler(args, kv):
    """Compute the starting learning rate and an optional step-decay scheduler.

    Parameters
    ----------
    args : argparse-style namespace
        Reads ``lr``, ``lr_factor``, ``lr_step_epochs`` (comma separated
        epoch numbers), ``num_examples``, ``batch_size``, ``kv_store`` and
        ``load_epoch``.
    kv : kvstore object
        Only ``kv.num_workers`` is read, and only when ``'dist'`` occurs in
        ``args.kv_store``.

    Returns
    -------
    tuple
        ``(lr, scheduler)``. ``scheduler`` is ``None`` when no decay is
        configured (``lr_factor`` missing or >= 1, no step epochs given) or
        when training resumes after every decay epoch has already passed.
    """
    if 'lr_factor' not in args or args.lr_factor >= 1:
        return (args.lr, None)

    # '--lr-step-epochs' has no default; without it there is nothing to
    # schedule, so keep a constant learning rate instead of crashing on None.
    step_spec = getattr(args, 'lr_step_epochs', None)
    if not step_spec:
        return (args.lr, None)

    # Scheduler steps are counted in updates, so keep them integral
    # (true division would produce floats on Python 3).
    epoch_size = args.num_examples // args.batch_size
    if 'dist' in args.kv_store:
        epoch_size //= kv.num_workers
    begin_epoch = args.load_epoch if args.load_epoch else 0
    step_epochs = [int(epoch) for epoch in step_spec.split(',')]

    # Apply the decay for every step epoch that was passed before resuming.
    lr = args.lr
    for step_epoch in step_epochs:
        if begin_epoch >= step_epoch:
            lr *= args.lr_factor
    if lr != args.lr:
        logging.info('Adjust learning rate to %e for epoch %d', lr, begin_epoch)

    steps = [epoch_size * (x - begin_epoch)
             for x in step_epochs if x - begin_epoch > 0]
    if not steps:
        # All decay epochs are behind us; MultiFactorScheduler needs at
        # least one step, so fall back to the already-decayed constant lr.
        return (lr, None)
    return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
|
||
def _load_model(args, rank=0):
    """Load a saved checkpoint when ``args.load_epoch`` asks for one.

    The checkpoint prefix is ``args.model_prefix`` joined onto
    ``args.output_dir``. For ``rank > 0`` a rank-specific checkpoint
    (``<prefix>-<rank>``) is used when its symbol file exists.

    Returns ``(sym, arg_params, aux_params)`` as produced by
    ``mx.model.load_checkpoint``, or ``(None, None, None)`` when
    ``load_epoch`` is absent or ``None``.
    """
    resume_requested = 'load_epoch' in args and args.load_epoch is not None
    if not resume_requested:
        return (None, None, None)
    assert args.model_prefix is not None

    # Checkpoints are stored below the UAI output directory.
    prefix = os.path.join(args.output_dir, args.model_prefix)
    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (prefix, rank)):
        prefix += "-%d" % (rank)

    epoch = args.load_epoch
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    logging.info('Loaded model %s_%04d.params', prefix, epoch)
    return (sym, arg_params, aux_params)
|
||
def _save_model(args, rank=0):
    """Create the epoch-end checkpoint callback, or ``None`` if disabled.

    Parameters
    ----------
    args : argparse-style namespace
        Reads ``model_prefix`` and ``output_dir``; checkpoints are written
        with the prefix ``os.path.join(output_dir, model_prefix)``.
    rank : int
        Worker rank. Non-zero ranks get ``-<rank>`` appended to the prefix.

    Returns
    -------
    The callback built by ``mx.callback.do_checkpoint``, or ``None`` when
    ``args.model_prefix`` is ``None``.
    """
    if args.model_prefix is None:
        return None

    # Add UAI output_dir path
    model_prefix = os.path.join(args.output_dir, args.model_prefix)
    dst_dir = os.path.dirname(model_prefix)
    # An empty dirname means "current directory": nothing to create.
    if dst_dir and not os.path.isdir(dst_dir):
        try:
            # makedirs (not mkdir) so nested prefixes such as "a/b/model" work.
            os.makedirs(dst_dir)
        except OSError:
            # Another worker may have created the directory in the meantime;
            # only propagate the error if it is really missing.
            if not os.path.isdir(dst_dir):
                raise
    if rank == 0:
        checkpoint_prefix = model_prefix
    else:
        checkpoint_prefix = "%s-%d" % (model_prefix, rank)
    return mx.callback.do_checkpoint(checkpoint_prefix)
|
||
def add_fit_args(parser):
    """Register the command line options consumed by ``fit``.

    parser : argparse.ArgumentParser
    Returns the 'Training' argument group that was added to ``parser``.
    Note that ``--monitor`` is registered on ``parser`` itself rather than
    on the group.
    """
    group = parser.add_argument_group('Training', 'model training')

    # (owner, flag, add_argument keyword arguments), in registration order.
    option_table = (
        (group, '--network', dict(
            type=str, help='the neural network to use')),
        (group, '--num-layers', dict(
            type=int,
            help='number of layers in the neural network, required by some networks such as resnet')),
        (group, '--gpus', dict(
            type=str,
            help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')),
        (group, '--kv-store', dict(
            type=str, default='device', help='key-value store type')),
        (group, '--num-epochs', dict(
            type=int, default=100, help='max num of epochs')),
        (group, '--lr', dict(
            type=float, default=0.1, help='initial learning rate')),
        (group, '--lr-factor', dict(
            type=float, default=0.1, help='the ratio to reduce lr on each step')),
        (group, '--lr-step-epochs', dict(
            type=str, help='the epochs to reduce the lr, e.g. 30,60')),
        (group, '--optimizer', dict(
            type=str, default='sgd', help='the optimizer type')),
        (group, '--mom', dict(
            type=float, default=0.9, help='momentum for sgd')),
        (group, '--wd', dict(
            type=float, default=0.0001, help='weight decay for sgd')),
        (group, '--batch-size', dict(
            type=int, default=128, help='the batch size')),
        (group, '--disp-batches', dict(
            type=int, default=20, help='show progress for every n batches')),
        (group, '--model-prefix', dict(
            type=str, help='model prefix')),
        (parser, '--monitor', dict(
            dest='monitor', type=int, default=0,
            help='log network parameters every N iters if larger than 0')),
        (group, '--load-epoch', dict(
            type=int,
            help='load the model on an epoch using the model-load-prefix')),
        (group, '--top-k', dict(
            type=int, default=0,
            help='report the top-k accuracy. 0 means no report.')),
        (group, '--test-io', dict(
            type=int, default=0,
            help='1 means test reading speed without training')),
        (group, '--dtype', dict(
            type=str, default='float32',
            help='precision: float32 or float16')),
    )
    for owner, flag, options in option_table:
        owner.add_argument(flag, **options)
    return group
|
||
def fit(args, network, data_loader, **kwargs):
    """
    Train a model.

    args : argparse returns
    network : the symbol definition of the neural network
    data_loader : function called as ``data_loader(args, kv)`` that returns
        the train and val data iterators
    kwargs : optional extras
        ``arg_params`` / ``aux_params`` (both must be given) are used as the
        initial parameters instead of loading a checkpoint;
        ``batch_end_callback`` (a callback or a list of callbacks) is
        appended to the default Speedometer callback.

    Returns None. When ``args.test_io`` is set, only the reading speed of the
    training iterator is measured and no training takes place.
    """
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # data iterators
    (train, val) = data_loader(args, kv)
    if args.test_io:
        tic = time.time()
        for i, batch in enumerate(train):
            # Block until the batch is actually available.
            for j in batch.data:
                j.wait_to_read()
            if (i + 1) % args.disp_batches == 0:
                logging.info('Batch [%d]\tSpeed: %.2f samples/sec',
                             i, args.disp_batches * args.batch_size / (time.time() - tic))
                tic = time.time()

        return

    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        sym, arg_params, aux_params = _load_model(args, kv.rank)
        if sym is not None:
            # A resumed checkpoint must describe the same network.
            assert sym.tojson() == network.tojson()

    # save model
    checkpoint = _save_model(args, kv.rank)

    # devices for training
    # Add UAI multi-gpu dev support
    # Compare by value: "is 0" tests object identity and only works by
    # accident of small-int caching.
    if args.num_gpus is None or args.num_gpus == 0:
        devs = mx.cpu()
    else:
        devs = [mx.gpu(i) for i in range(args.num_gpus)]

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args, kv)

    # create model
    model = mx.mod.Module(
        context=devs,
        symbol=network
    )

    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler}

    # Add 'multi_precision' parameter only for SGD optimizer
    if args.optimizer == 'sgd':
        optimizer_params['multi_precision'] = True

    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None

    if args.network == 'alexnet':
        # AlexNet will not converge using Xavier
        initializer = mx.init.Normal()
    else:
        initializer = mx.init.Xavier(
            rnd_type='gaussian', factor_type="in", magnitude=2)

    # evaluation metrics
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k))

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
    if 'batch_end_callback' in kwargs:
        cbs = kwargs['batch_end_callback']
        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

    # run
    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              batch_end_callback=batch_end_callbacks,
              epoch_end_callback=checkpoint,
              allow_missing=True,
              monitor=monitor)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
""" | ||
LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick Haffner. | ||
Gradient-based learning applied to document recognition. | ||
Proceedings of the IEEE (1998) | ||
""" | ||
import mxnet as mx | ||
|
||
def get_loc(data, attr={'lr_mult':'0.01'}):
    """
    Build the localisation network used by lenet-stn.

    Two conv/relu/pool stages followed by a 6-unit fully connected layer
    named "stn_loc"; ``attr`` is forwarded to that last layer. According to
    the original author it improves accuracy by more than 1% when
    num-epoch >= 15.

    NOTE: the default ``attr`` dict is shared between calls; it is only
    passed along here and must not be mutated.
    """
    sym = mx.symbol

    # stage 1: strided 5x5 convolution + max pooling
    conv_a = sym.Convolution(data=data, num_filter=30, kernel=(5, 5), stride=(2,2))
    relu_a = sym.Activation(data=conv_a, act_type='relu')
    pool_a = sym.Pooling(data=relu_a, kernel=(2, 2), stride=(2, 2), pool_type='max')

    # stage 2: 3x3 convolution + global average pooling
    conv_b = sym.Convolution(data=pool_a, num_filter=60, kernel=(3, 3), stride=(1,1), pad=(1, 1))
    relu_b = sym.Activation(data=conv_b, act_type='relu')
    pool_b = sym.Pooling(data=relu_b, global_pool=True, kernel=(2, 2), pool_type='avg')

    # regression head feeding the spatial transformer
    flat = sym.Flatten(data=pool_b)
    return sym.FullyConnected(data=flat, num_hidden=6, name="stn_loc", attr=attr)
|
||
|
||
def get_symbol(num_classes=10, add_stn=False, **kwargs):
    """
    Build the LeNet symbol.

    num_classes : number of units in the final fully connected layer
    add_stn : when true, the input first goes through a spatial transformer
        whose localisation network is built by ``get_loc``
    kwargs : accepted but not used
    Returns the SoftmaxOutput symbol named 'softmax'.
    """
    sym = mx.symbol
    net = sym.Variable('data')
    if add_stn:
        net = mx.sym.SpatialTransformer(
            data=net, loc=get_loc(net), target_shape=(28, 28),
            transform_type="affine", sampler_type="bilinear")

    # first conv
    net = sym.Convolution(data=net, kernel=(5,5), num_filter=20)
    net = sym.Activation(data=net, act_type="tanh")
    net = sym.Pooling(data=net, pool_type="max", kernel=(2,2), stride=(2,2))

    # second conv
    net = sym.Convolution(data=net, kernel=(5,5), num_filter=50)
    net = sym.Activation(data=net, act_type="tanh")
    net = sym.Pooling(data=net, pool_type="max", kernel=(2,2), stride=(2,2))

    # first fullc
    net = sym.Flatten(data=net)
    net = sym.FullyConnected(data=net, num_hidden=500)
    net = sym.Activation(data=net, act_type="tanh")

    # second fullc
    net = sym.FullyConnected(data=net, num_hidden=num_classes)

    # loss
    return sym.SoftmaxOutput(data=net, name='softmax')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
""" | ||
a simple multilayer perceptron | ||
""" | ||
import mxnet as mx | ||
|
||
def get_symbol(num_classes=10, **kwargs):
    """
    Build a simple multilayer perceptron symbol.

    The flattened input passes through fully connected layers of 128 and 64
    units (each followed by relu) and a final layer of ``num_classes`` units.
    kwargs : accepted but not used
    Returns the SoftmaxOutput symbol named 'softmax'.
    """
    sym = mx.symbol
    net = sym.Variable('data')
    net = mx.sym.Flatten(data=net)

    # hidden layer 1
    net = sym.FullyConnected(data=net, name='fc1', num_hidden=128)
    net = sym.Activation(data=net, name='relu1', act_type="relu")

    # hidden layer 2
    net = sym.FullyConnected(data=net, name='fc2', num_hidden=64)
    net = sym.Activation(data=net, name='relu2', act_type="relu")

    # output layer + loss
    net = sym.FullyConnected(data=net, name='fc3', num_hidden=num_classes)
    return sym.SoftmaxOutput(data=net, name='softmax')
Oops, something went wrong.