Skip to content

Commit

Permalink
Fix get_tensor_decomposition_images_box for data_format=channels_firs…
Browse files Browse the repository at this point in the history
…t + dc_decomp=True

- h and g were having the wrong shape
- we als rfactr get_standard_values_images_box by putting things
  necessary fr data_format=channels_last in the corresponding if
  • Loading branch information
nhuet committed Dec 18, 2023
1 parent ce6b566 commit 30295d3
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,18 +861,16 @@ def build_image_from_2D_box(odd=0, m0=0, m1=1, dc_decomp=True):

@staticmethod
def get_standard_values_images_box(data_format="channels_last", odd=0, m0=0, m1=1, dc_decomp=True):
output = Helpers.build_image_from_2D_box(odd, m0, m1, dc_decomp)
if dc_decomp:
x_0, y_0, z_0, u_c_0, w_u_0, b_u_0, l_c_0, w_l_0, b_l_0, h_0, g_0 = output
else:
x_0, y_0, z_0, u_c_0, w_u_0, b_u_0, l_c_0, w_l_0, b_l_0 = output
if data_format == "channels_last":
output = Helpers.build_image_from_2D_box(odd, m0, m1, dc_decomp)
if dc_decomp:
x_0, y_0, z_0, u_c_0, w_u_0, b_u_0, l_c_0, w_l_0, b_l_0, h_0, g_0 = output
else:
x_0, y_0, z_0, u_c_0, w_u_0, b_u_0, l_c_0, w_l_0, b_l_0 = output

x_ = x_0
z_ = z_0
z_min_ = z_0[:, 0]
z_max_ = z_0[:, 1]
x_ = x_0
z_ = z_0

if data_format == "channels_last":
y_0 = y_0[:, :, :, None]
b_u_0 = b_u_0[:, :, :, None]
b_l_0 = b_l_0[:, :, :, None]
Expand Down Expand Up @@ -951,7 +949,7 @@ def get_tensor_decomposition_images_box(data_format, odd, dc_decomp=True):
Input((2, n, n), dtype=keras_config.floatx()),
]
if dc_decomp:
output += [Input((n, n, 2), dtype=keras_config.floatx()), Input((n, n, 2), dtype=keras_config.floatx())]
output += [Input((2, n, n), dtype=keras_config.floatx()), Input((2, n, n), dtype=keras_config.floatx())]

return output

Expand Down

0 comments on commit 30295d3

Please sign in to comment.