diff --git a/demo/components/cond_inside_cond_test.py b/demo/components/cond_inside_cond_test.py new file mode 100755 index 000000000..557156b6a --- /dev/null +++ b/demo/components/cond_inside_cond_test.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= +import paddle +from visualdl import LogWriter +""" + pseudocode: + for i in range(1, 10): + a = 2 * i + if i < 5: + if i >= 3: + return a + a + else: + return a - a + else: + if i < 8: + return a * a + else: + return a / a +""" +paddle.enable_static() + + +def less_than_branch(i, a): + return paddle.static.nn.cond( + i >= 3.0, + lambda: paddle.add(a, a), + lambda: paddle.subtract(a, a), + ) + + +def greater_equal_branch(i, a): + return paddle.static.nn.cond( + i < 8.0, + lambda: paddle.multiply(a, a), + lambda: paddle.divide(a, a), + ) + + +main_program = paddle.static.Program() +startup_program = paddle.static.Program() +with paddle.static.program_guard(main_program, startup_program): + i = paddle.static.data(name="i", shape=[1], dtype='float32') + i.stop_gradient = False + a = 2.0 * i + out = paddle.static.nn.cond( + i < 5.0, + lambda: less_than_branch(i, a), + lambda: greater_equal_branch(i, a), + ) + mean = paddle.mean(out) + +with LogWriter(logdir="./log/cond_inside_cond_test/") as writer: + writer.add_graph( + model=main_program, + input_spec=[paddle.static.InputSpec([1], dtype='float32')], + verbose=True, + is_pir=True) diff --git a/demo/components/cond_test.py b/demo/components/cond_test.py new file mode 100755 index 000000000..3c0fe8d9e --- /dev/null +++ b/demo/components/cond_test.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ======================================================================= +import paddle +from visualdl import LogWriter + +paddle.enable_static() +""" + pseudocode: + for i in range(1, 10): + a = 2 * i + if i < 5: + return a + a + else: + return a - a +""" + + +class ConditionalLayer(paddle.nn.Layer): + def __init__(self): + super(ConditionalLayer, self).__init__() + + def forward(self, i): + a = 2.0 * i + out = paddle.static.nn.cond( + i < 5.0, + lambda: paddle.add(a, a), + lambda: paddle.subtract(a, a), + ) + return out + + +main_program = paddle.static.Program() +startup_program = paddle.static.Program() +with paddle.static.program_guard(main_program, startup_program): + i = paddle.static.data(name="i", shape=[1], dtype='float32') + i.stop_gradient = False + a = 2.0 * i + out = paddle.static.nn.cond( + i < 5.0, + lambda: paddle.add(a, a), + lambda: paddle.subtract(a, a), + ) + mean = paddle.mean(out) + +with LogWriter(logdir="./log/cond_test/") as writer: + writer.add_graph( + model=main_program, + input_spec=[paddle.static.InputSpec([1], 'float32')], + verbose=True, + is_pir=True) diff --git a/demo/components/pir_graph_test.py b/demo/components/pir_graph_test.py new file mode 100755 index 000000000..e870b8dcc --- /dev/null +++ b/demo/components/pir_graph_test.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= +import paddle +import paddle.nn.functional as F +from paddle import nn +from visualdl import LogWriter + + +class MyNet(nn.Layer): + def __init__(self): + super(MyNet, self).__init__() + self.conv1 = nn.Conv2D( + in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2) + self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2) + self.conv2 = nn.Conv2D( + in_channels=20, + out_channels=20, + kernel_size=5, + stride=1, + padding=2) + self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2) + self.fc = nn.Linear(in_features=980, out_features=10) + + def forward(self, inputs): + x = self.conv1(inputs) + x = F.relu(x) + x = self.max_pool1(x) + x = self.conv2(x) + x = F.relu(x) + x = self.max_pool2(x) + x = paddle.reshape(x, [x.shape[0], -1]) + x = self.fc(x) + return x + + +net = MyNet() +with LogWriter(logdir="./log/pir_graph_test/") as writer: + writer.add_graph( + model=net, + input_spec=[paddle.static.InputSpec([-1, 1, 28, 28], 'float32')], + verbose=True, + is_pir=True) diff --git a/demo/components/pir_translate.py b/demo/components/pir_program_test.py similarity index 53% rename from demo/components/pir_translate.py rename to demo/components/pir_program_test.py index 6b7868116..ee8b67993 100644 --- a/demo/components/pir_translate.py +++ b/demo/components/pir_program_test.py @@ -1,6 +1,18 @@ +# Copyright (c) 2024 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =======================================================================
 import paddle
-from paddle import ir
-
 from visualdl import LogWriter
 
 paddle.enable_static()
@@ -18,11 +30,9 @@
 batch_norm = paddle.nn.BatchNorm(32, act='relu', data_layout='NHWC')
 out = batch_norm(conv2d(tanh_out))
 
-newir_program = ir.translate_to_new_ir(main_program.desc)
-
 with LogWriter(logdir="./log/program_test/") as writer:
     writer.add_graph(
-        model=newir_program,
+        model=main_program,
         input_spec=[paddle.static.InputSpec([-1, 1, 28, 28], 'float32')],
         verbose=True,
         is_pir=True)
diff --git a/demo/components/while_test.py b/demo/components/while_test.py
new file mode 100755
index 000000000..a5e0838a3
--- /dev/null
+++ b/demo/components/while_test.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 VisualDL Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =======================================================================
+import paddle
+from visualdl import LogWriter
+
+paddle.enable_static()
+main_program = paddle.static.Program()
+startup_program = paddle.static.Program()
+with paddle.static.program_guard(main_program, startup_program):
+    linear = paddle.nn.Linear(16, 10)
+
+    def cond(i, loop_len, x, result):
+        return i < loop_len
+
+    def body(i, loop_len, x, result):
+        result = linear(x)
+        paddle.increment(i)
+        return [i, loop_len, x, result]
+
+    x = paddle.static.data(name='x', shape=[32, 16], dtype='float32')
+    i = paddle.zeros(shape=[1], dtype='int64')
+    loop_len = paddle.ones(shape=[1], dtype='int64')
+    result = paddle.zeros(
+        shape=x.shape[:-1] + linear.weight.shape[-1:], dtype="float32"
+    )
+    result.stop_gradient = False
+    _, _, _, results = paddle.static.nn.while_loop(
+        cond, body, [i, loop_len, x, result]
+    )
+    loss = paddle.mean(results)
+
+with LogWriter(logdir="./log/while_test/") as writer:
+    writer.add_graph(
+        model=main_program,
+        input_spec=[paddle.static.InputSpec([1], 'float32')],
+        verbose=True,
+        is_pir=True)
diff --git a/docs/Paddle PIR Visualization.md b/docs/Paddle PIR Visualization.md
new file mode 100755
index 000000000..1e7e1a36f
--- /dev/null
+++ b/docs/Paddle PIR Visualization.md
@@ -0,0 +1,241 @@
+## Project Information
+### Project Name
+PaddlePaddle-PIR adaptation for VisualDL model visualization
+
+### Design Overview
+VisualDL currently visualizes computation graphs as follows: it collects all variable and operator information from a static graph, writes it into a vdlgraph.log file with a specific format, and then reads the data back through a custom Model class that serves the frontend. The crucial information is the input/output relationship between operators and variables, i.e. the `input_vars` and `output_vars` fields of each operator and the `from_node` and `to_nodes` fields of each variable; these relationships determine the structure of the visualized graph.
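+
+For illustration only, the sketch below shows roughly what one operator record and one variable record in vdlgraph.log look like. The field names follow the analysis code in this PR; the concrete node names, shapes and values are made-up placeholders, not output from a real model.
+
+```python
+# Hypothetical vdlgraph.log entries (field names from graph_component.py,
+# values are placeholders for illustration only).
+example_op = {
+    'name': '/pd_op.conv2d_0',          # unique node id; '/' is the root block
+    'show_name': '/pd_op.conv2d_0',
+    'type': 'conv2d',
+    'input_vars': {'x_0': ['x_0']},
+    'output_vars': {'conv2d_0.tmp_0': ['conv2d_0.tmp_0']},
+    'parent_node': '/',
+    'children_node': [],
+    'input_nodes': [],                  # derived from from_node / to_nodes
+    'output_nodes': [],
+    'edge_input_nodes': [],
+    'edge_output_nodes': [],
+    'is_leaf_node': True,
+}
+example_var = {
+    'name': 'conv2d_0.tmp_0',
+    'shape': [-1, 20, 28, 28],
+    'dtype': 'float32',
+    'from_node': '/pd_op.conv2d_0',     # producer operator
+    'to_nodes': ['/pd_op.relu_0'],      # consumer operators
+    'persistable': False,
+}
+```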
+
+The initial PIR graph visualization extracts the variables and operators needed for visualization from a PIR program and builds a vdlgraph.log file with the same structure. It currently has four main shortcomings:
+
+1. It ignores PIR's new features: only the variables and operators of the top-level block are extracted, so graphs containing control-flow structures with nested blocks cannot be visualized.
+2. It does not support expanding and collapsing layers, and no edge information is stored in the vdlgraph.log file.
+3. It only accepts static graphs (paddle.base.libpaddle.pir.Program) as input; dynamic graphs are not supported.
+4. It cannot visualize models stored in the PIR JSON format.
+
+A solution was designed for each of these four shortcomings:
+
+1. Rewrite the PIR program analysis: using a depth-first traversal, collect the operators and variables of every block layer by layer starting from the top-level block, paying special attention to variable input/output relationships that cross blocks.
+2. Represent each control-flow structure as a layer that can be collapsed and expanded. When collapsed, the inner sub-blocks are hidden so that the overall model structure and data flow stand out; when expanded, a dashed frame marks the blocks inside the control flow so that the internal operator relationships and data flow are visible. In addition, add edge-information analysis modeled on the Paddle 2.x program analysis.
+3. Based on the dynamic-to-static conversion and save/load functionality of Paddle 3.x, change the input interface of the vdlgraph generation so that both static and dynamic graphs are accepted.
+4. Extend the model-upload interface of the frontend dynamic-graph page to accept *.json files, and update all related backend interfaces and code so that *.pdmodel and *.json files are handled separately.
+
+## Project Summary
+### Core Features
+#### Visualizing PIR control-flow operators in VisualDL
+In PIR, every control-flow operator owns sub-blocks that hold the operators of its branches or loop body. Below is the IR of a simple model containing an if op:
+
+```plain
+{
+    (%0) = "pd_op.data" () {dtype:(pd_op.DataType)float32,name:"i",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[1],stop_gradient:[false]} : () -> builtin.tensor<1xf32>
+    (%1) = "pd_op.full" () {dtype:(pd_op.DataType)float32,place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[1],stop_gradient:[true],value:(Double)2} : () -> builtin.tensor<1xf32>
+    (%2) = "pd_op.scale" (%0, %1) {bias:(Float)0,bias_after_scale:true,stop_gradient:[false]} : (builtin.tensor<1xf32>, builtin.tensor<1xf32>) -> builtin.tensor<1xf32>
+    (%3) = "pd_op.full" () {dtype:(pd_op.DataType)float32,place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[],stop_gradient:[true],value:(Double)5} : () -> builtin.tensor
+    (%4) = "pd_op.less_than" (%0, %3) {stop_gradient:[true]} : (builtin.tensor<1xf32>, builtin.tensor) -> builtin.tensor<1xb>
+    (%5) = "pd_op.if" (%4) {stop_gradient:[false]} -> builtin.tensor<1xf32> {
+        (%6) = "pd_op.add" (%2, %2) {stop_gradient:[false]} : (builtin.tensor<1xf32>, builtin.tensor<1xf32>) -> builtin.tensor<1xf32>
+        () = "cf.yield" (%6) {} : (builtin.tensor<1xf32>) ->
+    } else {
+        (%7) = "pd_op.subtract" (%2, %2) {stop_gradient:[false]} : (builtin.tensor<1xf32>, builtin.tensor<1xf32>) -> builtin.tensor<1xf32>
+        () = "cf.yield" (%7) {} : (builtin.tensor<1xf32>) ->
+    }
+    (%8) = "pd_op.mean" (%5) {axis:(pd_op.IntArray)[],keepdim:false,stop_gradient:[false]} : (builtin.tensor<1xf32>) -> builtin.tensor
+}
+```
+
+The four operators on lines 8-12 live inside sub-blocks. Because the existing method only traverses the operators of the top-level block and ignores sub-blocks, control-flow operators could not be visualized. To solve this, we first collect operator and variable information from the sub-blocks: the new `get_sub_var` and `get_sub_ops` functions extract the variables and operators inside a sub-block, and both are recursive so that arbitrarily nested blocks are handled (a simplified sketch of this traversal is given at the end of this subsection). They are called whenever a control-flow operator is encountered while traversing the top-level block. The key point here is the variable input/output relationships that cross blocks; the relevant check is:
+
+```python
+def is_same_block_op(from_node, to_node, all_ops):
+    if all_ops[to_node]["parent_node"] == '/':
+        return False
+    from_ancestors = set()
+    while all_ops[from_node]["parent_node"] != '/':
+        from_ancestors.add(all_ops[from_node]["parent_node"])
+        from_node = all_ops[from_node]["parent_node"]
+    if all_ops[to_node]["parent_node"] in from_ancestors:
+        return False
+    else:
+        return True
+```
+
+To keep the visualization intuitive and clean, VisualDL renders a control-flow operator as a layer that can be collapsed and expanded. Collapsed:
+
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725995876633-33d068ea-96ac-4c81-b637-0a242e22838a.png)
+
+Expanded:
+
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725995789430-d99ae3bd-b4b8-45a8-9dd1-42d13cefd871.png)
+
+After expansion, the name above the layer indicates that this is a control-flow operator. Note that in this figure the pd_op.less_than_0 operator computes the condition of the if op, so ideally an edge would connect pd_op.less_than_0 to the expanded pd_op.if_0 layer; however, the frontend is based on netron, which currently does not support drawing an edge from an operator to a layer.
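+
+The snippet below is a simplified sketch of that recursive depth-first traversal, assuming `program` is a PIR Program; the function name `collect_ops` is made up for illustration, and all of the variable/edge bookkeeping that the real get_sub_ops performs is omitted.
+
+```python
+all_ops = {'/': {'parent_node': None, 'children_node': []}}
+
+def collect_ops(block, parent_name):
+    """Depth-first walk over a PIR block and the sub-blocks of its control-flow ops."""
+    for op in block.ops:
+        # the real implementation uses paddle.utils.unique_name.generate()
+        # to keep node names unique; plain op.name() is enough for a sketch
+        op_name = parent_name.rstrip('/') + '/' + op.name()
+        all_ops[op_name] = {'parent_node': parent_name, 'children_node': []}
+        all_ops[parent_name]['children_node'].append(op_name)
+        # pd_op.if / pd_op.while own sub-blocks: recurse into them so that
+        # control flow nested at any depth is collected as well
+        if op.name() in ('pd_op.if', 'pd_op.while'):
+            for sub_block in op.blocks():
+                collect_ops(sub_block, op_name)
+
+collect_ops(program.global_block(), '/')
+```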
+
+#### Extracting PIR program edge information in VisualDL
+The logic is the same as for the old IR; the core code is:
+
+```python
+# edge info
+for var_name in all_vars.keys():
+    construct_edges(var_name, all_ops, all_vars, all_edges)
+
+for src_node, to_node in all_edges.keys():
+    all_ops[src_node]['edge_output_nodes'].append(to_node)
+    all_ops[to_node]['edge_input_nodes'].append(src_node)
+    all_edges[(src_node, to_node)]['vars'] = list(
+        all_edges[(src_node, to_node)]['vars'])
+    if len(all_edges[(src_node, to_node)]['vars']) > 1:
+        all_edges[(src_node, to_node)]['label'] = str(
+            len(all_edges[(src_node, to_node)]['vars'])) + ' tensors'
+    elif len(all_edges[(src_node, to_node)]['vars']) == 1:
+        all_edges[(src_node, to_node)]['label'] = str(
+            all_vars[all_edges[(src_node, to_node)]['vars'][0]]['shape'])
+```
+
+#### Accepting both static and dynamic computation graphs as input
+The PIR analysis itself works on static graphs. A dynamic graph is first turned into a static one via dynamic-to-static conversion plus save/load, and the resulting program is analysed. Core code:
+
+```python
+if isinstance(model, paddle.base.libpaddle.pir.Program):
+    result = analyse_pir(model)
+else:
+    model = paddle.jit.to_static(model, input_spec)
+    paddle.jit.save(model, os.path.join(tmp, 'temp'))
+    model_data = paddle.jit.load(os.path.join(tmp, 'temp'))
+    result = analyse_pir(model_data.program())
+```
+
+#### Accepting models stored in JSON format
+Under PIR, a model is stored as a JSON file after dynamic-to-static conversion, so the frontend needs to be able to import JSON files for visualization. At the moment this is only supported on the dynamic-graph page. The core backend code is:
+
+```python
+def set_input_graph(self, content, file_type='pdmodel'):
+    if isinstance(content, str):
+        if not is_VDLGraph_file(content):
+            return
+        if 'pdmodel' in content:
+            file_type = 'pdmodel'
+        elif 'json' in content:
+            file_type = 'json'
+        else:
+            file_type = 'vdlgraph'
+        content = bfile.BFile(content, 'rb').read()
+
+    if file_type == 'pdmodel':
+        data = analyse_model(content)
+        self.graph_buffer['manual_input_model'] = Model(data)
+
+    elif file_type == 'json':
+        json_object = json.loads(content)
+        with tempfile.TemporaryDirectory() as tmp:
+            with open(os.path.join(tmp, 'temp.json'), 'w') as json_file:
+                json.dump(json_object, json_file, indent=4)
+            model_data = load(os.path.join(tmp, 'temp'))
+            data = analyse_pir(model_data.program())
+        self.graph_buffer['manual_input_model'] = Model(data)
+
+    elif file_type == 'vdlgraph':
+        self.graph_buffer['manual_input_model'] = Model(
+            json.loads(content.decode()))
+
+    else:
+        return
+```
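+
+As a usage sketch (the save path below is made up for illustration, MyNet is the demo layer from demo/components/pir_graph_test.py, and the process is assumed to run with FLAGS_enable_pir_api=1 so that paddle.jit.save writes a .json program file), a model saved this way can be loaded back and fed to the same analysis:
+
+```python
+import os
+import paddle
+from visualdl.component.graph import analyse_pir
+
+net = MyNet()  # MyNet as defined in demo/components/pir_graph_test.py
+static_net = paddle.jit.to_static(
+    net, input_spec=[paddle.static.InputSpec([-1, 1, 28, 28], 'float32')])
+paddle.jit.save(static_net, os.path.join('saved', 'model'))  # -> saved/model.json
+
+# paddle.jit.load takes the path prefix without the .json extension;
+# analyse_pir then produces the same vdlgraph data structure as above
+loaded = paddle.jit.load(os.path.join('saved', 'model'))
+graph_data = analyse_pir(loaded.program())
+```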
+
+### Problems Encountered and Solutions
+1. To make the data flow of the graph more intuitive, we add an output operator to each control-flow sub-block. In PIR a control-flow sub-block ends with the auxiliary operator cf.yield, so after expansion the control-flow layer had nothing flowing out of it. The added output operator takes all cf.yield results of the control flow as input, and its outputs are the outputs that the collapsed control-flow operator had. Implementation:
+
+```python
+def create_control_output_node(all_ops, all_vars, control_node_name):
+    op_name = control_node_name + '/' + "output"
+    all_ops[op_name] = {}
+    all_ops[op_name]['name'] = op_name
+    all_ops[op_name]['show_name'] = op_name
+
+    all_ops[op_name]['type'] = "control_op.output"
+    all_ops[op_name]['dtype'] = all_ops[control_node_name]['dtype']
+    all_ops[op_name]['input_vars'] = {}
+    all_ops[op_name]['output_vars'] = all_ops[control_node_name]['output_vars']
+
+    all_ops[op_name]['is_leaf_node'] = True
+    for var in all_vars:
+        if all_vars[var]['from_node'] == control_node_name:
+            all_ops[op_name]['output_vars'][var] = [var]
+            all_vars[var]['from_node'] = op_name
+
+    all_ops[op_name]['attrs'] = all_ops[control_node_name]['attrs']
+    all_ops[op_name]['attr_types'] = all_ops[control_node_name]['attr_types']
+    all_ops[op_name]['children_node'] = []
+    all_ops[op_name]['input_nodes'] = []
+    all_ops[op_name]['output_nodes'] = []
+    all_ops[op_name]['edge_input_nodes'] = []
+    all_ops[op_name]['edge_output_nodes'] = []
+    all_ops[op_name]['parent_node'] = control_node_name
+    all_ops[control_node_name]['children_node'].append(op_name)
+    return all_ops, all_vars
+```
+
+2. For the loop control-flow operator (while op), a graph drawn directly from the intermediate representation does not show the back edge of the loop, which hides the loop's data flow. We therefore add, for the pd_op.increment_ operator, an extra edge pointing to its upstream operator to represent the data flow of the loop structure. Effect:
+
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725975662978-9479a84a-6a1d-418c-b97e-d6c35d92677d.png?x-oss-process=image%2Fformat%2Cwebp%2Fresize%2Cw_937%2Climit_0)
+
+Code:
+
+```python
+if op.name() == "pd_op.increment_":
+    all_vars[now_var]['to_nodes'].append(all_vars[input_name]['from_node'])
+    all_ops[all_vars[input_name]['from_node']]['input_vars'][now_var] = [now_var]
+```
+
+3. For the builtin.parameter operator the persistable attribute defaults to True, which makes the operator show no output in the frontend visualization, so we set persistable to False for these operators:
+
+```python
+if op.name() == "builtin.parameter":
+    all_vars[var_name]['persistable'] = False
+```
+
+4. The VisualDL frontend is based on netron, which gives certain operator types (e.g. conv2d) distinct colors. The frontend decides the type from the type field of each operator in vdlgraph.log. With the old IR the correct operator type could be obtained via op.type(), but PIR operators have no type() interface, so for now we derive the type from the operator name; this can be revisited after future Paddle updates. Current implementation: `all_ops[op_name]['type'] = op.name().replace("pd_op.", "")`
+
+5. The model-upload button of the static-graph page and the `visualdl --model` command on the terminal cannot visualize JSON-format models yet. Static-model visualization relies on netron: in VisualDL the static model is parsed in the frontend using netron's parsers, and netron does not support parsing Paddle JSON models; implementing such a parser ourselves would be quite complex, so this is left as a TODO.
+
+### Test Cases
+Five test cases are provided: pir_program_test, pir_graph_test, cond_test, while_test and cond_inside_cond_test. They cover static-graph input, dynamic-graph input, a branching model, a loop model, and a nested-branch model, respectively.
+
+#### Test Steps
+1. cd VisualDL
+2. export FLAGS_enable_pir_api=1
+3. python demo/components/pir_program_test.py (or pir_graph_test, cond_test, while_test, cond_inside_cond_test); the log is written under VisualDL/log/ (e.g. VisualDL/log/program_test for pir_program_test)
+4. visualdl --logdir ./log/program_test/ --host 0.0.0.0 (run from the VisualDL directory)
+
+#### Test Results
+After running a test case, open [http://0.0.0.0:8040/](http://0.0.0.0:8040/) to view the visualized computation graph.
+
+##### pir_program_test (PIR static graph visualization)
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725996090463-249b4ade-ad21-47fc-8017-dee9ad30f979.png)
+
+##### pir_graph_test (PIR dynamic-graph input visualization)
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725996125143-5c21f177-6fa3-4203-84ea-d1d276323163.png)
+
+##### cond_test (if op visualization)
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725996043233-fbf7b462-7436-41b4-8d8e-55febb4d10fe.png)
+
+Fully expanded:
+
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725996056427-ca87bc8e-78aa-4bbd-8abb-d9c6669780e9.png)
+
+##### while_test (while op visualization)
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725996156062-275278ea-900a-44d9-a477-55dac6494456.png)
+
+Fully expanded:
+
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725996166186-f5f350af-90fc-4192-835c-4d7847a9fa43.png)
+
+##### cond_inside_cond_test (nested if op visualization)
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725996206268-4586af9c-7f68-4b4f-9790-f26db4f1dd58.png)
+
+With one level of the if op expanded:
+
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725996221339-32a802b7-bad3-4369-9505-9b5b5e08e863.png)
+
+Fully expanded:
+
+![](https://cdn.nlark.com/yuque/0/2024/png/32921027/1725996236240-5b6252b6-acb3-4f4b-9bfa-7aab74ce4b35.png)
+
+### Future Work
++ Explore how to extract layer information from JSON-format models
++ Polish and refine the code
+
diff --git a/frontend/packages/core/src/pages/graphDynamic.tsx b/frontend/packages/core/src/pages/graphDynamic.tsx
index e2f8adb38..792627386 100644
--- a/frontend/packages/core/src/pages/graphDynamic.tsx
+++ b/frontend/packages/core/src/pages/graphDynamic.tsx
@@ -101,7 +101,7 @@ const Graph: FunctionComponent = () => {
     const onChangeFile = (e: React.ChangeEvent) => {
         const target = e.target as EventTarget & HTMLInputElement;
         const file: FileList | null = target.files as FileList;
-        if (file[0].name.split('.')[1] !== 'pdmodel') {
+        if (file[0].name.split('.')[1] !== 'pdmodel' && file[0].name.split('.')[1] !== 'json') {
             alert('该页面只能解析paddle的模型,如需解析请跳转网络结构静态图页面');
             return;
         }
diff --git
a/frontend/packages/netron/src/view.js b/frontend/packages/netron/src/view.js index 4fb672696..e7cf17a2a 100644 --- a/frontend/packages/netron/src/view.js +++ b/frontend/packages/netron/src/view.js @@ -1394,7 +1394,7 @@ view.ModelFactoryService = class { this.register('./uff', ['.uff', '.pb', '.trt', '.pbtxt', '.uff.txt']); this.register('./sklearn', ['.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5']); this.register('./cntk', ['.model', '.cntk', '.cmf', '.dnn']); - this.register('./paddle', ['.paddle', '.pdmodel', '__model__']); + this.register('./paddle', ['.paddle', '.pdmodel', '.json', '__model__']); this.register('./armnn', ['.armnn']); this.register('./bigdl', ['.model', '.bigdl']); this.register('./darknet', ['.cfg', '.model']); diff --git a/frontend/packages/netron/src/view2.js b/frontend/packages/netron/src/view2.js index 55c9705ef..bda5f4113 100644 --- a/frontend/packages/netron/src/view2.js +++ b/frontend/packages/netron/src/view2.js @@ -1075,7 +1075,7 @@ view.ModelFactoryService = class { this.register('./uff', ['.uff', '.pb', '.trt', '.pbtxt', '.uff.txt']); this.register('./sklearn', ['.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5']); this.register('./cntk', ['.model', '.cntk', '.cmf', '.dnn']); - this.register('./paddle', ['.paddle', '.pdmodel', '__model__']); + this.register('./paddle', ['.paddle', '.pdmodel', '.json', '__model__']); this.register('./armnn', ['.armnn']); this.register('./bigdl', ['.model', '.bigdl']); this.register('./darknet', ['.cfg', '.model']); diff --git a/frontend/packages/netron2/src/view.js b/frontend/packages/netron2/src/view.js index ab45507da..acfc6bf37 100644 --- a/frontend/packages/netron2/src/view.js +++ b/frontend/packages/netron2/src/view.js @@ -1398,7 +1398,7 @@ view.ModelFactoryService = class { this.register('./uff', ['.uff', '.pb', '.trt', '.pbtxt', '.uff.txt']); this.register('./sklearn', ['.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5']); this.register('./cntk', ['.model', '.cntk', '.cmf', '.dnn']); - this.register('./paddle', ['.paddle', '.pdmodel', '__model__']); + this.register('./paddle', ['.paddle', '.pdmodel', '.json', '__model__']); this.register('./armnn', ['.armnn']); this.register('./bigdl', ['.model', '.bigdl']); this.register('./darknet', ['.cfg', '.model']); diff --git a/frontend/packages/netron2/src/view2.js b/frontend/packages/netron2/src/view2.js index 7ad3c509d..daebb297d 100644 --- a/frontend/packages/netron2/src/view2.js +++ b/frontend/packages/netron2/src/view2.js @@ -1076,7 +1076,7 @@ view.ModelFactoryService = class { this.register('./uff', ['.uff', '.pb', '.trt', '.pbtxt', '.uff.txt']); this.register('./sklearn', ['.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5']); this.register('./cntk', ['.model', '.cntk', '.cmf', '.dnn']); - this.register('./paddle', ['.paddle', '.pdmodel', '__model__']); + this.register('./paddle', ['.paddle', '.pdmodel', '.json', '__model__']); this.register('./armnn', ['.armnn']); this.register('./bigdl', ['.model', '.bigdl']); this.register('./darknet', ['.cfg', '.model']); diff --git a/visualdl/component/graph/__init__.py b/visualdl/component/graph/__init__.py index bb89f7dc2..13e3036dc 100644 --- a/visualdl/component/graph/__init__.py +++ b/visualdl/component/graph/__init__.py @@ -14,6 +14,7 @@ # ======================================================================= from .exporter import translate_graph from .graph_component import analyse_model +from .graph_component import analyse_pir from .netron_graph 
import Model -__all__ = ['translate_graph', 'analyse_model', 'Model'] +__all__ = ['translate_graph', 'analyse_model', 'analyse_pir', 'Model'] diff --git a/visualdl/component/graph/exporter.py b/visualdl/component/graph/exporter.py index 3a9c545de..c04fdadb0 100644 --- a/visualdl/component/graph/exporter.py +++ b/visualdl/component/graph/exporter.py @@ -23,6 +23,13 @@ def translate_graph(model, input_spec, verbose=True, **kwargs): + try: + import paddle + except Exception: + print("Paddlepaddle is required to use add_graph interface.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") is_pir = kwargs.get('is_pir', False) with tempfile.TemporaryDirectory() as tmp: if (not is_pir): @@ -34,7 +41,13 @@ def translate_graph(model, input_spec, verbose=True, **kwargs): model_data = open(os.path.join(tmp, 'temp.pdmodel'), 'rb').read() result = analyse_model(model_data) else: - result = analyse_pir(model) + if isinstance(model, paddle.base.libpaddle.pir.Program): + result = analyse_pir(model) + else: + model = paddle.jit.to_static(model, input_spec) + paddle.jit.save(model, os.path.join(tmp, 'temp')) + model_data = paddle.jit.load(os.path.join(tmp, 'temp')) + result = analyse_pir(model_data.program()) if verbose: print_model(result) result = json.dumps(result, indent=2) diff --git a/visualdl/component/graph/graph_component.py b/visualdl/component/graph/graph_component.py index 5e91869b3..775da25c5 100644 --- a/visualdl/component/graph/graph_component.py +++ b/visualdl/component/graph/graph_component.py @@ -365,44 +365,267 @@ def analyse_model(model_pb): # noqa: C901 return final_data +def is_control_flow(op): + return op.name() == "pd_op.if" or op.name() == "pd_op.while" + + +def is_same_block_op(from_node, to_node, all_ops): + if all_ops[to_node]["parent_node"] == '/': + return False + from_ancestors = set() + while all_ops[from_node]["parent_node"] != '/': + from_ancestors.add(all_ops[from_node]["parent_node"]) + from_node = all_ops[from_node]["parent_node"] + if all_ops[to_node]["parent_node"] in from_ancestors: + return False + else: + return True + + +def create_control_output_node(all_ops, all_vars, control_node_name): + op_name = control_node_name + '/' + "output" + all_ops[op_name] = {} + all_ops[op_name]['name'] = op_name + all_ops[op_name]['show_name'] = op_name + + all_ops[op_name]['type'] = "control_op.output" + all_ops[op_name]['dtype'] = all_ops[control_node_name]['dtype'] + all_ops[op_name]['input_vars'] = {} + all_ops[op_name]['output_vars'] = all_ops[control_node_name]['output_vars'] + + all_ops[op_name]['is_leaf_node'] = True + for var in all_vars: + if all_vars[var]['from_node'] == control_node_name: + all_ops[op_name]['output_vars'][var] = [var] + all_vars[var]['from_node'] = op_name + + all_ops[op_name]['attrs'] = all_ops[control_node_name]['attrs'] + all_ops[op_name]['attr_types'] = all_ops[control_node_name]['attr_types'] + all_ops[op_name]['children_node'] = [] + all_ops[op_name]['input_nodes'] = [] + all_ops[op_name]['output_nodes'] = [] + all_ops[op_name]['edge_input_nodes'] = [] + all_ops[op_name]['edge_output_nodes'] = [] + all_ops[op_name]['parent_node'] = control_node_name + all_ops[control_node_name]['children_node'].append(op_name) + return all_ops, all_vars + + +def safe_get_shape(op): + try: + return op.result(0).shape + except Exception: + return [] + + +def safe_get_type(op): + try: + return op.result(0).dtype.name + except Exception: + return '' + + +def 
safe_get_dtype(op): + try: + return op.result(0).dtype.name + except Exception: + return '' + + +def safe_get_persistable(op): + try: + if op.name() == "builtin.parameter": + return False + else: + return op.result(0).persistable + except Exception: + return False + + +def get_sub_ops(op, op_name, all_ops, all_vars): + try: + from paddle.utils.unique_name import generate + except Exception: + print("Paddlepaddle is required to use add_graph interface.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") + for sub_block in op.blocks(): + for sub_op in sub_block.ops: + sub_op_name0 = generate(sub_op.name()) + sub_op_name = op_name + '/' + sub_op_name0 + all_ops[sub_op_name] = {} + all_ops[sub_op_name]['name'] = sub_op_name + all_ops[sub_op_name]['show_name'] = sub_op_name + all_ops[sub_op_name]['type'] = sub_op.name().replace("pd_op.", "") + all_ops[sub_op_name]['dtype'] = safe_get_dtype(sub_op) + all_ops[sub_op_name]['input_vars'] = {} + all_ops[sub_op_name]['output_vars'] = {} + all_ops[sub_op_name]['is_leaf_node'] = True + now_var = utils.gen_var_name(sub_op.results()) + for source in sub_op.operands_source(): + input_name = utils.gen_var_name(source) + if sub_op.name() == "pd_op.increment_": + all_vars[now_var]['to_nodes'].append(all_vars[input_name]['from_node']) + all_ops[all_vars[input_name]['from_node']]['input_vars'][now_var] = [now_var] + all_ops[sub_op_name]['input_vars'][input_name] = [input_name] + all_vars[input_name]['to_nodes'].append(sub_op_name) + all_vars[now_var]['from_node'] = sub_op_name + all_ops[sub_op_name]['output_vars'][now_var] = [now_var] + + try: + attrs = op.results()[0].get_defining_op().attrs() + if 'place' in attrs: + attrs['place'] = str(attrs['place']) + attrs['dtype'] = safe_get_dtype(op) + except Exception: + # attrs = {} + pass + + all_ops[sub_op_name]['attrs'] = attrs + all_ops[sub_op_name]['attr_types'] = attrs + all_ops[sub_op_name]['children_node'] = [] + all_ops[sub_op_name]['input_nodes'] = [] + all_ops[sub_op_name]['output_nodes'] = [] + all_ops[sub_op_name]['edge_input_nodes'] = [] + all_ops[sub_op_name]['edge_output_nodes'] = [] + all_ops[sub_op_name]["parent_node"] = op_name + all_ops[op_name]['children_node'].append(sub_op_name) + + # yield + if sub_op.name() == 'cf.yield': + var_name = "tmp_var_" + sub_op_name0 + all_vars[var_name] = {} + all_vars[var_name]['name'] = var_name + all_vars[var_name]['dtype'] = '' + all_vars[var_name]['shape'] = [] + all_vars[var_name]['value'] = [] + all_vars[var_name]['persistable'] = False + all_vars[var_name]['attrs'] = {} + all_vars[var_name]['from_node'] = sub_op_name + all_ops[sub_op_name]['output_vars'][var_name] = [var_name] + control_output = all_ops[sub_op_name]["parent_node"] + '/' + "output" + all_vars[var_name]['to_nodes'] = [control_output] + all_ops[control_output]['input_vars'][var_name] = [var_name] + if is_control_flow(sub_op): + all_ops[sub_op_name]['is_leaf_node'] = False + all_ops, all_vars = create_control_output_node(all_ops, all_vars, sub_op_name) + all_ops, all_vars = get_sub_ops(sub_op, sub_op_name, all_ops, all_vars) + + return all_ops, all_vars + + +def get_sub_var(op, all_vars): + for sub_block in op.blocks(): + for sub_op in sub_block.ops: + var_name = utils.gen_var_name(sub_op.results()) + all_vars[var_name] = {} + all_vars[var_name]['name'] = var_name + try: + attrs = op.results()[0].get_defining_op().attrs() + if 'place' in attrs: + attrs['place'] = str(attrs['place']) + 
attrs['dtype'] = safe_get_dtype(op) + except Exception: + attrs = {} + + all_vars[var_name]['shape'] = safe_get_shape(sub_op) + all_vars[var_name]['type'] = safe_get_type(sub_op) + all_vars[var_name]['dtype'] = safe_get_dtype(sub_op) + all_vars[var_name]['value'] = [] + all_vars[var_name]['persistable'] = safe_get_persistable(sub_op) + all_vars[var_name]['attrs'] = attrs + all_vars[var_name]['from_node'] = '' + all_vars[var_name]['to_nodes'] = [] + if is_control_flow(sub_op): + all_vars = get_sub_var(sub_op, all_vars) + return all_vars + + +def update_node_connections(all_vars, all_ops): + for variable_name in all_vars: + if all_vars[variable_name]['from_node'] == '': + continue + from_node = all_vars[variable_name]['from_node'] + for to_node in all_vars[variable_name]['to_nodes']: + if is_same_block_op(from_node, to_node, all_ops): + all_vars[variable_name]['to_nodes'].append(all_ops[to_node]["parent_node"]) + all_ops[all_ops[to_node]["parent_node"]]['input_vars'][variable_name] = [variable_name] + from_node_name = all_vars[variable_name]['from_node'] + for to_node_name in all_vars[variable_name]['to_nodes']: + if to_node_name != from_node_name: + all_ops[from_node_name]['output_nodes'].append(to_node_name) + all_ops[to_node_name]['input_nodes'].append(from_node_name) + all_vars[variable_name]['to_nodes'] = list(set(all_vars[variable_name]['to_nodes'])) + + for node in all_ops: + if node != '/': + all_ops[node]['input_nodes'] = list(set(all_ops[node]['input_nodes'])) + all_ops[node]['output_nodes'] = list(set(all_ops[node]['output_nodes'])) + + return all_vars, all_ops + + def analyse_pir(program): - from paddle.utils.unique_name import generate + try: + from paddle.utils.unique_name import generate + except Exception: + print("Paddlepaddle is required to use add_graph interface.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") all_ops = {} all_vars = {} all_edges = {} + + # create '/' op + all_ops['/'] = {} + all_ops['/']['name'] = '/' + all_ops['/']['show_name'] = '/' + all_ops['/']['type'] = '' + all_ops['/']['attrs'] = {} + all_ops['/']['input_vars'] = {} + all_ops['/']['output_vars'] = {} + all_ops['/']['is_leaf_node'] = False + all_ops['/']['children_node'] = [] + # vars info - for op in (program.global_block().ops): + for op in program.global_block().ops: var_name = utils.gen_var_name(op.results()) all_vars[var_name] = {} all_vars[var_name]['name'] = var_name - attrs = op.results()[0].get_defining_op().attrs() - - if 'place' in attrs: - attrs['place'] = str(attrs['place']) - attrs['dtype'] = op.result(0).dtype.name - - all_vars[var_name]['shape'] = op.result(0).shape - all_vars[var_name]['type'] = op.result(0).dtype.name - all_vars[var_name]['dtype'] = op.result(0).dtype.name - + try: + attrs = op.results()[0].get_defining_op().attrs() + if 'place' in attrs: + attrs['place'] = str(attrs['place']) + attrs['dtype'] = safe_get_dtype(op) + except Exception: + pass + + all_vars[var_name]['shape'] = safe_get_shape(op) + all_vars[var_name]['type'] = safe_get_type(op) + all_vars[var_name]['dtype'] = safe_get_dtype(op) all_vars[var_name]['value'] = [] - all_vars[var_name]['persistable'] = op.result(0).is_persistable + all_vars[var_name]['persistable'] = safe_get_persistable(op) all_vars[var_name]['attrs'] = attrs all_vars[var_name]['from_node'] = '' all_vars[var_name]['to_nodes'] = [] + if is_control_flow(op): + all_vars = get_sub_var(op, all_vars) # ops info - for op in 
(program.global_block().ops): + for op in program.global_block().ops: op_name = generate(op.name()) + op_name = '/' + op_name - if op.num_operands() > 0: + if op.num_operands() >= 0: all_ops[op_name] = {} all_ops[op_name]['name'] = op_name all_ops[op_name]['show_name'] = op_name - all_ops[op_name]['type'] = op.result(0).dtype.name - all_ops[op_name]['dtype'] = op.result(0).dtype.name + all_ops[op_name]['type'] = op.name().replace("pd_op.", "") + all_ops[op_name]['dtype'] = safe_get_dtype(op) all_ops[op_name]['input_vars'] = {} all_ops[op_name]['output_vars'] = {} @@ -410,6 +633,9 @@ def analyse_pir(program): now_var = utils.gen_var_name(op.results()) for source in op.operands_source(): input_name = utils.gen_var_name(source) + if op.name() == "pd_op.increment_": + all_vars[now_var]['to_nodes'].append(all_vars[input_name]['from_node']) + all_ops[all_vars[input_name]['from_node']]['input_vars'][now_var] = [now_var] all_ops[op_name]['input_vars'][input_name] = [input_name] all_vars[input_name]['to_nodes'].append(op_name) all_vars[now_var]['from_node'] = op_name @@ -422,32 +648,33 @@ def analyse_pir(program): all_ops[op_name]['output_nodes'] = [] all_ops[op_name]['edge_input_nodes'] = [] all_ops[op_name]['edge_output_nodes'] = [] + all_ops[op_name]['parent_node'] = '/' + all_ops['/']['children_node'].append(op_name) - # create '/' op - all_ops['/'] = {} - all_ops['/']['name'] = '/' - all_ops['/']['show_name'] = '/' - all_ops['/']['type'] = '' - all_ops['/']['attrs'] = {} - all_ops['/']['input_vars'] = {} - all_ops['/']['output_vars'] = {} - all_ops['/']['is_leaf_node'] = False - all_ops['/']['children_node'] = [] - for node in all_ops: - if node != '/': - all_ops['/']['children_node'].append(node) + if is_control_flow(op): + all_ops[op_name]['is_leaf_node'] = False + all_ops, all_vars = create_control_output_node(all_ops, all_vars, op_name) + all_ops, all_vars = get_sub_ops(op, op_name, all_ops, all_vars) - for variable_name in all_vars: - if all_vars[variable_name]['from_node'] == '': - continue - from_node_name = all_vars[variable_name]['from_node'] - for to_node_name in all_vars[variable_name]['to_nodes']: - if to_node_name != from_node_name: - all_ops[from_node_name]['output_nodes'].append(to_node_name) - all_ops[to_node_name]['input_nodes'].append(from_node_name) + # update node connections + all_vars, all_ops = update_node_connections(all_vars, all_ops) # edge info - # TODO(Difers):add edge info in future + for var_name in all_vars.keys(): + construct_edges(var_name, all_ops, all_vars, all_edges) + + for src_node, to_node in all_edges.keys(): + all_ops[src_node]['edge_output_nodes'].append(to_node) + all_ops[to_node]['edge_input_nodes'].append(src_node) + all_edges[(src_node, + to_node)]['vars'] = list(all_edges[(src_node, + to_node)]['vars']) + if len(all_edges[(src_node, to_node)]['vars']) > 1: + all_edges[(src_node, to_node)]['label'] = str( + len(all_edges[(src_node, to_node)]['vars'])) + ' tensors' + elif len(all_edges[(src_node, to_node)]['vars']) == 1: + all_edges[(src_node, to_node)]['label'] = str( + all_vars[all_edges[(src_node, to_node)]['vars'][0]]['shape']) final_data = { 'version': _graph_version, diff --git a/visualdl/reader/graph_reader.py b/visualdl/reader/graph_reader.py index 1acc99ed9..4e8035bd5 100644 --- a/visualdl/reader/graph_reader.py +++ b/visualdl/reader/graph_reader.py @@ -14,8 +14,10 @@ # ======================================================================= import json import os +import tempfile from visualdl.component.graph import analyse_model +from 
visualdl.component.graph import analyse_pir from visualdl.component.graph import Model from visualdl.io import bfile @@ -30,7 +32,7 @@ def is_VDLGraph_file(path): Returns: True if the file is a VDL graph file, otherwise false. """ - if "vdlgraph" not in path and 'pdmodel' not in path: + if "vdlgraph" not in path and 'pdmodel' not in path and 'json' not in path: return False return True @@ -136,6 +138,20 @@ def get_graph(self, data = bfile.BFile(bfile.join(run, self.walks[run]), 'rb').read() if 'pdmodel' in self.walks[run]: graph_model = Model(analyse_model(data)) + elif 'json' in self.walks[run]: + try: + from paddle.jit import load + except Exception: + print("Paddlepaddle is required to load json file.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") + json_object = json.loads(data) + with tempfile.TemporaryDirectory() as tmp: + with open(os.path.join(tmp, 'temp.json'), 'w') as json_file: + json.dump(json_object, json_file, indent=4) + model_data = load(os.path.join(tmp, 'temp')) + graph_model = Model(analyse_pir(model_data.program())) else: graph_model = Model(json.loads(data.decode())) self.graph_buffer[run] = graph_model @@ -163,6 +179,20 @@ def search_graph_node(self, run, nodeid, keep_state=False, is_node=True): data = bfile.BFile(bfile.join(run, self.walks[run]), 'rb').read() if 'pdmodel' in self.walks[run]: graph_model = Model(analyse_model(data)) + elif 'json' in self.walks[run]: + try: + from paddle.jit import load + except Exception: + print("Paddlepaddle is required to load json file.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") + json_object = json.loads(data) + with tempfile.TemporaryDirectory() as tmp: + with open(os.path.join(tmp, 'temp.json'), 'w') as json_file: + json.dump(json_object, json_file, indent=4) + model_data = load(os.path.join(tmp, 'temp')) + graph_model = Model(analyse_pir(model_data.program())) else: graph_model = Model(json.loads(data.decode())) self.graph_buffer[run] = graph_model @@ -184,6 +214,20 @@ def get_all_nodes(self, run): data = bfile.BFile(bfile.join(run, self.walks[run]), 'rb').read() if 'pdmodel' in self.walks[run]: graph_model = Model(analyse_model(data)) + elif 'json' in self.walks[run]: + try: + from paddle.jit import load + except Exception: + print("Paddlepaddle is required to load json file.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") + json_object = json.loads(data) + with tempfile.TemporaryDirectory() as tmp: + with open(os.path.join(tmp, 'temp.json'), 'w') as json_file: + json.dump(json_object, json_file, indent=4) + model_data = load(os.path.join(tmp, 'temp')) + graph_model = Model(analyse_pir(model_data.program())) else: graph_model = Model(json.loads(data.decode())) self.graph_buffer[run] = graph_model @@ -206,6 +250,8 @@ def set_input_graph(self, content, file_type='pdmodel'): return if 'pdmodel' in content: file_type = 'pdmodel' + elif 'json' in content: + file_type = 'json' else: file_type = 'vdlgraph' content = bfile.BFile(content, 'rb').read() @@ -214,6 +260,22 @@ def set_input_graph(self, content, file_type='pdmodel'): data = analyse_model(content) self.graph_buffer['manual_input_model'] = Model(data) + elif file_type == 'json': + try: + from paddle.jit import load + except Exception: + 
print("Paddlepaddle is required to load json file.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") + json_object = json.loads(content) + with tempfile.TemporaryDirectory() as tmp: + with open(os.path.join(tmp, 'temp.json'), 'w') as json_file: + json.dump(json_object, json_file, indent=4) + model_data = load(os.path.join(tmp, 'temp')) + data = analyse_pir(model_data.program()) + self.graph_buffer['manual_input_model'] = Model(data) + elif file_type == 'vdlgraph': self.graph_buffer['manual_input_model'] = Model( json.loads(content.decode())) diff --git a/visualdl/server/api.py b/visualdl/server/api.py index 0ef7b6dc1..eea55455d 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -351,6 +351,9 @@ def graph_upload(self): if 'pdmodel' in file_handle.filename: graph_reader.set_input_graph(file_handle.stream.read(), 'pdmodel') + elif 'json' in file_handle.filename: + graph_reader.set_input_graph(file_handle.stream.read(), + 'json') elif 'vdlgraph' in file_handle.filename: graph_reader.set_input_graph(file_handle.stream.read(), 'vdlgraph')