Skip to content

Commit

Permalink
Merge pull request #13 from soar-zhengjian/master
Browse files Browse the repository at this point in the history
add support for dist training
  • Loading branch information
classicsong authored Apr 19, 2018
2 parents 3301852 + 29de98a commit 6039b07
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 12 deletions.
10 changes: 9 additions & 1 deletion uaitrain/api/create_train_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class CreateUAITrainJobOp(BaseUAITrainAPIOp):
out_ufile_path string(required) the ufile path of output data
docker_cmd string(required) the cmd of run the job
max_exec_time int(required) the max exec time of job. if the job don't finish in the time, system will stop the job.
work_num int(optional) the num of server. This param should be greater than 1 for distributed train.
dist_ai_frame int(optional) the frame of distributed train
business_group string(optional) Which business group to run the job
job_memo string(optional) the memo of the job
Expand All @@ -47,7 +49,7 @@ class CreateUAITrainJobOp(BaseUAITrainAPIOp):
"""

def __init__(self, pub_key, priv_key, job_name, work_id, code_uhub_path, data_ufile_path, out_ufile_path,
docker_cmd, max_exec_time, business_group="", job_memo="", project_id="",
docker_cmd, max_exec_time, work_num=1, dist_ai_frame="", business_group="", job_memo="", project_id="",
region="", zone=""):
super(CreateUAITrainJobOp, self).__init__(self.ACTION_NAME,
pub_key,
Expand All @@ -63,13 +65,16 @@ def __init__(self, pub_key, priv_key, job_name, work_id, code_uhub_path, data_uf
self.cmd_params["DockerCmd"] = docker_cmd
self.cmd_params["PredictStartTime"] = 0
self.cmd_params["MaxExecuteTime"] = max_exec_time
self.cmd_params["TrainWorkAmount"] = work_num
self.cmd_params["DistAIFrame"] = dist_ai_frame

self.cmd_params["TrainPublicKey"] = pub_key
self.cmd_params["TrainPrivateKey"] = priv_key

self.cmd_params["TrainJobMemo"] = job_memo
self.cmd_params["BusinessGroup"] = business_group


def _check_args(self):
super(CreateUAITrainJobOp, self)._check_args()
if self.cmd_params["TrainJobName"] == "" or type(self.cmd_params["TrainJobName"]) != str:
Expand All @@ -91,4 +96,7 @@ def _check_args(self):
raise RuntimeError("docker_cmd shoud be <str> and is not nil.")

if self.cmd_params["MaxExecuteTime"] == "" or type(self.cmd_params["MaxExecuteTime"]) != int:
raise RuntimeError("max_exec_time shoud be <int> and is not nil.")

if self.cmd_params["TrainWorkAmount"] == "" or type(self.cmd_params["TrainWorkAmount"]) != int:
raise RuntimeError("max_exec_time shoud be <int> and is not nil.")
7 changes: 5 additions & 2 deletions uaitrain/api/get_train_job_running_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,20 @@ class GetUAITrainRunningLogOp(BaseUAITrainAPIOp):
RunningLog []string realtime log that train job produces
"""

def __init__(self, pub_key, priv_key, job_id, project_id="", region="", zone=""):
def __init__(self, pub_key, priv_key, job_id, log_topic_id, project_id="", region="", zone=""):
super(GetUAITrainRunningLogOp, self).__init__(self.ACTION_NAME,
pub_key,
priv_key,
project_id,
region,
zone)
self.cmd_params["TrainJobId"] = job_id
self.cmd_params["LogTopicId"] = log_topic_id

def _check_args(self):
super(GetUAITrainRunningLogOp, self)._check_args()

if type(self.cmd_params["TrainJobId"]) != str or self.cmd_params["TrainJobId"] == "":
raise RuntimeError("job_id shoud be str and is not nil.")
raise RuntimeError("job_id shoud be str and is not nil.")
if type(self.cmd_params["LogTopicId"]) != str or self.cmd_params["LogTopicId"] == "":
raise RuntimeError("log_topic_id shoud be str and is not nil.")
51 changes: 51 additions & 0 deletions uaitrain/api/get_train_log_topic_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2017 The UAI-SDK Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from uaitrain.api.base_op import BaseUAITrainAPIOp

class GetUAITrainRunningLogTopicListOp(BaseUAITrainAPIOp):
ACTION_NAME = "GetUAITrainRunningLogTopicList"
"""
GetUAITrainRunningLogTopicListOp
Compatable with UAI Train GetUAITrainRunningLogTopicList API func
Input:
pub_key string(required) Public key of the user
priv_key string(required) Private key of the user
project_id int(optional) Project ID of the job
region string(optional) Which Region to run the job
zone string(optional) Which Zone in the Region to run the job
job_id string(required) Job id of the job
Output:
RetCode int(required) Op return code: 0: success, others: error code
Action string(required) Action name
Message string(not required) Message: error description
RunningLog []string realtime log that train job produces
"""

