Skip to content

Commit

Permalink
Replace tfnp with native tf ops wherever possible (#18998)
Browse files Browse the repository at this point in the history
* Add dtype test for `floor_divide`

* Use native `tf.*` if possible in `backend.numpy`

* Apply `result_type` to np's `floor_divide`

* Increase test coverage

* Fix failed test
  • Loading branch information
james77777778 authored Dec 28, 2023
1 parent 8e897fb commit 86bd12c
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 109 deletions.
3 changes: 3 additions & 0 deletions keras/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ def argmin(x, axis=None):


def argsort(x, axis=-1):
x = convert_to_tensor(x)
if x.ndim == 0:
return jnp.argsort(x, axis=None)
return jnp.argsort(x, axis=axis)


Expand Down
9 changes: 9 additions & 0 deletions keras/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,15 @@ def eye(N, M=None, k=0, dtype=None):


def floor_divide(x1, x2):
if not isinstance(x1, (int, float)):
x1 = convert_to_tensor(x1)
if not isinstance(x2, (int, float)):
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(
getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
)
x1 = convert_to_tensor(x1, dtype)
x2 = convert_to_tensor(x2, dtype)
return np.floor_divide(x1, x2)


Expand Down
Loading

0 comments on commit 86bd12c

Please sign in to comment.