diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aa535b4f..6173fafb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -431,7 +431,7 @@ jobs: - run: bazel build :stim_dev_wheel - run: pip install bazel-bin/stim-0.0.dev0-py3-none-any.whl - run: pip install -e glue/sample - - run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.1 + - run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.5 - run: pytest glue/sample - run: dev/doctest_proper.py --module sinter - run: sinter help diff --git a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py index dfe1129c..915fd0b0 100644 --- a/glue/sample/src/sinter/_decoding/_decoding_mwpf.py +++ b/glue/sample/src/sinter/_decoding/_decoding_mwpf.py @@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError: return ImportError( "The decoder 'MWPF' isn't installed\n" "To fix this, install the python package 'MWPF' into your environment.\n" - "For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n" + "For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n" ) @@ -75,12 +75,18 @@ def compile_decoder_for_dem( # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only # grows the clusters until the first valid solution appears; some more optimized solvers uses # one or more plugins to further optimize the solution, which requires longer decoding time. + cluster_node_limit: int = 50, # The maximum number of nodes in a cluster. ) -> CompiledDecoder: solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks( - dem, decoder_cls=decoder_cls + dem, + decoder_cls=decoder_cls, + cluster_node_limit=cluster_node_limit, ) return MwpfCompiledDecoder( - solver, fault_masks, dem.num_detectors, dem.num_observables + solver, + fault_masks, + dem.num_detectors, + dem.num_observables, ) def decode_via_files( @@ -220,26 +226,31 @@ def _helper(m: stim.DetectorErrorModel, reps: int): def deduplicate_hyperedges( hyperedges: List[Tuple[List[int], float, int]] ) -> List[Tuple[List[int], float, int]]: - indices: dict[frozenset[int], int] = dict() + indices: dict[frozenset[int], Tuple[int, float]] = dict() result: List[Tuple[List[int], float, int]] = [] for dets, weight, mask in hyperedges: dets_set = frozenset(dets) if dets_set in indices: - idx = indices[dets_set] + idx, min_weight = indices[dets_set] p1 = 1 / (1 + math.exp(weight)) p2 = 1 / (1 + math.exp(result[idx][1])) p = p1 * (1 - p2) + p2 * (1 - p1) - # not sure why would this fail? two hyperedges with different masks? - # assert mask == result[idx][2], (result[idx], (dets, weight, mask)) - result[idx] = (dets, math.log((1 - p) / p), result[idx][2]) + # choosing the mask from the most likely error + new_mask = result[idx][2] + if weight < min_weight: + indices[dets_set] = (idx, weight) + new_mask = mask + result[idx] = (dets, math.log((1 - p) / p), new_mask) else: - indices[dets_set] = len(result) + indices[dets_set] = (len(result), weight) result.append((dets, weight, mask)) return result def detector_error_model_to_mwpf_solver_and_fault_masks( - model: stim.DetectorErrorModel, decoder_cls: Any = None + model: stim.DetectorErrorModel, + decoder_cls: Any = None, + cluster_node_limit: int = 50, ) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]: """Convert a stim error model into a NetworkX graph.""" @@ -261,7 +272,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int]): # Accept it and keep going, though of course decoding will probably perform terribly. return if p > 0.5: - # mwpf doesn't support negative edge weights. + # mwpf doesn't support negative edge weights (yet, will be supported in the next version). # approximate them as weight 0. p = 0.5 weight = math.log((1 - p) / p) @@ -280,7 +291,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray): # mwpf package panic on duplicate edges, thus we need to handle them here hyperedges = deduplicate_hyperedges(hyperedges) - # fix the input by connecting an edge to all isolated vertices + # fix the input by connecting an edge to all isolated vertices; will be supported in the next version for idx in range(num_detectors): if not is_detector_connected[idx]: hyperedges.append(([idx], 0, 0)) @@ -301,7 +312,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray): decoder_cls = mwpf.SolverSerialJointSingleHair return ( ( - decoder_cls(initializer) + decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit}) if num_detectors > 0 and len(rescaled_edges) > 0 else None ),