
Ragged Tensor as an output from Tensorflow serving #2222

Open
bajaj6 opened this issue May 23, 2024 · 5 comments
bajaj6 commented May 23, 2024

Bug Report

System information

  • OS Platform and Distribution: macOS
  • TensorFlow Serving installed from: binary
  • TensorFlow version: 2.14.1

Describe the problem

We use TensorFlow Serving to serve models in production. We have a use case where the output of the model is a ragged tensor.

To check whether TensorFlow Serving supports a ragged tensor as output, we created this toy example.

import tensorflow as tf

# Initiate the model
inputs = tf.keras.Input(shape=(1,), ragged=True, name='input_1')
output = tf.keras.layers.Lambda(lambda x: x + 1, dtype=tf.float32)(inputs)

model = tf.keras.Model(inputs=[inputs], outputs=output)

model.compile()

# Serialise/deserialise to the format the inference server expects
model.save("./my_model/1", save_format="tf")
model = tf.keras.models.load_model("./my_model/1")


# Make predictions with ragged tensors
x = tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.float32)
out = model.predict([x])
print(out)
tf.debugging.assert_equal(out, tf.ragged.constant([[2, 3, 4], [5, 6]], dtype=tf.float32))
print("All good!")
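For reference, the Lambda layer only adds 1 elementwise while preserving the ragged row structure, which is why the assertion compares against [[2, 3, 4], [5, 6]]. A plain-Python sketch of the same transform (no TensorFlow; the helper name is ours, not from the API):

```python
# Elementwise +1 over a ragged (nested, variable-length) structure,
# mirroring what the Lambda layer computes on the RaggedTensor.
def add_one_ragged(rows):
    return [[v + 1 for v in row] for row in rows]

print(add_one_ragged([[1.0, 2.0, 3.0], [4.0, 5.0]]))
# [[2.0, 3.0, 4.0], [5.0, 6.0]]
```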

We save the model to local disk and then load it via TensorFlow Serving. I used saved_model_cli to inspect the model signatures.

The model output has datatype DT_INVALID, so I suspect TensorFlow Serving will fail to load this model.

user@MR26DF61QG ~ % saved_model_cli show --dir /Users/user/ragged_tensor/my_model/1 --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['args_0'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: serving_default_args_0:0
    inputs['args_0_1'] tensor_info:
        dtype: DT_INT64
        shape: (-1)
        name: serving_default_args_0_1:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['lambda'] tensor_info:
        dtype: DT_INVALID
        shape: ()
        name:
  Method name is: tensorflow/serving/predict
The MetaGraph with tag set ['serve'] contains the following ops: {'StringJoin', 'VarHandleOp', 'ReadVariableOp', 'ShardedFilename', 'Placeholder', 'Select', 'AssignVariableOp', 'Const', 'Pack', 'MergeV2Checkpoints', 'RestoreV2', 'StatefulPartitionedCall', 'NoOp', 'DisableCopyOnRead', 'Identity', 'StaticRegexFullMatch', 'PartitionedCall', 'AddV2', 'SaveV2'}

Concrete Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          DType: RaggedTensorSpec
          Value: RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int64)
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #2
      Callable with:
        Argument #1
          DType: RaggedTensorSpec
          Value: RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int64)
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None

  Function Name: '_default_save_signature'
    Option #1
      Callable with:
        Argument #1
          DType: RaggedTensorSpec
          Value: RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int64)

  Function Name: 'call_and_return_all_conditional_losses'
    Option #1
      Callable with:
        Argument #1
          DType: RaggedTensorSpec
          Value: RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int64)
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #2
      Callable with:
        Argument #1
          DType: RaggedTensorSpec
          Value: RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int64)
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None

Exact Steps to Reproduce

1- Run the above Python code to save the model to local disk
2- Run the saved_model_cli command to print the model signatures:
saved_model_cli show --dir /Users/user/ragged_tensor/my_model/1 --all

I posted this on Stack Overflow here but haven't received any response.

Thanks

@singhniraj08

@bajaj6, Apologies for the late reply. Can you try serving the same model on TF Serving and let us know if you face any issues? Please share the whole error stack trace so we can debug the issue on our end. Thank you!


bajaj6 commented Jun 4, 2024

@singhniraj08

Request: POST http://localhost:8501/v1/models/my_model:predict

{
    "inputs": {
        "args_0": [1.0, 3.0],
        "args_0_1": [1]
    }
}
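For what it's worth, the signature printed by saved_model_cli flattens the RaggedTensor into two dense tensors: args_0 (the flat values, DT_FLOAT) and args_0_1 (an int64 row-partition tensor). If args_0_1 is a row_splits vector, the ragged input [[1.0, 2.0, 3.0], [4.0, 5.0]] would decompose as sketched below (an assumption about the partition encoding; the exact tensor depends on how Keras traced the input):

```python
# Decompose a ragged (nested) list into the flat-values / row-splits pair
# used by RaggedTensor.from_row_splits: row_splits[i] is the index in the
# flat values where row i starts, with a trailing entry for the total length.
def to_flat_and_row_splits(rows):
    flat = [v for row in rows for v in row]
    splits = [0]
    for row in rows:
        splits.append(splits[-1] + len(row))
    return flat, splits

flat, splits = to_flat_and_row_splits([[1.0, 2.0, 3.0], [4.0, 5.0]])
print(flat)    # [1.0, 2.0, 3.0, 4.0, 5.0]
print(splits)  # [0, 3, 5]

# A predict payload built from these components might then look like:
payload = {"inputs": {"args_0": flat, "args_0_1": splits}}
```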

Response:

{
    "error": "Tensor :0, specified in either feed_devices or fetch_devices was not found in the Graph"
}

Searching for this error:

@singhniraj08

@bajaj6, Let us keep this issue open as a feature request for supporting ragged tensors in TensorFlow Serving. Thank you for reporting this issue.


bajaj6 commented Jun 5, 2024

@singhniraj08 Sure, thanks.
Can you confirm that ragged tensors are not supported in TensorFlow Serving?

@singhniraj08

@bajaj6, Currently, ragged tensors are not supported in TensorFlow Serving. One way to work around this is to remove the ragged tensor from the model. I will keep this issue open as a feature request for ragged tensor support. Thank you!
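Regarding the suggested workaround of removing the ragged tensor from the model: a common approach is to pad the ragged output to a dense tensor before it leaves the model (in TF, by calling RaggedTensor.to_tensor() in a final layer, which pads with a default value). A plain-Python sketch of that padding semantics (the helper name is ours):

```python
# Pad variable-length rows to a dense rectangular matrix, filling the
# tail of short rows with a default value -- the effect of
# RaggedTensor.to_tensor(default_value=0.0).
def ragged_to_dense(rows, default=0.0):
    width = max((len(row) for row in rows), default=0)
    return [row + [default] * (width - len(row)) for row in rows]

print(ragged_to_dense([[2.0, 3.0, 4.0], [5.0, 6.0]]))
# [[2.0, 3.0, 4.0], [5.0, 6.0, 0.0]]
```

The caller then needs the original row lengths (or a mask) alongside the dense output to recover the ragged structure client-side.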
