diff --git a/tests/unit/tf/core/test_encoder.py b/tests/unit/tf/core/test_encoder.py index 36c309ed64..2b89bd6493 100644 --- a/tests/unit/tf/core/test_encoder.py +++ b/tests/unit/tf/core/test_encoder.py @@ -122,6 +122,15 @@ def test_topk_encoder(music_streaming_data: Dataset): loaded_topk_encoder = tf.keras.models.load_model(tmpdir) batch_output = loaded_topk_encoder(batch[0]) + output_signature = loaded_topk_encoder.signatures["serving_default"].structured_output + assert len(output_signature) == 2 + assert output_signature["scores"] == tf.TensorSpec( + shape=(None, TOP_K), dtype=tf.float32, name="scores" + ) + assert output_signature["identifiers"] == tf.TensorSpec( + shape=(None, TOP_K), dtype=tf.int32, name="identifiers" + ) + assert list(batch_output.scores.shape) == [BATCH_SIZE, TOP_K] tf.debugging.assert_equal( topk_encoder.topk_layer._candidates,