Skip to content

Commit

Permalink
add parameters --modelfile
Browse files Browse the repository at this point in the history
  • Loading branch information
cse0001 committed Dec 5, 2024
1 parent b815bb5 commit a46773f
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 8 deletions.
8 changes: 6 additions & 2 deletions visualdl/reader/graph_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class GraphReader(object):
"""Graph reader to read vdl graph files, support for frontend api in lib.py.
"""

def __init__(self, logdir=''):
def __init__(self, logdir='', model_name=''):
"""Instance of GraphReader
Args:
Expand All @@ -52,6 +52,7 @@ def __init__(self, logdir=''):
else:
self.dir = logdir

self.model_name = model_name
self.walks = {}
self.displayname2runs = {}
self.runs2displayname = {}
Expand Down Expand Up @@ -102,7 +103,10 @@ def graphs(self, update=False):
]
tags_temp.sort(reverse=True)
if len(tags_temp) > 0:
walks_temp.update({run: tags_temp[0]})
if self.model_name:
walks_temp.update({run: self.model_name})
else:
walks_temp.update({run: tags_temp[0]})
self.walks = walks_temp
return self.walks

Expand Down
14 changes: 9 additions & 5 deletions visualdl/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,13 @@ def try_call(function, *args, **kwargs):


class Api(object):
def __init__(self, logdir, model, cache_timeout):
def __init__(self, logdir, model, modelfile, cache_timeout):
self.model_name = ''
if not logdir and modelfile:
logdir = os.path.dirname(modelfile)
self.model_name = os.path.basename(modelfile)
self._reader = LogReader(logdir)
self._graph_reader = GraphReader(logdir)
self._graph_reader = GraphReader(logdir, self.model_name)
self._graph_reader.set_displayname(self._reader)
if model:
if 'vdlgraph' in model:
Expand Down Expand Up @@ -415,7 +419,7 @@ def get_component_tabs(*apis, vdl_args, request_args):
all_tabs = set()
if vdl_args.component_tabs:
return list(vdl_args.component_tabs)
if vdl_args.logdir:
if vdl_args.logdir or vdl_args.modelfile:
for api in apis:
all_tabs.update(api('component_tabs', request_args))
all_tabs.add('static_graph')
Expand All @@ -427,8 +431,8 @@ def get_component_tabs(*apis, vdl_args, request_args):
return list(all_tabs)


def create_api_call(logdir, model, cache_timeout):
api = Api(logdir, model, cache_timeout)
def create_api_call(logdir, model, modelfile, cache_timeout):
api = Api(logdir, model, modelfile, cache_timeout)
routes = {
'components': (api.components, []),
'runs': (api.runs, []),
Expand Down
2 changes: 1 addition & 1 deletion visualdl/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_locale():
) # we add this to prevent SIGINT not work in multiprocess queue waiting
babel = Babel(app, locale_selector=get_locale) # noqa:F841
# Babel api from flask_babel v3.0.0
api_call = create_api_call(args.logdir, args.model, args.cache_timeout)
api_call = create_api_call(args.logdir, args.model, args.modelfile, args.cache_timeout)
profiler_api_call = create_profiler_api_call(args.logdir)
inference_api_call = create_model_convert_api_call()
fastdeploy_api_call = create_fastdeploy_api_call()
Expand Down
9 changes: 9 additions & 0 deletions visualdl/server/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, args):
self.api_only = args.get('api_only', False)
self.open_browser = args.get('open_browser', False)
self.model = args.get('model', '')
self.modelfile = args.get('modelfile', '')
self.product = args.get('product', default_product)
self.telemetry = args.get('telemetry', True)
self.theme = args.get('theme', None)
Expand Down Expand Up @@ -123,6 +124,7 @@ def __init__(self, **kwargs):
self.api_only = args.api_only
self.open_browser = args.open_browser
self.model = args.model
self.modelfile = args.modelfile
self.product = args.product
self.telemetry = args.telemetry
self.theme = args.theme
Expand All @@ -141,6 +143,13 @@ def parse_args():
epilog="For more information: https://github.com/PaddlePaddle/VisualDL"
)

parser.add_argument(
"--modelfile",
type=str,
action="store",
default="",
help="json model file path")

parser.add_argument(
"--logdir", action="store", nargs="+", help="log file directory")

Expand Down

0 comments on commit a46773f

Please sign in to comment.