Skip to content

Commit

Permalink
fix pir to kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
Difers committed Sep 20, 2023
1 parent bdaff69 commit 05067af
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@
model=newir_program,
input_spec=[paddle.static.InputSpec([-1, 1, 28, 28], 'float32')],
verbose=True,
is_newir=True)
is_pir=True)
4 changes: 2 additions & 2 deletions visualdl/component/graph/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from .utils import print_model


def translate_graph(model, input_spec, verbose=True, is_pir=False):
import paddle
def translate_graph(model, input_spec, verbose=True, **kwargs):
is_pir = kwargs.get('is_pir', False)
with tempfile.TemporaryDirectory() as tmp:
if (not is_pir):
model._full_name = '{}[{}]'.format(model.__class__.__name__,
Expand Down
4 changes: 2 additions & 2 deletions visualdl/component/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def gen_var_name(ops):
if var in var_name:
return var_name[var]
else:
if (op.is_persistable):
try:
name = op.name
else:
except ValueError:
name = "tmp_var_" + str(var_idx[0])
var_idx[0] += 1
var_name[var] = name
Expand Down
8 changes: 3 additions & 5 deletions visualdl/writer/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import time

import numpy as np
from paddle.framework import in_pir_mode

from visualdl.component.base_component import audio
from visualdl.component.base_component import embedding
Expand Down Expand Up @@ -614,7 +613,7 @@ def add_roc_curve(self,
num_thresholds=num_thresholds,
weights=weights))

def add_graph(self, model, input_spec, verbose=False):
def add_graph(self, model, input_spec, verbose=False, **kwargs):
"""
Add a model graph to vdl graph file.
Args:
Expand Down Expand Up @@ -655,9 +654,8 @@ def forward(self, inputs):
verbose=True)
"""
try:
if in_pir_mode():
is_pir = True
result = translate_graph(model, input_spec, verbose, is_pir=False)
is_pir = kwargs.get('is_pir', False)
result = translate_graph(model, input_spec, verbose, is_pir=is_pir)
except Exception as e:
print("Failed to save model graph, error: {}".format(e))
raise e
Expand Down

0 comments on commit 05067af

Please sign in to comment.