Skip to content

Commit

Permalink
Fix minor issues in backend ops
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed May 8, 2024
1 parent 240b671 commit 10c27c0
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 9 deletions.
6 changes: 4 additions & 2 deletions keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,11 @@ def conv(
f"kernel in_channels {kernel_in_channels}. "
)
feature_group_count = channels // kernel_in_channels
kernel = convert_to_tensor(kernel)
inputs = convert_to_tensor(inputs, dtype=kernel.dtype)
return jax.lax.conv_general_dilated(
convert_to_tensor(inputs),
convert_to_tensor(kernel),
inputs,
kernel,
strides,
padding,
rhs_dilation=dilation_rate,
Expand Down
4 changes: 2 additions & 2 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,8 +1251,8 @@ def split(x, indices_or_sections, axis=0):
dim=axis,
)
if dim == 0 and isinstance(indices_or_sections, int):
out = tuple(out[0].clone() for _ in range(indices_or_sections))
return out
out = [out[0].clone() for _ in range(indices_or_sections)]
return list(out)


def stack(x, axis=0):
Expand Down
21 changes: 16 additions & 5 deletions keras/src/ops/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,15 +1076,26 @@ def _crop_images(
f"Received: target_width={target_width}"
)

if isinstance(top_cropping, int) and isinstance(left_cropping, int):
start_indices = [0, top_cropping, left_cropping, 0]
else:
start_indices = backend.numpy.stack([0, top_cropping, left_cropping, 0])
if (
isinstance(batch, int)
and isinstance(target_height, int)
and isinstance(target_width, int)
and isinstance(depth, int)
):
shape = [batch, target_height, target_width, depth]
else:
shape = backend.numpy.stack([batch, target_height, target_width, depth])
cropped = ops.slice(
images,
backend.numpy.stack([0, top_cropping, left_cropping, 0]),
backend.numpy.stack([batch, target_height, target_width, depth]),
start_indices,
shape,
)

cropped_shape = [batch, target_height, target_width, depth]
cropped = backend.numpy.reshape(cropped, cropped_shape)

cropped = backend.numpy.reshape(cropped, shape)
if not is_batch:
cropped = backend.numpy.squeeze(cropped, axis=[0])
return cropped
1 change: 1 addition & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3974,6 +3974,7 @@ def test_sort(self):

def test_split(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertIsInstance(knp.split(x, 2), list)
self.assertAllClose(knp.split(x, 2), np.split(x, 2))
self.assertAllClose(knp.Split(2)(x), np.split(x, 2))
self.assertAllClose(
Expand Down

0 comments on commit 10c27c0

Please sign in to comment.