From 641f5425b524a0ff53e456d59cbd9a19e22421b0 Mon Sep 17 00:00:00 2001 From: Masoud Kaviani Date: Sat, 19 Oct 2024 16:43:48 +0330 Subject: [PATCH] Update end2end_tfkeras.py Tensorflow Keras save model function change to export for `pb` export extension Signed-off-by: Masoud Kaviani --- examples/end2end_tfkeras.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/end2end_tfkeras.py b/examples/end2end_tfkeras.py index 19da4d3b3..21ffd2d76 100644 --- a/examples/end2end_tfkeras.py +++ b/examples/end2end_tfkeras.py @@ -41,9 +41,7 @@ ######################################## # Saves the model. -if not os.path.exists("simple_rnn"): - os.mkdir("simple_rnn") -tf.keras.models.save_model(model, "simple_rnn") +model.export("simple_rnn") ######################################## # Run the command line. @@ -57,7 +55,7 @@ ######################################## # Runs onnxruntime. session = InferenceSession("simple_rnn.onnx") -got = session.run(None, {'input_1': input}) +got = session.run(None, {'keras_tensor': input}) print(got[0]) ######################################## @@ -68,5 +66,5 @@ # Measures processing time. print('tf:', timeit.timeit('model.predict(input)', number=100, globals=globals())) -print('ort:', timeit.timeit("session.run(None, {'input_1': input})", +print('ort:', timeit.timeit("session.run(None, {'keras_tensor': input})", number=100, globals=globals()))