Skip to content

Commit

Permalink
[FRONTEND][KERAS] Fix bug concat convert for NCHW (#16159)
Browse files Browse the repository at this point in the history
* [FRONTEND][KERAS] Fix bug concat convert for NCHW

Fixed the bug in keras frontend for inception_v3 keras
in concat convertion for NCHW.

* fix the lint error

* Removed weight download
  • Loading branch information
krishnaraj36 authored Nov 28, 2023
1 parent 604b263 commit 3136ff4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ def _convert_concat(
if axis == -1:
axis = 1
else:
axis = axis + 1 if axis < dims else 1
axis = axis + 1 if axis < (dims - 1) else 1
return _op.concatenate(_as_list(inexpr), axis=axis)


Expand Down
24 changes: 24 additions & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ def test_forward_concatenate(self, keras_mod):
keras_model = keras_mod.models.Model([data1, data2], out)
verify_keras_frontend(keras_model, layout="NHWC")
verify_keras_frontend(keras_model, layout="NCHW")
# test axis at last dimension
data1 = keras_mod.layers.Input(shape=(1, 2, 2))
data2 = keras_mod.layers.Input(shape=(1, 2, 3))
merge_func = keras_mod.layers.Concatenate(axis=3)
out = merge_func([data1, data2])
keras_model = keras_mod.models.Model([data1, data2], out)
verify_keras_frontend(keras_model, layout="NHWC")
verify_keras_frontend(keras_model, layout="NCHW")

def test_forward_merge_dot(self, keras_mod):
"""test_forward_merge_dot"""
Expand Down Expand Up @@ -643,6 +651,20 @@ def test_forward_resnet50(self, keras_mod, layout="NCHW"):
)
verify_keras_frontend(keras_model, layout=layout)

def test_forward_inception_v3(self, keras_mod, layout="NCHW"):
"""test_forward_inception_v3"""
if hasattr(keras_mod.applications, "InceptionV3"):
# Keras 2.4.x and older
inception_v3_mod = keras_mod.applications.InceptionV3
else:
# Keras 2.6.x and newer
inception_v3_mod = keras_mod.applications.inception_v3.InceptionV3

keras_model = inception_v3_mod(
include_top=True, weights=None, input_shape=(299, 299, 3), classes=1000
)
verify_keras_frontend(keras_model, layout=layout)

def test_forward_mobilenet(self, keras_mod, layout="NCHW"):
mobilenet_mod = get_mobilenet(keras_mod)

Expand Down Expand Up @@ -877,6 +899,8 @@ def test_simplernn_with_infertype(self, keras_mod):
sut.test_forward_xception(keras_mod=k)
sut.test_forward_resnet50(keras_mod=k)
sut.test_forward_resnet50(keras_mod=k, layout="NHWC")
sut.test_forward_inception_v3(keras_mod=k)
sut.test_forward_inception_v3(keras_mod=k, layout="NHWC")
sut.test_forward_mobilenet(keras_mod=k)
sut.test_forward_mobilenet(keras_mod=k, layout="NHWC")
sut.test_forward_conv3d(keras_mod=k)
Expand Down

0 comments on commit 3136ff4

Please sign in to comment.