diff --git a/examples/end2end_tfkeras.py b/examples/end2end_tfkeras.py index 19da4d3b3..21ffd2d76 100644 --- a/examples/end2end_tfkeras.py +++ b/examples/end2end_tfkeras.py @@ -41,9 +41,7 @@ ######################################## # Saves the model. -if not os.path.exists("simple_rnn"): - os.mkdir("simple_rnn") -tf.keras.models.save_model(model, "simple_rnn") +model.export("simple_rnn") ######################################## # Run the command line. @@ -57,7 +55,7 @@ ######################################## # Runs onnxruntime. session = InferenceSession("simple_rnn.onnx") -got = session.run(None, {'input_1': input}) +got = session.run(None, {'keras_tensor': input}) print(got[0]) ######################################## @@ -68,5 +66,5 @@ # Measures processing time. print('tf:', timeit.timeit('model.predict(input)', number=100, globals=globals())) -print('ort:', timeit.timeit("session.run(None, {'input_1': input})", +print('ort:', timeit.timeit("session.run(None, {'keras_tensor': input})", number=100, globals=globals()))