From bad41d70ee8963972b5f824a71896d38b3cde8a1 Mon Sep 17 00:00:00 2001 From: Yue Wu Date: Fri, 23 Aug 2024 21:16:35 -0400 Subject: [PATCH 1/7] add hypergraph-UF and mwpf decoder in sinter --- .../sinter/_decoding_all_built_in_decoders.py | 3 + glue/sample/src/sinter/_decoding_mwpf.py | 262 ++++++++++++++++++ 2 files changed, 265 insertions(+) create mode 100644 glue/sample/src/sinter/_decoding_mwpf.py diff --git a/glue/sample/src/sinter/_decoding_all_built_in_decoders.py b/glue/sample/src/sinter/_decoding_all_built_in_decoders.py index a9fc5e76c..a27b3e9ae 100644 --- a/glue/sample/src/sinter/_decoding_all_built_in_decoders.py +++ b/glue/sample/src/sinter/_decoding_all_built_in_decoders.py @@ -4,9 +4,12 @@ from sinter._decoding_fusion_blossom import FusionBlossomDecoder from sinter._decoding_pymatching import PyMatchingDecoder from sinter._decoding_vacuous import VacuousDecoder +from sinter._decoding_mwpf import HyperUFDecoder, MwpfDecoder BUILT_IN_DECODERS: Dict[str, Decoder] = { 'vacuous': VacuousDecoder(), 'pymatching': PyMatchingDecoder(), 'fusion_blossom': FusionBlossomDecoder(), + 'hyper_uf': HyperUFDecoder(), + 'mwpf': MwpfDecoder(), } diff --git a/glue/sample/src/sinter/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding_mwpf.py new file mode 100644 index 000000000..636de1f01 --- /dev/null +++ b/glue/sample/src/sinter/_decoding_mwpf.py @@ -0,0 +1,262 @@ +import math +import pathlib +from typing import Callable, List, TYPE_CHECKING, Tuple + +import numpy as np +import stim + +from sinter._decoding_decoder_class import Decoder, CompiledDecoder + +if TYPE_CHECKING: + import mwpf + +DEFAULT_TIMEOUT: float = 10.0 # decoder timeout in seconds + + +class MwpfCompiledDecoder(CompiledDecoder): + def __init__( + self, + solver: "mwpf.SolverSerialJointSingleHair", + fault_masks: "np.ndarray", + num_dets: int, + num_obs: int, + ): + self.solver = solver + self.fault_masks = fault_masks + self.num_dets = num_dets + self.num_obs = num_obs + + def decode_shots_bit_packed( + self, + *, + bit_packed_detection_event_data: "np.ndarray", + ) -> "np.ndarray": + num_shots = bit_packed_detection_event_data.shape[0] + predictions = np.zeros(shape=(num_shots, self.num_obs), dtype=np.uint8) + import mwpf + + for shot in range(num_shots): + dets_sparse = np.flatnonzero( + np.unpackbits( + bit_packed_detection_event_data[shot], + count=self.num_dets, + bitorder="little", + ) + ) + syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse) + self.solver.solve(syndrome) + prediction = int( + np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]) + ) + predictions[shot] = np.packbits(prediction, bitorder="little") + self.solver.clear() + return predictions + + +class MwpfDecoder(Decoder): + """Use MWPF to predict observables from detection events.""" + + def compile_decoder_for_dem( + self, *, dem: "stim.DetectorErrorModel", timeout: float = DEFAULT_TIMEOUT + ) -> CompiledDecoder: + try: + import mwpf + except ImportError as ex: + raise ImportError( + "The decoder 'MWPF' isn't installed\n" + "To fix this, install the python package 'MWPF' into your environment.\n" + "For example, if you are using pip, run `pip install MWPF`.\n" + ) from ex + + solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( + dem, timeout=timeout + ) + return MwpfCompiledDecoder( + solver, fault_masks, dem.num_detectors, dem.num_observables + ) + + def decode_via_files( + self, + *, + num_shots: int, + num_dets: int, + num_obs: int, + dem_path: pathlib.Path, + dets_b8_in_path: pathlib.Path, + obs_predictions_b8_out_path: pathlib.Path, + tmp_dir: pathlib.Path, + timeout: float = DEFAULT_TIMEOUT, + ) -> None: + try: + import mwpf + except ImportError as ex: + raise ImportError( + "The decoder 'MWPF' isn't installed\n" + "To fix this, install the python package 'MWPF' into your environment.\n" + "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n" + ) from ex + + error_model = stim.DetectorErrorModel.from_file(dem_path) + solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( + error_model, timeout=timeout + ) + num_det_bytes = math.ceil(num_dets / 8) + with open(dets_b8_in_path, "rb") as dets_in_f: + with open(obs_predictions_b8_out_path, "wb") as obs_out_f: + for _ in range(num_shots): + dets_bit_packed = np.fromfile( + dets_in_f, dtype=np.uint8, count=num_det_bytes + ) + if dets_bit_packed.shape != (num_det_bytes,): + raise IOError("Missing dets data.") + dets_sparse = np.flatnonzero( + np.unpackbits( + dets_bit_packed, count=num_dets, bitorder="little" + ) + ) + syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse) + solver.solve(syndrome) + prediction = int( + np.bitwise_xor.reduce(fault_masks[solver.subgraph()]) + ) + obs_out_f.write( + prediction.to_bytes((num_obs + 7) // 8, byteorder="little") + ) + solver.clear() + + +class HyperUFDecoder(MwpfDecoder): + """Setting timeout to 0 becomes effectively a hypergraph UF decoder""" + + def compile_decoder_for_dem( + self, *, dem: "stim.DetectorErrorModel" + ) -> CompiledDecoder: + return super().compile_decoder_for_dem(dem=dem, timeout=0.0) + + def decode_via_files( + self, + *, + num_shots: int, + num_dets: int, + num_obs: int, + dem_path: pathlib.Path, + dets_b8_in_path: pathlib.Path, + obs_predictions_b8_out_path: pathlib.Path, + tmp_dir: pathlib.Path, + ) -> None: + return super().decode_via_files( + num_shots=num_shots, + num_dets=num_dets, + num_obs=num_obs, + dem_path=dem_path, + dets_b8_in_path=dets_b8_in_path, + obs_predictions_b8_out_path=obs_predictions_b8_out_path, + tmp_dir=tmp_dir, + timeout=0.0, + ) + + +def iter_flatten_model( + model: stim.DetectorErrorModel, + handle_error: Callable[[float, List[int], List[int]], None], + handle_detector_coords: Callable[[int, np.ndarray], None], +): + det_offset = 0 + coords_offset = np.zeros(100, dtype=np.float64) + + def _helper(m: stim.DetectorErrorModel, reps: int): + nonlocal det_offset + nonlocal coords_offset + for _ in range(reps): + for instruction in m: + if isinstance(instruction, stim.DemRepeatBlock): + _helper(instruction.body_copy(), instruction.repeat_count) + elif isinstance(instruction, stim.DemInstruction): + if instruction.type == "error": + dets: List[int] = [] + frames: List[int] = [] + t: stim.DemTarget + p = instruction.args_copy()[0] + for t in instruction.targets_copy(): + if t.is_relative_detector_id(): + dets.append(t.val + det_offset) + elif t.is_logical_observable_id(): + frames.append(t.val) + handle_error(p, dets, frames) + elif instruction.type == "shift_detectors": + det_offset += instruction.targets_copy()[0] + a = np.array(instruction.args_copy()) + coords_offset[: len(a)] += a + elif instruction.type == "detector": + a = np.array(instruction.args_copy()) + for t in instruction.targets_copy(): + handle_detector_coords( + t.val + det_offset, a + coords_offset[: len(a)] + ) + elif instruction.type == "logical_observable": + pass + else: + raise NotImplementedError() + else: + raise NotImplementedError() + + _helper(model, 1) + + +def detector_error_model_to_mwpf_solver_and_fault_masks( + model: stim.DetectorErrorModel, timeout: float = DEFAULT_TIMEOUT +) -> Tuple["mwpf.SolverSerialJointSingleHair", np.ndarray]: + """Convert a stim error model into a NetworkX graph.""" + + import mwpf + + num_detectors = model.num_detectors + is_detector_connected = np.full(num_detectors, False, dtype=bool) + hyperedges: List[Tuple[List[int], float, int]] = [] + + def handle_error(p: float, dets: List[int], frame_changes: List[int]): + if p == 0: + return + if len(dets) == 0: + # No symptoms for this error. + # Code probably has distance 1. + # Accept it and keep going, though of course decoding will probably perform terribly. + return + if p > 0.5: + # mwpf doesn't support negative edge weights. + # approximate them as weight 0. + p = 0.5 + weight = math.log((1 - p) / p) + mask = sum(1 << k for k in frame_changes) + is_detector_connected[dets] = True + hyperedges.append((dets, weight, mask)) + + def handle_detector_coords(detector: int, coords: np.ndarray): + pass + + iter_flatten_model( + model, + handle_error=handle_error, + handle_detector_coords=handle_detector_coords, + ) + + # fix the input by connecting an edge to all isolated vertices + for idx in range(num_detectors): + if not is_detector_connected[idx]: + hyperedges.append(([idx], 0, 0)) + + max_weight = max(1e-4, max((w for _, w, _ in hyperedges), default=1)) + rescaled_edges = [ + mwpf.HyperEdge(v, round(w * 2**10 / max_weight) * 2) for v, w, _ in hyperedges + ] + fault_masks = np.array([e[2] for e in hyperedges], dtype=np.uint64) + + initializer = mwpf.SolverInitializer( + num_detectors, # Total number of nodes. + rescaled_edges, # Weighted edges. + ) + + return ( + mwpf.SolverSerialJointSingleHair(initializer, {"primal": {"timeout": timeout}}), + fault_masks, + ) From 33f25a3ff00064be9d3c5b066e3eac3f0d0c1fd8 Mon Sep 17 00:00:00 2001 From: Yue Wu Date: Sun, 25 Aug 2024 20:55:59 -0400 Subject: [PATCH 2/7] remove timeout --- .../sinter/_decoding_all_built_in_decoders.py | 6 +- glue/sample/src/sinter/_decoding_mwpf.py | 74 +++++++++++-------- 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/glue/sample/src/sinter/_decoding_all_built_in_decoders.py b/glue/sample/src/sinter/_decoding_all_built_in_decoders.py index a27b3e9ae..4f011cdf1 100644 --- a/glue/sample/src/sinter/_decoding_all_built_in_decoders.py +++ b/glue/sample/src/sinter/_decoding_all_built_in_decoders.py @@ -10,6 +10,8 @@ 'vacuous': VacuousDecoder(), 'pymatching': PyMatchingDecoder(), 'fusion_blossom': FusionBlossomDecoder(), - 'hyper_uf': HyperUFDecoder(), - 'mwpf': MwpfDecoder(), + # an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049) + 'hypergraph_union_find': HyperUFDecoder(), + # Minimum-Weight Parity Factor using similar primal-dual method the blossom algorithm (https://pypi.org/project/mwpf/) + 'mw_parity_factor': MwpfDecoder(), } diff --git a/glue/sample/src/sinter/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding_mwpf.py index 636de1f01..9841915d7 100644 --- a/glue/sample/src/sinter/_decoding_mwpf.py +++ b/glue/sample/src/sinter/_decoding_mwpf.py @@ -1,6 +1,6 @@ import math import pathlib -from typing import Callable, List, TYPE_CHECKING, Tuple +from typing import Callable, List, TYPE_CHECKING, Tuple, Any import numpy as np import stim @@ -10,7 +10,13 @@ if TYPE_CHECKING: import mwpf -DEFAULT_TIMEOUT: float = 10.0 # decoder timeout in seconds + +def mwpf_import_error() -> ImportError: + return ImportError( + "The decoder 'MWPF' isn't installed\n" + "To fix this, install the python package 'MWPF' into your environment.\n" + "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n" + ) class MwpfCompiledDecoder(CompiledDecoder): @@ -57,19 +63,18 @@ class MwpfDecoder(Decoder): """Use MWPF to predict observables from detection events.""" def compile_decoder_for_dem( - self, *, dem: "stim.DetectorErrorModel", timeout: float = DEFAULT_TIMEOUT + self, + *, + dem: "stim.DetectorErrorModel", + decoder_cls: Any = None, # decoder class used to construct the MWPF decoder. + # in the Rust implementation, all of them inherits from the class of `SolverSerialPlugins` + # but just provide different plugins for optimizing the primal and/or dual solutions. + # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only + # grows the clusters until the first valid solution appears; some more optimized solvers uses + # one or more plugins to further optimize the solution, which requires longer decoding time. ) -> CompiledDecoder: - try: - import mwpf - except ImportError as ex: - raise ImportError( - "The decoder 'MWPF' isn't installed\n" - "To fix this, install the python package 'MWPF' into your environment.\n" - "For example, if you are using pip, run `pip install MWPF`.\n" - ) from ex - solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( - dem, timeout=timeout + dem, decoder_cls=decoder_cls ) return MwpfCompiledDecoder( solver, fault_masks, dem.num_detectors, dem.num_observables @@ -85,20 +90,11 @@ def decode_via_files( dets_b8_in_path: pathlib.Path, obs_predictions_b8_out_path: pathlib.Path, tmp_dir: pathlib.Path, - timeout: float = DEFAULT_TIMEOUT, + decoder_cls: Any = None, ) -> None: - try: - import mwpf - except ImportError as ex: - raise ImportError( - "The decoder 'MWPF' isn't installed\n" - "To fix this, install the python package 'MWPF' into your environment.\n" - "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n" - ) from ex - error_model = stim.DetectorErrorModel.from_file(dem_path) solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( - error_model, timeout=timeout + error_model, decoder_cls=decoder_cls ) num_det_bytes = math.ceil(num_dets / 8) with open(dets_b8_in_path, "rb") as dets_in_f: @@ -126,12 +122,17 @@ def decode_via_files( class HyperUFDecoder(MwpfDecoder): - """Setting timeout to 0 becomes effectively a hypergraph UF decoder""" - def compile_decoder_for_dem( self, *, dem: "stim.DetectorErrorModel" ) -> CompiledDecoder: - return super().compile_decoder_for_dem(dem=dem, timeout=0.0) + try: + import mwpf + except ImportError as ex: + raise mwpf_import_error() from ex + + return super().compile_decoder_for_dem( + dem=dem, decoder_cls=mwpf.SolverSerialUnionFind + ) def decode_via_files( self, @@ -144,6 +145,11 @@ def decode_via_files( obs_predictions_b8_out_path: pathlib.Path, tmp_dir: pathlib.Path, ) -> None: + try: + import mwpf + except ImportError as ex: + raise mwpf_import_error() from ex + return super().decode_via_files( num_shots=num_shots, num_dets=num_dets, @@ -152,7 +158,7 @@ def decode_via_files( dets_b8_in_path=dets_b8_in_path, obs_predictions_b8_out_path=obs_predictions_b8_out_path, tmp_dir=tmp_dir, - timeout=0.0, + decoder_cls=mwpf.SolverSerialUnionFind, ) @@ -204,11 +210,14 @@ def _helper(m: stim.DetectorErrorModel, reps: int): def detector_error_model_to_mwpf_solver_and_fault_masks( - model: stim.DetectorErrorModel, timeout: float = DEFAULT_TIMEOUT + model: stim.DetectorErrorModel, decoder_cls: Any = None ) -> Tuple["mwpf.SolverSerialJointSingleHair", np.ndarray]: """Convert a stim error model into a NetworkX graph.""" - import mwpf + try: + import mwpf + except ImportError as ex: + raise mwpf_import_error() from ex num_detectors = model.num_detectors is_detector_connected = np.full(num_detectors, False, dtype=bool) @@ -256,7 +265,10 @@ def handle_detector_coords(detector: int, coords: np.ndarray): rescaled_edges, # Weighted edges. ) + if decoder_cls is None: + # default to the solver with highest accuracy + decoder_cls = mwpf.SolverSerialJointSingleHair return ( - mwpf.SolverSerialJointSingleHair(initializer, {"primal": {"timeout": timeout}}), + decoder_cls(initializer), fault_masks, ) From eab795f8b43c15580148e127f523ddd777a580fb Mon Sep 17 00:00:00 2001 From: Yue Wu Date: Mon, 26 Aug 2024 19:06:45 -0400 Subject: [PATCH 3/7] install mwpf in github CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c6709c6f..7f5bb8ea4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -394,7 +394,7 @@ jobs: - run: bazel build :stim_dev_wheel - run: pip install bazel-bin/stim-0.0.dev0-py3-none-any.whl - run: pip install -e glue/sample - - run: pip install pytest pymatching fusion-blossom~=0.1.4 + - run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.1 - run: pytest glue/sample - run: dev/doctest_proper.py --module sinter - run: sinter help From 1c9080909dc9c0787aa065668a1509b938430d0b Mon Sep 17 00:00:00 2001 From: Yue Wu Date: Wed, 28 Aug 2024 09:18:51 -0400 Subject: [PATCH 4/7] solve failed test but logical error is still too high, need to check --- glue/sample/src/sinter/_decoding_mwpf.py | 59 ++++++++++++++++++------ glue/sample/src/sinter/_decoding_test.py | 5 ++ 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/glue/sample/src/sinter/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding_mwpf.py index 9841915d7..18643a90b 100644 --- a/glue/sample/src/sinter/_decoding_mwpf.py +++ b/glue/sample/src/sinter/_decoding_mwpf.py @@ -1,6 +1,6 @@ import math import pathlib -from typing import Callable, List, TYPE_CHECKING, Tuple, Any +from typing import Callable, List, TYPE_CHECKING, Tuple, Any, Optional import numpy as np import stim @@ -50,12 +50,15 @@ def decode_shots_bit_packed( ) ) syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse) - self.solver.solve(syndrome) - prediction = int( - np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]) - ) + if self.solver is None: + prediction = 0 + else: + self.solver.solve(syndrome) + prediction = int( + np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]) + ) + self.solver.clear() predictions[shot] = np.packbits(prediction, bitorder="little") - self.solver.clear() return predictions @@ -92,6 +95,8 @@ def decode_via_files( tmp_dir: pathlib.Path, decoder_cls: Any = None, ) -> None: + import mwpf + error_model = stim.DetectorErrorModel.from_file(dem_path) solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( error_model, decoder_cls=decoder_cls @@ -111,14 +116,17 @@ def decode_via_files( ) ) syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse) - solver.solve(syndrome) - prediction = int( - np.bitwise_xor.reduce(fault_masks[solver.subgraph()]) - ) + if solver is None: + prediction = 0 + else: + solver.solve(syndrome) + prediction = int( + np.bitwise_xor.reduce(fault_masks[solver.subgraph()]) + ) + solver.clear() obs_out_f.write( prediction.to_bytes((num_obs + 7) // 8, byteorder="little") ) - solver.clear() class HyperUFDecoder(MwpfDecoder): @@ -209,9 +217,28 @@ def _helper(m: stim.DetectorErrorModel, reps: int): _helper(model, 1) +def deduplicate_hyperedges( + hyperedges: List[Tuple[List[int], float, int]] +) -> List[Tuple[List[int], float, int]]: + indices: dict[frozenset[int], int] = dict() + result: List[Tuple[List[int], float, int]] = [] + for dets, weight, mask in hyperedges: + dets_set = frozenset(dets) + if dets_set in indices: + idx = indices[dets_set] + p1 = 1 / (1 + math.exp(weight)) + p2 = 1 / (1 + math.exp(result[idx][1])) + p = p1 * (1 - p2) + p2 * (1 - p1) + result[idx] = (dets, math.log((1 - p) / p), mask) + else: + indices[dets_set] = len(result) + result.append((dets, weight, mask)) + return result + + def detector_error_model_to_mwpf_solver_and_fault_masks( model: stim.DetectorErrorModel, decoder_cls: Any = None -) -> Tuple["mwpf.SolverSerialJointSingleHair", np.ndarray]: +) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]: """Convert a stim error model into a NetworkX graph.""" try: @@ -248,6 +275,8 @@ def handle_detector_coords(detector: int, coords: np.ndarray): handle_error=handle_error, handle_detector_coords=handle_detector_coords, ) + # mwpf package panic on duplicate edges, thus we need to handle them here + hyperedges = deduplicate_hyperedges(hyperedges) # fix the input by connecting an edge to all isolated vertices for idx in range(num_detectors): @@ -269,6 +298,10 @@ def handle_detector_coords(detector: int, coords: np.ndarray): # default to the solver with highest accuracy decoder_cls = mwpf.SolverSerialJointSingleHair return ( - decoder_cls(initializer), + ( + decoder_cls(initializer) + if num_detectors > 0 and len(rescaled_edges) > 0 + else None + ), fault_masks, ) diff --git a/glue/sample/src/sinter/_decoding_test.py b/glue/sample/src/sinter/_decoding_test.py index 2ca9fbbca..e7aafc7c6 100644 --- a/glue/sample/src/sinter/_decoding_test.py +++ b/glue/sample/src/sinter/_decoding_test.py @@ -27,6 +27,11 @@ def get_test_decoders() -> Tuple[List[str], Dict[str, sinter.Decoder]]: import fusion_blossom except ImportError: available_decoders.remove('fusion_blossom') + try: + import mwpf + except ImportError: + available_decoders.remove('hypergraph_union_find') + available_decoders.remove('mw_parity_factor') e = os.environ.get('SINTER_PYTEST_CUSTOM_DECODERS') if e is not None: From 82e9af5c6ce6b76ad6765a6f19d11a5a6cf627fc Mon Sep 17 00:00:00 2001 From: Yue Wu Date: Wed, 28 Aug 2024 09:24:20 -0400 Subject: [PATCH 5/7] fixed test errors --- glue/sample/src/sinter/_decoding_mwpf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/glue/sample/src/sinter/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding_mwpf.py index 18643a90b..642ae1aaa 100644 --- a/glue/sample/src/sinter/_decoding_mwpf.py +++ b/glue/sample/src/sinter/_decoding_mwpf.py @@ -229,7 +229,9 @@ def deduplicate_hyperedges( p1 = 1 / (1 + math.exp(weight)) p2 = 1 / (1 + math.exp(result[idx][1])) p = p1 * (1 - p2) + p2 * (1 - p1) - result[idx] = (dets, math.log((1 - p) / p), mask) + # not sure why would this fail? two hyperedges with different masks? + # assert mask == result[idx][2], (result[idx], (dets, weight, mask)) + result[idx] = (dets, math.log((1 - p) / p), result[idx][2]) else: indices[dets_set] = len(result) result.append((dets, weight, mask)) From 2d907bde30a228fb0e64c5369844e80905fe9103 Mon Sep 17 00:00:00 2001 From: Yue Wu Date: Wed, 20 Nov 2024 08:00:11 -0500 Subject: [PATCH 6/7] add `cluster_node_limit` for mwpf decoder to better tune decoding time and accuracy --- .github/workflows/ci.yml | 2 +- glue/sample/src/sinter/_decoding_mwpf.py | 37 +++++++++++++++--------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7f5bb8ea4..704539f66 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -394,7 +394,7 @@ jobs: - run: bazel build :stim_dev_wheel - run: pip install bazel-bin/stim-0.0.dev0-py3-none-any.whl - run: pip install -e glue/sample - - run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.1 + - run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.5 - run: pytest glue/sample - run: dev/doctest_proper.py --module sinter - run: sinter help diff --git a/glue/sample/src/sinter/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding_mwpf.py index 642ae1aaa..67a652026 100644 --- a/glue/sample/src/sinter/_decoding_mwpf.py +++ b/glue/sample/src/sinter/_decoding_mwpf.py @@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError: return ImportError( "The decoder 'MWPF' isn't installed\n" "To fix this, install the python package 'MWPF' into your environment.\n" - "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n" + "For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n" ) @@ -75,12 +75,18 @@ def compile_decoder_for_dem( # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only # grows the clusters until the first valid solution appears; some more optimized solvers uses # one or more plugins to further optimize the solution, which requires longer decoding time. + cluster_node_limit: int = 50, # The maximum number of nodes in a cluster. ) -> CompiledDecoder: solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( - dem, decoder_cls=decoder_cls + dem, + decoder_cls=decoder_cls, + cluster_node_limit=cluster_node_limit, ) return MwpfCompiledDecoder( - solver, fault_masks, dem.num_detectors, dem.num_observables + solver, + fault_masks, + dem.num_detectors, + dem.num_observables, ) def decode_via_files( @@ -220,26 +226,31 @@ def _helper(m: stim.DetectorErrorModel, reps: int): def deduplicate_hyperedges( hyperedges: List[Tuple[List[int], float, int]] ) -> List[Tuple[List[int], float, int]]: - indices: dict[frozenset[int], int] = dict() + indices: dict[frozenset[int], Tuple[int, float]] = dict() result: List[Tuple[List[int], float, int]] = [] for dets, weight, mask in hyperedges: dets_set = frozenset(dets) if dets_set in indices: - idx = indices[dets_set] + idx, min_weight = indices[dets_set] p1 = 1 / (1 + math.exp(weight)) p2 = 1 / (1 + math.exp(result[idx][1])) p = p1 * (1 - p2) + p2 * (1 - p1) - # not sure why would this fail? two hyperedges with different masks? - # assert mask == result[idx][2], (result[idx], (dets, weight, mask)) - result[idx] = (dets, math.log((1 - p) / p), result[idx][2]) + # choosing the mask from the most likely error + new_mask = result[idx][2] + if weight < min_weight: + indices[dets_set] = (idx, weight) + new_mask = mask + result[idx] = (dets, math.log((1 - p) / p), new_mask) else: - indices[dets_set] = len(result) + indices[dets_set] = (len(result), weight) result.append((dets, weight, mask)) return result def detector_error_model_to_mwpf_solver_and_fault_masks( - model: stim.DetectorErrorModel, decoder_cls: Any = None + model: stim.DetectorErrorModel, + decoder_cls: Any = None, + cluster_node_limit: int = 50, ) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]: """Convert a stim error model into a NetworkX graph.""" @@ -261,7 +272,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int]): # Accept it and keep going, though of course decoding will probably perform terribly. return if p > 0.5: - # mwpf doesn't support negative edge weights. + # mwpf doesn't support negative edge weights (yet, will be supported in the next version). # approximate them as weight 0. p = 0.5 weight = math.log((1 - p) / p) @@ -280,7 +291,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray): # mwpf package panic on duplicate edges, thus we need to handle them here hyperedges = deduplicate_hyperedges(hyperedges) - # fix the input by connecting an edge to all isolated vertices + # fix the input by connecting an edge to all isolated vertices; will be supported in the next version for idx in range(num_detectors): if not is_detector_connected[idx]: hyperedges.append(([idx], 0, 0)) @@ -301,7 +312,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray): decoder_cls = mwpf.SolverSerialJointSingleHair return ( ( - decoder_cls(initializer) + decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit}) if num_detectors > 0 and len(rescaled_edges) > 0 else None ), From 3b46f3ffc7241f443f18789527187f68a2d74f59 Mon Sep 17 00:00:00 2001 From: Yue Wu Date: Wed, 20 Nov 2024 08:19:50 -0500 Subject: [PATCH 7/7] move _decoding_mwpf to new _decoding folder --- .../src/sinter/_decoding/_decoding_mwpf.py | 37 +- glue/sample/src/sinter/_decoding_mwpf.py | 320 ------------------ 2 files changed, 24 insertions(+), 333 deletions(-) delete mode 100644 glue/sample/src/sinter/_decoding_mwpf.py diff --git a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py index 461cbc0f5..2b69c608e 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py +++ b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py @@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError: return ImportError( "The decoder 'MWPF' isn't installed\n" "To fix this, install the python package 'MWPF' into your environment.\n" - "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n" + "For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n" ) @@ -75,12 +75,18 @@ def compile_decoder_for_dem( # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only # grows the clusters until the first valid solution appears; some more optimized solvers uses # one or more plugins to further optimize the solution, which requires longer decoding time. + cluster_node_limit: int = 50, # The maximum number of nodes in a cluster. ) -> CompiledDecoder: solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( - dem, decoder_cls=decoder_cls + dem, + decoder_cls=decoder_cls, + cluster_node_limit=cluster_node_limit, ) return MwpfCompiledDecoder( - solver, fault_masks, dem.num_detectors, dem.num_observables + solver, + fault_masks, + dem.num_detectors, + dem.num_observables, ) def decode_via_files( @@ -220,26 +226,31 @@ def _helper(m: stim.DetectorErrorModel, reps: int): def deduplicate_hyperedges( hyperedges: List[Tuple[List[int], float, int]] ) -> List[Tuple[List[int], float, int]]: - indices: dict[frozenset[int], int] = dict() + indices: dict[frozenset[int], Tuple[int, float]] = dict() result: List[Tuple[List[int], float, int]] = [] for dets, weight, mask in hyperedges: dets_set = frozenset(dets) if dets_set in indices: - idx = indices[dets_set] + idx, min_weight = indices[dets_set] p1 = 1 / (1 + math.exp(weight)) p2 = 1 / (1 + math.exp(result[idx][1])) p = p1 * (1 - p2) + p2 * (1 - p1) - # not sure why would this fail? two hyperedges with different masks? - # assert mask == result[idx][2], (result[idx], (dets, weight, mask)) - result[idx] = (dets, math.log((1 - p) / p), result[idx][2]) + # choosing the mask from the most likely error + new_mask = result[idx][2] + if weight < min_weight: + indices[dets_set] = (idx, weight) + new_mask = mask + result[idx] = (dets, math.log((1 - p) / p), new_mask) else: - indices[dets_set] = len(result) + indices[dets_set] = (len(result), weight) result.append((dets, weight, mask)) return result def detector_error_model_to_mwpf_solver_and_fault_masks( - model: stim.DetectorErrorModel, decoder_cls: Any = None + model: stim.DetectorErrorModel, + decoder_cls: Any = None, + cluster_node_limit: int = 50, ) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]: """Convert a stim error model into a NetworkX graph.""" @@ -261,7 +272,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int]): # Accept it and keep going, though of course decoding will probably perform terribly. return if p > 0.5: - # mwpf doesn't support negative edge weights. + # mwpf doesn't support negative edge weights (yet, will be supported in the next version). # approximate them as weight 0. p = 0.5 weight = math.log((1 - p) / p) @@ -280,7 +291,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray): # mwpf package panic on duplicate edges, thus we need to handle them here hyperedges = deduplicate_hyperedges(hyperedges) - # fix the input by connecting an edge to all isolated vertices + # fix the input by connecting an edge to all isolated vertices; will be supported in the next version for idx in range(num_detectors): if not is_detector_connected[idx]: hyperedges.append(([idx], 0, 0)) @@ -301,7 +312,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray): decoder_cls = mwpf.SolverSerialJointSingleHair return ( ( - decoder_cls(initializer) + decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit}) if num_detectors > 0 and len(rescaled_edges) > 0 else None ), diff --git a/glue/sample/src/sinter/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding_mwpf.py deleted file mode 100644 index 67a652026..000000000 --- a/glue/sample/src/sinter/_decoding_mwpf.py +++ /dev/null @@ -1,320 +0,0 @@ -import math -import pathlib -from typing import Callable, List, TYPE_CHECKING, Tuple, Any, Optional - -import numpy as np -import stim - -from sinter._decoding_decoder_class import Decoder, CompiledDecoder - -if TYPE_CHECKING: - import mwpf - - -def mwpf_import_error() -> ImportError: - return ImportError( - "The decoder 'MWPF' isn't installed\n" - "To fix this, install the python package 'MWPF' into your environment.\n" - "For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n" - ) - - -class MwpfCompiledDecoder(CompiledDecoder): - def __init__( - self, - solver: "mwpf.SolverSerialJointSingleHair", - fault_masks: "np.ndarray", - num_dets: int, - num_obs: int, - ): - self.solver = solver - self.fault_masks = fault_masks - self.num_dets = num_dets - self.num_obs = num_obs - - def decode_shots_bit_packed( - self, - *, - bit_packed_detection_event_data: "np.ndarray", - ) -> "np.ndarray": - num_shots = bit_packed_detection_event_data.shape[0] - predictions = np.zeros(shape=(num_shots, self.num_obs), dtype=np.uint8) - import mwpf - - for shot in range(num_shots): - dets_sparse = np.flatnonzero( - np.unpackbits( - bit_packed_detection_event_data[shot], - count=self.num_dets, - bitorder="little", - ) - ) - syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse) - if self.solver is None: - prediction = 0 - else: - self.solver.solve(syndrome) - prediction = int( - np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]) - ) - self.solver.clear() - predictions[shot] = np.packbits(prediction, bitorder="little") - return predictions - - -class MwpfDecoder(Decoder): - """Use MWPF to predict observables from detection events.""" - - def compile_decoder_for_dem( - self, - *, - dem: "stim.DetectorErrorModel", - decoder_cls: Any = None, # decoder class used to construct the MWPF decoder. - # in the Rust implementation, all of them inherits from the class of `SolverSerialPlugins` - # but just provide different plugins for optimizing the primal and/or dual solutions. - # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only - # grows the clusters until the first valid solution appears; some more optimized solvers uses - # one or more plugins to further optimize the solution, which requires longer decoding time. - cluster_node_limit: int = 50, # The maximum number of nodes in a cluster. - ) -> CompiledDecoder: - solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( - dem, - decoder_cls=decoder_cls, - cluster_node_limit=cluster_node_limit, - ) - return MwpfCompiledDecoder( - solver, - fault_masks, - dem.num_detectors, - dem.num_observables, - ) - - def decode_via_files( - self, - *, - num_shots: int, - num_dets: int, - num_obs: int, - dem_path: pathlib.Path, - dets_b8_in_path: pathlib.Path, - obs_predictions_b8_out_path: pathlib.Path, - tmp_dir: pathlib.Path, - decoder_cls: Any = None, - ) -> None: - import mwpf - - error_model = stim.DetectorErrorModel.from_file(dem_path) - solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( - error_model, decoder_cls=decoder_cls - ) - num_det_bytes = math.ceil(num_dets / 8) - with open(dets_b8_in_path, "rb") as dets_in_f: - with open(obs_predictions_b8_out_path, "wb") as obs_out_f: - for _ in range(num_shots): - dets_bit_packed = np.fromfile( - dets_in_f, dtype=np.uint8, count=num_det_bytes - ) - if dets_bit_packed.shape != (num_det_bytes,): - raise IOError("Missing dets data.") - dets_sparse = np.flatnonzero( - np.unpackbits( - dets_bit_packed, count=num_dets, bitorder="little" - ) - ) - syndrome = mwpf.SyndromePattern(defect_vertices=dets_sparse) - if solver is None: - prediction = 0 - else: - solver.solve(syndrome) - prediction = int( - np.bitwise_xor.reduce(fault_masks[solver.subgraph()]) - ) - solver.clear() - obs_out_f.write( - prediction.to_bytes((num_obs + 7) // 8, byteorder="little") - ) - - -class HyperUFDecoder(MwpfDecoder): - def compile_decoder_for_dem( - self, *, dem: "stim.DetectorErrorModel" - ) -> CompiledDecoder: - try: - import mwpf - except ImportError as ex: - raise mwpf_import_error() from ex - - return super().compile_decoder_for_dem( - dem=dem, decoder_cls=mwpf.SolverSerialUnionFind - ) - - def decode_via_files( - self, - *, - num_shots: int, - num_dets: int, - num_obs: int, - dem_path: pathlib.Path, - dets_b8_in_path: pathlib.Path, - obs_predictions_b8_out_path: pathlib.Path, - tmp_dir: pathlib.Path, - ) -> None: - try: - import mwpf - except ImportError as ex: - raise mwpf_import_error() from ex - - return super().decode_via_files( - num_shots=num_shots, - num_dets=num_dets, - num_obs=num_obs, - dem_path=dem_path, - dets_b8_in_path=dets_b8_in_path, - obs_predictions_b8_out_path=obs_predictions_b8_out_path, - tmp_dir=tmp_dir, - decoder_cls=mwpf.SolverSerialUnionFind, - ) - - -def iter_flatten_model( - model: stim.DetectorErrorModel, - handle_error: Callable[[float, List[int], List[int]], None], - handle_detector_coords: Callable[[int, np.ndarray], None], -): - det_offset = 0 - coords_offset = np.zeros(100, dtype=np.float64) - - def _helper(m: stim.DetectorErrorModel, reps: int): - nonlocal det_offset - nonlocal coords_offset - for _ in range(reps): - for instruction in m: - if isinstance(instruction, stim.DemRepeatBlock): - _helper(instruction.body_copy(), instruction.repeat_count) - elif isinstance(instruction, stim.DemInstruction): - if instruction.type == "error": - dets: List[int] = [] - frames: List[int] = [] - t: stim.DemTarget - p = instruction.args_copy()[0] - for t in instruction.targets_copy(): - if t.is_relative_detector_id(): - dets.append(t.val + det_offset) - elif t.is_logical_observable_id(): - frames.append(t.val) - handle_error(p, dets, frames) - elif instruction.type == "shift_detectors": - det_offset += instruction.targets_copy()[0] - a = np.array(instruction.args_copy()) - coords_offset[: len(a)] += a - elif instruction.type == "detector": - a = np.array(instruction.args_copy()) - for t in instruction.targets_copy(): - handle_detector_coords( - t.val + det_offset, a + coords_offset[: len(a)] - ) - elif instruction.type == "logical_observable": - pass - else: - raise NotImplementedError() - else: - raise NotImplementedError() - - _helper(model, 1) - - -def deduplicate_hyperedges( - hyperedges: List[Tuple[List[int], float, int]] -) -> List[Tuple[List[int], float, int]]: - indices: dict[frozenset[int], Tuple[int, float]] = dict() - result: List[Tuple[List[int], float, int]] = [] - for dets, weight, mask in hyperedges: - dets_set = frozenset(dets) - if dets_set in indices: - idx, min_weight = indices[dets_set] - p1 = 1 / (1 + math.exp(weight)) - p2 = 1 / (1 + math.exp(result[idx][1])) - p = p1 * (1 - p2) + p2 * (1 - p1) - # choosing the mask from the most likely error - new_mask = result[idx][2] - if weight < min_weight: - indices[dets_set] = (idx, weight) - new_mask = mask - result[idx] = (dets, math.log((1 - p) / p), new_mask) - else: - indices[dets_set] = (len(result), weight) - result.append((dets, weight, mask)) - return result - - -def detector_error_model_to_mwpf_solver_and_fault_masks( - model: stim.DetectorErrorModel, - decoder_cls: Any = None, - cluster_node_limit: int = 50, -) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]: - """Convert a stim error model into a NetworkX graph.""" - - try: - import mwpf - except ImportError as ex: - raise mwpf_import_error() from ex - - num_detectors = model.num_detectors - is_detector_connected = np.full(num_detectors, False, dtype=bool) - hyperedges: List[Tuple[List[int], float, int]] = [] - - def handle_error(p: float, dets: List[int], frame_changes: List[int]): - if p == 0: - return - if len(dets) == 0: - # No symptoms for this error. - # Code probably has distance 1. - # Accept it and keep going, though of course decoding will probably perform terribly. - return - if p > 0.5: - # mwpf doesn't support negative edge weights (yet, will be supported in the next version). - # approximate them as weight 0. - p = 0.5 - weight = math.log((1 - p) / p) - mask = sum(1 << k for k in frame_changes) - is_detector_connected[dets] = True - hyperedges.append((dets, weight, mask)) - - def handle_detector_coords(detector: int, coords: np.ndarray): - pass - - iter_flatten_model( - model, - handle_error=handle_error, - handle_detector_coords=handle_detector_coords, - ) - # mwpf package panic on duplicate edges, thus we need to handle them here - hyperedges = deduplicate_hyperedges(hyperedges) - - # fix the input by connecting an edge to all isolated vertices; will be supported in the next version - for idx in range(num_detectors): - if not is_detector_connected[idx]: - hyperedges.append(([idx], 0, 0)) - - max_weight = max(1e-4, max((w for _, w, _ in hyperedges), default=1)) - rescaled_edges = [ - mwpf.HyperEdge(v, round(w * 2**10 / max_weight) * 2) for v, w, _ in hyperedges - ] - fault_masks = np.array([e[2] for e in hyperedges], dtype=np.uint64) - - initializer = mwpf.SolverInitializer( - num_detectors, # Total number of nodes. - rescaled_edges, # Weighted edges. - ) - - if decoder_cls is None: - # default to the solver with highest accuracy - decoder_cls = mwpf.SolverSerialJointSingleHair - return ( - ( - decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit}) - if num_detectors > 0 and len(rescaled_edges) > 0 - else None - ), - fault_masks, - )