diff --git a/src/decomon/layers/activations/activation.py b/src/decomon/layers/activations/activation.py
index 78a7262e..6390d5dc 100644
--- a/src/decomon/layers/activations/activation.py
+++ b/src/decomon/layers/activations/activation.py
@@ -176,7 +176,7 @@ def compute_output_shape(
         ) = self.inputs_outputs_spec.split_input_shape(input_shape=input_shape)
         return self.inputs_outputs_spec.flatten_outputs_shape(
             affine_bounds_propagated_shape=affine_bounds_to_propagate_shape,
-            constant_bounds_propagated_shape=constant_oracle_bounds_shape,
+            constant_bounds_propagated_shape=constant_oracle_bounds_shape,  # type: ignore
         )
diff --git a/src/decomon/layers/crown.py b/src/decomon/layers/crown.py
index 142023bb..52e5ad76 100644
--- a/src/decomon/layers/crown.py
+++ b/src/decomon/layers/crown.py
@@ -89,7 +89,7 @@ def call(self, inputs: list[list[BackendTensor]]) -> list[BackendTensor]:
             b_l = add_tensors(b_l, b_l_i, missing_batchsize=missing_batchsize)
             b_u = add_tensors(b_u, b_u_i, missing_batchsize=missing_batchsize)

-        affine_bounds = w_l, b_l, w_u, b_u
+        affine_bounds = [w_l, b_l, w_u, b_u]
         return affine_bounds

     def build(self, input_shape: list[list[tuple[Optional[int], ...]]]) -> None:
diff --git a/src/decomon/layers/fuse.py b/src/decomon/layers/fuse.py
index f9447f09..5886133c 100644
--- a/src/decomon/layers/fuse.py
+++ b/src/decomon/layers/fuse.py
@@ -53,7 +53,7 @@ def __init__(
         m1_input_shape: tuple[int, ...],
         m_1_output_shapes: list[tuple[int, ...]],
         from_linear_2: list[bool],
-        **kwargs,
+        **kwargs: Any,
     ):
         """
@@ -138,7 +138,7 @@ def _is_from_linear_m1_ith_affine_bounds(self, affine_bounds: list[Tensor], i: i
         return len(affine_bounds) == 0 or affine_bounds[1].shape == self.m_1_output_shapes[i]

     def _is_from_linear_m1_ith_affine_bounds_shape(
-        self, affine_bounds_shape: list[tuple[Optional[int]]], i: int
+        self, affine_bounds_shape: list[tuple[Optional[int], ...]], i: int
     ) -> bool:
         return len(affine_bounds_shape) == 0 or affine_bounds_shape[1] == self.m_1_output_shapes[i]

@@ -226,14 +226,16 @@ def compute_output_shape(
     ) -> list[tuple[Optional[int], ...]]:
         bounds_1_shape, bounds_2_shape = input_shape

-        bounds_fused_shape: list[tuple[int, ...]] = []
+        bounds_fused_shape: list[tuple[Optional[int], ...]] = []
         for i in range(self.nb_outputs_first_model):
             bounds_1_i_shape = bounds_1_shape[
                 i * self.inputs_outputs_spec_1.nb_output_tensors : (i + 1) * self.inputs_outputs_spec_1.nb_output_tensors
             ]
-            affine_bounds_1_shape, constant_bounds_1_shape = self.inputs_outputs_spec_1.split_output_shape(
+            affine_bounds_1_shape: list[tuple[Optional[int], ...]]
+            constant_bounds_1_shape: list[tuple[Optional[int], ...]]
+            affine_bounds_1_shape, constant_bounds_1_shape = self.inputs_outputs_spec_1.split_output_shape(  # type: ignore
                 bounds_1_i_shape
             )
@@ -242,9 +244,14 @@ def compute_output_shape(
             bounds_2_i_shape = bounds_2_shape[
                 i
                 * self.inputs_outputs_spec_2[0].nb_output_tensors : (i + 1)
                 * self.inputs_outputs_spec_2[0].nb_output_tensors
             ]
-            affine_bounds_2_shape, constant_bounds_2_shape = self.inputs_outputs_spec_2[0].split_output_shape(
-                bounds_2_i_shape
-            )
+            affine_bounds_2_shape: list[tuple[Optional[int], ...]]
+            constant_bounds_2_shape: list[tuple[Optional[int], ...]]
+            (
+                affine_bounds_2_shape,
+                constant_bounds_2_shape,
+            ) = self.inputs_outputs_spec_2[  # type:ignore
+                0
+            ].split_output_shape(bounds_2_i_shape)

             # constant bounds
             if self.ibp_2:
@@ -264,10 +271,11 @@ def compute_output_shape(
             # affine bounds
             if self.affine_1 and self.affine_2:
                 _, b2_shape, _, _ = affine_bounds_2_shape
+                model_2_output_shape_wo_batchisze: tuple[int, ...]
                 if self.from_linear_2[i]:
-                    model_2_output_shape_wo_batchisze = b2_shape
+                    model_2_output_shape_wo_batchisze = b2_shape  # type: ignore
                 else:
-                    model_2_output_shape_wo_batchisze = b2_shape[1:]
+                    model_2_output_shape_wo_batchisze = b2_shape[1:]  # type: ignore

                 diagonal = self.inputs_outputs_spec_1.is_diagonal_bounds_shape(
                     affine_bounds_1_shape
@@ -281,6 +289,8 @@ def compute_output_shape(
                     self._is_from_linear_m1_ith_affine_bounds_shape(affine_bounds_shape=affine_bounds_1_shape, i=i)
                     and self.from_linear_2[i]
                 )
+                w_fused_shape: tuple[Optional[int], ...]
+                b_fused_shape: tuple[Optional[int], ...]
                 if from_linear_layer:
                     w_fused_shape = w_fused_shape_wo_batchsize
                     b_fused_shape = model_2_output_shape_wo_batchisze
diff --git a/src/decomon/layers/input.py b/src/decomon/layers/input.py
index 607f8fc2..a2e79921 100644
--- a/src/decomon/layers/input.py
+++ b/src/decomon/layers/input.py
@@ -95,7 +95,7 @@ def compute_output_shape(
         self,
         input_shape: tuple[Optional[int], ...],
     ) -> list[tuple[Optional[int], ...]]:
-        perturbation_domain_input_shape_wo_batchsize = input_shape[1:]
+        perturbation_domain_input_shape_wo_batchsize: tuple[int, ...] = input_shape[1:]  # type: ignore
         keras_input_shape_wo_batchsize = self.perturbation_domain.get_keras_input_shape_wo_batchsize(
             x_shape=perturbation_domain_input_shape_wo_batchsize
         )
@@ -112,7 +112,7 @@ def compute_output_shape(
         else:
             constant_bounds_shape = []
         return self.inputs_outputs_spec.flatten_inputs_shape(
-            affine_bounds_to_propagate_shape=affine_bounds_shape,
+            affine_bounds_to_propagate_shape=affine_bounds_shape,  # type: ignore
             constant_oracle_bounds_shape=constant_bounds_shape,
             perturbation_domain_inputs_shape=[],
         )
@@ -165,7 +165,7 @@ def compute_output_shape(
         self,
         input_shape: tuple[Optional[int], ...],
     ) -> list[tuple[Optional[int], ...]]:
-        perturbation_domain_input_shape_wo_batchsize = input_shape[1:]
+        perturbation_domain_input_shape_wo_batchsize: tuple[int, ...] = input_shape[1:]  # type: ignore
         keras_input_shape_wo_batchsize = self.perturbation_domain.get_keras_input_shape_wo_batchsize(
             x_shape=perturbation_domain_input_shape_wo_batchsize
         )
@@ -312,11 +312,13 @@ def compute_output_shape(
         else:
             w_shape_wo_batchsize = w_shape[1:]
         is_diag = w_shape_wo_batchsize == model_output_shape
+        m2_output_shape: tuple[Optional[int], ...]
         if is_diag:
             m2_output_shape = model_output_shape
         else:
             m2_output_shape = w_shape_wo_batchsize[len(model_output_shape) :]
         b_shape_wo_batchsize = m2_output_shape
+        b_shape: tuple[Optional[int], ...]
         if from_linear:
             b_shape = b_shape_wo_batchsize
         else:
@@ -332,10 +334,6 @@ def compute_output_shape(
     return input_shape


-def _is_keras_tensor_shape(shape):
-    return len(shape) > 0 and (shape[0] is None or isinstance(shape[0], int))
-
-
 def flatten_backward_bounds(
     backward_bounds: Union[keras.KerasTensor, list[keras.KerasTensor], list[list[keras.KerasTensor]]]
 ) -> list[keras.KerasTensor]:
diff --git a/src/decomon/layers/inputs_outputs_specs.py b/src/decomon/layers/inputs_outputs_specs.py
index 0d215a0f..3c6cde63 100644
--- a/src/decomon/layers/inputs_outputs_specs.py
+++ b/src/decomon/layers/inputs_outputs_specs.py
@@ -496,7 +496,7 @@ def has_multiple_bounds_inputs(self) -> bool:
         return self.propagation == Propagation.FORWARD and self.is_merging_layer

     @overload
-    def extract_shapes_from_affine_bounds(
+    def extract_shapes_from_affine_bounds(  # type:ignore
         self, affine_bounds: list[Tensor], i: int = -1
     ) -> list[tuple[Optional[int], ...]]:
         ...
@@ -575,14 +575,14 @@ def is_wo_batch_bounds_shape(
         b_shape = affine_bounds_shape[1]
         if self.propagation == Propagation.FORWARD:
             if i > -1:
-                return len(b_shape) == len(self.layer_input_shape[i])
+                return len(b_shape) == len(self.layer_input_shape[i])  # type: ignore
             else:
                 return len(b_shape) == len(self.layer_input_shape)
         else:
             return len(b_shape) == len(self.model_output_shape)

     @overload
-    def is_wo_batch_bounds_by_keras_input(
+    def is_wo_batch_bounds_by_keras_input(  # type: ignore
         self,
         affine_bounds: list[Tensor],
     ) -> bool:
diff --git a/src/decomon/layers/layer.py b/src/decomon/layers/layer.py
index a37a504b..cbfcfc9b 100644
--- a/src/decomon/layers/layer.py
+++ b/src/decomon/layers/layer.py
@@ -1,5 +1,5 @@
 from inspect import Parameter, signature
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import keras
 import keras.ops as K
@@ -169,7 +169,7 @@ def is_merging_layer(self) -> bool:

     @property
     def layer_input_shape(self) -> tuple[int, ...]:
-        return self.inputs_outputs_spec.layer_input_shape
+        return self.inputs_outputs_spec.layer_input_shape  # type: ignore

     @property
     def model_input_shape(self) -> tuple[int, ...]:
@@ -640,6 +640,7 @@ def compute_output_shape(
         # outputs shape depends if layer and inputs are diagonal / linear (w/o batch)
         b_shape_wo_batchisze = model_output_shape_wo_batchsize
         if self.diagonal and self.inputs_outputs_spec.is_diagonal_bounds_shape(affine_bounds_to_propagate_shape):
+            w_shape_wo_batchsize: Union[tuple[int, ...], list[tuple[int, ...]]]
             if self._is_merging_layer:
                 w_shape_wo_batchsize = [model_output_shape_wo_batchsize] * self.inputs_outputs_spec.nb_keras_inputs
             else:
@@ -652,15 +653,17 @@ def compute_output_shape(
                 ]
             else:
                 w_shape_wo_batchsize = self.layer.input.shape[1:] + model_output_shape_wo_batchsize
+        b_shape: tuple[Optional[int], ...]
+        w_shape: Union[tuple[Optional[int], ...], list[tuple[Optional[int], ...]]]
         if self.linear and self.inputs_outputs_spec.is_wo_batch_bounds_shape(affine_bounds_to_propagate_shape):
             b_shape = b_shape_wo_batchisze
-            w_shape = w_shape_wo_batchsize
+            w_shape = w_shape_wo_batchsize  # type: ignore
         else:
             b_shape = (None,) + b_shape_wo_batchisze
             if self._is_merging_layer:
-                w_shape = [(None,) + sub_w_shape_wo_batchsize for sub_w_shape_wo_batchsize in w_shape_wo_batchsize]
+                w_shape = [(None,) + sub_w_shape_wo_batchsize for sub_w_shape_wo_batchsize in w_shape_wo_batchsize]  # type: ignore
             else:
-                w_shape = (None,) + w_shape_wo_batchsize
+                w_shape = (None,) + w_shape_wo_batchsize  # type: ignore
         if self._is_merging_layer:
             affine_bounds_propagated_shape = [
                 [
diff --git a/src/decomon/layers/merging/base_merge.py b/src/decomon/layers/merging/base_merge.py
index c73d0afa..fd966e66 100644
--- a/src/decomon/layers/merging/base_merge.py
+++ b/src/decomon/layers/merging/base_merge.py
@@ -1,5 +1,6 @@
 from typing import Any

+import keras
 import keras.ops as K

 from decomon.keras_utils import add_tensors, batch_multid_dot
@@ -12,7 +13,7 @@ class DecomonMerge(DecomonLayer):
     _is_merging_layer = True

     @property
-    def keras_layer_input(self):
+    def keras_layer_input(self) -> list[keras.KerasTensor]:
        """self.layer.input returned as a list.

        In the degenerate case where only 1 input is merged, self.layer.input is a single keras tensor.
@@ -25,7 +26,7 @@ def keras_layer_input(self):
         return [self.layer.input]

     @property
-    def nb_keras_inputs(self):
+    def nb_keras_inputs(self) -> int:
         """Number of inputs merged by the underlying layer."""
         return len(self.keras_layer_input)

@@ -274,7 +275,7 @@ def forward_affine_propagate(
         from_linear_layer_new = all(from_linear_add)
         return w_l_new, b_l_new, w_u_new, b_u_new

-    def backward_affine_propagate(
+    def backward_affine_propagate(  # type: ignore
         self, output_affine_bounds: list[Tensor], input_constant_bounds: list[list[Tensor]]
     ) -> list[tuple[Tensor, Tensor, Tensor, Tensor]]:
         """Propagate model affine bounds in backward direction.
diff --git a/src/decomon/layers/oracle.py b/src/decomon/layers/oracle.py
index 773048ca..a2f211d7 100644
--- a/src/decomon/layers/oracle.py
+++ b/src/decomon/layers/oracle.py
@@ -134,12 +134,13 @@ def compute_output_shape(
         """Compute output shape in case of symbolic call."""
         if self.is_merging_layer:
             output_shape = []
-            for layer_input_shape_i in self.layer_input_shape:
+            layer_input_shape_i: tuple[int, ...]
+            for layer_input_shape_i in self.layer_input_shape:  # type: ignore
                 layer_input_shape_w_batchsize_i = (None,) + layer_input_shape_i
                 output_shape.append([layer_input_shape_w_batchsize_i, layer_input_shape_w_batchsize_i])
             return output_shape
         else:
-            layer_input_shape_w_batchsize = (None,) + self.layer_input_shape
+            layer_input_shape_w_batchsize = (None,) + self.layer_input_shape  # type: ignore
             return [layer_input_shape_w_batchsize, layer_input_shape_w_batchsize]
@@ -222,7 +223,7 @@ def get_forward_oracle(
         x = perturbation_domain_inputs[0]
         if is_merging_layer:
             constant_bounds = []
-            for affine_bounds_i, from_linear_i in zip(affine_bounds, from_linear):
+            for affine_bounds_i, from_linear_i in zip(affine_bounds, from_linear):  # type: ignore
                 if len(affine_bounds_i) == 0:
                     # special case: empty affine bounds => identity bounds
                     l_affine = perturbation_domain.get_lower_x(x)
@@ -239,6 +240,8 @@ def get_forward_oracle(
                 l_affine = perturbation_domain.get_lower_x(x)
                 u_affine = perturbation_domain.get_upper_x(x)
             else:
+                if not isinstance(from_linear, bool):
+                    raise ValueError("from_linear must be a boolean for a non-merging layer")
                 w_l, b_l, w_u, b_u = affine_bounds
                 l_affine = perturbation_domain.get_lower(x, w_l, b_l, missing_batchsize=from_linear)
                 u_affine = perturbation_domain.get_upper(x, w_u, b_u, missing_batchsize=from_linear)
diff --git a/src/decomon/layers/output.py b/src/decomon/layers/output.py
index 398f5f07..74e1147d 100644
--- a/src/decomon/layers/output.py
+++ b/src/decomon/layers/output.py
@@ -143,11 +143,16 @@ def compute_output_shape(
         self,
         input_shape: list[tuple[Optional[int], ...]],
     ) -> list[tuple[Optional[int], ...]]:
+        affine_bounds_from_shape: list[list[tuple[Optional[int], ...]]]
+        constant_bounds_from_shape: list[list[tuple[Optional[int], ...]]]
+        perturbation_domain_inputs_shape: list[tuple[Optional[int], ...]]
         (
             affine_bounds_from_shape,
             constant_bounds_from_shape,
             perturbation_domain_inputs_shape,
-        ) = self.inputs_outputs_spec.split_input_shape(input_shape)
+        ) = self.inputs_outputs_spec.split_input_shape(  # type: ignore
+            input_shape
+        )

         constant_bounds_to_shape: list[list[tuple[Optional[int], ...]]]
         affine_bounds_to_shape: list[list[tuple[Optional[int], ...]]]
@@ -167,7 +172,8 @@ def compute_output_shape(
             affine_bounds_to_shape = affine_bounds_from_shape
         else:
             x_shape = perturbation_domain_inputs_shape[0]
-            keras_input_shape = self.perturbation_domain.get_keras_input_shape_wo_batchsize(x_shape[1:])
+            x_shape_wo_batchsize: tuple[int, ...] = x_shape[1:]  # type: ignore
+            keras_input_shape = self.perturbation_domain.get_keras_input_shape_wo_batchsize(x_shape_wo_batchsize)
             affine_bounds_to_shape = []
             for model_output_shape in self.model_output_shapes:
                 b_shape = (None,) + model_output_shape
diff --git a/src/decomon/models/backward_cloning.py b/src/decomon/models/backward_cloning.py
index 80a8a24e..1e3e81c4 100644
--- a/src/decomon/models/backward_cloning.py
+++ b/src/decomon/models/backward_cloning.py
@@ -378,7 +378,7 @@ def crown_model(
         model_output_shape = get_model_output_shape(
             node=node, backward_bounds=backward_bounds_node, from_linear=from_linear
         )
-        backward_map_node = {}
+        backward_map_node: dict[int, DecomonLayer] = {}

         output_crown = crown(
             node=node,
@@ -406,7 +406,7 @@ def convert_backward(
     layer_fn: Callable[..., DecomonLayer] = to_decomon,
     backward_bounds: Optional[list[keras.KerasTensor]] = None,
     from_linear_backward_bounds: Union[bool, list[bool]] = False,
-    slope: Union[str, Slope] = Slope.V_SLOPE,
+    slope: Slope = Slope.V_SLOPE,
     forward_output_map: Optional[dict[int, list[keras.KerasTensor]]] = None,
     forward_layer_map: Optional[dict[int, DecomonLayer]] = None,
     mapping_keras2decomon_classes: Optional[dict[type[Layer], type[DecomonLayer]]] = None,
@@ -440,6 +440,7 @@ def convert_backward(
     """
     if perturbation_domain is None:
         perturbation_domain = BoxDomain()
+    backward_bounds_for_crown_model: list[list[keras.KerasTensor]]
     if backward_bounds is None:
         backward_bounds_for_crown_model = [[]] * len(model.outputs)
     else:
@@ -474,7 +475,7 @@ def convert_backward(
     return output


-def get_model_output_shape(node: Node, backward_bounds: list[Tensor], from_linear: bool = False):
+def get_model_output_shape(node: Node, backward_bounds: list[Tensor], from_linear: bool = False) -> tuple[int, ...]:
     """Get outer model output shape w/o batchsize.

     If any backward bounds are passed, we deduce the outer keras model output shape from it.
diff --git a/src/decomon/models/convert.py b/src/decomon/models/convert.py
index efd55abb..c8233be5 100644
--- a/src/decomon/models/convert.py
+++ b/src/decomon/models/convert.py
@@ -311,18 +311,18 @@ def clone(
     if perturbation_domain is None:
         perturbation_domain = BoxDomain()

-    default_final_ibp, default_final_affine = get_final_ibp_affine_from_method(method)
-    if final_ibp is None:
-        final_ibp = default_final_ibp
-    if final_affine is None:
-        final_affine = default_final_affine
-
     if isinstance(method, str):
         method = ConvertMethod(method.lower())

     if isinstance(slope, str):
         slope = Slope(slope.lower())

+    default_final_ibp, default_final_affine = get_final_ibp_affine_from_method(method)
+    if final_ibp is None:
+        final_ibp = default_final_ibp
+    if final_affine is None:
+        final_affine = default_final_affine
+
     # preprocess backward_bounds
     backward_bounds_flattened: Optional[list[keras.KerasTensor]]
     backward_bounds_for_convert: Optional[list[keras.KerasTensor]]
diff --git a/src/decomon/perturbation_domain.py b/src/decomon/perturbation_domain.py
index 4202251d..dbfbe1d6 100644
--- a/src/decomon/perturbation_domain.py
+++ b/src/decomon/perturbation_domain.py
@@ -31,7 +31,7 @@ def get_lower_x(self, x: Tensor) -> Tensor:
         ...

     @abstractmethod
-    def get_upper(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize=False, **kwargs: Any) -> Tensor:
+    def get_upper(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize: bool = False, **kwargs: Any) -> Tensor:
         """Merge upper affine bounds with perturbation domain input to get upper constant bound.

         Args:
@@ -47,7 +47,7 @@ def get_upper(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize=False, **
         ...

     @abstractmethod
-    def get_lower(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize=False, **kwargs: Any) -> Tensor:
+    def get_lower(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize: bool = False, **kwargs: Any) -> Tensor:
         """Merge lower affine bounds with perturbation domain input to get lower constant bound.

         Args:
@@ -111,12 +111,12 @@ def get_keras_input_shape_wo_batchsize(self, x_shape: tuple[int, ...]) -> tuple[


 class BoxDomain(PerturbationDomain):
-    def get_upper(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize=False, **kwargs: Any) -> Tensor:
+    def get_upper(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize: bool = False, **kwargs: Any) -> Tensor:
         x_min = x[:, 0]
         x_max = x[:, 1]
         return get_upper_box(x_min=x_min, x_max=x_max, w=w, b=b, missing_batchsize=missing_batchsize, **kwargs)

-    def get_lower(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize=False, **kwargs: Any) -> Tensor:
+    def get_lower(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize: bool = False, **kwargs: Any) -> Tensor:
         x_min = x[:, 0]
         x_max = x[:, 1]
         return get_lower_box(x_min=x_min, x_max=x_max, w=w, b=b, missing_batchsize=missing_batchsize, **kwargs)
@@ -162,10 +162,10 @@ def get_config(self) -> dict[str, Any]:
         )
         return config

-    def get_lower(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize=False, **kwargs: Any) -> Tensor:
+    def get_lower(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize: bool = False, **kwargs: Any) -> Tensor:
         return get_lower_ball(x_0=x, eps=self.eps, p=self.p, w=w, b=b, missing_batchsize=missing_batchsize, **kwargs)

-    def get_upper(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize=False, **kwargs: Any) -> Tensor:
+    def get_upper(self, x: Tensor, w: Tensor, b: Tensor, missing_batchsize: bool = False, **kwargs: Any) -> Tensor:
         return get_upper_ball(x_0=x, eps=self.eps, p=self.p, w=w, b=b, missing_batchsize=missing_batchsize, **kwargs)

     def get_nb_x_components(self) -> int:
@@ -178,7 +178,9 @@ def get_upper_x(self, x: Tensor) -> Tensor:
         return x + self.eps


-def get_upper_box(x_min: Tensor, x_max: Tensor, w: Tensor, b: Tensor, missing_batchsize=False, **kwargs: Any) -> Tensor:
+def get_upper_box(
+    x_min: Tensor, x_max: Tensor, w: Tensor, b: Tensor, missing_batchsize: bool = False, **kwargs: Any
+) -> Tensor:
     """Compute the max of an affine function
     within a box (hypercube) defined by its extremal corners

@@ -203,16 +205,18 @@ def get_upper_box(x_min: Tensor, x_max: Tensor, w: Tensor, b: Tensor, missing_ba
     is_diag = w.shape == b.shape
     diagonal = (False, is_diag)
-    missing_batchsize = (False, missing_batchsize)
+    missing_batchsize_dot = (False, missing_batchsize)

     return (
-        batch_multid_dot(x_max, w_pos, diagonal=diagonal, missing_batchsize=missing_batchsize)
-        + batch_multid_dot(x_min, w_neg, diagonal=diagonal, missing_batchsize=missing_batchsize)
+        batch_multid_dot(x_max, w_pos, diagonal=diagonal, missing_batchsize=missing_batchsize_dot)
+        + batch_multid_dot(x_min, w_neg, diagonal=diagonal, missing_batchsize=missing_batchsize_dot)
         + b
     )


-def get_lower_box(x_min: Tensor, x_max: Tensor, w: Tensor, b: Tensor, missing_batchsize=False, **kwargs: Any) -> Tensor:
+def get_lower_box(
+    x_min: Tensor, x_max: Tensor, w: Tensor, b: Tensor, missing_batchsize: bool = False, **kwargs: Any
+) -> Tensor:
     """
     Args:
         x_min: lower bound of the box domain
@@ -296,8 +300,8 @@ def get_upper_ball(
         w_q = get_lq_norm(w, p, axis=reduced_axes)

         diagonal = (False, is_diag)
-        missing_batchsize = (False, missing_batchsize)
-        return batch_multid_dot(x_0, w, diagonal=diagonal, missing_batchsize=missing_batchsize) + b + w_q * eps
+        missing_batchsize_dot = (False, missing_batchsize)
+        return batch_multid_dot(x_0, w, diagonal=diagonal, missing_batchsize=missing_batchsize_dot) + b + w_q * eps


 def get_lower_ball(
@@ -343,8 +347,8 @@ def get_lower_ball(
         w_q = get_lq_norm(w, p, axis=reduced_axes)

         diagonal = (False, is_diag)
-        missing_batchsize = (False, missing_batchsize)
-        return batch_multid_dot(x_0, w, diagonal=diagonal, missing_batchsize=missing_batchsize) + b - w_q * eps
+        missing_batchsize_dot = (False, missing_batchsize)
+        return batch_multid_dot(x_0, w, diagonal=diagonal, missing_batchsize=missing_batchsize_dot) + b - w_q * eps


 def get_lower_ball_finetune(
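
Note on the typing patterns used above. The changes rely on two recurring techniques for satisfying mypy without touching runtime behavior: (1) pre-declaring a variable's annotation once before a branch, so every assignment in the branch is checked against the same declared type, with a targeted `# type: ignore` where the declared source type (e.g. `tuple[Optional[int], ...]`) is looser than the value known to flow through; and (2) narrowing a union at runtime with `isinstance` (as added in `get_forward_oracle`), which mypy understands on its own, so no ignore is needed. A minimal sketch of both patterns follows; `bias_shape_wo_batchsize` is an illustrative helper, not decomon API:

from typing import Optional, Union


def bias_shape_wo_batchsize(
    b_shape: tuple[Optional[int], ...], from_linear: Union[bool, list[bool]]
) -> tuple[int, ...]:
    # Pattern 2: runtime narrowing. After this check mypy treats
    # `from_linear` as a plain bool, mirroring the isinstance guard
    # introduced in get_forward_oracle.
    if not isinstance(from_linear, bool):
        raise ValueError("from_linear must be a boolean here")

    # Pattern 1: declare the annotation once, before the branch, so both
    # assignments below are checked against the same declared type.
    out_shape: tuple[int, ...]
    if from_linear:
        # Shapes w/o batchsize contain only concrete ints, but the parameter
        # is declared with Optional[int] elements, hence the targeted ignore.
        out_shape = b_shape  # type: ignore
    else:
        out_shape = b_shape[1:]  # type: ignore
    return out_shape

An alternative to the ignores is typing.cast(tuple[int, ...], b_shape), which keeps the assertion visible to the type checker instead of silencing the whole line; the patch consistently opts for the lighter-weight ignore comments.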