Skip to content

Commit

Permalink
Add MXNet support for UAI Train
Browse files Browse the repository at this point in the history
  • Loading branch information
System Administrator authored and System Administrator committed Sep 27, 2017
1 parent 69aa210 commit bacb424
Show file tree
Hide file tree
Showing 14 changed files with 542 additions and 0 deletions.
Empty file.
24 changes: 24 additions & 0 deletions examples/mxnet/train/mnist/code/common/find_mxnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os, sys
try:
import mxnet as mx
except ImportError:
curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, "../../../python"))
import mxnet as mx
220 changes: 220 additions & 0 deletions examples/mxnet/train/mnist/code/common/fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import logging
import os
import time

def _get_lr_scheduler(args, kv):
if 'lr_factor' not in args or args.lr_factor >= 1:
return (args.lr, None)
epoch_size = args.num_examples / args.batch_size
if 'dist' in args.kv_store:
epoch_size /= kv.num_workers
begin_epoch = args.load_epoch if args.load_epoch else 0
step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
lr = args.lr
for s in step_epochs:
if begin_epoch >= s:
lr *= args.lr_factor
if lr != args.lr:
logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch))

steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))

def _load_model(args, rank=0):
if 'load_epoch' not in args or args.load_epoch is None:
return (None, None, None)
assert args.model_prefix is not None

#Add UAI output_dir path
model_prefix = os.path.join(args.output_dir, args.model_prefix)
if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
model_prefix += "-%d" % (rank)
sym, arg_params, aux_params = mx.model.load_checkpoint(
model_prefix, args.load_epoch)
logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
return (sym, arg_params, aux_params)

def _save_model(args, rank=0):
if args.model_prefix is None:
return None

#Add UAI output_dir path
model_prefix = os.path.join(args.output_dir, args.model_prefix)
dst_dir = os.path.dirname(model_prefix)
if not os.path.isdir(dst_dir):
os.mkdir(dst_dir)
return mx.callback.do_checkpoint(model_prefix if rank == 0 else "%s-%d" % (
model_prefix, rank))

def add_fit_args(parser):
"""
parser : argparse.ArgumentParser
return a parser added with args required by fit
"""
train = parser.add_argument_group('Training', 'model training')
train.add_argument('--network', type=str,
help='the neural network to use')
train.add_argument('--num-layers', type=int,
help='number of layers in the neural network, required by some networks such as resnet')
train.add_argument('--gpus', type=str,
help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
train.add_argument('--kv-store', type=str, default='device',
help='key-value store type')
train.add_argument('--num-epochs', type=int, default=100,
help='max num of epochs')
train.add_argument('--lr', type=float, default=0.1,
help='initial learning rate')
train.add_argument('--lr-factor', type=float, default=0.1,
help='the ratio to reduce lr on each step')
train.add_argument('--lr-step-epochs', type=str,
help='the epochs to reduce the lr, e.g. 30,60')
train.add_argument('--optimizer', type=str, default='sgd',
help='the optimizer type')
train.add_argument('--mom', type=float, default=0.9,
help='momentum for sgd')
train.add_argument('--wd', type=float, default=0.0001,
help='weight decay for sgd')
train.add_argument('--batch-size', type=int, default=128,
help='the batch size')
train.add_argument('--disp-batches', type=int, default=20,
help='show progress for every n batches')
train.add_argument('--model-prefix', type=str,
help='model prefix')
parser.add_argument('--monitor', dest='monitor', type=int, default=0,
help='log network parameters every N iters if larger than 0')
train.add_argument('--load-epoch', type=int,
help='load the model on an epoch using the model-load-prefix')
train.add_argument('--top-k', type=int, default=0,
help='report the top-k accuracy. 0 means no report.')
train.add_argument('--test-io', type=int, default=0,
help='1 means test reading speed without training')
train.add_argument('--dtype', type=str, default='float32',
help='precision: float32 or float16')
return train

def fit(args, network, data_loader, **kwargs):
"""
train a model
args : argparse returns
network : the symbol definition of the nerual network
data_loader : function that returns the train and val data iterators
"""
# kvstore
kv = mx.kvstore.create(args.kv_store)

# logging
head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
logging.info('start with arguments %s', args)

# data iterators
(train, val) = data_loader(args, kv)
if args.test_io:
tic = time.time()
for i, batch in enumerate(train):
for j in batch.data:
j.wait_to_read()
if (i+1) % args.disp_batches == 0:
logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % (
i, args.disp_batches*args.batch_size/(time.time()-tic)))
tic = time.time()