def __init__(self, pub_key, priv_key, job_id, project_id="", region="", zone=""):
super(GetUAITrainRunningLogTopicListOp, self).__init__(self.ACTION_NAME,
pub_key,
priv_key,
project_id,
region,
zone)
self.cmd_params["TrainJobId"] = job_id

def _check_args(self):
super(GetUAITrainRunningLogTopicListOp, self)._check_args()

if type(self.cmd_params["TrainJobId"]) != str or self.cmd_params["TrainJobId"] == "":
raise RuntimeError("job_id shoud be str and is not nil.")
56 changes: 51 additions & 5 deletions uaitrain/operation/create_train_job/base_create_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,11 @@
# limitations under the License.
# ==============================================================================

import sys
import os
import argparse

from uai.utils.utils import GATEWAY_DEFAULT
from uai.utils.logger import uai_logger
from uaitrain.operation.base_op import BaseUAITrainOp
from uaitrain.api.create_train_job import CreateUAITrainJobOp
from uaitrain.api.get_train_available_resource import GetUAITrainAvailableResourceOp
from uaitrain.api.get_env_pkg import GetUAITrainEnvPkgAPIOp

class BaseUAITrainCreateTrainJobOp(BaseUAITrainOp):
def __init__(self, parser):
Expand Down Expand Up @@ -116,6 +112,22 @@ def _add_create_ufs_args(self, create_parser):
required=False,
help='The ufs mount point for the output')

def _add_create_dist_args(self, create_parser):
dist_parser = create_parser.add_argument_group(
'Dist-train Params', '')

dist_parser.add_argument(
'--dist_ai_frame',
type=str,
required=False,
help='The AI framework for dist-train.(eg. tensorflow, mxnet). if you do not use dist-train, ignore it')
dist_parser.add_argument(
'--node_num',
type=int,
default=1,
required=False,
help='The num of node for dist-train. if you do not use dist-train, ignore it')

def _add_args(self):
parser = self.parser.add_parser('create', help='Create UAI Train Job')
self.create_parser = parser
Expand All @@ -125,6 +137,7 @@ def _add_args(self):
self._add_create_job_args(parser)
self._add_create_ufile_args(parser)
self._add_create_ufs_args(parser)
self._add_create_dist_args(parser)

def _parse_args(self, args):
super(BaseUAITrainCreateTrainJobOp, self)._parse_args(args)
Expand Down Expand Up @@ -168,6 +181,12 @@ def _parse_args(self, args):
else:
raise RuntimeError("Need either output_ufile_path or output_ufs_path")

#dist
self.dist_ai_frame = args['dist_ai_frame'] if 'dist_ai_frame' in args else ""
self.worker_num = args['node_num']
if self.dist_ai_frame != "" and self.worker_num <= 1:
raise RuntimeError("The num of node for dist-train {0} should be greater 1, please check param node_num".format(self.dist_ai_frame))

return True

def _check_res(self):
Expand All @@ -190,6 +209,29 @@ def _check_res(self):
RuntimeError('Unsupported node_type')
return -1

def _get_dist_ai_frame_id(self):
pkgtype = "DistAIFrame"
api_op = GetUAITrainEnvPkgAPIOp(self.pub_key,
self.pri_key,
pkgtype,
self.project_id,
self.region,
self.zone)
succ, result = api_op.call_api()

if succ is False:
raise RuntimeError("Error get {0} info from server".format(pkgtype))

for avpkg in result['PkgSet']:
if avpkg["PkgName"] == self.dist_ai_frame:
return avpkg["PkgId"]

ai_frame_set = [avpkg["PkgId"] for avpkg in result['PkgSet']]

print("Required Dist-frame {0} not exist", self.dist_ai_frame)
print("Now only support {0}", ai_frame_set)
raise RuntimeError("Some {0} package is not supported: {1}".format(pkgtype, self.dist_ai_frame))

def cmd_run(self, args):
if self._parse_args(args) == False:
return False
Expand All @@ -198,6 +240,8 @@ def cmd_run(self, args):
if node_id < 0:
return False

ai_frame_id = self._get_dist_ai_frame_id()

