Skip to content

Commit

Permalink
fix axis use for sum
Browse files Browse the repository at this point in the history
  • Loading branch information
timmysilv committed Jan 20, 2025
1 parent 14ca459 commit 32f9cca
Show file tree
Hide file tree
Showing 10 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def quadrature(self, quad: Batch[Vector], phi: float = 0.0) -> ComplexTensor:
# Find where all the bras and kets are so they can be conjugated appropriately
conjugates = [i not in self.wires.ket.indices for i in range(len(self.wires.indices))]
quad_basis = math.sum(
[quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axis=[0]
[quadrature_basis(array, quad, conjugates, phi) for array in fock_arrays], axis=0
)
return quad_basis

Expand Down
6 changes: 3 additions & 3 deletions mrmustard/lab_dev/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def visualize_2d(
shape = [max(min_shape, d) for d in self.auto_shape()]
state = self.to_fock(tuple(shape))
state = state.dm()
dm = math.sum(state.ansatz.array, axis=[0])
dm = math.sum(state.ansatz.array, axis=0)

x, prob_x = quadrature_distribution(dm)
p, prob_p = quadrature_distribution(dm, np.pi / 2)
Expand Down Expand Up @@ -502,7 +502,7 @@ def visualize_3d(
shape = [max(min_shape, d) for d in self.auto_shape()]
state = self.to_fock(tuple(shape))
state = state.dm()
dm = math.sum(state.ansatz.array, axis=[0])
dm = math.sum(state.ansatz.array, axis=0)

xvec = np.linspace(*xbounds, resolution)
pvec = np.linspace(*pbounds, resolution)
Expand Down Expand Up @@ -576,7 +576,7 @@ def visualize_dm(
raise ValueError("DM visualization not available for multi-mode states.")
state = self.to_fock(cutoff)
state = state.dm()
dm = math.sum(state.ansatz.array, axis=[0])
dm = math.sum(state.ansatz.array, axis=0)

fig = go.Figure(
data=go.Heatmap(z=abs(dm), colorscale="viridis", name="abs(ρ)", showscale=False)
Expand Down
6 changes: 3 additions & 3 deletions mrmustard/math/backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,17 +1102,17 @@ def sqrtm(self, tensor: Tensor, dtype=None) -> Tensor:
The square root of ``x``"""
return self._apply("sqrtm", (tensor, dtype))

def sum(self, array: Tensor, axis: Sequence[int] = None):
def sum(self, array: Tensor, axis: int | Sequence[int] | None = None):
r"""The sum of array.
Args:
array: The array to take the sum of
axes (tuple): The axis/axes to sum over
axis (int | Sequence[int] | None): The axis/axes to sum over
Returns:
The sum of array
"""
if axis is not None:
if axis is not None and not isinstance(axis, int):
neg = [a for a in axis if a < 0]
pos = [a for a in axis if a >= 0]
axis = tuple(sorted(neg) + sorted(pos)[::-1])
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/math/backend_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def sort(self, array: np.ndarray, axis: int = -1) -> np.ndarray:
def sqrt(self, x: np.ndarray, dtype=None) -> np.ndarray:
return np.sqrt(self.cast(x, dtype))

def sum(self, array: np.ndarray, axis: Sequence[int] = None):
def sum(self, array: np.ndarray, axis: int | Sequence[int] | None = None):
return np.sum(array, axis=axis)

@Autocast()
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/math/backend_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def sort(self, array: tf.Tensor, axis: int = -1) -> tf.Tensor:
def sqrt(self, x: tf.Tensor, dtype=None) -> tf.Tensor:
return tf.sqrt(self.cast(x, dtype))

def sum(self, array: tf.Tensor, axis: Sequence[int] = None):
def sum(self, array: tf.Tensor, axis: int | tuple[int] | None = None):
return tf.reduce_sum(array, axis)

@Autocast()
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/physics/ansatz/array_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def sum_batch(self) -> ArrayAnsatz:
Returns:
The collapsed ArrayAnsatz object.
"""
return ArrayAnsatz(math.expand_dims(math.sum(self.array, axis=[0]), 0), batched=True)
return ArrayAnsatz(math.expand_dims(math.sum(self.array, axis=0), 0), batched=True)

def to_dict(self) -> dict[str, ArrayLike]:
return {"array": self.data}
Expand Down
6 changes: 3 additions & 3 deletions mrmustard/physics/ansatz/polyexp_ansatz.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,12 @@ def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz:
self.A[..., :dim_alpha, :dim_alpha] * zz, axis=[-1, -2]
) # sum((b_arg,1,n,n) * (b_abc,n,n), [-1,-2]) ~ (b_arg,b_abc)
b_part = math.sum(
self.b[..., :dim_alpha] * z[..., None, :], axis=[-1]
self.b[..., :dim_alpha] * z[..., None, :], axis=-1
) # sum((b_arg,1,n) * (b_abc,n), [-1]) ~ (b_arg,b_abc)

exp_sum = math.exp(1 / 2 * A_part + b_part) # (b_arg, b_abc)
if dim_beta == 0:
val = math.sum(exp_sum * self.c, axis=[-1]) # (b_arg)
val = math.sum(exp_sum * self.c, axis=-1) # (b_arg)
else:
b_poly = math.astensor(
math.einsum(
Expand All @@ -441,7 +441,7 @@ def _call_all(self, z: Batch[Vector]) -> PolyExpAnsatz:
poly * self.c,
axis=math.arange(2, 2 + dim_beta, dtype=math.int32).tolist(),
),
axis=[-1],
axis=-1,
) # (b_arg)
return val

Expand Down
2 changes: 1 addition & 1 deletion mrmustard/physics/gaussian_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def complex_gaussian_integral_1(
inv_M = math.inv(M)
c_post = c * math.reshape(
math.sqrt(math.cast((-1) ** m / det_M, "complex128"))
* math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axis=[-1])),
* math.exp(-0.5 * math.sum(bM * math.solve(M, bM), axis=-1)),
c.shape[:1] + (1,) * (len(c.shape) - 1),
)
A_post = R - math.einsum("bij,bjk,blk->bil", D, inv_M, D)
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/physics/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def fock_array(self, shape: int | Sequence[int], batched=False) -> ComplexTensor
f"Expected Fock shape of length {num_vars}, got length {len(shape)}"
) from e
arrays = self.ansatz.reduce(shape).array
array = math.sum(arrays, axis=[0])
array = math.sum(arrays, axis=0)
arrays = math.expand_dims(array, 0) if batched else array
return arrays

Expand Down
2 changes: 1 addition & 1 deletion tests/test_lab_dev/test_transformations/test_cft.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_wigner_function(self):

state = Ket.random([0]) >> Dgate([0], x=1.0, y=0.1)

dm = math.sum(state.to_fock(100).dm().ansatz.array, axis=[0])
dm = math.sum(state.to_fock(100).dm().ansatz.array, axis=0)
vec = np.linspace(-5, 5, 100)
wigner, _, _ = wigner_discretized(dm, vec, vec)

Expand Down

0 comments on commit 32f9cca

Please sign in to comment.