return


# load model
if 'arg_params' in kwargs and 'aux_params' in kwargs:
arg_params = kwargs['arg_params']
aux_params = kwargs['aux_params']
else:
sym, arg_params, aux_params = _load_model(args, kv.rank)
if sym is not None:
assert sym.tojson() == network.tojson()

# save model
checkpoint = _save_model(args, kv.rank)

# devices for training
# Add UAI multi-gpu dev support
devs = mx.cpu() if args.num_gpus is None or args.num_gpus is 0 else [
mx.gpu(i) for i in range(args.num_gpus)]

# learning rate
lr, lr_scheduler = _get_lr_scheduler(args, kv)

# create model
model = mx.mod.Module(
context = devs,
symbol = network
)

lr_scheduler = lr_scheduler
optimizer_params = {
'learning_rate': lr,
'wd' : args.wd,
'lr_scheduler': lr_scheduler}

# Add 'multi_precision' parameter only for SGD optimizer
if args.optimizer == 'sgd':
optimizer_params['multi_precision'] = True

# Only a limited number of optimizers have 'momentum' property
has_momentum = {'sgd', 'dcasgd', 'nag'}
if args.optimizer in has_momentum:
optimizer_params['momentum'] = args.mom

monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None

if args.network == 'alexnet':
# AlexNet will not converge using Xavier
initializer = mx.init.Normal()
else:
initializer = mx.init.Xavier(
rnd_type='gaussian', factor_type="in", magnitude=2)
# initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),

# evaluation metrices
eval_metrics = ['accuracy']
if args.top_k > 0:
eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k))

# callbacks that run after each batch
batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
if 'batch_end_callback' in kwargs:
cbs = kwargs['batch_end_callback']
batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

# run
model.fit(train,
begin_epoch = args.load_epoch if args.load_epoch else 0,
num_epoch = args.num_epochs,
eval_data = val,
eval_metric = eval_metrics,
kvstore = kv,
optimizer = args.optimizer,
optimizer_params = optimizer_params,
initializer = initializer,
arg_params = arg_params,
aux_params = aux_params,
batch_end_callback = batch_end_callbacks,
epoch_end_callback = checkpoint,
allow_missing = True,
monitor = monitor)
Empty file.
64 changes: 64 additions & 0 deletions examples/mxnet/train/mnist/code/symbols/lenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick Haffner.
Gradient-based learning applied to document recognition.
Proceedings of the IEEE (1998)
"""
import mxnet as mx

def get_loc(data, attr={'lr_mult':'0.01'}):
"""
the localisation network in lenet-stn, it will increase acc about more than 1%,
when num-epoch >=15
"""
loc = mx.symbol.Convolution(data=data, num_filter=30, kernel=(5, 5), stride=(2,2))
loc = mx.symbol.Activation(data = loc, act_type='relu')
loc = mx.symbol.Pooling(data=loc, kernel=(2, 2), stride=(2, 2), pool_type='max')
loc = mx.symbol.Convolution(data=loc, num_filter=60, kernel=(3, 3), stride=(1,1), pad=(1, 1))
loc = mx.symbol.Activation(data = loc, act_type='relu')
loc = mx.symbol.Pooling(data=loc, global_pool=True, kernel=(2, 2), pool_type='avg')
loc = mx.symbol.Flatten(data=loc)
loc = mx.symbol.FullyConnected(data=loc, num_hidden=6, name="stn_loc", attr=attr)
return loc


def get_symbol(num_classes=10, add_stn=False, **kwargs):
data = mx.symbol.Variable('data')
if add_stn:
data = mx.sym.SpatialTransformer(data=data, loc=get_loc(data), target_shape = (28,28),
transform_type="affine", sampler_type="bilinear")
# first conv
conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
kernel=(2,2), stride=(2,2))
# second conv
conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
kernel=(2,2), stride=(2,2))
# first fullc
flatten = mx.symbol.Flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")
# second fullc
fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=num_classes)
# loss
lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
return lenet
32 changes: 32 additions & 0 deletions examples/mxnet/train/mnist/code/symbols/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
a simple multilayer perceptron
"""
import mxnet as mx

def get_symbol(num_classes=10, **kwargs):
data = mx.symbol.Variable('data')
data = mx.sym.Flatten(data=data)
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
return mlp
Loading

0 comments on commit bacb424

Please sign in to comment.