diff --git a/src/decomon/keras_utils.py b/src/decomon/keras_utils.py index 0209c0a8..681f359c 100644 --- a/src/decomon/keras_utils.py +++ b/src/decomon/keras_utils.py @@ -13,33 +13,6 @@ BACKEND_JAX = "jax" -class LinalgSolve(keras.Operation): - """Keras operation mimicking tensorflow.linalg.solve().""" - - def compute_output_spec(self, matrix: keras.KerasTensor, rhs: keras.KerasTensor) -> keras.KerasTensor: - rhs_shape = rhs.shape - rhs_dtype = getattr(rhs, "dtype", type(rhs)) - rhs_sparse = getattr(rhs, "sparse", False) - return keras.KerasTensor( - shape=rhs_shape, - dtype=rhs_dtype, - sparse=rhs_sparse, - ) - - def call(self, matrix: BackendTensor, rhs: BackendTensor) -> BackendTensor: - backend = keras.config.backend() - if backend == BACKEND_TENSORFLOW: - import tensorflow as tf - - return tf.linalg.solve(matrix=matrix, rhs=rhs) - elif backend == BACKEND_PYTORCH: - import torch - - return torch.linalg.solve(A=matrix, B=rhs) - else: - raise NotImplementedError(f"linalg_solve() not yet implemented for backend {backend}.") - - class BatchedIdentityLike(keras.Operation): """Keras Operation creating an identity tensor with shape (including batch_size) based on input. diff --git a/src/decomon/layers/utils_pooling.py b/src/decomon/layers/utils_pooling.py index 4b7ecefd..541d5eb6 100644 --- a/src/decomon/layers/utils_pooling.py +++ b/src/decomon/layers/utils_pooling.py @@ -10,7 +10,6 @@ PerturbationDomain, get_affine, ) -from decomon.keras_utils import LinalgSolve from decomon.types import BackendTensor, Tensor # step 1: compute (x_i, y_i) such that x_i[j]=l_j if j==i else u_j @@ -112,7 +111,7 @@ def get_upper_linear_hull_max( if dtype != dtype32: corners_collapse = K.cast(corners_collapse, dtype32) corners_pred = K.cast(corners_pred, dtype32) - w_hull = LinalgSolve()(matrix=corners_collapse, rhs=K.expand_dims(corners_pred, -1)) # (None, shape_, n_dim+1, 1) + w_hull = K.solve(corners_collapse, K.expand_dims(corners_pred, -1)) # (None, shape_, n_dim+1, 1) if dtype != dtype32: w_hull = K.cast(w_hull, dtype=dtype) diff --git a/tests/test_keras_utils.py b/tests/test_keras_utils.py index c65ad19a..6fa5aaf3 100644 --- a/tests/test_keras_utils.py +++ b/tests/test_keras_utils.py @@ -5,13 +5,7 @@ from keras.layers import Dense, Input from numpy.testing import assert_almost_equal -from decomon.keras_utils import ( - BACKEND_PYTORCH, - BACKEND_TENSORFLOW, - LinalgSolve, - get_weight_index_from_name, - share_layer_all_weights, -) +from decomon.keras_utils import get_weight_index_from_name, share_layer_all_weights def test_get_weight_index_from_name_nok_attribute(): @@ -34,33 +28,6 @@ def test_get_weight_index_from_name_ok(): assert get_weight_index_from_name(layer=layer, weight_name="bias") in [0, 1] -def test_linalgsolve(floatx, decimal): - if keras.config.backend() in (BACKEND_TENSORFLOW, BACKEND_PYTORCH) and floatx == 16: - pytest.skip("LinalgSolve not implemented for float16 on torch and tensorflow") - - dtype = f"float{floatx}" - - matrix = np.array([[1, 0, 0], [2, 1, 0], [3, 2, 1]]) - matrix = np.repeat(matrix[None, None], 2, axis=0) - matrix_symbolic_tensor = keras.KerasTensor(shape=matrix.shape, dtype=dtype) - matrix_tensor = keras.ops.convert_to_tensor(matrix, dtype=dtype) - - rhs = np.array([[1, 0], [0, 0], [0, 1]]) - rhs = np.repeat(rhs[None, None], 2, axis=0) - rhs_symbolic_tensor = keras.KerasTensor(shape=rhs.shape, dtype=dtype) - rhs_tensor = keras.ops.convert_to_tensor(rhs, dtype=dtype) - - expected_sol = np.array([[1, 0], [-2, 0], [1, 1]]) - expected_sol = np.repeat(expected_sol[None, None], 2, axis=0) - - sol_symbolic_tensor = LinalgSolve()(matrix_symbolic_tensor, rhs_symbolic_tensor) - assert tuple(sol_symbolic_tensor.shape) == tuple(expected_sol.shape) - - sol_tensor = LinalgSolve()(matrix_tensor, rhs_tensor) - assert keras.backend.standardize_dtype(sol_tensor.dtype) == dtype - assert_almost_equal(expected_sol, keras.ops.convert_to_numpy(sol_tensor), decimal=decimal) - - def test_share_layer_all_weights_nok_original_layer_unbuilt(): original_layer = Dense(3) new_layer = original_layer.__class__.from_config(original_layer.get_config())