Skip to content

Commit

Permalink
Remove decomon.keras_utils.LinalgSolve as it exists now as keras.ops.…
Browse files Browse the repository at this point in the history
…solve()
  • Loading branch information
nhuet authored and ducoffeM committed Mar 19, 2024
1 parent 439d7ed commit bec8cd8
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 63 deletions.
27 changes: 0 additions & 27 deletions src/decomon/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,6 @@
BACKEND_JAX = "jax"


class LinalgSolve(keras.Operation):
"""Keras operation mimicking tensorflow.linalg.solve()."""

def compute_output_spec(self, matrix: keras.KerasTensor, rhs: keras.KerasTensor) -> keras.KerasTensor:
rhs_shape = rhs.shape
rhs_dtype = getattr(rhs, "dtype", type(rhs))
rhs_sparse = getattr(rhs, "sparse", False)
return keras.KerasTensor(
shape=rhs_shape,
dtype=rhs_dtype,
sparse=rhs_sparse,
)

def call(self, matrix: BackendTensor, rhs: BackendTensor) -> BackendTensor:
backend = keras.config.backend()
if backend == BACKEND_TENSORFLOW:
import tensorflow as tf

return tf.linalg.solve(matrix=matrix, rhs=rhs)
elif backend == BACKEND_PYTORCH:
import torch

return torch.linalg.solve(A=matrix, B=rhs)
else:
raise NotImplementedError(f"linalg_solve() not yet implemented for backend {backend}.")


class BatchedIdentityLike(keras.Operation):
"""Keras Operation creating an identity tensor with shape (including batch_size) based on input.
Expand Down
3 changes: 1 addition & 2 deletions src/decomon/layers/utils_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
PerturbationDomain,
get_affine,
)
from decomon.keras_utils import LinalgSolve
from decomon.types import BackendTensor, Tensor

# step 1: compute (x_i, y_i) such that x_i[j]=l_j if j==i else u_j
Expand Down Expand Up @@ -112,7 +111,7 @@ def get_upper_linear_hull_max(
if dtype != dtype32:
corners_collapse = K.cast(corners_collapse, dtype32)
corners_pred = K.cast(corners_pred, dtype32)
w_hull = LinalgSolve()(matrix=corners_collapse, rhs=K.expand_dims(corners_pred, -1)) # (None, shape_, n_dim+1, 1)
w_hull = K.solve(corners_collapse, K.expand_dims(corners_pred, -1)) # (None, shape_, n_dim+1, 1)

if dtype != dtype32:
w_hull = K.cast(w_hull, dtype=dtype)
Expand Down
35 changes: 1 addition & 34 deletions tests/test_keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@
from keras.layers import Dense, Input
from numpy.testing import assert_almost_equal

from decomon.keras_utils import (
BACKEND_PYTORCH,
BACKEND_TENSORFLOW,
LinalgSolve,
get_weight_index_from_name,
share_layer_all_weights,
)
from decomon.keras_utils import get_weight_index_from_name, share_layer_all_weights


def test_get_weight_index_from_name_nok_attribute():
Expand All @@ -34,33 +28,6 @@ def test_get_weight_index_from_name_ok():
assert get_weight_index_from_name(layer=layer, weight_name="bias") in [0, 1]


def test_linalgsolve(floatx, decimal):
if keras.config.backend() in (BACKEND_TENSORFLOW, BACKEND_PYTORCH) and floatx == 16:
pytest.skip("LinalgSolve not implemented for float16 on torch and tensorflow")

dtype = f"float{floatx}"

matrix = np.array([[1, 0, 0], [2, 1, 0], [3, 2, 1]])
matrix = np.repeat(matrix[None, None], 2, axis=0)
matrix_symbolic_tensor = keras.KerasTensor(shape=matrix.shape, dtype=dtype)
matrix_tensor = keras.ops.convert_to_tensor(matrix, dtype=dtype)

rhs = np.array([[1, 0], [0, 0], [0, 1]])
rhs = np.repeat(rhs[None, None], 2, axis=0)
rhs_symbolic_tensor = keras.KerasTensor(shape=rhs.shape, dtype=dtype)
rhs_tensor = keras.ops.convert_to_tensor(rhs, dtype=dtype)

expected_sol = np.array([[1, 0], [-2, 0], [1, 1]])
expected_sol = np.repeat(expected_sol[None, None], 2, axis=0)

sol_symbolic_tensor = LinalgSolve()(matrix_symbolic_tensor, rhs_symbolic_tensor)
assert tuple(sol_symbolic_tensor.shape) == tuple(expected_sol.shape)

sol_tensor = LinalgSolve()(matrix_tensor, rhs_tensor)
assert keras.backend.standardize_dtype(sol_tensor.dtype) == dtype
assert_almost_equal(expected_sol, keras.ops.convert_to_numpy(sol_tensor), decimal=decimal)


def test_share_layer_all_weights_nok_original_layer_unbuilt():
original_layer = Dense(3)
new_layer = original_layer.__class__.from_config(original_layer.get_config())
Expand Down

0 comments on commit bec8cd8

Please sign in to comment.