diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index e89c0a3c03a9..21862089944e 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -967,7 +967,7 @@ def _convert_concat( if axis == -1: axis = 1 else: - axis = axis + 1 if axis < dims else 1 + axis = axis + 1 if axis < (dims - 1) else 1 return _op.concatenate(_as_list(inexpr), axis=axis) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 8c5b578060a4..aef137e634a7 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -177,6 +177,14 @@ def test_forward_concatenate(self, keras_mod): keras_model = keras_mod.models.Model([data1, data2], out) verify_keras_frontend(keras_model, layout="NHWC") verify_keras_frontend(keras_model, layout="NCHW") + # test axis at last dimension + data1 = keras_mod.layers.Input(shape=(1, 2, 2)) + data2 = keras_mod.layers.Input(shape=(1, 2, 3)) + merge_func = keras_mod.layers.Concatenate(axis=3) + out = merge_func([data1, data2]) + keras_model = keras_mod.models.Model([data1, data2], out) + verify_keras_frontend(keras_model, layout="NHWC") + verify_keras_frontend(keras_model, layout="NCHW") def test_forward_merge_dot(self, keras_mod): """test_forward_merge_dot""" @@ -643,6 +651,20 @@ def test_forward_resnet50(self, keras_mod, layout="NCHW"): ) verify_keras_frontend(keras_model, layout=layout) + def test_forward_inception_v3(self, keras_mod, layout="NCHW"): + """test_forward_inception_v3""" + if hasattr(keras_mod.applications, "InceptionV3"): + # Keras 2.4.x and older + inception_v3_mod = keras_mod.applications.InceptionV3 + else: + # Keras 2.6.x and newer + inception_v3_mod = keras_mod.applications.inception_v3.InceptionV3 + + keras_model = inception_v3_mod( + include_top=True, weights=None, input_shape=(299, 299, 3), classes=1000 + ) + verify_keras_frontend(keras_model, layout=layout) + def test_forward_mobilenet(self, keras_mod, layout="NCHW"): mobilenet_mod = get_mobilenet(keras_mod) @@ -877,6 +899,8 @@ def test_simplernn_with_infertype(self, keras_mod): sut.test_forward_xception(keras_mod=k) sut.test_forward_resnet50(keras_mod=k) sut.test_forward_resnet50(keras_mod=k, layout="NHWC") + sut.test_forward_inception_v3(keras_mod=k) + sut.test_forward_inception_v3(keras_mod=k, layout="NHWC") sut.test_forward_mobilenet(keras_mod=k) sut.test_forward_mobilenet(keras_mod=k, layout="NHWC") sut.test_forward_conv3d(keras_mod=k)