create_op = CreateUAITrainJobOp(
pub_key=self.pub_key,
priv_key=self.pri_key,
Expand All @@ -208,6 +252,8 @@ def cmd_run(self, args):
out_ufile_path=self.output_path,
docker_cmd=self.docker_cmd,
max_exec_time=self.max_exec_time,
work_num=self.worker_num,
dist_ai_frame=ai_frame_id,
business_group=self.business_group,
job_memo=self.job_memo,
project_id=self.project_id,
Expand Down
Empty file.
89 changes: 89 additions & 0 deletions uaitrain/operation/get_log_topic/get_log_topic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2017 The UAI-SDK Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import time
from uai.utils.logger import uai_logger
from uai.utils.logger import printConsoleOnlyError
from uaitrain.operation.base_op import BaseUAITrainOp
from uaitrain.api.get_train_log_topic_list import GetUAITrainRunningLogTopicListOp
from uaitrain.api.get_train_job_list import GetUAITrainJobListOp

class BaseUAITrainGetLogTopicOp(BaseUAITrainOp):
def __init__(self, parser):
super(BaseUAITrainGetLogTopicOp, self).__init__(parser)
printConsoleOnlyError()

def _add_job_info_args(self, job_parser):
info_parser = job_parser.add_argument_group(
'Job Info Params', 'Job Infos')
info_parser.add_argument(
'--job_id',
type=str,
required=True,
help='The <job_id> to query')

def _add_args(self):
parser = self.parser.add_parser('topic', help='Get realtime log topic of UAI Train Job')
self.job_parser = parser
self._add_account_args(parser)
self._add_job_info_args(parser)

def _parse_args(self, args):
super(BaseUAITrainGetLogTopicOp, self)._parse_args(args)

self.job_id = args['job_id']
return True

def _check_job_running(self):
job_op = GetUAITrainJobListOp(
pub_key=self.pub_key,
priv_key=self.pri_key,
job_id=self.job_id,
project_id=self.project_id,
region=self.region,
zone=self.zone)

succ, resp = job_op.call_api()
if succ is False:
print("Error get job status info. job {0} ".format(self.job_id))
return False

if resp['DataSet'][0]['Status'] in ['Done', 'Stopped', 'Deleted', 'Error']:
return False
return True

def cmd_run(self, args):
if self._parse_args(args) == False:
return False

topic_op = GetUAITrainRunningLogTopicListOp(
pub_key=self.pub_key,
priv_key=self.pri_key,
job_id=self.job_id,
project_id=self.project_id,
region=self.region,
zone=self.zone)

succ, resp = topic_op.call_api()
if succ is False:
uai_logger.warn("Error get realtime topic info. job {0}, check your job_id, it may be not running.".format(self.job_id))
return False

result = resp['DataSet'] if resp['DataSet'] is not None else []
print ("The Topic list is:")
for topic in result:
print (topic['TopicId'])

return True
7 changes: 7 additions & 0 deletions uaitrain/operation/get_realtime_log/base_log_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def _add_job_info_args(self, job_parser):
type=str,
required=True,
help='The <job_id> to query')
info_parser.add_argument(
'--log_topic_id',
type=str,
required=True,
help='The <log_topic_id> to query')

def _add_args(self):
parser = self.parser.add_parser('log', help='Get realtime log of UAI Train Job')
Expand All @@ -44,6 +49,7 @@ def _parse_args(self, args):
super(BaseUAITrainGetRealtimeLogOp, self)._parse_args(args)

self.job_id = args['job_id']
self.log_topic_id = args['log_topic_id']
return True

def _check_job_running(self):
Expand Down Expand Up @@ -73,6 +79,7 @@ def cmd_run(self, args):
pub_key=self.pub_key,
priv_key=self.pri_key,
job_id=self.job_id,
log_topic_id=self.log_topic_id,
project_id=self.project_id,
region=self.region,
zone=self.zone)
Expand Down
5 changes: 4 additions & 1 deletion uaitrain_tool/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from uaitrain.operation.info_train_job.info_train_op import BaseUAITrainRunningJobInfoOp
from uaitrain.operation.rename_train_job.base_rename_op import BaseUAITrainRenameTrainJobOp
from uaitrain.operation.get_train_job_conf.base_conf_op import BaseUAITrainTrainJobConfOp
from uaitrain.operation.get_tensorboard_url.get_tensorboard_url import BaseUAITrainGetTensorBoardUrlOp
from uaitrain.operation.get_log_topic.get_log_topic import BaseUAITrainGetLogTopicOp

if __name__ == '__main__':
parser = argparse.ArgumentParser(
Expand All @@ -40,6 +40,7 @@
info_op = BaseUAITrainRunningJobInfoOp(subparsers)
rename_op = BaseUAITrainRenameTrainJobOp(subparsers)
conf_op = BaseUAITrainTrainJobConfOp(subparsers)
topic_op = BaseUAITrainGetLogTopicOp(subparsers)
cmd_args = vars(parser.parse_args())

if cmd_args['commands'] == 'pack':
Expand All @@ -58,6 +59,8 @@
conf_op.cmd_run(cmd_args)
elif cmd_args['commands'] == 'rename':
rename_op.cmd_run(cmd_args)
elif cmd_args['commands'] == 'topic':
topic_op.cmd_run(cmd_args)
else:
print("UAI Train Base Tool Only Support General operations, please use python base_tool.py -h to check")

Loading

0 comments on commit 6039b07

Please sign in to comment.