diff --git a/visualdl/reader/graph_reader.py b/visualdl/reader/graph_reader.py index 4e8035bd..3c3af3de 100644 --- a/visualdl/reader/graph_reader.py +++ b/visualdl/reader/graph_reader.py @@ -41,7 +41,7 @@ class GraphReader(object): """Graph reader to read vdl graph files, support for frontend api in lib.py. """ - def __init__(self, logdir=''): + def __init__(self, logdir='', model_name=''): """Instance of GraphReader Args: @@ -52,6 +52,7 @@ def __init__(self, logdir=''): else: self.dir = logdir + self.model_name = model_name self.walks = {} self.displayname2runs = {} self.runs2displayname = {} @@ -102,7 +103,10 @@ def graphs(self, update=False): ] tags_temp.sort(reverse=True) if len(tags_temp) > 0: - walks_temp.update({run: tags_temp[0]}) + if self.model_name: + walks_temp.update({run: self.model_name}) + else: + walks_temp.update({run: tags_temp[0]}) self.walks = walks_temp return self.walks diff --git a/visualdl/server/api.py b/visualdl/server/api.py index eea55455..7c644025 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -77,9 +77,13 @@ def try_call(function, *args, **kwargs): class Api(object): - def __init__(self, logdir, model, cache_timeout): + def __init__(self, logdir, model, modelfile, cache_timeout): + self.model_name = '' + if not logdir and modelfile: + logdir = os.path.dirname(modelfile) + self.model_name = os.path.basename(modelfile) self._reader = LogReader(logdir) - self._graph_reader = GraphReader(logdir) + self._graph_reader = GraphReader(logdir, self.model_name) self._graph_reader.set_displayname(self._reader) if model: if 'vdlgraph' in model: @@ -415,7 +419,7 @@ def get_component_tabs(*apis, vdl_args, request_args): all_tabs = set() if vdl_args.component_tabs: return list(vdl_args.component_tabs) - if vdl_args.logdir: + if vdl_args.logdir or vdl_args.modelfile: for api in apis: all_tabs.update(api('component_tabs', request_args)) all_tabs.add('static_graph') @@ -427,8 +431,8 @@ def get_component_tabs(*apis, vdl_args, request_args): return list(all_tabs) -def create_api_call(logdir, model, cache_timeout): - api = Api(logdir, model, cache_timeout) +def create_api_call(logdir, model, modelfile, cache_timeout): + api = Api(logdir, model, modelfile, cache_timeout) routes = { 'components': (api.components, []), 'runs': (api.runs, []), diff --git a/visualdl/server/app.py b/visualdl/server/app.py index 2f7b0f3b..a1d36812 100644 --- a/visualdl/server/app.py +++ b/visualdl/server/app.py @@ -86,7 +86,7 @@ def get_locale(): ) # we add this to prevent SIGINT not work in multiprocess queue waiting babel = Babel(app, locale_selector=get_locale) # noqa:F841 # Babel api from flask_babel v3.0.0 - api_call = create_api_call(args.logdir, args.model, args.cache_timeout) + api_call = create_api_call(args.logdir, args.model, args.modelfile, args.cache_timeout) profiler_api_call = create_profiler_api_call(args.logdir) inference_api_call = create_model_convert_api_call() fastdeploy_api_call = create_fastdeploy_api_call() diff --git a/visualdl/server/args.py b/visualdl/server/args.py index 71f97afb..fdbefea2 100644 --- a/visualdl/server/args.py +++ b/visualdl/server/args.py @@ -40,6 +40,7 @@ def __init__(self, args): self.api_only = args.get('api_only', False) self.open_browser = args.get('open_browser', False) self.model = args.get('model', '') + self.modelfile = args.get('modelfile', '') self.product = args.get('product', default_product) self.telemetry = args.get('telemetry', True) self.theme = args.get('theme', None) @@ -123,6 +124,7 @@ def __init__(self, **kwargs): self.api_only = args.api_only self.open_browser = args.open_browser self.model = args.model + self.modelfile = args.modelfile self.product = args.product self.telemetry = args.telemetry self.theme = args.theme @@ -141,6 +143,13 @@ def parse_args(): epilog="For more information: https://github.com/PaddlePaddle/VisualDL" ) + parser.add_argument( + "--modelfile", + type=str, + action="store", + default="", + help="json model file path") + parser.add_argument( "--logdir", action="store", nargs="+", help="log file directory")