diff --git a/mrmustard/lab_dev/sampler.py b/mrmustard/lab_dev/sampler.py index d122348e1..b4e0828cf 100644 --- a/mrmustard/lab_dev/sampler.py +++ b/mrmustard/lab_dev/sampler.py @@ -88,7 +88,7 @@ def probabilities(self, state: State | None = None) -> list[float] | None: state.probability if isinstance(state, State) else math.real(state) for state in states ] - return probs + return probs / sum(probs) return self.prob_dist def sample(self, state: State | None = None, n_samples: int = 1000) -> list[any]: @@ -134,6 +134,10 @@ class PNRSampler(Sampler): def __init__(self, modes: Sequence[int], cutoff: int) -> None: super().__init__(list(range(cutoff)), [Number(modes, n) for n in range(cutoff)]) + def probabilities(self, state: State | None = None) -> list[float] | None: + fock_state = state.to_fock() if state else state + return super().probabilities(fock_state) + class HomodyneSampler(Sampler): r""" diff --git a/tests/test_lab_dev/test_sampler.py b/tests/test_lab_dev/test_sampler.py index 288aa3474..5734631a3 100644 --- a/tests/test_lab_dev/test_sampler.py +++ b/tests/test_lab_dev/test_sampler.py @@ -55,7 +55,7 @@ def test_probabilities(self): assert sampler2.probabilities() is None state = Vacuum([0]) - assert sampler2.probabilities(state) == [1, 0, 0] + assert all(sampler2.probabilities(state) == [1, 0, 0]) with pytest.raises(ValueError, match="incompatible"): sampler_two_mode = Sampler( @@ -88,8 +88,8 @@ def test_probabilities(self): sampler = PNRSampler([0, 1], cutoff=10) vac_prob = [1.0] + [0.0] * 9 assert sampler.probabilities() is None - assert sampler.probabilities(Vacuum([0, 1])) == vac_prob - assert sampler.probabilities(Vacuum([0, 1, 2])) == vac_prob + assert all(sampler.probabilities(Vacuum([0, 1])) == vac_prob) + assert all(sampler.probabilities(Vacuum([0, 1, 2])) == vac_prob) class TestHomodyneSampler: