diff --git a/mrmustard/lab_dev/circuit_components.py b/mrmustard/lab_dev/circuit_components.py index 324a93da6..de6dc902c 100644 --- a/mrmustard/lab_dev/circuit_components.py +++ b/mrmustard/lab_dev/circuit_components.py @@ -337,7 +337,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor: # Find where all the bras and kets are so they can be conjugated appropriately conjugates = [i not in self.wires.ket.indices for i in range(len(self.wires.indices))] quad_basis = math.sum( - [quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axis=[0] + [quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axis=0 ) return quad_basis diff --git a/mrmustard/lab_dev/states/base.py b/mrmustard/lab_dev/states/base.py index 9b9330d9f..dcedad33f 100644 --- a/mrmustard/lab_dev/states/base.py +++ b/mrmustard/lab_dev/states/base.py @@ -386,7 +386,7 @@ def visualize_2d( shape = [max(min_shape, d) for d in self.auto_shape()] state = self.to_fock(tuple(shape)) state = state.dm() - dm = math.sum(state.ansatz.array, axis=[0]) + dm = math.sum(state.ansatz.array, axis=0) x, prob_x = quadrature_distribution(dm) p, prob_p = quadrature_distribution(dm, np.pi / 2) @@ -502,7 +502,7 @@ def visualize_3d( shape = [max(min_shape, d) for d in self.auto_shape()] state = self.to_fock(tuple(shape)) state = state.dm() - dm = math.sum(state.ansatz.array, axis=[0]) + dm = math.sum(state.ansatz.array, axis=0) xvec = np.linspace(*xbounds, resolution) pvec = np.linspace(*pbounds, resolution) @@ -576,7 +576,7 @@ def visualize_dm( raise ValueError("DM visualization not available for multi-mode states.") state = self.to_fock(cutoff) state = state.dm() - dm = math.sum(state.ansatz.array, axis=[0]) + dm = math.sum(state.ansatz.array, axis=0) fig = go.Figure( data=go.Heatmap(z=abs(dm), colorscale="viridis", name="abs(ρ)", showscale=False) diff --git a/mrmustard/math/backend_manager.py b/mrmustard/math/backend_manager.py index 42754890b..6dd6dc7bf 100644 --- a/mrmustard/math/backend_manager.py +++ b/mrmustard/math/backend_manager.py @@ -1102,17 +1102,17 @@ def sqrtm(self, tensor: Tensor, dtype=None) -> Tensor: The square root of ``x``""" return self._apply("sqrtm", (tensor, dtype)) - def sum(self, array: Tensor, axis: Sequence[int] = None): + def sum(self, array: Tensor, axis: int | Sequence[int] | None = None): r"""The sum of array. Args: array: The array to take the sum of - axes (tuple): The axis/axes to sum over + axis (int | Sequence[int] | None): The axis/axes to sum over Returns: The sum of array """ - if axis is not None: + if axis is not None and not isinstance(axis, int): neg = [a for a in axis if a < 0] pos = [a for a in axis if a >= 0] axis = tuple(sorted(neg) + sorted(pos)[::-1]) diff --git a/mrmustard/math/backend_numpy.py b/mrmustard/math/backend_numpy.py index d04eaeaba..643bc3086 100644 --- a/mrmustard/math/backend_numpy.py +++ b/mrmustard/math/backend_numpy.py @@ -390,7 +390,7 @@ def sort(self, array: np.ndarray, axis: int = -1) -> np.ndarray: def sqrt(self, x: np.ndarray, dtype=None) -> np.ndarray: return np.sqrt(self.cast(x, dtype)) - def sum(self, array: np.ndarray, axis: Sequence[int] = None): + def sum(self, array: np.ndarray, axis: int | Sequence[int] | None = None): return np.sum(array, axis=axis) @Autocast() diff --git a/mrmustard/math/backend_tensorflow.py b/mrmustard/math/backend_tensorflow.py index 474bd9760..009f638ed 100644 --- a/mrmustard/math/backend_tensorflow.py +++ b/mrmustard/math/backend_tensorflow.py @@ -336,7 +336,7 @@ def sort(self, array: tf.Tensor, axis: int = -1) -> tf.Tensor: def sqrt(self, x: tf.Tensor, dtype=None) -> tf.Tensor: return tf.sqrt(self.cast(x, dtype)) - def sum(self, array: tf.Tensor, axis: Sequence[int] = None): + def sum(self, array: tf.Tensor, axis: int | tuple[int] | None = None): return tf.reduce_sum(array, axis) @Autocast() diff --git a/mrmustard/physics/ansatz/array_ansatz.py b/mrmustard/physics/ansatz/array_ansatz.py index bb3dd5e25..9c1d8bd47 100644 --- a/mrmustard/physics/ansatz/array_ansatz.py +++ b/mrmustard/physics/ansatz/array_ansatz.py @@ -223,7 +223,7 @@ def sum_batch(self) -> ArrayAnsatz: Returns: The collapsed ArrayAnsatz object. """ - return ArrayAnsatz(math.expand_dims(math.sum(self.array, axis=[0]), 0), batched=True) + return ArrayAnsatz(math.expand_dims(math.sum(self.array, axis=0), 0), batched=True) def to_dict(self) -> dict[str, ArrayLike]: return {"array": self.data} diff --git a/mrmustard/physics/ansatz/polyexp_ansatz.py b/mrmustard/physics/ansatz/polyexp_ansatz.py index 42d9ffed2..f42360327 100644 --- a/mrmustard/physics/ansatz/polyexp_ansatz.py +++ b/mrmustard/physics/ansatz/polyexp_ansatz.py @@ -411,12 +411,12 @@ def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz: self.A[..., :dim_alpha, :dim_alpha] * zz, axis=[-1, -2] ) # sum((b_arg,1,n,n) * (b_abc,n,n), [-1,-2]) ~ (b_arg,b_abc) b_part = math.sum( - self.b[..., :dim_alpha] * z[..., None, :], axis=[-1] + self.b[..., :dim_alpha] * z[..., None, :], axis=-1 ) # sum((b_arg,1,n) * (b_abc,n), [-1]) ~ (b_arg,b_abc) exp_sum = math.exp(1 / 2 * A_part + b_part) # (b_arg, b_abc) if dim_beta == 0: - val = math.sum(exp_sum * self.c, axis=[-1]) # (b_arg) + val = math.sum(exp_sum * self.c, axis=-1) # (b_arg) else: b_poly = math.astensor( math.einsum( @@ -441,7 +441,7 @@ def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz: poly * self.c, axis=math.arange(2, 2 + dim_beta, dtype=math.int32).tolist(), ), - axis=[-1], + axis=-1, ) # (b_arg) return val diff --git a/mrmustard/physics/gaussian_integrals.py b/mrmustard/physics/gaussian_integrals.py index c36eddc4e..4dce69552 100644 --- a/mrmustard/physics/gaussian_integrals.py +++ b/mrmustard/physics/gaussian_integrals.py @@ -382,7 +382,7 @@ def complex_gaussian_integral_1( inv_M = math.inv(M) c_post = c * math.reshape( math.sqrt(math.cast((-1) ** m / det_M, "complex128")) - * math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axis=[-1])), + * math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axis=-1)), c.shape[:1] + (1,) * (len(c.shape) - 1), ) A_post = R - math.einsum("bij,bjk,blk->bil", D, inv_M, D) diff --git a/mrmustard/physics/representations.py b/mrmustard/physics/representations.py index db40a5c74..891c7c99b 100644 --- a/mrmustard/physics/representations.py +++ b/mrmustard/physics/representations.py @@ -155,7 +155,7 @@ def fock_array(self, shape: int | Sequence[int], batched=False) -> ComplexTensor f"Expected Fock shape of length {num_vars}, got length {len(shape)}" ) from e arrays = self.ansatz.reduce(shape).array - array = math.sum(arrays, axis=[0]) + array = math.sum(arrays, axis=0) arrays = math.expand_dims(array, 0) if batched else array return arrays diff --git a/tests/test_lab_dev/test_transformations/test_cft.py b/tests/test_lab_dev/test_transformations/test_cft.py index f93812ca7..aa350aca5 100644 --- a/tests/test_lab_dev/test_transformations/test_cft.py +++ b/tests/test_lab_dev/test_transformations/test_cft.py @@ -41,7 +41,7 @@ def test_wigner_function(self): state = Ket.random([0]) >> Dgate([0], x=1.0, y=0.1) - dm = math.sum(state.to_fock(100).dm().ansatz.array, axis=[0]) + dm = math.sum(state.to_fock(100).dm().ansatz.array, axis=0) vec = np.linspace(-5, 5, 100) wigner, _, _ = wigner_discretized(dm, vec, vec)