Skip to content

Commit

Permalink
Fix reshape/transpose issues seen by PyTorch
Browse files Browse the repository at this point in the history
PyTorch does not accept numpy.ndarray shapes; we must use tuples (or
lists) of integers.
  • Loading branch information
nhuet committed Dec 15, 2023
1 parent 0ba5862 commit f670071
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/decomon/backward_layers/utils_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_toeplitz_channels_last(conv_layer: Conv2D, flatten: bool = True) -> Back

diag_patches_ = K.reshape(diag_patches, (w_in, h_in, c_in, w_out, h_out, filter_size**2, c_in))

shape = np.arange(len(diag_patches_.shape))
shape = list(range(len(diag_patches_.shape)))
shape[-1] -= 1
shape[-2] += 1
diag_patches_ = K.transpose(diag_patches_, shape)
Expand Down Expand Up @@ -135,7 +135,7 @@ def get_toeplitz_channels_first(conv_layer: Conv2D, flatten: bool = True) -> Bac

diag_patches_ = K.reshape(diag_patches, (w_in, h_in, c_in, w_out, h_out, filter_size**2, c_in))

shape = np.arange(len(diag_patches_.shape))
shape = list(range(len(diag_patches_.shape)))
shape[-1] -= 1
shape[-2] += 1
diag_patches_ = K.transpose(diag_patches_, shape)
Expand Down
4 changes: 2 additions & 2 deletions src/decomon/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,9 +560,9 @@ def permute_dimensions(
# not enough dim to permute
return inputs
index = np.arange(len(input_shape))
index = np.insert(np.delete(index, axis), axis_perm, axis)
index = tuple(np.insert(np.delete(index, axis), axis_perm, axis))
index_w = np.arange(len(input_shape) + 1)
index_w = np.insert(np.delete(index_w, axis), axis_perm + 1, axis)
index_w = tuple(np.insert(np.delete(index_w, axis), axis_perm + 1, axis))

if ibp:
u_c_out = K.transpose(u_c, index)
Expand Down
8 changes: 4 additions & 4 deletions src/decomon/layers/utils_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_upper_linear_hull_max(

# expand dim/broadcast
mask = K.identity(n_dim, dtype=dtype) # (n_dim, n_dim)
mask_shape = np.ones(len(u_c.shape) + 1)
mask_shape = [1] * (len(u_c.shape) + 1)
mask_shape[-1] = n_dim
if axis != -1:
mask_shape[axis] = n_dim
Expand Down Expand Up @@ -101,11 +101,11 @@ def get_upper_linear_hull_max(
bias_corner = o_value + K.sum(z_value * corners, -2, keepdims=True)
corners_collapse = K.concatenate([corners_collapse, bias_corner], axis=-2)

dimensions = np.arange(len(corners.shape))
dimensions = list(range(len(corners.shape)))
if axis != -1:
dim_permutation = np.concatenate([dimensions[:axis], dimensions[axis + 1 :], [dimensions[axis]]])
dim_permutation = dimensions[:axis] + dimensions[axis + 1 :] + [dimensions[axis]]
else:
dim_permutation = np.concatenate([dimensions[:-2], dimensions[-1:], [dimensions[-2]]])
dim_permutation = dimensions[:-2] + dimensions[-1:] + [dimensions[-2]]

corners_collapse = K.transpose(corners_collapse, dim_permutation)
# tf.linalg.solve works only for float32
Expand Down

0 comments on commit f670071

Please sign in to comment.