Skip to content

Commit

Permalink
Update end2end_tfkeras.py (#2361)
Browse files Browse the repository at this point in the history
Tensorflow Keras save model function change to export for `pb` export extension

Signed-off-by: Masoud Kaviani <[email protected]>
Co-authored-by: Jay Zhang <[email protected]>
  • Loading branch information
MasoudKaviani and fatcat-z authored Dec 24, 2024
1 parent 6298b26 commit 4b97bcf
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions examples/end2end_tfkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@

########################################
# Saves the model.
if not os.path.exists("simple_rnn"):
os.mkdir("simple_rnn")
tf.keras.models.save_model(model, "simple_rnn")
model.export("simple_rnn")

########################################
# Run the command line.
Expand All @@ -57,7 +55,7 @@
########################################
# Runs onnxruntime.
session = InferenceSession("simple_rnn.onnx")
got = session.run(None, {'input_1': input})
got = session.run(None, {'keras_tensor': input})
print(got[0])

########################################
Expand All @@ -68,5 +66,5 @@
# Measures processing time.
print('tf:', timeit.timeit('model.predict(input)',
number=100, globals=globals()))
print('ort:', timeit.timeit("session.run(None, {'input_1': input})",
print('ort:', timeit.timeit("session.run(None, {'keras_tensor': input})",
number=100, globals=globals()))

0 comments on commit 4b97bcf

Please sign in to comment.