From 2dcf9ec5a6d6873cee0461d4c3f3e6990916e020 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 7 Feb 2024 13:40:21 +0300 Subject: [PATCH] [Keras] Enable Dense operator for any input dims (#16526) Our dense op expects 2D, but there are no limitation in Keras on the shape of the input tensor. Reshaping of all "batch" axes into one was added in this commit. After that, it is possible to import Dense layer with ND input tensor from Keras to TVM. --- python/tvm/relay/frontend/keras.py | 14 ++++++++------ tests/python/frontend/keras/test_forward.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 21862089944e..d53647cc684c 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -266,11 +266,12 @@ def _convert_dense( # In case of RNN dense, input shape will be (1, 1, n) if input_dim > 2: input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0]) - if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1: - raise tvm.error.OpAttributeInvalid( - f"Input shape {input_shape} is not valid for operator Dense." - ) - inexpr = _op.squeeze(inexpr, axis=[0]) + # Keras has no limitations on the shape of the input tensor. But our + # dense op expects 2D input. All inputs with number of dimensions > 2 + # are reshaped all "batch" axes into one. + # For example: (N, d1, d2, d3) -> (N * d1 * d2, d3) + new_batch_size = np.prod(input_shape[:-1]) + inexpr = _op.reshape(inexpr, newshape=(new_batch_size, input_shape[-1])) out = _op.nn.dense(data=inexpr, **params) if keras_layer.use_bias: bias = etab.new_const(weightList[1]) @@ -283,7 +284,8 @@ def _convert_dense( if act_type != "linear": out = _convert_activation(out, act_type, etab, data_layout) if input_dim > 2: - out = _op.expand_dims(out, axis=0) + out_shape = (*input_shape[:-1], units) + out = _op.reshape(out, newshape=out_shape) return out diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index aef137e634a7..0d05e34a155b 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -285,6 +285,16 @@ def test_forward_dense(self, keras_mod): keras_model = keras_mod.models.Model(data, x) verify_keras_frontend(keras_model, need_transpose=False) + data = keras_mod.layers.Input(shape=(120, 2560), name="image_set") + x = keras_mod.layers.Dense(1, activation="linear", name="e")(data) + keras_model = keras_mod.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + + data = keras_mod.layers.Input(shape=(10, 12, 2560), name="image_set") + x = keras_mod.layers.Dense(32, activation="linear", name="e")(data) + keras_model = keras_mod.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + def test_forward_permute(self, keras_mod): data = keras_mod.layers.Input(shape=(2, 3, 4)) x = keras_mod.layers.Permute([2, 3, 1])(data)