From 5784b113c3fbec5a68d78c05723f8aefe19de8b7 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Wed, 15 Nov 2023 14:17:44 -0800 Subject: [PATCH] Add `stim.PauliString.iter_all` (#654) - Add `sinter plot --ymax` - Add `simd_bits::countr_zero` - Add `simd_bits::operator-=` - Add `stim.PauliString.iter_all` - Add `stim.PauliStringIterator` - Autoformat the code Fixes https://github.com/quantumlib/Stim/issues/397 --- dev/util_gen_stub_file.py | 3 +- doc/python_api_reference_vDev.md | 119 +++++++ doc/stim.pyi | 87 +++++ file_lists/benchmark_files | 1 + file_lists/python_api_files | 1 + file_lists/test_files | 1 + glue/python/src/stim/__init__.pyi | 87 +++++ glue/sample/src/sinter/__init__.py | 3 + glue/sample/src/sinter/_main_plot.py | 12 +- src/stim.h | 1 + src/stim/circuit/stabilizer_flow.inl | 75 ++--- src/stim/io/measure_record_batch.inl | 13 +- src/stim/io/measure_record_reader.inl | 43 ++- src/stim/io/measure_record_reader.perf.cc | 32 +- src/stim/main_namespaced.perf.cc | 36 ++- src/stim/mem/simd_bit_table.inl | 22 +- src/stim/mem/simd_bits.h | 3 + src/stim/mem/simd_bits.inl | 11 + src/stim/mem/simd_bits.test.cc | 13 + src/stim/mem/simd_bits_range_ref.h | 3 + src/stim/mem/simd_bits_range_ref.inl | 24 ++ src/stim/mem/simd_bits_range_ref.test.cc | 14 + src/stim/py/stim.pybind.cc | 3 + .../count_determined_measurements.inl | 3 +- src/stim/simulators/dem_sampler.inl | 13 +- src/stim/simulators/frame_simulator.inl | 215 +++++++------ src/stim/simulators/frame_simulator_util.inl | 23 +- .../measurements_to_detection_events.inl | 12 +- src/stim/simulators/tableau_simulator.inl | 40 +-- src/stim/stabilizers/conversions.inl | 9 +- src/stim/stabilizers/pauli_string.inl | 3 +- src/stim/stabilizers/pauli_string.pybind.cc | 94 ++++++ src/stim/stabilizers/pauli_string_iter.h | 145 +++++++++ src/stim/stabilizers/pauli_string_iter.inl | 75 +++++ .../stabilizers/pauli_string_iter.perf.cc | 55 ++++ .../stabilizers/pauli_string_iter.pybind.cc | 78 +++++ .../stabilizers/pauli_string_iter.pybind.h | 28 ++ .../stabilizers/pauli_string_iter.test.cc | 297 ++++++++++++++++++ .../stabilizers/pauli_string_pybind_test.py | 43 +++ src/stim/stabilizers/pauli_string_ref.inl | 42 +-- src/stim/stabilizers/tableau.inl | 10 +- src/stim/stabilizers/tableau_iter.inl | 7 +- src/stim/stabilizers/tableau_iter.perf.cc | 4 +- .../stabilizers/tableau_transposed_raii.inl | 27 +- 44 files changed, 1510 insertions(+), 320 deletions(-) create mode 100644 src/stim/stabilizers/pauli_string_iter.h create mode 100644 src/stim/stabilizers/pauli_string_iter.inl create mode 100644 src/stim/stabilizers/pauli_string_iter.perf.cc create mode 100644 src/stim/stabilizers/pauli_string_iter.pybind.cc create mode 100644 src/stim/stabilizers/pauli_string_iter.pybind.h create mode 100644 src/stim/stabilizers/pauli_string_iter.test.cc diff --git a/dev/util_gen_stub_file.py b/dev/util_gen_stub_file.py index ec0f723c2..3bfabd0fe 100644 --- a/dev/util_gen_stub_file.py +++ b/dev/util_gen_stub_file.py @@ -252,6 +252,8 @@ def print_doc(*, full_name: str, parent: object, obj: object, level: int) -> Opt elif isinstance(obj, (int, str)): text = f"{term_name}: {type(obj).__name__} = {obj!r}" doc = '' + elif term_name == term_name.upper(): + return None # Skip constants because they lack a doc string. else: text = f"class {term_name}" if inspect.isabstract(obj): @@ -296,7 +298,6 @@ def print_doc(*, full_name: str, parent: object, obj: object, level: int) -> Opt def generate_documentation(*, obj: object, level: int, full_name: str) -> Iterator[DescribedObject]: - if full_name.endswith("__"): return if not inspect.ismodule(obj) and not inspect.isclass(obj): diff --git a/doc/python_api_reference_vDev.md b/doc/python_api_reference_vDev.md index 17bae3a03..b729716b3 100644 --- a/doc/python_api_reference_vDev.md +++ b/doc/python_api_reference_vDev.md @@ -247,12 +247,16 @@ API references for stable versions are kept on the [stim github wiki](https://gi - [`stim.PauliString.copy`](#stim.PauliString.copy) - [`stim.PauliString.from_numpy`](#stim.PauliString.from_numpy) - [`stim.PauliString.from_unitary_matrix`](#stim.PauliString.from_unitary_matrix) + - [`stim.PauliString.iter_all`](#stim.PauliString.iter_all) - [`stim.PauliString.random`](#stim.PauliString.random) - [`stim.PauliString.sign`](#stim.PauliString.sign) - [`stim.PauliString.to_numpy`](#stim.PauliString.to_numpy) - [`stim.PauliString.to_tableau`](#stim.PauliString.to_tableau) - [`stim.PauliString.to_unitary_matrix`](#stim.PauliString.to_unitary_matrix) - [`stim.PauliString.weight`](#stim.PauliString.weight) +- [`stim.PauliStringIterator`](#stim.PauliStringIterator) + - [`stim.PauliStringIterator.__iter__`](#stim.PauliStringIterator.__iter__) + - [`stim.PauliStringIterator.__next__`](#stim.PauliStringIterator.__next__) - [`stim.Tableau`](#stim.Tableau) - [`stim.Tableau.__add__`](#stim.Tableau.__add__) - [`stim.Tableau.__call__`](#stim.Tableau.__call__) @@ -8346,6 +8350,68 @@ def from_unitary_matrix( """ ``` + +```python +# stim.PauliString.iter_all + +# (in class stim.PauliString) +@staticmethod +def iter_all( + num_qubits: int, + *, + min_weight: int = 0, + max_weight: object = None, + allowed_paulis: str = 'XYZ', +) -> stim.PauliStringIterator: + """Returns an iterator that iterates over all matching pauli strings. + + Args: + num_qubits: The desired number of qubits in the pauli strings. + min_weight: Defaults to 0. The minimum number of non-identity terms that + must be present in each yielded pauli string. + max_weight: Defaults to None (unused). The maximum number of non-identity + terms that must be present in each yielded pauli string. + allowed_paulis: Defaults to "XYZ". Set this to a string containing the + non-identity paulis that are allowed to appear in each yielded pauli + string. This argument must be a string made up of only "X", "Y", and + "Z" characters. A non-identity Pauli is allowed if it appears in the + string, and not allowed if it doesn't. Identity Paulis are always + allowed. + + Returns: + An Iterable[stim.PauliString] that yields the requested pauli strings. + + Examples: + >>> import stim + >>> pauli_string_iterator = stim.PauliString.iter_all( + ... num_qubits=3, + ... min_weight=1, + ... max_weight=2, + ... allowed_paulis="XZ", + ... ) + >>> for p in pauli_string_iterator: + ... print(p) + +X__ + +Z__ + +_X_ + +_Z_ + +__X + +__Z + +XX_ + +XZ_ + +ZX_ + +ZZ_ + +X_X + +X_Z + +Z_X + +Z_Z + +_XX + +_XZ + +_ZX + +_ZZ + """ +``` + ```python # stim.PauliString.random @@ -8577,6 +8643,59 @@ def weight( """ ``` + +```python +# stim.PauliStringIterator + +# (at top-level in the stim module) +class PauliStringIterator: + """Iterates over all pauli strings matching specified patterns. + + Examples: + >>> import stim + >>> pauli_string_iterator = stim.PauliString.iter_all( + ... 2, + ... min_weight=1, + ... max_weight=1, + ... allowed_paulis="XZ", + ... ) + >>> for p in pauli_string_iterator: + ... print(p) + +X_ + +Z_ + +_X + +_Z + """ +``` + + +```python +# stim.PauliStringIterator.__iter__ + +# (in class stim.PauliStringIterator) +def __iter__( + self, +) -> stim.PauliStringIterator: + """Returns an independent copy of the pauli string iterator. + + Since for-loops and loop-comprehensions call `iter` on things they + iterate, this effectively allows the iterator to be iterated + multiple times. + """ +``` + + +```python +# stim.PauliStringIterator.__next__ + +# (in class stim.PauliStringIterator) +def __next__( + self, +) -> stim.PauliString: + """Returns the next iterated pauli string. + """ +``` + ```python # stim.Tableau diff --git a/doc/stim.pyi b/doc/stim.pyi index 3b2def10e..bb3b2f8cd 100644 --- a/doc/stim.pyi +++ b/doc/stim.pyi @@ -6403,6 +6403,61 @@ class PauliString: stim.PauliString("+XZ") """ @staticmethod + def iter_all( + num_qubits: int, + *, + min_weight: int = 0, + max_weight: object = None, + allowed_paulis: str = 'XYZ', + ) -> stim.PauliStringIterator: + """Returns an iterator that iterates over all matching pauli strings. + + Args: + num_qubits: The desired number of qubits in the pauli strings. + min_weight: Defaults to 0. The minimum number of non-identity terms that + must be present in each yielded pauli string. + max_weight: Defaults to None (unused). The maximum number of non-identity + terms that must be present in each yielded pauli string. + allowed_paulis: Defaults to "XYZ". Set this to a string containing the + non-identity paulis that are allowed to appear in each yielded pauli + string. This argument must be a string made up of only "X", "Y", and + "Z" characters. A non-identity Pauli is allowed if it appears in the + string, and not allowed if it doesn't. Identity Paulis are always + allowed. + + Returns: + An Iterable[stim.PauliString] that yields the requested pauli strings. + + Examples: + >>> import stim + >>> pauli_string_iterator = stim.PauliString.iter_all( + ... num_qubits=3, + ... min_weight=1, + ... max_weight=2, + ... allowed_paulis="XZ", + ... ) + >>> for p in pauli_string_iterator: + ... print(p) + +X__ + +Z__ + +_X_ + +_Z_ + +__X + +__Z + +XX_ + +XZ_ + +ZX_ + +ZZ_ + +X_X + +X_Z + +Z_X + +Z_Z + +_XX + +_XZ + +_ZX + +_ZZ + """ + @staticmethod def random( num_qubits: int, *, @@ -6591,6 +6646,38 @@ class PauliString: >>> stim.PauliString("-XXX___XXYZ").weight 7 """ +class PauliStringIterator: + """Iterates over all pauli strings matching specified patterns. + + Examples: + >>> import stim + >>> pauli_string_iterator = stim.PauliString.iter_all( + ... 2, + ... min_weight=1, + ... max_weight=1, + ... allowed_paulis="XZ", + ... ) + >>> for p in pauli_string_iterator: + ... print(p) + +X_ + +Z_ + +_X + +_Z + """ + def __iter__( + self, + ) -> stim.PauliStringIterator: + """Returns an independent copy of the pauli string iterator. + + Since for-loops and loop-comprehensions call `iter` on things they + iterate, this effectively allows the iterator to be iterated + multiple times. + """ + def __next__( + self, + ) -> stim.PauliString: + """Returns the next iterated pauli string. + """ class Tableau: """A stabilizer tableau. diff --git a/file_lists/benchmark_files b/file_lists/benchmark_files index d91532e93..a2b5f1e28 100644 --- a/file_lists/benchmark_files +++ b/file_lists/benchmark_files @@ -16,5 +16,6 @@ src/stim/simulators/frame_simulator.perf.cc src/stim/simulators/tableau_simulator.perf.cc src/stim/stabilizers/conversions.perf.cc src/stim/stabilizers/pauli_string.perf.cc +src/stim/stabilizers/pauli_string_iter.perf.cc src/stim/stabilizers/tableau.perf.cc src/stim/stabilizers/tableau_iter.perf.cc diff --git a/file_lists/python_api_files b/file_lists/python_api_files index c8e36a645..50bc6ecac 100644 --- a/file_lists/python_api_files +++ b/file_lists/python_api_files @@ -21,5 +21,6 @@ src/stim/simulators/matched_error.pybind.cc src/stim/simulators/measurements_to_detection_events.pybind.cc src/stim/simulators/tableau_simulator.pybind.cc src/stim/stabilizers/pauli_string.pybind.cc +src/stim/stabilizers/pauli_string_iter.pybind.cc src/stim/stabilizers/tableau.pybind.cc src/stim/stabilizers/tableau_iter.pybind.cc diff --git a/file_lists/test_files b/file_lists/test_files index 593a91d0f..3fd76c0d8 100644 --- a/file_lists/test_files +++ b/file_lists/test_files @@ -70,6 +70,7 @@ src/stim/simulators/transform_without_feedback.test.cc src/stim/simulators/vector_simulator.test.cc src/stim/stabilizers/conversions.test.cc src/stim/stabilizers/pauli_string.test.cc +src/stim/stabilizers/pauli_string_iter.test.cc src/stim/stabilizers/tableau.test.cc src/stim/stabilizers/tableau_iter.test.cc src/stim/str_util.test.cc diff --git a/glue/python/src/stim/__init__.pyi b/glue/python/src/stim/__init__.pyi index 3b2def10e..bb3b2f8cd 100644 --- a/glue/python/src/stim/__init__.pyi +++ b/glue/python/src/stim/__init__.pyi @@ -6403,6 +6403,61 @@ class PauliString: stim.PauliString("+XZ") """ @staticmethod + def iter_all( + num_qubits: int, + *, + min_weight: int = 0, + max_weight: object = None, + allowed_paulis: str = 'XYZ', + ) -> stim.PauliStringIterator: + """Returns an iterator that iterates over all matching pauli strings. + + Args: + num_qubits: The desired number of qubits in the pauli strings. + min_weight: Defaults to 0. The minimum number of non-identity terms that + must be present in each yielded pauli string. + max_weight: Defaults to None (unused). The maximum number of non-identity + terms that must be present in each yielded pauli string. + allowed_paulis: Defaults to "XYZ". Set this to a string containing the + non-identity paulis that are allowed to appear in each yielded pauli + string. This argument must be a string made up of only "X", "Y", and + "Z" characters. A non-identity Pauli is allowed if it appears in the + string, and not allowed if it doesn't. Identity Paulis are always + allowed. + + Returns: + An Iterable[stim.PauliString] that yields the requested pauli strings. + + Examples: + >>> import stim + >>> pauli_string_iterator = stim.PauliString.iter_all( + ... num_qubits=3, + ... min_weight=1, + ... max_weight=2, + ... allowed_paulis="XZ", + ... ) + >>> for p in pauli_string_iterator: + ... print(p) + +X__ + +Z__ + +_X_ + +_Z_ + +__X + +__Z + +XX_ + +XZ_ + +ZX_ + +ZZ_ + +X_X + +X_Z + +Z_X + +Z_Z + +_XX + +_XZ + +_ZX + +_ZZ + """ + @staticmethod def random( num_qubits: int, *, @@ -6591,6 +6646,38 @@ class PauliString: >>> stim.PauliString("-XXX___XXYZ").weight 7 """ +class PauliStringIterator: + """Iterates over all pauli strings matching specified patterns. + + Examples: + >>> import stim + >>> pauli_string_iterator = stim.PauliString.iter_all( + ... 2, + ... min_weight=1, + ... max_weight=1, + ... allowed_paulis="XZ", + ... ) + >>> for p in pauli_string_iterator: + ... print(p) + +X_ + +Z_ + +_X + +_Z + """ + def __iter__( + self, + ) -> stim.PauliStringIterator: + """Returns an independent copy of the pauli string iterator. + + Since for-loops and loop-comprehensions call `iter` on things they + iterate, this effectively allows the iterator to be iterated + multiple times. + """ + def __next__( + self, + ) -> stim.PauliString: + """Returns the next iterated pauli string. + """ class Tableau: """A stabilizer tableau. diff --git a/glue/sample/src/sinter/__init__.py b/glue/sample/src/sinter/__init__.py index 17fd45c86..da5836199 100644 --- a/glue/sample/src/sinter/__init__.py +++ b/glue/sample/src/sinter/__init__.py @@ -15,6 +15,9 @@ from sinter._csv_out import ( CSV_HEADER, ) +from sinter._decoding_all_built_in_decoders import ( + BUILT_IN_DECODERS, +) from sinter._existing_data import ( read_stats_from_csv_files, stats_from_csv_files, diff --git a/glue/sample/src/sinter/_main_plot.py b/glue/sample/src/sinter/_main_plot.py index 86f91713c..c6defeb90 100644 --- a/glue/sample/src/sinter/_main_plot.py +++ b/glue/sample/src/sinter/_main_plot.py @@ -220,7 +220,11 @@ def parse_args(args: List[str]) -> Any: parser.add_argument('--ymin', default=None, type=float, - help='Sets the minimum value of the y axis (max always 1).') + help='Forces the minimum value of the y axis.') + parser.add_argument('--ymax', + default=None, + type=float, + help='Forces the maximum value of the y axis.') parser.add_argument('--title', default=None, type=str, @@ -474,6 +478,7 @@ def _plot_helper( xaxis: str, yaxis: Optional[str], min_y: Optional[float], + max_y: Optional[float], max_x: Optional[float], min_x: Optional[float], title: Optional[str], @@ -557,7 +562,7 @@ def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]: y_not_x=True, axis_label=f"Logical Error Rate (per {failure_unit})" if yaxis is None else yaxis, default_scale='log', - forced_max_v=1 if min_y is None or 1 > min_y else None, + forced_max_v=max_y if max_y is not None else 1 if min_y is None or 1 > min_y else None, default_min_v=1e-4, default_max_v=1, forced_min_v=min_y, @@ -612,6 +617,8 @@ def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]: default_max_v=1, plotted_stats=plotted_stats, v_func=y_func, + forced_min_v=min_y, + forced_max_v=max_y, ) plot_custom( ax=ax_cus, @@ -754,6 +761,7 @@ def main_plot(*, command_line_args: List[str]): yaxis=args.yaxis, fig_size=args.fig_size, min_y=args.ymin, + max_y=args.ymax, max_x=args.xmax, min_x=args.xmin, highlight_max_likelihood_factor=args.highlight_max_likelihood_factor, diff --git a/src/stim.h b/src/stim.h index a70c5b15b..6c5a2a040 100644 --- a/src/stim.h +++ b/src/stim.h @@ -98,6 +98,7 @@ #include "stim/simulators/vector_simulator.h" #include "stim/stabilizers/conversions.h" #include "stim/stabilizers/pauli_string.h" +#include "stim/stabilizers/pauli_string_iter.h" #include "stim/stabilizers/pauli_string_ref.h" #include "stim/stabilizers/tableau.h" #include "stim/stabilizers/tableau_iter.h" diff --git a/src/stim/circuit/stabilizer_flow.inl b/src/stim/circuit/stabilizer_flow.inl index 3566d1c11..1766b6308 100644 --- a/src/stim/circuit/stabilizer_flow.inl +++ b/src/stim/circuit/stabilizer_flow.inl @@ -1,7 +1,6 @@ -#include "stim/circuit/stabilizer_flow.h" - #include "stim/arg_parse.h" #include "stim/circuit/circuit.h" +#include "stim/circuit/stabilizer_flow.h" #include "stim/simulators/frame_simulator_util.h" #include "stim/simulators/tableau_simulator.h" @@ -26,11 +25,7 @@ void _pauli_string_controlled_not(PauliStringRef control, uint32_t target, Ci template bool _check_if_circuit_has_stabilizer_flow( - size_t num_samples, - std::mt19937_64 &rng, - const Circuit &circuit, - const StabilizerFlow &flow) { - + size_t num_samples, std::mt19937_64 &rng, const Circuit &circuit, const StabilizerFlow &flow) { uint32_t n = (uint32_t)circuit.count_qubits(); n = std::max(n, (uint32_t)flow.input.num_qubits); n = std::max(n, (uint32_t)flow.output.num_qubits); @@ -56,11 +51,7 @@ bool _check_if_circuit_has_stabilizer_flow( augmented_circuit.safe_append_u("M", {n}, {}); auto out = sample_batch_measurements( - augmented_circuit, - TableauSimulator::reference_sample_circuit(augmented_circuit), - num_samples, - rng, - false); + augmented_circuit, TableauSimulator::reference_sample_circuit(augmented_circuit), num_samples, rng, false); size_t m = augmented_circuit.count_measurements() - 1; return !out[m].not_zero(); @@ -68,14 +59,10 @@ bool _check_if_circuit_has_stabilizer_flow( template std::vector check_if_circuit_has_stabilizer_flows( - size_t num_samples, - std::mt19937_64 &rng, - const Circuit &circuit, - const std::vector> flows) { + size_t num_samples, std::mt19937_64 &rng, const Circuit &circuit, const std::vector> flows) { std::vector result; for (const auto &flow : flows) { - result.push_back(_check_if_circuit_has_stabilizer_flow( - num_samples, rng, circuit, flow)); + result.push_back(_check_if_circuit_has_stabilizer_flow(num_samples, rng, circuit, flow)); } return result; } @@ -85,42 +72,44 @@ StabilizerFlow StabilizerFlow::from_str(const char *text) { try { auto parts = split('>', text); if (parts.size() != 2 || parts[0].empty() || parts[0].back() != '-') { - throw std::invalid_argument(""); + throw std::invalid_argument(""); } parts[0].pop_back(); while (!parts[0].empty() && parts[0].back() == ' ') { - parts[0].pop_back(); + parts[0].pop_back(); } - PauliString input = parts[0] == "1" ? PauliString(0) : parts[0] == "-1" ? PauliString::from_str("-") : PauliString::from_str(parts[0].c_str()); + PauliString input = parts[0] == "1" ? PauliString(0) + : parts[0] == "-1" ? PauliString::from_str("-") + : PauliString::from_str(parts[0].c_str()); parts = split(' ', parts[1]); size_t k = 0; while (k < parts.size() && parts[k].empty()) { - k += 1; + k += 1; } PauliString output(0); std::vector measurements; if (!parts[k].empty() && parts[k][0] != 'r') { - output = PauliString::from_str(parts[k].c_str()); + output = PauliString::from_str(parts[k].c_str()); } else { - auto t = stim::GateTarget::from_target_str(parts[k].c_str()); - if (!t.is_measurement_record_target()) { - throw std::invalid_argument(""); - } - measurements.push_back(t); + auto t = stim::GateTarget::from_target_str(parts[k].c_str()); + if (!t.is_measurement_record_target()) { + throw std::invalid_argument(""); + } + measurements.push_back(t); } k++; while (k < parts.size()) { - if (parts[k] != "xor" || k + 1 == parts.size()) { - throw std::invalid_argument(""); - } - auto t = stim::GateTarget::from_target_str(parts[k + 1].c_str()); - if (!t.is_measurement_record_target()) { - throw std::invalid_argument(""); - } - measurements.push_back(t); - k += 2; + if (parts[k] != "xor" || k + 1 == parts.size()) { + throw std::invalid_argument(""); + } + auto t = stim::GateTarget::from_target_str(parts[k + 1].c_str()); + if (!t.is_measurement_record_target()) { + throw std::invalid_argument(""); + } + measurements.push_back(t); + k += 2; } return StabilizerFlow{input, output, measurements}; } catch (const std::invalid_argument &ex) { @@ -130,12 +119,12 @@ StabilizerFlow StabilizerFlow::from_str(const char *text) { template bool StabilizerFlow::operator==(const StabilizerFlow &other) const { - return input == other.input && output == other.output && measurement_outputs == other.measurement_outputs; + return input == other.input && output == other.output && measurement_outputs == other.measurement_outputs; } template bool StabilizerFlow::operator!=(const StabilizerFlow &other) const { - return !(*this == other); + return !(*this == other); } template @@ -149,7 +138,7 @@ template std::ostream &operator<<(std::ostream &out, const StabilizerFlow &flow) { if (flow.input.num_qubits == 0) { if (flow.input.sign) { - out << "-"; + out << "-"; } out << "1"; } else { @@ -159,9 +148,9 @@ std::ostream &operator<<(std::ostream &out, const StabilizerFlow &flow) { bool skip_xor = false; if (flow.output.num_qubits == 0) { if (flow.output.sign) { - out << "-1"; + out << "-1"; } else if (flow.measurement_outputs.empty()) { - out << "+1"; + out << "+1"; } skip_xor = true; } else { @@ -169,7 +158,7 @@ std::ostream &operator<<(std::ostream &out, const StabilizerFlow &flow) { } for (const auto &t : flow.measurement_outputs) { if (!skip_xor) { - out << " xor "; + out << " xor "; } skip_xor = false; t.write_succinct(out); diff --git a/src/stim/io/measure_record_batch.inl b/src/stim/io/measure_record_batch.inl index 03ebec9f4..dee0405ae 100644 --- a/src/stim/io/measure_record_batch.inl +++ b/src/stim/io/measure_record_batch.inl @@ -14,10 +14,9 @@ * limitations under the License. */ -#include "stim/io/measure_record_batch.h" - #include +#include "stim/io/measure_record_batch.h" #include "stim/io/measure_record_batch_writer.h" #include "stim/probability_util.h" @@ -25,7 +24,13 @@ namespace stim { template MeasureRecordBatch::MeasureRecordBatch(size_t num_shots, size_t max_lookback) - : num_shots(num_shots), max_lookback(max_lookback), unwritten(0), stored(0), written(0), shot_mask(num_shots), storage(1, num_shots) { + : num_shots(num_shots), + max_lookback(max_lookback), + unwritten(0), + stored(0), + written(0), + shot_mask(num_shots), + storage(1, num_shots) { for (size_t k = 0; k < num_shots; k++) { shot_mask[k] = true; } @@ -163,4 +168,4 @@ void MeasureRecordBatch::destructive_resize(size_t new_num_shots, size_t new_ } } -} +} // namespace stim diff --git a/src/stim/io/measure_record_reader.inl b/src/stim/io/measure_record_reader.inl index fa202d25a..40b9384a6 100644 --- a/src/stim/io/measure_record_reader.inl +++ b/src/stim/io/measure_record_reader.inl @@ -14,10 +14,10 @@ * limitations under the License. */ -#include "stim/io/measure_record_reader.h" - #include +#include "stim/io/measure_record_reader.h" + namespace stim { template @@ -48,17 +48,23 @@ std::unique_ptr> MeasureRecordReader::make( FILE *in, SampleFormat input_format, size_t num_measurements, size_t num_detectors, size_t num_observables) { switch (input_format) { case SampleFormat::SAMPLE_FORMAT_01: - return std::make_unique>(in, num_measurements, num_detectors, num_observables); + return std::make_unique>( + in, num_measurements, num_detectors, num_observables); case SampleFormat::SAMPLE_FORMAT_B8: - return std::make_unique>(in, num_measurements, num_detectors, num_observables); + return std::make_unique>( + in, num_measurements, num_detectors, num_observables); case SampleFormat::SAMPLE_FORMAT_DETS: - return std::make_unique>(in, num_measurements, num_detectors, num_observables); + return std::make_unique>( + in, num_measurements, num_detectors, num_observables); case SampleFormat::SAMPLE_FORMAT_HITS: - return std::make_unique>(in, num_measurements, num_detectors, num_observables); + return std::make_unique>( + in, num_measurements, num_detectors, num_observables); case SampleFormat::SAMPLE_FORMAT_PTB64: - return std::make_unique>(in, num_measurements, num_detectors, num_observables); + return std::make_unique>( + in, num_measurements, num_detectors, num_observables); case SampleFormat::SAMPLE_FORMAT_R8: - return std::make_unique>(in, num_measurements, num_detectors, num_observables); + return std::make_unique>( + in, num_measurements, num_detectors, num_observables); default: throw std::invalid_argument("Sample format not recognized by MeasurementRecordReader"); } @@ -92,8 +98,7 @@ void MeasureRecordReader::move_obs_in_shots_to_mask_assuming_sorted(SparseSho } template -size_t MeasureRecordReader::read_into_table_with_major_shot_index( - simd_bit_table &out_table, size_t max_shots) { +size_t MeasureRecordReader::read_into_table_with_major_shot_index(simd_bit_table &out_table, size_t max_shots) { size_t read_shots = 0; while (read_shots < max_shots && start_and_read_entire_record(out_table[read_shots])) { read_shots++; @@ -110,8 +115,7 @@ MeasureRecordReaderFormat01::MeasureRecordReaderFormat01( } template -bool MeasureRecordReaderFormat01::start_and_read_entire_record( - simd_bits_range_ref dirty_out_buffer) { +bool MeasureRecordReaderFormat01::start_and_read_entire_record(simd_bits_range_ref dirty_out_buffer) { return start_and_read_entire_record_helper( [&](size_t k) { dirty_out_buffer[k] = false; @@ -212,8 +216,7 @@ MeasureRecordReaderFormatB8::MeasureRecordReaderFormatB8( } template -bool MeasureRecordReaderFormatB8::start_and_read_entire_record( - simd_bits_range_ref dirty_out_buffer) { +bool MeasureRecordReaderFormatB8::start_and_read_entire_record(simd_bits_range_ref dirty_out_buffer) { size_t n = this->bits_per_record(); size_t nb = (n + 7) >> 3; size_t nr = fread(dirty_out_buffer.u8, 1, nb, in); @@ -302,8 +305,7 @@ MeasureRecordReaderFormatHits::MeasureRecordReaderFormatHits( } template -bool MeasureRecordReaderFormatHits::start_and_read_entire_record( - simd_bits_range_ref dirty_out_buffer) { +bool MeasureRecordReaderFormatHits::start_and_read_entire_record(simd_bits_range_ref dirty_out_buffer) { size_t m = this->bits_per_record(); dirty_out_buffer.prefix_ref(m).clear(); return start_and_read_entire_record_helper([&](size_t bit_index) { @@ -399,8 +401,7 @@ MeasureRecordReaderFormatR8::MeasureRecordReaderFormatR8( } template -bool MeasureRecordReaderFormatR8::start_and_read_entire_record( - simd_bits_range_ref dirty_out_buffer) { +bool MeasureRecordReaderFormatR8::start_and_read_entire_record(simd_bits_range_ref dirty_out_buffer) { dirty_out_buffer.prefix_ref(this->bits_per_record()).clear(); return start_and_read_entire_record_helper([&](size_t bit_index) { dirty_out_buffer[bit_index] = 1; @@ -477,8 +478,7 @@ bool MeasureRecordReaderFormatR8::start_and_read_entire_record_helper(HANDLE_ /// DETS format template -bool MeasureRecordReaderFormatDets::start_and_read_entire_record( - simd_bits_range_ref dirty_out_buffer) { +bool MeasureRecordReaderFormatDets::start_and_read_entire_record(simd_bits_range_ref dirty_out_buffer) { dirty_out_buffer.prefix_ref(this->bits_per_record()).clear(); return start_and_read_entire_record_helper([&](size_t bit_index) { dirty_out_buffer[bit_index] = true; @@ -634,8 +634,7 @@ bool MeasureRecordReaderFormatPTB64::load_cache() { } template -bool MeasureRecordReaderFormatPTB64::start_and_read_entire_record( - simd_bits_range_ref dirty_out_buffer) { +bool MeasureRecordReaderFormatPTB64::start_and_read_entire_record(simd_bits_range_ref dirty_out_buffer) { if (num_unread_shots_in_buf == 0) { load_cache(); } diff --git a/src/stim/io/measure_record_reader.perf.cc b/src/stim/io/measure_record_reader.perf.cc index 1617add4c..f5f00dbcf 100644 --- a/src/stim/io/measure_record_reader.perf.cc +++ b/src/stim/io/measure_record_reader.perf.cc @@ -77,54 +77,54 @@ void sparse_reader_benchmark(double goal_micros) { } BENCHMARK(read_01_dense_per10) { - dense_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_01>(60); + dense_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_01>(60); } BENCHMARK(read_01_sparse_per10) { - sparse_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_01>(45); + sparse_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_01>(45); } BENCHMARK(read_b8_dense_per10) { - dense_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_B8>(0.65); + dense_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_B8>(0.65); } BENCHMARK(read_b8_sparse_per10) { - sparse_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_B8>(6); + sparse_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_B8>(6); } BENCHMARK(read_hits_dense_per10) { - dense_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_HITS>(16); + dense_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_HITS>(16); } BENCHMARK(read_hits_dense_per100) { - dense_reader_benchmark<10000, 100, SampleFormat:: SAMPLE_FORMAT_HITS>(2.1); + dense_reader_benchmark<10000, 100, SampleFormat::SAMPLE_FORMAT_HITS>(2.1); } BENCHMARK(read_hits_sparse_per10) { - sparse_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_HITS>(15); + sparse_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_HITS>(15); } BENCHMARK(read_hits_sparse_per100) { - sparse_reader_benchmark<10000, 100, SampleFormat:: SAMPLE_FORMAT_HITS>(2.2); + sparse_reader_benchmark<10000, 100, SampleFormat::SAMPLE_FORMAT_HITS>(2.2); } BENCHMARK(read_dets_dense_per10) { - dense_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_DETS>(23); + dense_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_DETS>(23); } BENCHMARK(read_dets_dense_per100) { - dense_reader_benchmark<10000, 100, SampleFormat:: SAMPLE_FORMAT_DETS>(3.0); + dense_reader_benchmark<10000, 100, SampleFormat::SAMPLE_FORMAT_DETS>(3.0); } BENCHMARK(read_dets_sparse_per10) { - sparse_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_DETS>(23); + sparse_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_DETS>(23); } BENCHMARK(read_dets_sparse_per100) { - sparse_reader_benchmark<10000, 100, SampleFormat:: SAMPLE_FORMAT_DETS>(3.0); + sparse_reader_benchmark<10000, 100, SampleFormat::SAMPLE_FORMAT_DETS>(3.0); } BENCHMARK(read_r8_dense_per10) { - dense_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_R8>(5); + dense_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_R8>(5); } BENCHMARK(read_r8_dense_per100) { - dense_reader_benchmark<10000, 100, SampleFormat:: SAMPLE_FORMAT_R8>(1.3); + dense_reader_benchmark<10000, 100, SampleFormat::SAMPLE_FORMAT_R8>(1.3); } BENCHMARK(read_r8_sparse_per10) { - sparse_reader_benchmark<10000, 10, SampleFormat:: SAMPLE_FORMAT_R8>(3.5); + sparse_reader_benchmark<10000, 10, SampleFormat::SAMPLE_FORMAT_R8>(3.5); } BENCHMARK(read_r8_sparse_per100) { - sparse_reader_benchmark<10000, 100, SampleFormat:: SAMPLE_FORMAT_R8>(1.0); + sparse_reader_benchmark<10000, 100, SampleFormat::SAMPLE_FORMAT_R8>(1.0); } diff --git a/src/stim/main_namespaced.perf.cc b/src/stim/main_namespaced.perf.cc index 41c0e64eb..4e09fc6e9 100644 --- a/src/stim/main_namespaced.perf.cc +++ b/src/stim/main_namespaced.perf.cc @@ -87,7 +87,7 @@ BENCHMARK(main_sample1_pauliframe_b8_rep_d1000_r100) { simd_bits ref(0); benchmark_go([&]() { rewind(out); - sample_batch_measurements_writing_results_to_disk(circuit, ref, 1, out, SampleFormat:: SAMPLE_FORMAT_B8, rng); + sample_batch_measurements_writing_results_to_disk(circuit, ref, 1, out, SampleFormat::SAMPLE_FORMAT_B8, rng); }) .goal_millis(9) .show_rate("Samples", circuit.count_measurements()); @@ -104,7 +104,15 @@ BENCHMARK(main_sample1_detectors_b8_rep_d1000_r100) { benchmark_go([&]() { rewind(out); sample_batch_detection_events_writing_results_to_disk( - circuit, 1, false, false, out, SampleFormat:: SAMPLE_FORMAT_B8, rng, obs_out, SampleFormat:: SAMPLE_FORMAT_B8); + circuit, + 1, + false, + false, + out, + SampleFormat::SAMPLE_FORMAT_B8, + rng, + obs_out, + SampleFormat::SAMPLE_FORMAT_B8); }) .goal_millis(11) .show_rate("Samples", circuit.count_measurements()); @@ -119,7 +127,7 @@ BENCHMARK(main_sample256_pauliframe_b8_rep_d1000_r100) { simd_bits ref(0); benchmark_go([&]() { rewind(out); - sample_batch_measurements_writing_results_to_disk(circuit, ref, 256, out, SampleFormat:: SAMPLE_FORMAT_B8, rng); + sample_batch_measurements_writing_results_to_disk(circuit, ref, 256, out, SampleFormat::SAMPLE_FORMAT_B8, rng); }) .goal_millis(13) .show_rate("Samples", circuit.count_measurements()); @@ -135,7 +143,7 @@ BENCHMARK(main_sample256_pauliframe_b8_rep_d1000_r1000_stream) { simd_bits ref(0); benchmark_go([&]() { rewind(out); - sample_batch_measurements_writing_results_to_disk(circuit, ref, 256, out, SampleFormat:: SAMPLE_FORMAT_B8, rng); + sample_batch_measurements_writing_results_to_disk(circuit, ref, 256, out, SampleFormat::SAMPLE_FORMAT_B8, rng); }) .goal_millis(360) .show_rate("Samples", circuit.count_measurements()); @@ -152,7 +160,15 @@ BENCHMARK(main_sample256_detectors_b8_rep_d1000_r100) { benchmark_go([&]() { rewind(out); sample_batch_detection_events_writing_results_to_disk( - circuit, 256, false, false, out, SampleFormat:: SAMPLE_FORMAT_B8, rng, obs_out, SampleFormat:: SAMPLE_FORMAT_B8); + circuit, + 256, + false, + false, + out, + SampleFormat::SAMPLE_FORMAT_B8, + rng, + obs_out, + SampleFormat::SAMPLE_FORMAT_B8); }) .goal_millis(15) .show_rate("Samples", circuit.count_measurements()); @@ -170,7 +186,15 @@ BENCHMARK(main_sample256_detectors_b8_rep_d1000_r1000_stream) { benchmark_go([&]() { rewind(out); sample_batch_detection_events_writing_results_to_disk( - circuit, 256, false, false, out, SampleFormat:: SAMPLE_FORMAT_B8, rng, obs_out, SampleFormat:: SAMPLE_FORMAT_B8); + circuit, + 256, + false, + false, + out, + SampleFormat::SAMPLE_FORMAT_B8, + rng, + obs_out, + SampleFormat::SAMPLE_FORMAT_B8); }) .goal_millis(360) .show_rate("Samples", circuit.count_measurements()); diff --git a/src/stim/mem/simd_bit_table.inl b/src/stim/mem/simd_bit_table.inl index 79c74cf81..0fe88ff12 100644 --- a/src/stim/mem/simd_bit_table.inl +++ b/src/stim/mem/simd_bit_table.inl @@ -153,8 +153,7 @@ void simd_bit_table::do_square_transpose() { for (size_t maj_low = 0; maj_low < W; maj_low++) { std::swap( data.ptr_simd[get_index_of_bitword(maj_high, maj_low, min_high)], - data.ptr_simd[get_index_of_bitword(min_high, maj_low, maj_high)] - ); + data.ptr_simd[get_index_of_bitword(min_high, maj_low, maj_high)]); } } } @@ -169,7 +168,8 @@ simd_bit_table simd_bit_table::transposed() const { } template -simd_bits simd_bit_table::read_across_majors_at_minor_index(size_t major_start, size_t major_stop, size_t minor_index) const { +simd_bits simd_bit_table::read_across_majors_at_minor_index( + size_t major_start, size_t major_stop, size_t minor_index) const { assert(major_stop >= major_start); assert(major_stop <= num_major_bits_padded()); assert(minor_index < num_minor_bits_padded()); @@ -256,8 +256,10 @@ std::string simd_bit_table::str() const { } template -simd_bit_table simd_bit_table::concat_major(const simd_bit_table &second, size_t n_first, size_t n_second) const { - if (num_major_bits_padded() < n_first || second.num_major_bits_padded() < n_second || num_minor_bits_padded() != second.num_minor_bits_padded()) { +simd_bit_table simd_bit_table::concat_major( + const simd_bit_table &second, size_t n_first, size_t n_second) const { + if (num_major_bits_padded() < n_first || second.num_major_bits_padded() < n_second || + num_minor_bits_padded() != second.num_minor_bits_padded()) { throw std::invalid_argument("Size mismatch"); } simd_bit_table result(n_first + n_second, num_minor_bits_padded()); @@ -269,9 +271,13 @@ simd_bit_table simd_bit_table::concat_major(const simd_bit_table &secon } template -void simd_bit_table::overwrite_major_range_with(size_t dst_major_start, const simd_bit_table &src, size_t src_major_start, size_t num_major_indices) const { +void simd_bit_table::overwrite_major_range_with( + size_t dst_major_start, const simd_bit_table &src, size_t src_major_start, size_t num_major_indices) const { assert(src.num_minor_bits_padded() == num_minor_bits_padded()); - memcpy(data.u8 + dst_major_start * num_minor_u8_padded(), src.data.u8 + src_major_start * src.num_minor_u8_padded(), num_major_indices * num_minor_u8_padded()); + memcpy( + data.u8 + dst_major_start * num_minor_u8_padded(), + src.data.u8 + src_major_start * src.num_minor_u8_padded(), + num_major_indices * num_minor_u8_padded()); } template @@ -345,4 +351,4 @@ std::ostream &operator<<(std::ostream &out, const stim::simd_bit_table &v) { return out; } -} +} // namespace stim diff --git a/src/stim/mem/simd_bits.h b/src/stim/mem/simd_bits.h index cdd1e5f28..dd0b170e0 100644 --- a/src/stim/mem/simd_bits.h +++ b/src/stim/mem/simd_bits.h @@ -68,6 +68,7 @@ struct simd_bits { simd_bits &operator|=(const simd_bits_range_ref other); // Addition assigment simd_bits &operator+=(const simd_bits_range_ref other); + simd_bits &operator-=(const simd_bits_range_ref other); // right shift assignment simd_bits &operator>>=(int offset); // left shift assignment @@ -112,6 +113,8 @@ struct simd_bits { /// Returns the number of bits that are 1 in the bit range. size_t popcnt() const; + /// Returns the power-of-two-ness of the number, or SIZE_MAX if the number has no 1s. + size_t countr_zero() const; /// Inverts all bits in the range. void invert_bits(); diff --git a/src/stim/mem/simd_bits.inl b/src/stim/mem/simd_bits.inl index 72f418f70..73d1343f1 100644 --- a/src/stim/mem/simd_bits.inl +++ b/src/stim/mem/simd_bits.inl @@ -256,6 +256,12 @@ simd_bits &simd_bits::operator+=(const simd_bits_range_ref other) { return *this; } +template +simd_bits &simd_bits::operator-=(const simd_bits_range_ref other) { + simd_bits_range_ref(*this) -= other; + return *this; +} + template simd_bits &simd_bits::operator>>=(int offset) { simd_bits_range_ref(*this) >>= offset; @@ -302,6 +308,11 @@ size_t simd_bits::popcnt() const { return simd_bits_range_ref(*this).popcnt(); } +template +size_t simd_bits::countr_zero() const { + return simd_bits_range_ref(*this).countr_zero(); +} + template std::ostream &operator<<(std::ostream &out, const simd_bits m) { return out << simd_bits_range_ref(m); diff --git a/src/stim/mem/simd_bits.test.cc b/src/stim/mem/simd_bits.test.cc index 3080f2163..3421bc770 100644 --- a/src/stim/mem/simd_bits.test.cc +++ b/src/stim/mem/simd_bits.test.cc @@ -653,6 +653,19 @@ TEST_EACH_WORD_SIZE_W(simd_bits, popcnt, { ASSERT_EQ(simd_bits(0).popcnt(), 0); }) +TEST_EACH_WORD_SIZE_W(simd_bits, countr_zero, { + simd_bits data(1024); + ASSERT_EQ(data.countr_zero(), SIZE_MAX); + data[1000] = 1; + ASSERT_EQ(data.countr_zero(), 1000); + data[101] = 1; + ASSERT_EQ(data.countr_zero(), 101); + data[260] = 1; + ASSERT_EQ(data.countr_zero(), 101); + data[0] = 1; + ASSERT_EQ(data.countr_zero(), 0); +}) + TEST_EACH_WORD_SIZE_W(simd_bits, prefix_ref, { simd_bits data(1024); auto prefix = data.prefix_ref(257); diff --git a/src/stim/mem/simd_bits_range_ref.h b/src/stim/mem/simd_bits_range_ref.h index 94ea73973..376e1fe2f 100644 --- a/src/stim/mem/simd_bits_range_ref.h +++ b/src/stim/mem/simd_bits_range_ref.h @@ -68,6 +68,7 @@ struct simd_bits_range_ref { simd_bits_range_ref operator|=(const simd_bits_range_ref other); // Addition assigment simd_bits_range_ref operator+=(const simd_bits_range_ref other); + simd_bits_range_ref operator-=(const simd_bits_range_ref other); // Shift assigment simd_bits_range_ref operator>>=(int offset); simd_bits_range_ref operator<<=(int offset); @@ -110,6 +111,8 @@ struct simd_bits_range_ref { void randomize(size_t num_bits, std::mt19937_64 &rng); /// Returns the number of bits that are 1 in the bit range. size_t popcnt() const; + /// Returns the power-of-two-ness of the number, or SIZE_MAX if the number has no 1s. + size_t countr_zero() const; /// Returns whether or not the two ranges have set bits in common. bool intersects(const simd_bits_range_ref other) const; diff --git a/src/stim/mem/simd_bits_range_ref.inl b/src/stim/mem/simd_bits_range_ref.inl index 0d6040331..d6135a005 100644 --- a/src/stim/mem/simd_bits_range_ref.inl +++ b/src/stim/mem/simd_bits_range_ref.inl @@ -66,6 +66,14 @@ simd_bits_range_ref simd_bits_range_ref::operator+=(const simd_bits_range_ return *this; } +template +simd_bits_range_ref simd_bits_range_ref::operator-=(const simd_bits_range_ref other) { + invert_bits(); + *this += other; + invert_bits(); + return *this; +} + template simd_bits_range_ref simd_bits_range_ref::operator>>=(int offset) { uint64_t incoming_word; @@ -213,6 +221,22 @@ size_t simd_bits_range_ref::popcnt() const { return result; } +template +size_t simd_bits_range_ref::countr_zero() const { + size_t n = num_u64_padded(); + for (size_t k = 0; k < n; k++) { + uint64_t u = u64[k]; + if (u) { + for (size_t r = 0; r < 64; r++) { + if ((u >> r) & 1) { + return r + 64 * k; + } + } + } + } + return SIZE_MAX; +} + template bool simd_bits_range_ref::intersects(const simd_bits_range_ref other) const { size_t n = std::min(num_u64_padded(), other.num_u64_padded()); diff --git a/src/stim/mem/simd_bits_range_ref.test.cc b/src/stim/mem/simd_bits_range_ref.test.cc index 235b85d69..2af444dab 100644 --- a/src/stim/mem/simd_bits_range_ref.test.cc +++ b/src/stim/mem/simd_bits_range_ref.test.cc @@ -354,6 +354,20 @@ TEST_EACH_WORD_SIZE_W(simd_bits_range_ref, popcnt, { ASSERT_EQ(ref.popcnt(), 66); }) +TEST_EACH_WORD_SIZE_W(simd_bits_range_ref, countr_zero, { + simd_bits data(1024); + simd_bits_range_ref ref(data); + ASSERT_EQ(ref.countr_zero(), SIZE_MAX); + data[1000] = 1; + ASSERT_EQ(ref.countr_zero(), 1000); + data[101] = 1; + ASSERT_EQ(ref.countr_zero(), 101); + data[260] = 1; + ASSERT_EQ(ref.countr_zero(), 101); + data[0] = 1; + ASSERT_EQ(ref.countr_zero(), 0); +}) + TEST_EACH_WORD_SIZE_W(simd_bits_range_ref, intersects, { simd_bits data(1024); simd_bits other(512); diff --git a/src/stim/py/stim.pybind.cc b/src/stim/py/stim.pybind.cc index 4a6828450..159d172e2 100644 --- a/src/stim/py/stim.pybind.cc +++ b/src/stim/py/stim.pybind.cc @@ -37,6 +37,7 @@ #include "stim/simulators/measurements_to_detection_events.pybind.h" #include "stim/simulators/tableau_simulator.pybind.h" #include "stim/stabilizers/pauli_string.pybind.h" +#include "stim/stabilizers/pauli_string_iter.pybind.h" #include "stim/stabilizers/tableau.h" #include "stim/stabilizers/tableau.pybind.h" #include "stim/stabilizers/tableau_iter.pybind.h" @@ -418,6 +419,7 @@ PYBIND11_MODULE(STIM_PYBIND11_MODULE_NAME, m) { auto c_compiled_measurement_sampler = pybind_compiled_measurement_sampler(m); auto c_compiled_m2d_converter = pybind_compiled_measurements_to_detection_events_converter(m); auto c_pauli_string = pybind_pauli_string(m); + auto c_pauli_string_iter = pybind_pauli_string_iter(m); auto c_tableau = pybind_tableau(m); auto c_tableau_iter = pybind_tableau_iter(m); @@ -466,6 +468,7 @@ PYBIND11_MODULE(STIM_PYBIND11_MODULE_NAME, m) { pybind_tableau_methods(m, c_tableau); pybind_pauli_string_methods(m, c_pauli_string); + pybind_pauli_string_iter_methods(m, c_pauli_string_iter); pybind_compiled_detector_sampler_methods(m, c_compiled_detector_sampler); pybind_compiled_measurement_sampler_methods(m, c_compiled_measurement_sampler); diff --git a/src/stim/simulators/count_determined_measurements.inl b/src/stim/simulators/count_determined_measurements.inl index 27320be1c..e060ee290 100644 --- a/src/stim/simulators/count_determined_measurements.inl +++ b/src/stim/simulators/count_determined_measurements.inl @@ -96,8 +96,7 @@ uint64_t count_determined_measurements(const Circuit &circuit) { break; } default: - throw std::invalid_argument( - "count_determined_measurements unhandled measurement type " + inst.str()); + throw std::invalid_argument("count_determined_measurements unhandled measurement type " + inst.str()); } }); return result; diff --git a/src/stim/simulators/dem_sampler.inl b/src/stim/simulators/dem_sampler.inl index bf46cf2f9..8cce14b23 100644 --- a/src/stim/simulators/dem_sampler.inl +++ b/src/stim/simulators/dem_sampler.inl @@ -14,13 +14,12 @@ * limitations under the License. */ -#include "stim/simulators/dem_sampler.h" - #include #include "stim/io/measure_record_reader.h" #include "stim/io/measure_record_writer.h" #include "stim/probability_util.h" +#include "stim/simulators/dem_sampler.h" namespace stim { @@ -97,15 +96,7 @@ void DemSampler::sample_write( if (err_out != nullptr) { write_table_data( - err_out, - shots_left, - (size_t)num_errors, - simd_bits(0), - err_buffer, - err_out_format, - 'M', - 'M', - false); + err_out, shots_left, (size_t)num_errors, simd_bits(0), err_buffer, err_out_format, 'M', 'M', false); } if (obs_out != nullptr) { diff --git a/src/stim/simulators/frame_simulator.inl b/src/stim/simulators/frame_simulator.inl index 006ab91a4..4b5d6c0c8 100644 --- a/src/stim/simulators/frame_simulator.inl +++ b/src/stim/simulators/frame_simulator.inl @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stim/simulators/frame_simulator.h" - #include #include #include "stim/circuit/gate_decomposition.h" #include "stim/probability_util.h" +#include "stim/simulators/frame_simulator.h" #include "stim/simulators/tableau_simulator.h" namespace stim { @@ -58,10 +57,15 @@ FrameSimulator::FrameSimulator( } template -void FrameSimulator::configure_for(CircuitStats new_circuit_stats, FrameSimulatorMode new_mode, size_t new_batch_size) { - bool storing_all_measurements = new_mode == FrameSimulatorMode::STORE_MEASUREMENTS_TO_MEMORY || new_mode == FrameSimulatorMode::STORE_EVERYTHING_TO_MEMORY; - bool storing_all_detections = new_mode == FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY || new_mode == FrameSimulatorMode::STORE_EVERYTHING_TO_MEMORY; - bool storing_any_detections = new_mode == FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY || new_mode == FrameSimulatorMode::STORE_EVERYTHING_TO_MEMORY || new_mode == FrameSimulatorMode::STREAM_DETECTIONS_TO_DISK; +void FrameSimulator::configure_for( + CircuitStats new_circuit_stats, FrameSimulatorMode new_mode, size_t new_batch_size) { + bool storing_all_measurements = new_mode == FrameSimulatorMode::STORE_MEASUREMENTS_TO_MEMORY || + new_mode == FrameSimulatorMode::STORE_EVERYTHING_TO_MEMORY; + bool storing_all_detections = new_mode == FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY || + new_mode == FrameSimulatorMode::STORE_EVERYTHING_TO_MEMORY; + bool storing_any_detections = new_mode == FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY || + new_mode == FrameSimulatorMode::STORE_EVERYTHING_TO_MEMORY || + new_mode == FrameSimulatorMode::STREAM_DETECTIONS_TO_DISK; batch_size = new_batch_size; num_qubits = new_circuit_stats.num_qubits; @@ -80,8 +84,12 @@ void FrameSimulator::configure_for(CircuitStats new_circuit_stats, FrameSimul m_record.destructive_resize(batch_size, num_stored_measurements); num_observables = storing_any_detections ? new_circuit_stats.num_observables : 0; - det_record.destructive_resize(batch_size, storing_all_detections ? new_circuit_stats.num_detectors : storing_any_detections ? 1 : 0), - obs_record.destructive_resize(num_observables, batch_size); + det_record.destructive_resize( + batch_size, + storing_all_detections ? new_circuit_stats.num_detectors + : storing_any_detections ? 1 + : 0), + obs_record.destructive_resize(num_observables, batch_size); } template @@ -370,9 +378,10 @@ void FrameSimulator::single_cx(uint32_t c, uint32_t t) { t &= ~TARGET_INVERTED_BIT; if (!((c | t) & (TARGET_RECORD_BIT | TARGET_SWEEP_BIT))) { x_table[c].for_each_word( - z_table[c], x_table[t], z_table[t], []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { + z_table[c], + x_table[t], + z_table[t], + [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { z1 ^= z2; x2 ^= x1; }); @@ -390,9 +399,10 @@ void FrameSimulator::single_cy(uint32_t c, uint32_t t) { t &= ~TARGET_INVERTED_BIT; if (!((c | t) & (TARGET_RECORD_BIT | TARGET_SWEEP_BIT))) { x_table[c].for_each_word( - z_table[c], x_table[t], z_table[t], []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { + z_table[c], + x_table[t], + z_table[t], + [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { z1 ^= x2 ^ z2; z2 ^= x1; x2 ^= x1; @@ -435,9 +445,10 @@ void FrameSimulator::do_ZCZ(const CircuitInstruction &target_data) { t &= ~TARGET_INVERTED_BIT; if (!((c | t) & (TARGET_RECORD_BIT | TARGET_SWEEP_BIT))) { x_table[c].for_each_word( - z_table[c], x_table[t], z_table[t], []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { + z_table[c], + x_table[t], + z_table[t], + [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { z1 ^= x2; z2 ^= x1; }); @@ -459,9 +470,10 @@ void FrameSimulator::do_SWAP(const CircuitInstruction &target_data) { size_t q1 = targets[k].data; size_t q2 = targets[k + 1].data; x_table[q1].for_each_word( - z_table[q1], x_table[q2], z_table[q2], []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { + z_table[q1], + x_table[q2], + z_table[q2], + [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { std::swap(z1, z2); std::swap(x1, x2); }); @@ -470,96 +482,88 @@ void FrameSimulator::do_SWAP(const CircuitInstruction &target_data) { template void FrameSimulator::do_ISWAP(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - auto dx = x1 ^ x2; - auto t1 = z1 ^ dx; - auto t2 = z2 ^ dx; - z1 = t2; - z2 = t1; - std::swap(x1, x2); - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + auto dx = x1 ^ x2; + auto t1 = z1 ^ dx; + auto t2 = z2 ^ dx; + z1 = t2; + z2 = t1; + std::swap(x1, x2); + }); } template void FrameSimulator::do_CXSWAP(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - z2 ^= z1; - z1 ^= z2; - x1 ^= x2; - x2 ^= x1; - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + z2 ^= z1; + z1 ^= z2; + x1 ^= x2; + x2 ^= x1; + }); } template void FrameSimulator::do_SWAPCX(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - z1 ^= z2; - z2 ^= z1; - x2 ^= x1; - x1 ^= x2; - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + z1 ^= z2; + z2 ^= z1; + x2 ^= x1; + x1 ^= x2; + }); } template void FrameSimulator::do_SQRT_XX(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - auto dz = z1 ^ z2; - x1 ^= dz; - x2 ^= dz; - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + auto dz = z1 ^ z2; + x1 ^= dz; + x2 ^= dz; + }); } template void FrameSimulator::do_SQRT_YY(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - auto d = x1 ^ z1 ^ x2 ^ z2; - x1 ^= d; - z1 ^= d; - x2 ^= d; - z2 ^= d; - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + auto d = x1 ^ z1 ^ x2 ^ z2; + x1 ^= d; + z1 ^= d; + x2 ^= d; + z2 ^= d; + }); } template void FrameSimulator::do_SQRT_ZZ(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - auto dx = x1 ^ x2; - z1 ^= dx; - z2 ^= dx; - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + auto dx = x1 ^ x2; + z1 ^= dx; + z2 ^= dx; + }); } template void FrameSimulator::do_XCX(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - x1 ^= z2; - x2 ^= z1; - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + x1 ^= z2; + x2 ^= z1; + }); } template void FrameSimulator::do_XCY(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - x1 ^= x2 ^ z2; - x2 ^= z1; - z2 ^= z1; - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + x1 ^= x2 ^ z2; + x2 ^= z1; + z2 ^= z1; + }); } template @@ -573,27 +577,25 @@ void FrameSimulator::do_XCZ(const CircuitInstruction &target_data) { template void FrameSimulator::do_YCX(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - x2 ^= x1 ^ z1; - x1 ^= z2; - z1 ^= z2; - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + x2 ^= x1 ^ z1; + x1 ^= z2; + z1 ^= z2; + }); } template void FrameSimulator::do_YCY(const CircuitInstruction &target_data) { - for_each_target_pair(*this, target_data, []( - simd_word &x1, simd_word &z1, - simd_word &x2, simd_word &z2) { - auto y1 = x1 ^ z1; - auto y2 = x2 ^ z2; - x1 ^= y2; - z1 ^= y2; - x2 ^= y1; - z2 ^= y1; - }); + for_each_target_pair( + *this, target_data, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + auto y1 = x1 ^ z1; + auto y2 = x2 ^ z2; + x1 ^= y2; + z1 ^= y2; + x2 ^= y1; + z2 ^= y1; + }); } template @@ -732,8 +734,7 @@ void FrameSimulator::do_ELSE_CORRELATED_ERROR(const CircuitInstruction &targe } // Omit locations blocked by prev error, while updating prev error mask. simd_bits_range_ref{rng_buffer}.for_each_word( - last_correlated_error_occurred, []( - simd_word &buf, simd_word &prev) { + last_correlated_error_occurred, [](simd_word &buf, simd_word &prev) { buf = prev.andnot(buf); prev |= buf; }); @@ -857,9 +858,7 @@ void FrameSimulator::do_MZZ_disjoint_controls_segment(const CircuitInstructio template void FrameSimulator::do_MXX(const CircuitInstruction &inst) { decompose_pair_instruction_into_segments_with_single_use_controls( - inst, - num_qubits, - [&](CircuitInstruction segment){ + inst, num_qubits, [&](CircuitInstruction segment) { do_MXX_disjoint_controls_segment(segment); }); } @@ -867,9 +866,7 @@ void FrameSimulator::do_MXX(const CircuitInstruction &inst) { template void FrameSimulator::do_MYY(const CircuitInstruction &inst) { decompose_pair_instruction_into_segments_with_single_use_controls( - inst, - num_qubits, - [&](CircuitInstruction segment){ + inst, num_qubits, [&](CircuitInstruction segment) { do_MYY_disjoint_controls_segment(segment); }); } @@ -877,9 +874,7 @@ void FrameSimulator::do_MYY(const CircuitInstruction &inst) { template void FrameSimulator::do_MZZ(const CircuitInstruction &inst) { decompose_pair_instruction_into_segments_with_single_use_controls( - inst, - num_qubits, - [&](CircuitInstruction segment){ + inst, num_qubits, [&](CircuitInstruction segment) { do_MZZ_disjoint_controls_segment(segment); }); } diff --git a/src/stim/simulators/frame_simulator_util.inl b/src/stim/simulators/frame_simulator_util.inl index c4f3b89cf..94b920534 100644 --- a/src/stim/simulators/frame_simulator_util.inl +++ b/src/stim/simulators/frame_simulator_util.inl @@ -12,20 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stim/simulators/frame_simulator_util.h" - -#include "stim/simulators/frame_simulator.h" #include "stim/simulators/force_streaming.h" +#include "stim/simulators/frame_simulator.h" +#include "stim/simulators/frame_simulator_util.h" namespace stim { template std::pair, simd_bit_table> sample_batch_detection_events( const Circuit &circuit, size_t num_shots, std::mt19937_64 &rng) { - FrameSimulator sim(circuit.compute_stats(), FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY, num_shots, std::move(rng)); + FrameSimulator sim( + circuit.compute_stats(), FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY, num_shots, std::move(rng)); sim.reset_all(); sim.do_circuit(circuit); - rng = std::move(sim.rng); // Update input rng as if it was used directly, by moving the updated state out of the simulator. + rng = std::move( + sim.rng); // Update input rng as if it was used directly, by moving the updated state out of the simulator. return std::pair, simd_bit_table>{ std::move(sim.det_record.storage), @@ -148,12 +149,14 @@ void rerun_frame_sim_in_memory_and_write_dets_to_disk( if (prepend_observables || append_observables) { if (prepend_observables) { assert(!append_observables); - out_concat_buf.overwrite_major_range_with(circuit_stats.num_observables, det_data, 0, circuit_stats.num_detectors); + out_concat_buf.overwrite_major_range_with( + circuit_stats.num_observables, det_data, 0, circuit_stats.num_detectors); out_concat_buf.overwrite_major_range_with(0, obs_data, 0, circuit_stats.num_observables); } else { assert(append_observables); out_concat_buf.overwrite_major_range_with(0, det_data, 0, circuit_stats.num_detectors); - out_concat_buf.overwrite_major_range_with(circuit_stats.num_detectors, obs_data, 0, circuit_stats.num_observables); + out_concat_buf.overwrite_major_range_with( + circuit_stats.num_detectors, obs_data, 0, circuit_stats.num_observables); } char c1 = append_observables ? 'D' : 'L'; @@ -291,11 +294,13 @@ simd_bit_table sample_batch_measurements( size_t num_samples, std::mt19937_64 &rng, bool transposed) { - FrameSimulator sim(circuit.compute_stats(), FrameSimulatorMode::STORE_MEASUREMENTS_TO_MEMORY, num_samples, std::move(rng)); + FrameSimulator sim( + circuit.compute_stats(), FrameSimulatorMode::STORE_MEASUREMENTS_TO_MEMORY, num_samples, std::move(rng)); sim.reset_all(); sim.do_circuit(circuit); simd_bit_table result = std::move(sim.m_record.storage); - rng = std::move(sim.rng); // Update input rng as if it was used directly, by moving the updated state out of the simulator. + rng = std::move( + sim.rng); // Update input rng as if it was used directly, by moving the updated state out of the simulator. if (reference_sample.not_zero()) { result = transposed_vs_ref(num_samples, result, reference_sample); diff --git a/src/stim/simulators/measurements_to_detection_events.inl b/src/stim/simulators/measurements_to_detection_events.inl index 535198f73..53c7109fd 100644 --- a/src/stim/simulators/measurements_to_detection_events.inl +++ b/src/stim/simulators/measurements_to_detection_events.inl @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stim/simulators/measurements_to_detection_events.h" - #include #include "stim/circuit/gate_data.h" @@ -22,6 +20,7 @@ #include "stim/io/stim_data_formats.h" #include "stim/mem/simd_util.h" #include "stim/simulators/frame_simulator.h" +#include "stim/simulators/measurements_to_detection_events.h" #include "stim/simulators/tableau_simulator.h" #include "stim/stabilizers/pauli_string.h" @@ -58,7 +57,8 @@ void measurements_to_detection_events_helper( // The frame simulator is used to account for flips in the measurement results that originate from the sweep data. // Eg. a `CNOT sweep[5] 0` can bit flip qubit 0, which can invert later measurement results, which will invert the // expected parity of detectors involving that measurement. This can vary from shot to shot. - FrameSimulator frame_sim(circuit_stats, FrameSimulatorMode::STREAM_DETECTIONS_TO_DISK, batch_size, std::mt19937_64(0)); + FrameSimulator frame_sim( + circuit_stats, FrameSimulatorMode::STREAM_DETECTIONS_TO_DISK, batch_size, std::mt19937_64(0)); frame_sim.sweep_table = sweep_bits__minor_shot_index; frame_sim.guarantee_anticommutation_via_frame_randomization = false; @@ -69,8 +69,7 @@ void measurements_to_detection_events_helper( switch (op.gate_type) { case GateType::DETECTOR: { - simd_bits_range_ref out_row = - out_detection_results__minor_shot_index[detector_offset]; + simd_bits_range_ref out_row = out_detection_results__minor_shot_index[detector_offset]; detector_offset++; // Include dependence from gates controlled by sweep bits. @@ -221,8 +220,7 @@ void stream_measurements_to_detection_events_helper( } // Buffers and transposed buffers. - simd_bit_table measurements__minor_shot_index( - circuit_stats.num_measurements, num_buffered_shots); + simd_bit_table measurements__minor_shot_index(circuit_stats.num_measurements, num_buffered_shots); simd_bit_table out__minor_shot_index(num_out_bits_including_any_obs, num_buffered_shots); simd_bit_table out__major_shot_index(num_buffered_shots, num_out_bits_including_any_obs); simd_bit_table sweep_bits__minor_shot_index(num_sweep_bits_available, num_buffered_shots); diff --git a/src/stim/simulators/tableau_simulator.inl b/src/stim/simulators/tableau_simulator.inl index 831b38917..112fde44a 100644 --- a/src/stim/simulators/tableau_simulator.inl +++ b/src/stim/simulators/tableau_simulator.inl @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stim/simulators/tableau_simulator.h" - #include #include "stim/circuit/gate_data.h" #include "stim/circuit/gate_decomposition.h" #include "stim/probability_util.h" +#include "stim/simulators/tableau_simulator.h" #include "stim/simulators/vector_simulator.h" namespace stim { @@ -166,9 +165,7 @@ uint32_t TableauSimulator::try_isolate_observable_to_qubit_z(PauliStringRef -void TableauSimulator::postselect_observable( - PauliStringRef observable, - bool desired_result) { +void TableauSimulator::postselect_observable(PauliStringRef observable, bool desired_result) { ensure_large_enough_for_qubits(observable.num_qubits); uint32_t pivot = try_isolate_observable_to_qubit_z(observable, false); @@ -232,9 +229,7 @@ void TableauSimulator::do_MX(const CircuitInstruction &target_data) { } template -void TableauSimulator::do_MXX_disjoint_controls_segment( - const CircuitInstruction &inst) { - +void TableauSimulator::do_MXX_disjoint_controls_segment(const CircuitInstruction &inst) { // Transform from 2 qubit measurements to single qubit measurements. do_ZCX(CircuitInstruction{GateType::CX, {}, inst.targets}); @@ -257,9 +252,7 @@ void TableauSimulator::do_MXX_disjoint_controls_segment( } template -void TableauSimulator::do_MYY_disjoint_controls_segment( - const CircuitInstruction &inst) { - +void TableauSimulator::do_MYY_disjoint_controls_segment(const CircuitInstruction &inst) { // Transform from 2 qubit measurements to single qubit measurements. do_ZCY(CircuitInstruction{GateType::CY, {}, inst.targets}); @@ -282,9 +275,7 @@ void TableauSimulator::do_MYY_disjoint_controls_segment( } template -void TableauSimulator::do_MZZ_disjoint_controls_segment( - const CircuitInstruction &inst) { - +void TableauSimulator::do_MZZ_disjoint_controls_segment(const CircuitInstruction &inst) { // Transform from 2 qubit measurements to single qubit measurements. do_XCZ(CircuitInstruction{GateType::XCZ, {}, inst.targets}); @@ -309,9 +300,7 @@ void TableauSimulator::do_MZZ_disjoint_controls_segment( template void TableauSimulator::do_MXX(const CircuitInstruction &inst) { decompose_pair_instruction_into_segments_with_single_use_controls( - inst, - inv_state.num_qubits, - [&](CircuitInstruction segment){ + inst, inv_state.num_qubits, [&](CircuitInstruction segment) { do_MXX_disjoint_controls_segment(segment); }); } @@ -319,9 +308,7 @@ void TableauSimulator::do_MXX(const CircuitInstruction &inst) { template void TableauSimulator::do_MYY(const CircuitInstruction &inst) { decompose_pair_instruction_into_segments_with_single_use_controls( - inst, - inv_state.num_qubits, - [&](CircuitInstruction segment){ + inst, inv_state.num_qubits, [&](CircuitInstruction segment) { do_MYY_disjoint_controls_segment(segment); }); } @@ -329,9 +316,7 @@ void TableauSimulator::do_MYY(const CircuitInstruction &inst) { template void TableauSimulator::do_MZZ(const CircuitInstruction &inst) { decompose_pair_instruction_into_segments_with_single_use_controls( - inst, - inv_state.num_qubits, - [&](CircuitInstruction segment){ + inst, inv_state.num_qubits, [&](CircuitInstruction segment) { do_MZZ_disjoint_controls_segment(segment); }); } @@ -1108,8 +1093,7 @@ void TableauSimulator::do_Z(const CircuitInstruction &target_data) { } template -simd_bits TableauSimulator::sample_circuit( - const Circuit &circuit, std::mt19937_64 &rng, int8_t sign_bias) { +simd_bits TableauSimulator::sample_circuit(const Circuit &circuit, std::mt19937_64 &rng, int8_t sign_bias) { TableauSimulator sim(std::move(rng), circuit.count_qubits(), sign_bias); sim.expand_do_circuit(circuit); @@ -1131,7 +1115,8 @@ void TableauSimulator::ensure_large_enough_for_qubits(size_t num_qubits) { } template -void TableauSimulator::sample_stream(FILE *in, FILE *out, SampleFormat format, bool interactive, std::mt19937_64 &rng) { +void TableauSimulator::sample_stream( + FILE *in, FILE *out, SampleFormat format, bool interactive, std::mt19937_64 &rng) { TableauSimulator sim(std::move(rng), 1); auto writer = MeasureRecordWriter::make(out, format); Circuit unprocessed; @@ -1526,7 +1511,7 @@ int8_t TableauSimulator::peek_observable_expectation(const PauliString &ob template void TableauSimulator::do_gate(const CircuitInstruction &inst) { - switch(inst.gate_type) { + switch (inst.gate_type) { case GateType::DETECTOR: do_I(inst); break; @@ -1732,4 +1717,3 @@ void TableauSimulator::do_gate(const CircuitInstruction &inst) { } } // namespace stim - diff --git a/src/stim/stabilizers/conversions.inl b/src/stim/stabilizers/conversions.inl index 3c424cc2c..b37c24fd7 100644 --- a/src/stim/stabilizers/conversions.inl +++ b/src/stim/stabilizers/conversions.inl @@ -1,8 +1,7 @@ -#include "stim/stabilizers/conversions.h" - #include "stim/probability_util.h" #include "stim/simulators/tableau_simulator.h" #include "stim/simulators/vector_simulator.h" +#include "stim/stabilizers/conversions.h" namespace stim { @@ -30,8 +29,7 @@ inline size_t compute_occupation(const std::vector> &state_v } template -Circuit stabilizer_state_vector_to_circuit( - const std::vector> &state_vector, bool little_endian) { +Circuit stabilizer_state_vector_to_circuit(const std::vector> &state_vector, bool little_endian) { if (!is_power_of_2(state_vector.size())) { std::stringstream ss; ss << "Expected number of amplitudes to be a power of 2."; @@ -144,8 +142,7 @@ std::vector>> tableau_to_unitary(const Tableau -Tableau circuit_to_tableau( - const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset) { +Tableau circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset) { Tableau result(circuit.count_qubits()); TableauSimulator sim(std::mt19937_64(0), circuit.count_qubits()); diff --git a/src/stim/stabilizers/pauli_string.inl b/src/stim/stabilizers/pauli_string.inl index 2b017ae95..f3c5a1e0c 100644 --- a/src/stim/stabilizers/pauli_string.inl +++ b/src/stim/stabilizers/pauli_string.inl @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stim/stabilizers/pauli_string.h" - #include #include #include #include "stim/mem/simd_util.h" +#include "stim/stabilizers/pauli_string.h" namespace stim { diff --git a/src/stim/stabilizers/pauli_string.pybind.cc b/src/stim/stabilizers/pauli_string.pybind.cc index 6bc494963..fa1b68b54 100644 --- a/src/stim/stabilizers/pauli_string.pybind.cc +++ b/src/stim/stabilizers/pauli_string.pybind.cc @@ -14,6 +14,7 @@ #include "stim/stabilizers/pauli_string.h" +#include "pauli_string_iter.h" #include "stim/circuit/circuit_instruction.pybind.h" #include "stim/py/base.pybind.h" #include "stim/py/numpy.pybind.h" @@ -1612,4 +1613,97 @@ void stim_pybind::pybind_pauli_string_methods(pybind11::module &m, pybind11::cla [](const pybind11::str &d) { return PyPauliString::from_text(pybind11::cast(d).data()); })); + + c.def_static( + "iter_all", + [](size_t num_qubits, + size_t min_weight, + const pybind11::object &max_weight_obj, + const std::string &allowed_paulis) { + bool allow_x = false; + bool allow_y = false; + bool allow_z = false; + for (char c : allowed_paulis) { + switch (c) { + case 'X': + allow_x = true; + break; + case 'Y': + allow_y = true; + break; + case 'Z': + allow_z = true; + break; + default: + throw std::invalid_argument( + "allowed_paulis='" + allowed_paulis + "' had characters other than 'X', 'Y', and 'Z'."); + } + } + size_t max_weight = num_qubits; + if (!max_weight_obj.is_none()) { + int64_t v = pybind11::cast(max_weight_obj); + if (v < 0) { + min_weight = 1; + max_weight = 0; + } else { + max_weight = (size_t)v; + } + } + return PauliStringIterator( + num_qubits, min_weight, max_weight, allow_x, allow_y, allow_z); + }, + pybind11::arg("num_qubits"), + pybind11::kw_only(), + pybind11::arg("min_weight") = 0, + pybind11::arg("max_weight") = pybind11::none(), + pybind11::arg("allowed_paulis") = "XYZ", + clean_doc_string(R"DOC( + Returns an iterator that iterates over all matching pauli strings. + + Args: + num_qubits: The desired number of qubits in the pauli strings. + min_weight: Defaults to 0. The minimum number of non-identity terms that + must be present in each yielded pauli string. + max_weight: Defaults to None (unused). The maximum number of non-identity + terms that must be present in each yielded pauli string. + allowed_paulis: Defaults to "XYZ". Set this to a string containing the + non-identity paulis that are allowed to appear in each yielded pauli + string. This argument must be a string made up of only "X", "Y", and + "Z" characters. A non-identity Pauli is allowed if it appears in the + string, and not allowed if it doesn't. Identity Paulis are always + allowed. + + Returns: + An Iterable[stim.PauliString] that yields the requested pauli strings. + + Examples: + >>> import stim + >>> pauli_string_iterator = stim.PauliString.iter_all( + ... num_qubits=3, + ... min_weight=1, + ... max_weight=2, + ... allowed_paulis="XZ", + ... ) + >>> for p in pauli_string_iterator: + ... print(p) + +X__ + +Z__ + +_X_ + +_Z_ + +__X + +__Z + +XX_ + +XZ_ + +ZX_ + +ZZ_ + +X_X + +X_Z + +Z_X + +Z_Z + +_XX + +_XZ + +_ZX + +_ZZ + )DOC") + .data()); } diff --git a/src/stim/stabilizers/pauli_string_iter.h b/src/stim/stabilizers/pauli_string_iter.h new file mode 100644 index 000000000..5c49c8555 --- /dev/null +++ b/src/stim/stabilizers/pauli_string_iter.h @@ -0,0 +1,145 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef _STIM_STABILIZERS_PAULI_STRING_ITER_H +#define _STIM_STABILIZERS_PAULI_STRING_ITER_H + +#include "stim/mem/fixed_cap_vector.h" +#include "stim/mem/span_ref.h" +#include "stim/stabilizers/tableau.h" + +namespace stim { + +/// Tracks the state of a loop. +struct NestedLooperLoop { + /// The first index that should be iterated. + uint64_t start; + /// One past the last index that should be iterated. + uint64_t end; + /// If this is set to the index of another loop, the starting offset is shifted by that other loop's value. + /// Set to UINT64_MAX to not use. + /// This is used in 'append_combination_loops' to avoid repeating combinations. + uint64_t offset_from_other = UINT64_MAX; + /// The current value of the iteration variable for this loop. + /// UINT64_MAX means 'loop is not started yet'. + uint64_t cur = UINT64_MAX; +}; + +/// A helper class for managing dynamically nested loops. +struct NestedLooper { + std::vector loops; + uint64_t k = 0; + + /// Adds a series of nested loops for iterating combinations of w values from [0, n). + inline void append_combination_loops(uint64_t n, uint64_t w) { + if (w > 0) { + loops.push_back(NestedLooperLoop{0, n - w + 1}); + for (uint64_t j = 1; j < w; j++) { + auto v = loops.size() - 1; + loops.push_back(NestedLooperLoop{1, n - w + j + 1, v}); + } + } + } + + /// Clears all loop variables and sets the loop index to the outermost loop. + inline void start() { + k = 0; + for (auto &loop : loops) { + loop.cur = UINT64_MAX; + } + } + + inline bool iter_next(const std::function &on_iter) { + if (loops.empty()) { + return false; + } + + // k is the index of the loop to advance. + // In the first step, k will be 0. + // In later step, k is loops.size(). + if (k == loops.size()) { + // Drop k by 1 to advance the innermost loop. + k--; + } + + while (true) { + // Start or advance the current loop. + if (loops[k].cur == UINT64_MAX) { + loops[k].cur = loops[k].start; + if (loops[k].offset_from_other != UINT64_MAX) { + loops[k].cur += loops[loops[k].offset_from_other].cur; + } + } else { + loops[k].cur++; + } + + // Notify the caller so they can dynamically add inner loops if needed. + on_iter(k); + + // Check if the current loop has ended. + if (loops[k].cur >= loops[k].end) { + if (k == 0) { + // The outermost loop ended. + return false; + } + loops[k].cur = UINT64_MAX; + k -= 1; + continue; + } + + // Move down to the next loop. + k++; + if (k == loops.size()) { + // We're inside the innermost loop. + return true; + } + } + } +}; + +/// Iterates over pauli strings matching specified parameters. +/// +/// The template parameter, W, represents the SIMD width. +template +struct PauliStringIterator { + // Parameter storage. + size_t num_qubits; /// Number of qubits in results. + size_t min_weight; /// Minimum number of non-identity terms in results. + size_t max_weight; /// Maximum number of non-identity terms in results. + bool allow_x; /// Whether results are permitted to contain 'X' terms. + bool allow_y; /// Whether results are permitted to contain 'Y' terms. + bool allow_z; /// Whether results are permitted to contain 'Z' terms. + + // Progress storage. + NestedLooper looper; /// Tracks nested loops over target weight, the target qubits, and the target paulis. + PauliString result; /// When iter_next() returns true, the result is stored in this field. + + PauliStringIterator( + size_t num_qubits, size_t min_weight, size_t max_weight, bool allow_x, bool allow_y, bool allow_z); + + /// Updates the `result` field to point at the next yielded pauli string. + /// Returns true if this succeeded, or false if iteration has ended. + bool iter_next(); + + // Restarts iteration. + void restart(); +}; + +} // namespace stim + +#include "stim/stabilizers/pauli_string_iter.inl" + +#endif diff --git a/src/stim/stabilizers/pauli_string_iter.inl b/src/stim/stabilizers/pauli_string_iter.inl new file mode 100644 index 000000000..097be2be2 --- /dev/null +++ b/src/stim/stabilizers/pauli_string_iter.inl @@ -0,0 +1,75 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "stim/stabilizers/pauli_string.h" +#include "stim/stabilizers/pauli_string_iter.h" + +namespace stim { + +template +PauliStringIterator::PauliStringIterator( + size_t num_qubits, size_t min_weight, size_t max_weight, bool allow_x, bool allow_y, bool allow_z) + : num_qubits(num_qubits), + min_weight(min_weight), + max_weight(max_weight), + allow_x(allow_x), + allow_y(allow_y), + allow_z(allow_z), + result(num_qubits) { + restart(); +} + +template +bool PauliStringIterator::iter_next() { + return looper.iter_next([&](size_t loop_index) { + const NestedLooperLoop &loop = looper.loops[loop_index]; + if (loop_index == 0) { + // Reached a new weight. Need to iterate over xs. + looper.loops.resize(loop_index + 1); + looper.append_combination_loops(num_qubits, loop.cur); + } else if (loop_index == looper.loops[0].cur) { + // Reached a new weight mask. Need to iterate over X/Z values. + looper.loops.resize(loop_index + 1); + result.xs.clear(); + result.zs.clear(); + size_t pauli_weight = allow_x + allow_y + allow_z; + for (size_t j = 0; j < looper.loops[0].cur; j++) { + looper.loops.push_back(NestedLooperLoop{1, 1 + pauli_weight}); + } + } else if (loop_index > looper.loops[0].cur) { + // Iterating a pauli. Keep the results up to date as the paulis change. + auto q = looper.loops[loop_index - looper.loops[0].cur].cur; + auto v = loop.cur; + v += !allow_x && v >= 1; + v += !allow_y && v >= 2; + v += !allow_z && v >= 3; + bool y = (v & 1) != 0; + bool z = (v & 2) != 0; + result.xs[q] = y ^ z; + result.zs[q] = z; + } + }); +} + +template +void PauliStringIterator::restart() { + looper.loops.clear(); + size_t clamped_max_weight = std::min(max_weight, num_qubits); + if (clamped_max_weight >= min_weight) { + looper.loops.push_back({min_weight, clamped_max_weight + 1, UINT64_MAX}); + } + looper.start(); +} + +} // namespace stim diff --git a/src/stim/stabilizers/pauli_string_iter.perf.cc b/src/stim/stabilizers/pauli_string_iter.perf.cc new file mode 100644 index 000000000..94db2f1b5 --- /dev/null +++ b/src/stim/stabilizers/pauli_string_iter.perf.cc @@ -0,0 +1,55 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "stim/stabilizers/pauli_string_iter.h" + +#include "stim/benchmark_util.perf.h" + +using namespace stim; + +BENCHMARK(pauli_iter_xz_2_to_5_of_5) { + size_t c = 0; + size_t n = 0; + benchmark_go([&]() { + PauliStringIterator iter(5, 2, 5, true, false, true); + n = 0; + while (iter.iter_next()) { + c += iter.result.num_qubits; + n += 1; + } + }) + .goal_micros(8) + .show_rate("PauliStrings", n); + if (c == 0) { + std::cerr << "use the output\n"; + } +} + +BENCHMARK(pauli_iter_xyz_1_of_1000) { + size_t c = 0; + size_t n = 0; + benchmark_go([&]() { + PauliStringIterator iter(1000, 1, 1, true, true, true); + n = 0; + while (iter.iter_next()) { + c += iter.result.num_qubits; + n += 1; + } + }) + .goal_micros(55) + .show_rate("PauliStrings", n); + if (c == 0) { + std::cerr << "use the output\n"; + } +} diff --git a/src/stim/stabilizers/pauli_string_iter.pybind.cc b/src/stim/stabilizers/pauli_string_iter.pybind.cc new file mode 100644 index 000000000..ea7a7ee80 --- /dev/null +++ b/src/stim/stabilizers/pauli_string_iter.pybind.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "stim/stabilizers/pauli_string_iter.pybind.h" + +#include "pauli_string.pybind.h" +#include "stim/py/base.pybind.h" + +using namespace stim; +using namespace stim_pybind; + +pybind11::class_> stim_pybind::pybind_pauli_string_iter(pybind11::module &m) { + auto c = pybind11::class_>( + m, + "PauliStringIterator", + clean_doc_string(R"DOC( + Iterates over all pauli strings matching specified patterns. + + Examples: + >>> import stim + >>> pauli_string_iterator = stim.PauliString.iter_all( + ... 2, + ... min_weight=1, + ... max_weight=1, + ... allowed_paulis="XZ", + ... ) + >>> for p in pauli_string_iterator: + ... print(p) + +X_ + +Z_ + +_X + +_Z + )DOC") + .data()); + return c; +} + +void stim_pybind::pybind_pauli_string_iter_methods( + pybind11::module &m, pybind11::class_> &c) { + c.def( + "__iter__", + [](PauliStringIterator &self) -> PauliStringIterator { + PauliStringIterator copy = self; + return copy; + }, + clean_doc_string(R"DOC( + Returns an independent copy of the pauli string iterator. + + Since for-loops and loop-comprehensions call `iter` on things they + iterate, this effectively allows the iterator to be iterated + multiple times. + )DOC") + .data()); + + c.def( + "__next__", + [](PauliStringIterator &self) -> PyPauliString { + if (!self.iter_next()) { + throw pybind11::stop_iteration(); + } + return PyPauliString(self.result); + }, + clean_doc_string(R"DOC( + Returns the next iterated pauli string. + )DOC") + .data()); +} diff --git a/src/stim/stabilizers/pauli_string_iter.pybind.h b/src/stim/stabilizers/pauli_string_iter.pybind.h new file mode 100644 index 000000000..ebf16cb19 --- /dev/null +++ b/src/stim/stabilizers/pauli_string_iter.pybind.h @@ -0,0 +1,28 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _STIM_STABILIZERS_PAULI_STRING_ITER_PYBIND_H +#define _STIM_STABILIZERS_PAULI_STRING_ITER_PYBIND_H + +#include + +#include "stim/stabilizers/pauli_string_iter.h" + +namespace stim_pybind { +pybind11::class_> pybind_pauli_string_iter(pybind11::module &m); +void pybind_pauli_string_iter_methods( + pybind11::module &m, pybind11::class_> &c); +} // namespace stim_pybind + +#endif diff --git a/src/stim/stabilizers/pauli_string_iter.test.cc b/src/stim/stabilizers/pauli_string_iter.test.cc new file mode 100644 index 000000000..b79a32079 --- /dev/null +++ b/src/stim/stabilizers/pauli_string_iter.test.cc @@ -0,0 +1,297 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "stim/stabilizers/pauli_string_iter.h" + +#include "gtest/gtest.h" + +#include "stim/mem/simd_word.test.h" + +using namespace stim; + +std::vector loop_state(const NestedLooper &looper) { + std::vector state; + for (const auto &e : looper.loops) { + state.push_back(e.cur); + } + return state; +} + +std::vector> loop_drain(NestedLooper &looper) { + std::vector> state; + looper.start(); + while (looper.iter_next([](size_t k) { + })) { + state.push_back(loop_state(looper)); + } + return state; +} + +TEST(pauli_string_iter, NestedLooper_simple) { + NestedLooper looper; + looper.loops.push_back(NestedLooperLoop{0, 3}); + looper.loops.push_back(NestedLooperLoop{2, 6}); + + ASSERT_EQ( + loop_drain(looper), + (std::vector>{ + {0, 2}, + {0, 3}, + {0, 4}, + {0, 5}, + {1, 2}, + {1, 3}, + {1, 4}, + {1, 5}, + {2, 2}, + {2, 3}, + {2, 4}, + {2, 5}, + })); +} + +TEST(pauli_string_iter, NestedLooper_shifted) { + NestedLooper looper; + looper.loops.push_back(NestedLooperLoop{0, 3}); + looper.loops.push_back(NestedLooperLoop{2, 6, 0}); + ASSERT_EQ( + loop_drain(looper), + (std::vector>{ + {0, 2}, + {0, 3}, + {0, 4}, + {0, 5}, + {1, 3}, + {1, 4}, + {1, 5}, + {2, 4}, + {2, 5}, + })); +} + +TEST(pauli_string_iter, NestedLooper_append_combination_loops) { + NestedLooper looper; + looper.append_combination_loops(6, 3); + ASSERT_EQ( + loop_drain(looper), + (std::vector>{ + {0, 1, 2}, {0, 1, 3}, {0, 1, 4}, {0, 1, 5}, {0, 2, 3}, {0, 2, 4}, {0, 2, 5}, + {0, 3, 4}, {0, 3, 5}, {0, 4, 5}, {1, 2, 3}, {1, 2, 4}, {1, 2, 5}, {1, 3, 4}, + {1, 3, 5}, {1, 4, 5}, {2, 3, 4}, {2, 3, 5}, {2, 4, 5}, {3, 4, 5}, + })); + + looper.loops.clear(); + looper.append_combination_loops(10, 9); + ASSERT_EQ( + loop_drain(looper), + (std::vector>{ + {0, 1, 2, 3, 4, 5, 6, 7, 8}, + {0, 1, 2, 3, 4, 5, 6, 7, 9}, + {0, 1, 2, 3, 4, 5, 6, 8, 9}, + {0, 1, 2, 3, 4, 5, 7, 8, 9}, + {0, 1, 2, 3, 4, 6, 7, 8, 9}, + {0, 1, 2, 3, 5, 6, 7, 8, 9}, + {0, 1, 2, 4, 5, 6, 7, 8, 9}, + {0, 1, 3, 4, 5, 6, 7, 8, 9}, + {0, 2, 3, 4, 5, 6, 7, 8, 9}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, + })); +} + +TEST(pauli_string_iter, NestedLooper_inplace_edit) { + NestedLooper looper; + looper.loops.push_back(NestedLooperLoop{1, 3}); + + std::vector> state; + looper.start(); + while (looper.iter_next([&](size_t k) { + if (k == 0 && looper.loops[0].cur == 2) { + looper.loops.push_back(NestedLooperLoop{2, 4}); + } + })) { + state.push_back(loop_state(looper)); + } + + ASSERT_EQ( + state, + (std::vector>{ + {1}, + {2, 2}, + {2, 3}, + })); +} + +template +std::vector record_pauli_string(PauliStringIterator iter) { + std::vector results; + while (iter.iter_next()) { + results.push_back(iter.result.str()); + } + return results; +} + +TEST_EACH_WORD_SIZE_W(pauli_string_iter, small_cases, { + // Empty. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(3, 0, 0, true, true, true)), + (std::vector{ + "+___", + })); + + // Empty or single. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 0, 1, true, true, true)), + (std::vector{ + "+__", + "+X_", + "+Y_", + "+Z_", + "+_X", + "+_Y", + "+_Z", + })); + + // Single. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(3, 1, 1, true, true, true)), + (std::vector{ + "+X__", + "+Y__", + "+Z__", + "+_X_", + "+_Y_", + "+_Z_", + "+__X", + "+__Y", + "+__Z", + })); + + // Full doubles. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 2, 2, true, true, true)), + (std::vector{ + "+XX", + "+XY", + "+XZ", + "+YX", + "+YY", + "+YZ", + "+ZX", + "+ZY", + "+ZZ", + })); + + // All length 2. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 0, 2, true, true, true)), + (std::vector{ + "+__", + "+X_", + "+Y_", + "+Z_", + "+_X", + "+_Y", + "+_Z", + "+XX", + "+XY", + "+XZ", + "+YX", + "+YY", + "+YZ", + "+ZX", + "+ZY", + "+ZZ", + })); + + // XY subset. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 0, 2, false, true, true)), + (std::vector{ + "+__", + "+Y_", + "+Z_", + "+_Y", + "+_Z", + "+YY", + "+YZ", + "+ZY", + "+ZZ", + })); + // XZ subset. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 0, 2, true, false, true)), + (std::vector{ + "+__", + "+X_", + "+Z_", + "+_X", + "+_Z", + "+XX", + "+XZ", + "+ZX", + "+ZZ", + })); + // YZ subset. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 0, 2, true, true, false)), + (std::vector{ + "+__", + "+X_", + "+Y_", + "+_X", + "+_Y", + "+XX", + "+XY", + "+YX", + "+YY", + })); + + // X subset. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 0, 2, true, false, false)), + (std::vector{ + "+__", + "+X_", + "+_X", + "+XX", + })); + // Y subset. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 0, 2, false, true, false)), + (std::vector{ + "+__", + "+Y_", + "+_Y", + "+YY", + })); + // Z subset. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 0, 2, false, false, true)), + (std::vector{ + "+__", + "+Z_", + "+_Z", + "+ZZ", + })); + + // No pauli subset. + ASSERT_EQ( + record_pauli_string(PauliStringIterator(2, 0, 2, false, false, false)), + (std::vector{ + "+__", + })); + ASSERT_EQ(record_pauli_string(PauliStringIterator(2, 1, 2, false, false, false)), (std::vector{})); + ASSERT_EQ(record_pauli_string(PauliStringIterator(2, 3, 6, false, false, false)), (std::vector{})); + ASSERT_EQ(record_pauli_string(PauliStringIterator(2, 2, 1, false, false, false)), (std::vector{})); +}) diff --git a/src/stim/stabilizers/pauli_string_pybind_test.py b/src/stim/stabilizers/pauli_string_pybind_test.py index 97eced9a9..632f72d36 100644 --- a/src/stim/stabilizers/pauli_string_pybind_test.py +++ b/src/stim/stabilizers/pauli_string_pybind_test.py @@ -833,3 +833,46 @@ def test_before_after(): assert after.before(stim.Circuit("C_XYZ 1 4 6")) == before assert after.before(stim.Circuit("C_XYZ 1 4 6")[0]) == before assert after.before(stim.Tableau.from_named_gate("C_XYZ"), targets=[1, 4, 6]) == before + + +def test_iter_small(): + assert list(stim.PauliString.iter_all(0)) == [stim.PauliString(0)] + assert list(stim.PauliString.iter_all(1)) == [ + stim.PauliString("_"), + stim.PauliString("X"), + stim.PauliString("Y"), + stim.PauliString("Z"), + ] + assert list(stim.PauliString.iter_all(1, max_weight=-1)) == [ + ] + assert list(stim.PauliString.iter_all(1, max_weight=0)) == [ + stim.PauliString("_"), + ] + assert list(stim.PauliString.iter_all(1, max_weight=1)) == [ + stim.PauliString("_"), + stim.PauliString("X"), + stim.PauliString("Y"), + stim.PauliString("Z"), + ] + assert list(stim.PauliString.iter_all(1, min_weight=1, max_weight=1)) == [ + stim.PauliString("X"), + stim.PauliString("Y"), + stim.PauliString("Z"), + ] + assert list(stim.PauliString.iter_all(2, min_weight=1, max_weight=1, allowed_paulis="XY")) == [ + stim.PauliString("X_"), + stim.PauliString("Y_"), + stim.PauliString("_X"), + stim.PauliString("_Y"), + ] + + with pytest.raises(ValueError, match="characters other than"): + stim.PauliString.iter_all(2, allowed_paulis="A") + + +def test_iter_reusable(): + v = stim.PauliString.iter_all(2) + vs1 = list(v) + vs2 = list(v) + assert vs1 == vs2 + assert len(vs1) == 4**2 diff --git a/src/stim/stabilizers/pauli_string_ref.inl b/src/stim/stabilizers/pauli_string_ref.inl index c9d1c634d..522c446eb 100644 --- a/src/stim/stabilizers/pauli_string_ref.inl +++ b/src/stim/stabilizers/pauli_string_ref.inl @@ -24,10 +24,7 @@ namespace stim { template PauliStringRef::PauliStringRef( - size_t init_num_qubits, - bit_ref init_sign, - simd_bits_range_ref init_xs, - simd_bits_range_ref init_zs) + size_t init_num_qubits, bit_ref init_sign, simd_bits_range_ref init_xs, simd_bits_range_ref init_zs) : num_qubits(init_num_qubits), sign(init_sign), xs(init_xs), zs(init_zs) { assert(init_xs.num_bits_padded() == init_zs.num_bits_padded()); assert(init_xs.num_simd_words == (init_num_qubits + W - 1) / W); @@ -107,19 +104,20 @@ uint8_t PauliStringRef::inplace_right_mul_returning_log_i_scalar(const PauliS simd_word cnt1{}; simd_word cnt2{}; - xs.for_each_word(zs, rhs.xs, rhs.zs, [&cnt1, &cnt2](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { - // Update the left hand side Paulis. - auto old_x1 = x1; - auto old_z1 = z1; - x1 ^= x2; - z1 ^= z2; + xs.for_each_word( + zs, rhs.xs, rhs.zs, [&cnt1, &cnt2](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + // Update the left hand side Paulis. + auto old_x1 = x1; + auto old_z1 = z1; + x1 ^= x2; + z1 ^= z2; - // At each bit position: accumulate anti-commutation (+i or -i) counts. - auto x1z2 = old_x1 & z2; - auto anti_commutes = (x2 & old_z1) ^ x1z2; - cnt2 ^= (cnt1 ^ x1 ^ z1 ^ x1z2) & anti_commutes; - cnt1 ^= anti_commutes; - }); + // At each bit position: accumulate anti-commutation (+i or -i) counts. + auto x1z2 = old_x1 & z2; + auto anti_commutes = (x2 & old_z1) ^ x1z2; + cnt2 ^= (cnt1 ^ x1 ^ z1 ^ x1z2) & anti_commutes; + cnt1 ^= anti_commutes; + }); // Combine final anti-commutation phase tally (mod 4). auto s = (uint8_t)cnt1.popcount(); @@ -134,14 +132,16 @@ bool PauliStringRef::commutes(const PauliStringRef &other) const noexcept return other.commutes(*this); } simd_word cnt1{}; - xs.for_each_word(zs, other.xs, other.zs, [&cnt1](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { - cnt1 ^= (x1 & z2) ^ (x2 & z1); - }); + xs.for_each_word( + zs, other.xs, other.zs, [&cnt1](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2) { + cnt1 ^= (x1 & z2) ^ (x2 & z1); + }); return (cnt1.popcount() & 1) == 0; } template -void PauliStringRef::after_inplace_broadcast(const Tableau &tableau, SpanRef indices, bool inverse) { +void PauliStringRef::after_inplace_broadcast( + const Tableau &tableau, SpanRef indices, bool inverse) { if (tableau.num_qubits == 0 || indices.size() % tableau.num_qubits != 0) { throw std::invalid_argument("len(tableau) == 0 or len(indices) % len(tableau) != 0"); } @@ -334,7 +334,7 @@ template size_t PauliStringRef::weight() const { size_t total = 0; xs.for_each_word(zs, [&](const simd_word &w1, const simd_word &w2) { - total += (w1 | w2).popcount(); + total += (w1 | w2).popcount(); }); return total; } diff --git a/src/stim/stabilizers/tableau.inl b/src/stim/stabilizers/tableau.inl index 4e3bab6c7..3c7a750a1 100644 --- a/src/stim/stabilizers/tableau.inl +++ b/src/stim/stabilizers/tableau.inl @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stim/stabilizers/tableau.h" - #include #include #include @@ -24,6 +22,7 @@ #include "stim/circuit/gate_data.h" #include "stim/simulators/vector_simulator.h" #include "stim/stabilizers/pauli_string.h" +#include "stim/stabilizers/tableau.h" namespace stim { @@ -380,10 +379,8 @@ simd_bit_table random_stabilizer_tableau_raw(size_t n, std::mt19937_64 &rng) inv.do_square_transpose(); inv_m.do_square_transpose(); - auto fused = - simd_bit_table::from_quadrants(n, lower, simd_bit_table(n, n), prod, inv); - auto fused_m = simd_bit_table::from_quadrants( - n, lower_m, simd_bit_table(n, n), prod_m, inv_m); + auto fused = simd_bit_table::from_quadrants(n, lower, simd_bit_table(n, n), prod, inv); + auto fused_m = simd_bit_table::from_quadrants(n, lower_m, simd_bit_table(n, n), prod_m, inv_m); simd_bit_table u(2 * n, 2 * n); @@ -753,4 +750,3 @@ PauliString Tableau::y_output(size_t input_index) const { } } // namespace stim - diff --git a/src/stim/stabilizers/tableau_iter.inl b/src/stim/stabilizers/tableau_iter.inl index a277ccd93..484f830cd 100644 --- a/src/stim/stabilizers/tableau_iter.inl +++ b/src/stim/stabilizers/tableau_iter.inl @@ -12,9 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stim/stabilizers/tableau_iter.h" - #include "stim/stabilizers/pauli_string.h" +#include "stim/stabilizers/tableau_iter.h" namespace stim { @@ -186,8 +185,8 @@ TableauIterator &TableauIterator::operator=(const TableauIterator &othe } template -std::pair>, SpanRef>> TableauIterator::constraints_for_pauli_iterator( - size_t k) const { +std::pair>, SpanRef>> +TableauIterator::constraints_for_pauli_iterator(size_t k) const { const PauliStringRef *tab_obs_start = &tableau_column_refs[0]; SpanRef> commute_rng = {tab_obs_start, tab_obs_start + k}; SpanRef> anticommute_rng; diff --git a/src/stim/stabilizers/tableau_iter.perf.cc b/src/stim/stabilizers/tableau_iter.perf.cc index 15793abf6..e9e416b40 100644 --- a/src/stim/stabilizers/tableau_iter.perf.cc +++ b/src/stim/stabilizers/tableau_iter.perf.cc @@ -27,7 +27,7 @@ BENCHMARK(tableau_iter_unsigned_3q) { } }) .goal_millis(200) - .show_rate("TableausPerSecond", 1451520); + .show_rate("Tableaus", 1451520); if (c == 0) { std::cerr << "use the output\n"; } @@ -42,7 +42,7 @@ BENCHMARK(tableau_iter_all_3q) { } }) .goal_millis(420) - .show_rate("TableausPerSecond", 92897280); + .show_rate("Tableaus", 92897280); if (c == 0) { std::cerr << "use the output\n"; } diff --git a/src/stim/stabilizers/tableau_transposed_raii.inl b/src/stim/stabilizers/tableau_transposed_raii.inl index 932290ecd..520e51321 100644 --- a/src/stim/stabilizers/tableau_transposed_raii.inl +++ b/src/stim/stabilizers/tableau_transposed_raii.inl @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stim/stabilizers/tableau_transposed_raii.h" - #include #include #include "stim/stabilizers/pauli_string.h" +#include "stim/stabilizers/tableau_transposed_raii.h" namespace stim { @@ -61,7 +60,10 @@ inline void for_each_trans_obs(TableauTransposedRaii &trans, size_t q1, size_ template void TableauTransposedRaii::append_ZCX(size_t control, size_t target) { for_each_trans_obs( - *this, control, target, [](simd_word &cx, simd_word &cz, simd_word &tx, simd_word &tz, simd_word &s) { + *this, + control, + target, + [](simd_word &cx, simd_word &cz, simd_word &tx, simd_word &tz, simd_word &s) { s ^= (cz ^ tx).andnot(cx & tz); cz ^= tz; tx ^= cx; @@ -71,7 +73,10 @@ void TableauTransposedRaii::append_ZCX(size_t control, size_t target) { template void TableauTransposedRaii::append_ZCY(size_t control, size_t target) { for_each_trans_obs( - *this, control, target, [](simd_word &cx, simd_word &cz, simd_word &tx, simd_word &tz, simd_word &s) { + *this, + control, + target, + [](simd_word &cx, simd_word &cz, simd_word &tx, simd_word &tz, simd_word &s) { cz ^= tx; s ^= cx & cz & (tx ^ tz); cz ^= tz; @@ -83,7 +88,10 @@ void TableauTransposedRaii::append_ZCY(size_t control, size_t target) { template void TableauTransposedRaii::append_ZCZ(size_t control, size_t target) { for_each_trans_obs( - *this, control, target, [](simd_word &cx, simd_word &cz, simd_word &tx, simd_word &tz, simd_word &s) { + *this, + control, + target, + [](simd_word &cx, simd_word &cz, simd_word &tx, simd_word &tz, simd_word &s) { s ^= cx & tx & (cz ^ tz); cz ^= tx; tz ^= cx; @@ -92,10 +100,11 @@ void TableauTransposedRaii::append_ZCZ(size_t control, size_t target) { template void TableauTransposedRaii::append_SWAP(size_t q1, size_t q2) { - for_each_trans_obs(*this, q1, q2, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2, simd_word &s) { - std::swap(x1, x2); - std::swap(z1, z2); - }); + for_each_trans_obs( + *this, q1, q2, [](simd_word &x1, simd_word &z1, simd_word &x2, simd_word &z2, simd_word &s) { + std::swap(x1, x2); + std::swap(z1, z2); + }); } template