diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c6709c6f..a1b050ac1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -165,7 +165,7 @@ jobs: - run: mv dist/* output/stim - run: mv glue/cirq/dist/* output/stimcirq - run: mv glue/sample/dist/* output/sinter - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.1.7 with: name: dist path: | @@ -185,7 +185,7 @@ jobs: if: github.ref == 'refs/heads/main' runs-on: ubuntu-latest steps: - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v4.1.7 with: name: dist path: dist diff --git a/dev/gen_sinter_api_reference.py b/dev/gen_sinter_api_reference.py index c3e348af0..a4b907f08 100644 --- a/dev/gen_sinter_api_reference.py +++ b/dev/gen_sinter_api_reference.py @@ -48,6 +48,25 @@ def main(): ``` '''.strip()) + replace_rules = [] + for package in ['stim', 'sinter']: + p = __import__(package) + for name in dir(p): + x = getattr(p, name) + if isinstance(x, type) and 'class' in str(x): + desired_name = f'{package}.{name}' + if '._' in str(x): + bad_name = str(x).split("'")[1] + replace_rules.append((bad_name, desired_name)) + lonely_name = desired_name.split(".")[-1] + for q in ['"', "'"]: + replace_rules.append(('ForwardRef(' + q + lonely_name + q + ')', desired_name)) + replace_rules.append(('ForwardRef(' + q + desired_name + q + ')', desired_name)) + replace_rules.append((q + desired_name + q, desired_name)) + replace_rules.append((q + lonely_name + q, desired_name)) + replace_rules.append(('ForwardRef(' + desired_name + ')', desired_name)) + replace_rules.append(('ForwardRef(' + lonely_name + ')', desired_name)) + for obj in objects: print() print(f'') @@ -58,7 +77,10 @@ def main(): print(f'# (in class {".".join(obj.full_name.split(".")[:-1])})') else: print(f'# (at top-level in the sinter module)') - print('\n'.join(obj.lines)) + for line in obj.lines: + for a, b in replace_rules: + line = line.replace(a, b) + print(line) print("```") diff --git a/dev/util_gen_stub_file.py b/dev/util_gen_stub_file.py index 26a34530c..2c6519994 100644 --- a/dev/util_gen_stub_file.py +++ b/dev/util_gen_stub_file.py @@ -1,5 +1,4 @@ import dataclasses -import sys import types from typing import Any from typing import Optional, Iterator, List @@ -9,6 +8,7 @@ keep = { "__add__", + "__radd__", "__eq__", "__call__", "__ge__", @@ -224,17 +224,6 @@ def print_doc(*, full_name: str, parent: object, obj: object, level: int) -> Opt text += '@abc.abstractmethod\n' sig_name = f'{term_name}{inspect.signature(obj)}' text += "\n".join(splay_signature(f"def {sig_name}:")) - text = text.replace('''ForwardRef('sinter.TaskStats')''', 'sinter.TaskStats') - text = text.replace('''ForwardRef('sinter.Task')''', 'sinter.Task') - text = text.replace('''ForwardRef('sinter.Progress')''', 'sinter.Progress') - text = text.replace('''ForwardRef('sinter.Decoder')''', 'sinter.Decoder') - text = text.replace("'AnonTaskStats'", "sinter.AnonTaskStats") - text = text.replace('sinter._decoding_decoder_class.CompiledDecoder', 'sinter.CompiledDecoder') - text = text.replace("'AnonTaskStats'", "sinter.AnonTaskStats") - text = text.replace("'stim.Circuit'", "stim.Circuit") - text = text.replace("'stim.DetectorErrorModel'", "stim.DetectorErrorModel") - text = text.replace("'sinter.CollectionOptions'", "sinter.CollectionOptions") - text = text.replace("'sinter.Fit'", 'sinter.Fit') # Replace default value lambdas with their source. if 'lambda' in str(text): diff --git a/doc/python_api_reference_vDev.md b/doc/python_api_reference_vDev.md index f509eef04..3b7e34cfa 100644 --- a/doc/python_api_reference_vDev.md +++ b/doc/python_api_reference_vDev.md @@ -1610,6 +1610,7 @@ def diagram( *, tick: Union[None, int, range] = None, filter_coords: Iterable[Union[Iterable[float], stim.DemTarget]] = ((),), + rows: int | None = None, ) -> 'stim._DiagramHelper': """Returns a diagram of the circuit, from a variety of options. diff --git a/doc/sinter_api.md b/doc/sinter_api.md index f940813f3..3093e41af 100644 --- a/doc/sinter_api.md +++ b/doc/sinter_api.md @@ -12,11 +12,16 @@ API references for stable versions are kept on the [stim github wiki](https://gi - [`sinter.CollectionOptions.combine`](#sinter.CollectionOptions.combine) - [`sinter.CompiledDecoder`](#sinter.CompiledDecoder) - [`sinter.CompiledDecoder.decode_shots_bit_packed`](#sinter.CompiledDecoder.decode_shots_bit_packed) +- [`sinter.CompiledSampler`](#sinter.CompiledSampler) + - [`sinter.CompiledSampler.handles_throttling`](#sinter.CompiledSampler.handles_throttling) + - [`sinter.CompiledSampler.sample`](#sinter.CompiledSampler.sample) - [`sinter.Decoder`](#sinter.Decoder) - [`sinter.Decoder.compile_decoder_for_dem`](#sinter.Decoder.compile_decoder_for_dem) - [`sinter.Decoder.decode_via_files`](#sinter.Decoder.decode_via_files) - [`sinter.Fit`](#sinter.Fit) - [`sinter.Progress`](#sinter.Progress) +- [`sinter.Sampler`](#sinter.Sampler) + - [`sinter.Sampler.compiled_sampler_for_task`](#sinter.Sampler.compiled_sampler_for_task) - [`sinter.Task`](#sinter.Task) - [`sinter.Task.__init__`](#sinter.Task.__init__) - [`sinter.Task.strong_id`](#sinter.Task.strong_id) @@ -26,6 +31,7 @@ API references for stable versions are kept on the [stim github wiki](https://gi - [`sinter.TaskStats`](#sinter.TaskStats) - [`sinter.TaskStats.to_anon_stats`](#sinter.TaskStats.to_anon_stats) - [`sinter.TaskStats.to_csv_line`](#sinter.TaskStats.to_csv_line) + - [`sinter.TaskStats.with_edits`](#sinter.TaskStats.with_edits) - [`sinter.better_sorted_str_terms`](#sinter.better_sorted_str_terms) - [`sinter.collect`](#sinter.collect) - [`sinter.comma_separated_key_values`](#sinter.comma_separated_key_values) @@ -257,6 +263,73 @@ def decode_shots_bit_packed( """ ``` + +```python +# sinter.CompiledSampler + +# (at top-level in the sinter module) +class CompiledSampler(metaclass=abc.ABCMeta): + """A sampler that has been configured for efficiently sampling some task. + """ +``` + + +```python +# sinter.CompiledSampler.handles_throttling + +# (in class sinter.CompiledSampler) +def handles_throttling( + self, +) -> bool: + """Return True to disable sinter wrapping samplers with throttling. + + By default, sinter will wrap samplers so that they initially only do + a small number of shots then slowly ramp up. Sometimes this behavior + is not desired (e.g. in unit tests). Override this method to return True + to disable it. + """ +``` + + +```python +# sinter.CompiledSampler.sample + +# (in class sinter.CompiledSampler) +@abc.abstractmethod +def sample( + self, + suggested_shots: int, +) -> sinter.AnonTaskStats: + """Samples shots and returns statistics. + + Args: + suggested_shots: The number of shots being requested. The sampler + may perform more shots or fewer shots than this, so technically + this argument can just be ignored. If a sampler is optimized for + a specific batch size, it can simply return one batch per call + regardless of this parameter. + + However, this parameter is a useful hint about the amount of + work being done. The sampler can use this to optimize its + behavior. For example, it could adjust its batch size downward + if the suggested shots is very small. Whereas if the suggested + shots is very high, the sampler should focus entirely on + achieving the best possible throughput. + + Note that, in typical workloads, the sampler will be called + repeatedly with the same value of suggested_shots. Therefore it + is reasonable to allocate buffers sized to accomodate the + current suggested_shots, expecting them to be useful again for + the next shot. + + Returns: + A sinter.AnonTaskStats saying how many shots were actually taken, + how many errors were seen, etc. + + The returned stats must have at least one shot. + """ +``` + ```python # sinter.Decoder @@ -385,9 +458,9 @@ class Fit: of the best fit's square error, or whose likelihood was within some maximum Bayes factor of the max likelihood hypothesis. """ - low: float - best: float - high: float + low: Optional[float] + best: Optional[float] + high: Optional[float] ``` @@ -409,10 +482,45 @@ class Progress: collection status, such as the number of tasks left and the estimated time to completion for each task. """ - new_stats: Tuple[sinter._task_stats.TaskStats, ...] + new_stats: Tuple[sinter.TaskStats, ...] status_message: str ``` + +```python +# sinter.Sampler + +# (at top-level in the sinter module) +class Sampler(metaclass=abc.ABCMeta): + """A strategy for producing stats from tasks. + + Call `sampler.compiled_sampler_for_task(task)` to get a compiled sampler for + a task, then call `compiled_sampler.sample(shots)` to collect statistics. + + A sampler differs from a `sinter.Decoder` because the sampler is responsible + for the full sampling process (e.g. simulating the circuit), whereas a + decoder can do nothing except predict observable flips from detection event + data. This prevents the decoders from cheating, but makes them less flexible + overall. A sampler can do things like use simulators other than stim, or + really anything at all as long as it ends with returning statistics about + shot counts, error counts, and etc. + """ +``` + + +```python +# sinter.Sampler.compiled_sampler_for_task + +# (in class sinter.Sampler) +@abc.abstractmethod +def compiled_sampler_for_task( + self, + task: sinter.Task, +) -> sinter.CompiledSampler: + """Creates, configures, and returns an object for sampling the task. + """ +``` + ```python # sinter.Task @@ -475,9 +583,9 @@ class Task: def __init__( self, *, - circuit: Optional[ForwardRef(stim.Circuit)] = None, + circuit: Optional[stim.Circuit] = None, decoder: Optional[str] = None, - detector_error_model: Optional[ForwardRef(stim.DetectorErrorModel)] = None, + detector_error_model: Optional[stim.DetectorErrorModel] = None, postselection_mask: Optional[np.ndarray] = None, postselected_observables_mask: Optional[np.ndarray] = None, json_metadata: Any = None, @@ -699,7 +807,7 @@ class TaskStats: # (in class sinter.TaskStats) def to_anon_stats( self, -) -> sinter._anon_task_stats.AnonTaskStats: +) -> sinter.AnonTaskStats: """Returns a `sinter.AnonTaskStats` with the same statistics. Examples: @@ -745,6 +853,25 @@ def to_csv_line( """ ``` + +```python +# sinter.TaskStats.with_edits + +# (in class sinter.TaskStats) +def with_edits( + self, + *, + strong_id: Optional[str] = None, + decoder: Optional[str] = None, + json_metadata: Optional[Any] = None, + shots: Optional[int] = None, + errors: Optional[int] = None, + discards: Optional[int] = None, + seconds: Optional[float] = None, + custom_counts: Optional[Counter[str]] = None, +) -> sinter.TaskStats: +``` + ```python # sinter.better_sorted_str_terms @@ -809,7 +936,7 @@ def collect( start_batch_size: Optional[int] = None, print_progress: bool = False, hint_num_tasks: Optional[int] = None, - custom_decoders: Optional[Dict[str, sinter.Decoder]] = None, + custom_decoders: Optional[Dict[str, Union[sinter.Decoder, sinter.Sampler]]] = None, custom_error_count_key: Optional[str] = None, allowed_cpu_affinity_ids: Optional[Iterable[int]] = None, ) -> List[sinter.TaskStats]: @@ -1124,7 +1251,7 @@ def iter_collect( num_workers: int, tasks: Union[Iterator[sinter.Task], Iterable[sinter.Task]], hint_num_tasks: Optional[int] = None, - additional_existing_data: Optional[sinter._existing_data.ExistingData] = None, + additional_existing_data: Union[NoneType, Dict[str, sinter.TaskStats], Iterable[sinter.TaskStats]] = None, max_shots: Optional[int] = None, max_errors: Optional[int] = None, decoders: Optional[Iterable[str]] = None, @@ -1133,7 +1260,7 @@ def iter_collect( start_batch_size: Optional[int] = None, count_observable_error_combos: bool = False, count_detection_events: bool = False, - custom_decoders: Optional[Dict[str, sinter.Decoder]] = None, + custom_decoders: Optional[Dict[str, Union[sinter.Decoder, sinter.Sampler]]] = None, custom_error_count_key: Optional[str] = None, allowed_cpu_affinity_ids: Optional[Iterable[int]] = None, ) -> Iterator[sinter.Progress]: @@ -1337,6 +1464,7 @@ def plot_discard_rate( filter_func: Callable[[sinter.TaskStats], Any] = lambda _: True, plot_args_func: Callable[[int, ~TCurveId, List[sinter.TaskStats]], Dict[str, Any]] = lambda index, group_key, group_stats: dict(), highlight_max_likelihood_factor: Optional[float] = 1000.0, + point_label_func: Callable[[sinter.TaskStats], Any] = lambda _: None, ) -> None: """Plots discard rates in curves with uncertainty highlights. @@ -1353,11 +1481,21 @@ def plot_discard_rate( group_func: Optional. When specified, multiple curves will be plotted instead of one curve. The statistics are grouped into curves based on whether or not they get the same result out of this function. For example, this could be `group_func=lambda stat: stat.decoder`. + If the result of the function is a dictionary, then optional keys in the dictionary will + also control the plotting of each curve. Available keys are: + 'label': the label added to the legend for the curve + 'color': the color used for plotting the curve + 'marker': the marker used for the curve + 'linestyle': the linestyle used for the curve + 'sort': the order in which the curves will be plotted and added to the legend + e.g. if two curves (with different resulting dictionaries from group_func) share the same + value for key 'marker', they will be plotted with the same marker. + Colors, markers and linestyles are assigned in order, sorted by the values for those keys. filter_func: Optional. When specified, some curves will not be plotted. The statistics are filtered and only plotted if filter_func(stat) returns True. For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats where the saved metadata indicates the basis was 'x'. - plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to + plot_args_func: Optional. Specifies additional arguments to give the underlying calls to `plot` and `fill_between` used to do the actual plotting. For example, this can be used to specify markers and colors. Takes the index of the curve in sorted order and also a curve_id (these will be 0 and None respectively if group_func is not specified). For example, @@ -1370,6 +1508,7 @@ def plot_discard_rate( highlight_max_likelihood_factor: Controls how wide the uncertainty highlight region around curves is. Must be 1 or larger. Hypothesis probabilities at most that many times as unlikely as the max likelihood hypothesis will be highlighted. + point_label_func: Optional. Specifies text to draw next to data points. """ ``` @@ -1390,6 +1529,7 @@ def plot_error_rate( plot_args_func: Callable[[int, ~TCurveId, List[sinter.TaskStats]], Dict[str, Any]] = lambda index, group_key, group_stats: dict(), highlight_max_likelihood_factor: Optional[float] = 1000.0, line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None, + point_label_func: Callable[[sinter.TaskStats], Any] = lambda _: None, ) -> None: """Plots error rates in curves with uncertainty highlights. @@ -1410,11 +1550,21 @@ def plot_error_rate( group_func: Optional. When specified, multiple curves will be plotted instead of one curve. The statistics are grouped into curves based on whether or not they get the same result out of this function. For example, this could be `group_func=lambda stat: stat.decoder`. + If the result of the function is a dictionary, then optional keys in the dictionary will + also control the plotting of each curve. Available keys are: + 'label': the label added to the legend for the curve + 'color': the color used for plotting the curve + 'marker': the marker used for the curve + 'linestyle': the linestyle used for the curve + 'sort': the order in which the curves will be plotted and added to the legend + e.g. if two curves (with different resulting dictionaries from group_func) share the same + value for key 'marker', they will be plotted with the same marker. + Colors, markers and linestyles are assigned in order, sorted by the values for those keys. filter_func: Optional. When specified, some curves will not be plotted. The statistics are filtered and only plotted if filter_func(stat) returns True. For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats where the saved metadata indicates the basis was 'x'. - plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to + plot_args_func: Optional. Specifies additional arguments to give the underlying calls to `plot` and `fill_between` used to do the actual plotting. For example, this can be used to specify markers and colors. Takes the index of the curve in sorted order and also a curve_id (these will be 0 and None respectively if group_func is not specified). For example, @@ -1430,6 +1580,7 @@ def plot_error_rate( line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line fit to every curve. The scales determine how to transform the coordinates before performing the fit, and can be set to 'linear', 'sqrt', or 'log'. + point_label_func: Optional. Specifies text to draw next to data points. """ ``` @@ -1712,11 +1863,11 @@ def read_stats_from_csv_files( # (at top-level in the sinter module) def shot_error_rate_to_piece_error_rate( - shot_error_rate: Union[float, ForwardRef(sinter.Fit)], + shot_error_rate: Union[float, sinter.Fit], *, pieces: float, values: float = 1, -) -> Union[float, ForwardRef(sinter.Fit)]: +) -> Union[float, sinter.Fit]: """Convert from total error rate to per-piece error rate. Args: diff --git a/doc/stim.pyi b/doc/stim.pyi index e20cbea04..2ffe7d478 100644 --- a/doc/stim.pyi +++ b/doc/stim.pyi @@ -1010,6 +1010,7 @@ class Circuit: *, tick: Union[None, int, range] = None, filter_coords: Iterable[Union[Iterable[float], stim.DemTarget]] = ((),), + rows: int | None = None, ) -> 'stim._DiagramHelper': """Returns a diagram of the circuit, from a variety of options. diff --git a/glue/python/src/stim/__init__.pyi b/glue/python/src/stim/__init__.pyi index e20cbea04..2ffe7d478 100644 --- a/glue/python/src/stim/__init__.pyi +++ b/glue/python/src/stim/__init__.pyi @@ -1010,6 +1010,7 @@ class Circuit: *, tick: Union[None, int, range] = None, filter_coords: Iterable[Union[Iterable[float], stim.DemTarget]] = ((),), + rows: int | None = None, ) -> 'stim._DiagramHelper': """Returns a diagram of the circuit, from a variety of options. diff --git a/glue/sample/setup.py b/glue/sample/setup.py index f6c5da088..7fc73a6d9 100644 --- a/glue/sample/setup.py +++ b/glue/sample/setup.py @@ -37,6 +37,6 @@ install_requires=requirements, tests_require=['pytest', 'pymatching'], entry_points={ - 'console_scripts': ['sinter=sinter._main:main'], + 'console_scripts': ['sinter=sinter._command._main:main'], }, ) diff --git a/glue/sample/src/sinter/__init__.py b/glue/sample/src/sinter/__init__.py index 1237b79b4..a8ccc6788 100644 --- a/glue/sample/src/sinter/__init__.py +++ b/glue/sample/src/sinter/__init__.py @@ -1,26 +1,27 @@ __version__ = '1.14.dev0' -from sinter._anon_task_stats import ( - AnonTaskStats, -) from sinter._collection import ( collect, iter_collect, post_selection_mask_from_4th_coord, Progress, ) -from sinter._collection_options import ( +from sinter._data import ( + AnonTaskStats, CollectionOptions, -) -from sinter._csv_out import ( CSV_HEADER, -) -from sinter._decoding_all_built_in_decoders import ( - BUILT_IN_DECODERS, -) -from sinter._existing_data import ( read_stats_from_csv_files, stats_from_csv_files, + Task, + TaskStats, +) +from sinter._decoding import ( + CompiledDecoder, + Decoder, + BUILT_IN_DECODERS, + BUILT_IN_SAMPLERS, + Sampler, + CompiledSampler, ) from sinter._probability_util import ( comma_separated_key_values, @@ -38,19 +39,9 @@ plot_error_rate, group_by, ) -from sinter._task import ( - Task, -) -from sinter._task_stats import ( - TaskStats, -) from sinter._predict import ( predict_discards_bit_packed, predict_observables_bit_packed, predict_on_disk, predict_observables, ) -from sinter._decoding_decoder_class import ( - CompiledDecoder, - Decoder, -) diff --git a/glue/sample/src/sinter/_collection/__init__.py b/glue/sample/src/sinter/_collection/__init__.py new file mode 100644 index 000000000..271e17c7e --- /dev/null +++ b/glue/sample/src/sinter/_collection/__init__.py @@ -0,0 +1,10 @@ +from sinter._collection._collection import ( + collect, + iter_collect, + post_selection_mask_from_4th_coord, + post_selection_mask_from_predicate, + Progress, +) +from sinter._collection._printer import ( + ThrottledProgressPrinter, +) diff --git a/glue/sample/src/sinter/_collection.py b/glue/sample/src/sinter/_collection/_collection.py similarity index 87% rename from glue/sample/src/sinter/_collection.py rename to glue/sample/src/sinter/_collection/_collection.py index 40bfceef6..54f875ba5 100644 --- a/glue/sample/src/sinter/_collection.py +++ b/glue/sample/src/sinter/_collection/_collection.py @@ -1,19 +1,15 @@ import contextlib import dataclasses import pathlib -from typing import Any -from typing import Callable, Iterator, Optional, Union, Iterable, List, TYPE_CHECKING, Tuple, Dict +from typing import Any, Callable, Iterator, Optional, Union, Iterable, List, TYPE_CHECKING, Tuple, Dict import math import numpy as np import stim -from sinter._collection_options import CollectionOptions -from sinter._csv_out import CSV_HEADER -from sinter._collection_work_manager import CollectionWorkManager -from sinter._existing_data import ExistingData -from sinter._printer import ThrottledProgressPrinter -from sinter._task_stats import TaskStats +from sinter._data import CSV_HEADER, ExistingData, TaskStats, CollectionOptions, Task +from sinter._collection._collection_manager import CollectionManager +from sinter._collection._printer import ThrottledProgressPrinter if TYPE_CHECKING: import sinter @@ -42,7 +38,7 @@ def iter_collect(*, tasks: Union[Iterator['sinter.Task'], Iterable['sinter.Task']], hint_num_tasks: Optional[int] = None, - additional_existing_data: Optional[ExistingData] = None, + additional_existing_data: Union[None, dict[str, 'TaskStats'], Iterable['TaskStats']] = None, max_shots: Optional[int] = None, max_errors: Optional[int] = None, decoders: Optional[Iterable[str]] = None, @@ -51,7 +47,7 @@ def iter_collect(*, start_batch_size: Optional[int] = None, count_observable_error_combos: bool = False, count_detection_events: bool = False, - custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None, + custom_decoders: Optional[Dict[str, Union['sinter.Decoder', 'sinter.Sampler']]] = None, custom_error_count_key: Optional[str] = None, allowed_cpu_affinity_ids: Optional[Iterable[int]] = None, ) -> Iterator['sinter.Progress']: @@ -156,6 +152,19 @@ def iter_collect(*, >>> print(total_shots) 200 """ + existing_data: dict[str, TaskStats] + if isinstance(additional_existing_data, ExistingData): + existing_data = additional_existing_data.data + elif isinstance(additional_existing_data, dict): + existing_data = additional_existing_data + elif additional_existing_data is None: + existing_data = {} + else: + acc = ExistingData() + for stat in additional_existing_data: + acc.add_sample(stat) + existing_data = acc.data + if isinstance(decoders, str): decoders = [decoders] @@ -166,50 +175,65 @@ def iter_collect(*, except TypeError: pass - with CollectionWorkManager( - tasks_iter=iter(tasks), - global_collection_options=CollectionOptions( + if decoders is not None: + old_tasks = tasks + tasks = ( + Task( + circuit=task.circuit, + decoder=decoder, + detector_error_model=task.detector_error_model, + postselection_mask=task.postselection_mask, + postselected_observables_mask=task.postselected_observables_mask, + json_metadata=task.json_metadata, + collection_options=task.collection_options, + circuit_path=task.circuit_path, + ) + for task in old_tasks + for decoder in (decoders if task.decoder is None else [task.decoder]) + ) + + progress_log: list[Optional[TaskStats]] = [] + def log_progress(e: Optional[TaskStats]): + progress_log.append(e) + with CollectionManager( + num_workers=num_workers, + tasks=tasks, + collection_options=CollectionOptions( max_shots=max_shots, max_errors=max_errors, max_batch_seconds=max_batch_seconds, start_batch_size=start_batch_size, max_batch_size=max_batch_size, ), - decoders=decoders, + existing_data=existing_data, count_observable_error_combos=count_observable_error_combos, count_detection_events=count_detection_events, - additional_existing_data=additional_existing_data, - custom_decoders=custom_decoders, custom_error_count_key=custom_error_count_key, + custom_decoders=custom_decoders or {}, allowed_cpu_affinity_ids=allowed_cpu_affinity_ids, + worker_flush_period=max_batch_seconds or 120, + progress_callback=log_progress, ) as manager: try: yield Progress( new_stats=(), status_message=f"Starting {num_workers} workers..." ) - manager.start_workers(num_workers) - - yield Progress( - new_stats=(), - status_message="Finding work..." - ) - manager.fill_work_queue() - yield Progress( - new_stats=(), - status_message=manager.status(num_circuits=hint_num_tasks) - ) - - while manager.fill_work_queue(): - # Wait for a worker to finish a job. - sample = manager.wait_for_next_sample() - manager.fill_work_queue() + manager.start_workers() + manager.start_distributing_work() + + while manager.task_states: + manager.process_message() + if progress_log: + vals = list(progress_log) + progress_log.clear() + for e in vals: + if e is not None: + yield Progress( + new_stats=(e,), + status_message=manager.status_message(), + ) - # Report the incremental results. - yield Progress( - new_stats=(sample,) if sample.shots > 0 else (), - status_message=manager.status(num_circuits=hint_num_tasks), - ) except KeyboardInterrupt: yield Progress( new_stats=(), @@ -234,7 +258,7 @@ def collect(*, start_batch_size: Optional[int] = None, print_progress: bool = False, hint_num_tasks: Optional[int] = None, - custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None, + custom_decoders: Optional[Dict[str, Union['sinter.Decoder', 'sinter.Sampler']]] = None, custom_error_count_key: Optional[str] = None, allowed_cpu_affinity_ids: Optional[Iterable[int]] = None, ) -> List['sinter.TaskStats']: @@ -356,7 +380,7 @@ def collect(*, progress_printer = ThrottledProgressPrinter( outs=[], print_progress=print_progress, - min_progress_delay=1, + min_progress_delay=0.1, ) with contextlib.ExitStack() as exit_stack: # Open save/resume file. diff --git a/glue/sample/src/sinter/_collection/_collection_manager.py b/glue/sample/src/sinter/_collection/_collection_manager.py new file mode 100644 index 000000000..3c331355d --- /dev/null +++ b/glue/sample/src/sinter/_collection/_collection_manager.py @@ -0,0 +1,577 @@ +import collections +import contextlib +import math +import multiprocessing +import os +import pathlib +import queue +import tempfile +import threading +from typing import Any, Optional, List, Dict, Iterable, Callable, Tuple +from typing import Union +from typing import cast + +from sinter._collection._collection_worker_loop import collection_worker_loop +from sinter._collection._mux_sampler import MuxSampler +from sinter._collection._sampler_ramp_throttled import RampThrottledSampler +from sinter._data import CollectionOptions, Task, AnonTaskStats, TaskStats +from sinter._decoding import Sampler, Decoder + + +class _ManagedWorkerState: + def __init__(self, worker_id: int, *, cpu_pin: Optional[int] = None): + self.worker_id: int = worker_id + self.process: Union[multiprocessing.Process, threading.Thread, None] = None + self.input_queue: Optional[multiprocessing.Queue[Tuple[str, Any]]] = None + self.assigned_work_key: Any = None + self.asked_to_drop_shots: int = 0 + self.cpu_pin = cpu_pin + + # Shots transfer into this field when manager sends shot requests to workers. + # Shots transfer out of this field when clients flush results or respond to work return requests. + self.assigned_shots: int = 0 + + def send_message(self, message: Any): + self.input_queue.put(message) + + def ask_to_return_all_shots(self): + if self.asked_to_drop_shots == 0 and self.assigned_shots > 0: + self.send_message(( + 'return_shots', + ( + self.assigned_work_key, + self.assigned_shots, + ), + )) + self.asked_to_drop_shots = self.assigned_shots + + def has_returned_all_shots(self) -> bool: + return self.assigned_shots == 0 and self.asked_to_drop_shots == 0 + + def is_available_to_reassign(self) -> bool: + return self.assigned_work_key is None + + +class _ManagedTaskState: + def __init__(self, *, partial_task: Task, strong_id: str, shots_left: int, errors_left: int): + self.partial_task = partial_task + self.strong_id = strong_id + self.shots_left = shots_left + self.errors_left = errors_left + self.shots_unassigned = shots_left + self.shot_return_requests = 0 + self.assigned_soft_error_flush_threshold: int = errors_left + self.workers_assigned: list[int] = [] + + def is_completed(self) -> bool: + return self.shots_left <= 0 or self.errors_left <= 0 + + +class CollectionManager: + def __init__( + self, + *, + existing_data: Dict[Any, TaskStats], + collection_options: CollectionOptions, + custom_decoders: dict[str, Union[Decoder, Sampler]], + num_workers: int, + worker_flush_period: float, + tasks: Iterable[Task], + progress_callback: Callable[[Optional[TaskStats]], None], + allowed_cpu_affinity_ids: Optional[Iterable[int]], + count_observable_error_combos: bool = False, + count_detection_events: bool = False, + custom_error_count_key: Optional[str] = None, + use_threads_for_debugging: bool = False, + ): + assert isinstance(custom_decoders, dict) + self.existing_data = existing_data + self.num_workers: int = num_workers + self.custom_decoders = custom_decoders + self.worker_flush_period: float = worker_flush_period + self.progress_callback = progress_callback + self.collection_options = collection_options + self.partial_tasks: list[Task] = list(tasks) + self.task_strong_ids: List[Optional[str]] = [None] * len(self.partial_tasks) + self.allowed_cpu_affinity_ids = None if allowed_cpu_affinity_ids is None else sorted(set(allowed_cpu_affinity_ids)) + self.count_observable_error_combos = count_observable_error_combos + self.count_detection_events = count_detection_events + self.custom_error_count_key = custom_error_count_key + self.use_threads_for_debugging = use_threads_for_debugging + + self.shared_worker_output_queue: Optional[multiprocessing.SimpleQueue[Tuple[str, int, Any]]] = None + self.task_states: Dict[Any, _ManagedTaskState] = {} + self.started: bool = False + self.total_collected = {k: v.to_anon_stats() for k, v in existing_data.items()} + + if self.allowed_cpu_affinity_ids is None: + cpus = range(os.cpu_count()) + else: + num_cpus = os.cpu_count() + cpus = [e for e in self.allowed_cpu_affinity_ids if e < num_cpus] + self.worker_states: List[_ManagedWorkerState] = [] + for index in range(num_workers): + cpu_pin = None if len(cpus) == 0 else cpus[index % len(cpus)] + self.worker_states.append(_ManagedWorkerState(index, cpu_pin=cpu_pin)) + self.tmp_dir: Optional[pathlib.Path] = None + + def __enter__(self): + self.exit_stack = contextlib.ExitStack().__enter__() + self.tmp_dir = pathlib.Path(self.exit_stack.enter_context(tempfile.TemporaryDirectory())) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.hard_stop() + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + self.exit_stack = None + self.tmp_dir = None + + def start_workers(self, *, actually_start_worker_processes: bool = True): + assert not self.started + + sampler = RampThrottledSampler( + sub_sampler=MuxSampler( + custom_decoders=self.custom_decoders, + count_observable_error_combos=self.count_observable_error_combos, + count_detection_events=self.count_detection_events, + tmp_dir=self.tmp_dir, + ), + target_batch_seconds=1, + max_batch_shots=1024, + ) + + self.started = True + current_method = multiprocessing.get_start_method() + try: + # To ensure the child processes do not accidentally share ANY state + # related to random number generation, we use 'spawn' instead of 'fork'. + multiprocessing.set_start_method('spawn', force=True) + # Create queues after setting start method to work around a deadlock + # bug that occurs otherwise. + self.shared_worker_output_queue = multiprocessing.SimpleQueue() + + for worker_id in range(self.num_workers): + worker_state = self.worker_states[worker_id] + worker_state.input_queue = multiprocessing.Queue() + worker_state.input_queue.cancel_join_thread() + worker_state.assigned_work_key = None + args = ( + self.worker_flush_period, + worker_id, + sampler, + worker_state.input_queue, + self.shared_worker_output_queue, + worker_state.cpu_pin, + self.custom_error_count_key, + ) + if self.use_threads_for_debugging: + worker_state.process = threading.Thread( + target=collection_worker_loop, + args=args, + ) + else: + worker_state.process = multiprocessing.Process( + target=collection_worker_loop, + args=args, + ) + + if actually_start_worker_processes: + worker_state.process.start() + finally: + multiprocessing.set_start_method(current_method, force=True) + + def start_distributing_work(self): + self._compute_task_ids() + self._distribute_work() + + def _compute_task_ids(self): + idle_worker_ids = list(range(self.num_workers)) + unknown_task_ids = list(range(len(self.partial_tasks))) + worker_to_task_map = {} + while worker_to_task_map or unknown_task_ids: + while idle_worker_ids and unknown_task_ids: + worker_id = idle_worker_ids.pop() + unknown_task_id = unknown_task_ids.pop() + worker_to_task_map[worker_id] = unknown_task_id + self.worker_states[worker_id].send_message(('compute_strong_id', self.partial_tasks[unknown_task_id])) + + try: + message = self.shared_worker_output_queue.get() + message_type, worker_id, message_body = message + if message_type == 'computed_strong_id': + assert worker_id in worker_to_task_map + assert isinstance(message_body, str) + self.task_strong_ids[worker_to_task_map.pop(worker_id)] = message_body + idle_worker_ids.append(worker_id) + elif message_type == 'stopped_due_to_exception': + cur_task, cur_shots_left, unflushed_work_done, traceback, ex = message_body + raise ValueError(f'Worker failed: traceback={traceback}') from ex + else: + raise NotImplementedError(f'{message_type=}') + self.progress_callback(None) + except queue.Empty: + pass + + assert len(idle_worker_ids) == self.num_workers + seen = set() + for k in range(len(self.partial_tasks)): + options = self.partial_tasks[k].collection_options.combine(self.collection_options) + key: str = self.task_strong_ids[k] + if key in seen: + raise ValueError(f'Same task given twice: {self.partial_tasks[k]!r}') + seen.add(key) + + shots_left = options.max_shots + errors_left = options.max_errors + if errors_left is None: + errors_left = shots_left + errors_left = min(errors_left, shots_left) + if key in self.existing_data: + val = self.existing_data[key] + shots_left -= val.shots + if self.custom_error_count_key is None: + errors_left -= val.errors + else: + errors_left -= val.custom_counts[self.custom_error_count_key] + if shots_left <= 0: + continue + self.task_states[key] = _ManagedTaskState( + partial_task=self.partial_tasks[k], + strong_id=key, + shots_left=shots_left, + errors_left=errors_left, + ) + if self.task_states[key].is_completed(): + del self.task_states[key] + + def hard_stop(self): + if not self.started: + return + + removed_workers = [state.process for state in self.worker_states] + for state in self.worker_states: + if isinstance(state.process, threading.Thread): + state.send_message('stop') + state.process = None + state.assigned_work_key = None + state.input_queue = None + self.shared_worker_output_queue = None + self.started = False + self.task_states.clear() + + # SIGKILL everything. + for w in removed_workers: + if isinstance(w, multiprocessing.Process): + w.kill() + # Wait for them to be done. + for w in removed_workers: + w.join() + + def _handle_task_progress(self, task_id: Any): + task_state = self.task_states[task_id] + if task_state.is_completed(): + workers_ready = all(self.worker_states[worker_id].has_returned_all_shots() for worker_id in task_state.workers_assigned) + if workers_ready: + # Task is fully completed and can be forgotten entirely. Re-assign the workers. + del self.task_states[task_id] + for worker_id in task_state.workers_assigned: + w = self.worker_states[worker_id] + assert w.assigned_shots <= 0 + assert w.asked_to_drop_shots == 0 + w.assigned_work_key = None + self._distribute_work() + else: + # Task is sufficiently sampled, but some workers are still running. + for worker_id in task_state.workers_assigned: + self.worker_states[worker_id].ask_to_return_all_shots() + self.progress_callback(None) + else: + self._distribute_unassigned_workers_to_jobs() + self._distribute_work_within_a_job(task_state) + + def state_summary(self) -> str: + lines = [] + for worker_id, worker in enumerate(self.worker_states): + lines.append(f'worker {worker_id}:' + f' asked_to_drop_shots={worker.asked_to_drop_shots}' + f' assigned_shots={worker.assigned_shots}' + f' assigned_work_key={worker.assigned_work_key}') + for task in self.task_states.values(): + lines.append(f'task {task.strong_id=}:\n' + f' workers_assigned={task.workers_assigned}\n' + f' shot_return_requests={task.shot_return_requests}\n' + f' shots_left={task.shots_left}\n' + f' errors_left={task.errors_left}\n' + f' shots_unassigned={task.shots_unassigned}') + return '\n' + '\n'.join(lines) + '\n' + + def process_message(self) -> bool: + try: + message = self.shared_worker_output_queue.get() + except queue.Empty: + return False + + message_type, worker_id, message_body = message + worker_state = self.worker_states[worker_id] + + if message_type == 'flushed_results': + task_strong_id, anon_stat = message_body + assert isinstance(anon_stat, AnonTaskStats) + assert worker_state.assigned_work_key == task_strong_id + task_state = self.task_states[task_strong_id] + + worker_state.assigned_shots -= anon_stat.shots + task_state.shots_left -= anon_stat.shots + if worker_state.assigned_shots < 0: + # Worker over-achieved. Correct the imbalance by giving them the shots. + extra_shots = abs(worker_state.assigned_shots) + worker_state.assigned_shots += extra_shots + task_state.shots_unassigned -= extra_shots + worker_state.send_message(( + 'accept_shots', + (task_state.strong_id, extra_shots), + )) + + if self.custom_error_count_key is None: + task_state.errors_left -= anon_stat.errors + else: + task_state.errors_left -= anon_stat.custom_counts[self.custom_error_count_key] + + stat = TaskStats( + strong_id=task_state.strong_id, + decoder=task_state.partial_task.decoder, + json_metadata=task_state.partial_task.json_metadata, + shots=anon_stat.shots, + discards=anon_stat.discards, + seconds=anon_stat.seconds, + errors=anon_stat.errors, + custom_counts=anon_stat.custom_counts, + ) + + self._handle_task_progress(task_strong_id) + + if stat.strong_id not in self.total_collected: + self.total_collected[stat.strong_id] = AnonTaskStats() + self.total_collected[stat.strong_id] += stat.to_anon_stats() + self.progress_callback(stat) + + elif message_type == 'changed_job': + pass + + elif message_type == 'accepted_shots': + pass + + elif message_type == 'returned_shots': + task_key, shots_returned = message_body + assert isinstance(shots_returned, int) + assert shots_returned >= 0 + assert worker_state.assigned_work_key == task_key + assert worker_state.asked_to_drop_shots or worker_state.asked_to_drop_errors + task_state = self.task_states[task_key] + task_state.shot_return_requests -= 1 + worker_state.asked_to_drop_shots = 0 + worker_state.asked_to_drop_errors = 0 + task_state.shots_unassigned += shots_returned + worker_state.assigned_shots -= shots_returned + assert worker_state.assigned_shots >= 0 + self._handle_task_progress(task_key) + + elif message_type == 'stopped_due_to_exception': + cur_task, cur_shots_left, unflushed_work_done, traceback, ex = message_body + raise RuntimeError(f'Worker failed: traceback={traceback}') from ex + + else: + raise NotImplementedError(f'{message_type=}') + + return True + + def run_until_done(self) -> bool: + try: + while self.task_states: + self.process_message() + return True + + except KeyboardInterrupt: + return False + + finally: + self.hard_stop() + + def _distribute_unassigned_workers_to_jobs(self): + idle_workers = [ + w + for w in range(self.num_workers)[::-1] + if self.worker_states[w].is_available_to_reassign() + ] + if not idle_workers or not self.started: + return + + groups = collections.defaultdict(list) + for work_state in self.task_states.values(): + if not work_state.is_completed(): + groups[len(work_state.workers_assigned)].append(work_state) + for k in groups.keys(): + groups[k] = groups[k][::-1] + if not groups: + return + min_assigned = min(groups.keys(), default=0) + + # Distribute workers to unfinished jobs with the fewest workers. + while idle_workers: + task_state: _ManagedTaskState = groups[min_assigned].pop() + groups[min_assigned + 1].append(task_state) + if not groups[min_assigned]: + min_assigned += 1 + + worker_id = idle_workers.pop() + task_state.workers_assigned.append(worker_id) + worker_state = self.worker_states[worker_id] + worker_state.assigned_work_key = task_state.strong_id + worker_state.send_message(( + 'change_job', + (task_state.partial_task, CollectionOptions(max_errors=task_state.errors_left), task_state.assigned_soft_error_flush_threshold), + )) + + def _distribute_unassigned_work_to_workers_within_a_job(self, task_state: _ManagedTaskState): + if not self.started or not task_state.workers_assigned or task_state.shots_left <= 0: + return + + num_task_workers = len(task_state.workers_assigned) + expected_shots_per_worker = (task_state.shots_left + num_task_workers - 1) // num_task_workers + + # Give unassigned shots to idle workers. + for worker_id in sorted(task_state.workers_assigned, key=lambda wid: self.worker_states[wid].assigned_shots): + worker_state = self.worker_states[worker_id] + if worker_state.assigned_shots < expected_shots_per_worker: + shots_to_assign = min(expected_shots_per_worker - worker_state.assigned_shots, + task_state.shots_unassigned) + if shots_to_assign > 0: + task_state.shots_unassigned -= shots_to_assign + worker_state.assigned_shots += shots_to_assign + worker_state.send_message(( + 'accept_shots', + (task_state.strong_id, shots_to_assign), + )) + + def status_message(self) -> str: + num_known_tasks_ids = sum(e is not None for e in self.task_strong_ids) + if num_known_tasks_ids < len(self.task_strong_ids): + return f"Analyzed {num_known_tasks_ids}/{len(self.task_strong_ids)} tasks..." + max_errors = self.collection_options.max_errors + max_shots = self.collection_options.max_shots + + tasks_left = 0 + lines = [] + skipped_lines = [] + for k, strong_id in enumerate(self.task_strong_ids): + if strong_id not in self.task_states: + continue + c = self.total_collected.get(strong_id, AnonTaskStats()) + tasks_left += 1 + w = len(self.task_states[strong_id].workers_assigned) + dt = None + if max_shots is not None and c.shots: + dt = (max_shots - c.shots) * c.seconds / c.shots + c_errors = c.custom_counts[self.custom_error_count_key] if self.custom_error_count_key is not None else c.errors + if max_errors is not None and c_errors and c.seconds: + dt2 = (max_errors - c_errors) * c.seconds / c_errors + if dt is None: + dt = dt2 + else: + dt = min(dt, dt2) + if dt is not None: + dt /= 60 + if dt is not None and w > 0: + dt /= w + line = [ + f'{w}', + self.partial_tasks[k].decoder, + ("?" if dt is None or dt == 0 else "[draining]" if dt <= 0 else "<1m" if dt < 1 else str(round(dt)) + 'm') + ('·∞' if w == 0 else ''), + f'{max_shots - c.shots}' if max_shots is not None else f'{c.shots}', + f'{max_errors - c_errors}' if max_errors is not None else f'{c_errors}', + ",".join( + [f"{k}={v}" for k, v in self.partial_tasks[k].json_metadata.items()] + if isinstance(self.partial_tasks[k].json_metadata, dict) + else str(self.partial_tasks[k].json_metadata) + ) + ] + if w == 0: + skipped_lines.append(line) + else: + lines.append(line) + if len(lines) < 50 and skipped_lines: + missing_lines = 50 - len(lines) + lines.extend(skipped_lines[:missing_lines]) + skipped_lines = skipped_lines[missing_lines:] + + if lines: + lines.insert(0, [ + 'workers', + 'decoder', + 'eta', + 'shots_left' if max_shots is not None else 'shots_taken', + 'errors_left' if max_errors is not None else 'errors_seen', + 'json_metadata']) + justs = cast(list[Callable[[str, int], str]], [str.rjust, str.rjust, str.rjust, str.rjust, str.rjust, str.ljust]) + cols = len(lines[0]) + lengths = [ + max(len(lines[row][col]) for row in range(len(lines))) + for col in range(cols) + ] + lines = [ + " " + " ".join(justs[col](row[col], lengths[col]) for col in range(cols)) + for row in lines + ] + if skipped_lines: + lines.append(' ... (' + str(len(skipped_lines)) + ' more tasks) ...') + return f'{tasks_left} tasks left:\n' + '\n'.join(lines) + + def _update_soft_error_threshold_for_a_job(self, task_state: _ManagedTaskState): + if task_state.errors_left <= len(task_state.workers_assigned): + desired_threshold = 1 + elif task_state.errors_left <= task_state.assigned_soft_error_flush_threshold * self.num_workers: + desired_threshold = max(1, math.ceil(task_state.errors_left * 0.5 / self.num_workers)) + else: + return + + if task_state.assigned_soft_error_flush_threshold != desired_threshold: + task_state.assigned_soft_error_flush_threshold = desired_threshold + for wid in task_state.workers_assigned: + self.worker_states[wid].send_message(('set_soft_error_flush_threshold', desired_threshold)) + + def _take_work_if_unsatisfied_workers_within_a_job(self, task_state: _ManagedTaskState): + if not self.started or not task_state.workers_assigned or task_state.shots_left <= 0: + return + + if all(self.worker_states[w].assigned_shots > 0 for w in task_state.workers_assigned): + return + + w = len(task_state.workers_assigned) + expected_shots_per_worker = (task_state.shots_left + w - 1) // w + + # There are idle workers that couldn't be given any shots. Take shots from other workers. + for worker_id in sorted(task_state.workers_assigned, key=lambda w: self.worker_states[w].assigned_shots, reverse=True): + worker_state = self.worker_states[worker_id] + if worker_state.asked_to_drop_shots or worker_state.assigned_shots <= expected_shots_per_worker: + continue + shots_to_take = worker_state.assigned_shots - expected_shots_per_worker + assert shots_to_take > 0 + worker_state.asked_to_drop_shots = shots_to_take + task_state.shot_return_requests += 1 + worker_state.send_message(( + 'return_shots', + ( + task_state.strong_id, + shots_to_take, + ), + )) + + def _distribute_work_within_a_job(self, t: _ManagedTaskState): + self._distribute_unassigned_work_to_workers_within_a_job(t) + self._take_work_if_unsatisfied_workers_within_a_job(t) + + def _distribute_work(self): + self._distribute_unassigned_workers_to_jobs() + for w in self.task_states.values(): + if not w.is_completed(): + self._distribute_work_within_a_job(w) diff --git a/glue/sample/src/sinter/_collection/_collection_manager_test.py b/glue/sample/src/sinter/_collection/_collection_manager_test.py new file mode 100644 index 000000000..c4cb359b2 --- /dev/null +++ b/glue/sample/src/sinter/_collection/_collection_manager_test.py @@ -0,0 +1,287 @@ +import multiprocessing +import time +from typing import Any, List, Union + +import sinter +import stim + +from sinter._collection._collection_manager import CollectionManager + + +def _assert_drain_queue(q: multiprocessing.Queue, expected_contents: List[Any]): + for v in expected_contents: + assert q.get(timeout=0.1) == v + if not q.empty(): + assert False, f'queue had another item: {q.get()=}' + + +def _put_wait_not_empty(q: Union[multiprocessing.Queue, multiprocessing.SimpleQueue], item: Any): + q.put(item) + while q.empty(): + time.sleep(0.0001) + + +def test_manager(): + log = [] + t0 = sinter.Task( + circuit=stim.Circuit('H 0'), + detector_error_model=stim.DetectorErrorModel(), + decoder='fusion_blossom', + collection_options=sinter.CollectionOptions(max_shots=100_000_000, max_errors=100), + json_metadata={'a': 3}, + ) + t1 = sinter.Task( + circuit=stim.Circuit('M 0'), + detector_error_model=stim.DetectorErrorModel(), + decoder='pymatching', + collection_options=sinter.CollectionOptions(max_shots=10_000_000), + json_metadata=None, + ) + manager = CollectionManager( + num_workers=3, + worker_flush_period=30, + tasks=[t0, t1], + progress_callback=log.append, + existing_data={}, + collection_options=sinter.CollectionOptions(), + custom_decoders={}, + allowed_cpu_affinity_ids=None, + ) + + assert manager.state_summary() == """ +worker 0: asked_to_drop_shots=0 assigned_shots=0 assigned_work_key=None +worker 1: asked_to_drop_shots=0 assigned_shots=0 assigned_work_key=None +worker 2: asked_to_drop_shots=0 assigned_shots=0 assigned_work_key=None +""" + + manager.start_workers(actually_start_worker_processes=False) + manager.shared_worker_output_queue.put(('computed_strong_id', 2, 'c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa')) + manager.shared_worker_output_queue.put(('computed_strong_id', 1, 'a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604')) + manager.start_distributing_work() + + assert manager.state_summary() == """ +worker 0: asked_to_drop_shots=0 assigned_shots=100000000 assigned_work_key=a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604 +worker 1: asked_to_drop_shots=0 assigned_shots=5000000 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +worker 2: asked_to_drop_shots=0 assigned_shots=5000000 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +task task.strong_id='a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604': + workers_assigned=[0] + shot_return_requests=0 + shots_left=100000000 + errors_left=100 + shots_unassigned=0 +task task.strong_id='c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa': + workers_assigned=[1, 2] + shot_return_requests=0 + shots_left=10000000 + errors_left=10000000 + shots_unassigned=0 +""" + + _assert_drain_queue(manager.worker_states[0].input_queue, [ + ( + 'change_job', + (t0, sinter.CollectionOptions(max_errors=100), 100), + ), + ( + 'accept_shots', + (t0.strong_id(), 100_000_000), + ), + ]) + _assert_drain_queue(manager.worker_states[1].input_queue, [ + ('compute_strong_id', t0), + ( + 'change_job', + (t1, sinter.CollectionOptions(max_errors=10000000), 10000000), + ), + ( + 'accept_shots', + (t1.strong_id(), 5_000_000), + ), + ]) + _assert_drain_queue(manager.worker_states[2].input_queue, [ + ('compute_strong_id', t1), + ( + 'change_job', + (t1, sinter.CollectionOptions(max_errors=10000000), 10000000), + ), + ( + 'accept_shots', + (t1.strong_id(), 5_000_000), + ), + ]) + + assert manager.shared_worker_output_queue.empty() + assert log.pop() is None + assert log.pop() is None + assert not log + _put_wait_not_empty(manager.shared_worker_output_queue, ( + 'flushed_results', + 2, + (t1.strong_id(), sinter.AnonTaskStats( + shots=5_000_000, + errors=123, + discards=0, + seconds=1, + )), + )) + + assert manager.process_message() + assert manager.state_summary() == """ +worker 0: asked_to_drop_shots=0 assigned_shots=100000000 assigned_work_key=a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604 +worker 1: asked_to_drop_shots=2500000 assigned_shots=5000000 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +worker 2: asked_to_drop_shots=0 assigned_shots=0 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +task task.strong_id='a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604': + workers_assigned=[0] + shot_return_requests=0 + shots_left=100000000 + errors_left=100 + shots_unassigned=0 +task task.strong_id='c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa': + workers_assigned=[1, 2] + shot_return_requests=1 + shots_left=5000000 + errors_left=9999877 + shots_unassigned=0 +""" + + assert log.pop() == sinter.TaskStats( + strong_id=t1.strong_id(), + decoder=t1.decoder, + json_metadata=t1.json_metadata, + shots=5_000_000, + errors=123, + discards=0, + seconds=1, + ) + assert not log + + _assert_drain_queue(manager.worker_states[0].input_queue, []) + _assert_drain_queue(manager.worker_states[1].input_queue, [ + ( + 'return_shots', + (t1.strong_id(), 2_500_000), + ), + ]) + _assert_drain_queue(manager.worker_states[2].input_queue, []) + + _put_wait_not_empty(manager.shared_worker_output_queue, ( + 'returned_shots', + 1, + (t1.strong_id(), 2_000_000), + )) + assert manager.process_message() + assert manager.state_summary() == """ +worker 0: asked_to_drop_shots=0 assigned_shots=100000000 assigned_work_key=a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604 +worker 1: asked_to_drop_shots=0 assigned_shots=3000000 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +worker 2: asked_to_drop_shots=0 assigned_shots=2000000 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +task task.strong_id='a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604': + workers_assigned=[0] + shot_return_requests=0 + shots_left=100000000 + errors_left=100 + shots_unassigned=0 +task task.strong_id='c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa': + workers_assigned=[1, 2] + shot_return_requests=0 + shots_left=5000000 + errors_left=9999877 + shots_unassigned=0 +""" + + _assert_drain_queue(manager.worker_states[0].input_queue, []) + _assert_drain_queue(manager.worker_states[1].input_queue, []) + _assert_drain_queue(manager.worker_states[2].input_queue, [ + ( + 'accept_shots', + (t1.strong_id(), 2_000_000), + ), + ]) + + _put_wait_not_empty(manager.shared_worker_output_queue, ( + 'flushed_results', + 1, + (t1.strong_id(), sinter.AnonTaskStats( + shots=3_000_000, + errors=444, + discards=1, + seconds=2, + )) + )) + assert manager.process_message() + assert manager.state_summary() == """ +worker 0: asked_to_drop_shots=0 assigned_shots=100000000 assigned_work_key=a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604 +worker 1: asked_to_drop_shots=0 assigned_shots=0 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +worker 2: asked_to_drop_shots=1000000 assigned_shots=2000000 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +task task.strong_id='a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604': + workers_assigned=[0] + shot_return_requests=0 + shots_left=100000000 + errors_left=100 + shots_unassigned=0 +task task.strong_id='c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa': + workers_assigned=[1, 2] + shot_return_requests=1 + shots_left=2000000 + errors_left=9999433 + shots_unassigned=0 +""" + + _put_wait_not_empty(manager.shared_worker_output_queue, ( + 'flushed_results', + 2, + (t1.strong_id(), sinter.AnonTaskStats( + shots=2_000_000, + errors=555, + discards=2, + seconds=2.5, + )) + )) + assert manager.process_message() + assert manager.state_summary() == """ +worker 0: asked_to_drop_shots=0 assigned_shots=100000000 assigned_work_key=a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604 +worker 1: asked_to_drop_shots=0 assigned_shots=0 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +worker 2: asked_to_drop_shots=1000000 assigned_shots=0 assigned_work_key=c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa +task task.strong_id='a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604': + workers_assigned=[0] + shot_return_requests=0 + shots_left=100000000 + errors_left=100 + shots_unassigned=0 +task task.strong_id='c03f7852e4579e2a99cefac80eeb6b09556907540ab3d7787a3d07309c3333aa': + workers_assigned=[1, 2] + shot_return_requests=1 + shots_left=0 + errors_left=9998878 + shots_unassigned=0 +""" + + assert manager.shared_worker_output_queue.empty() + _put_wait_not_empty(manager.shared_worker_output_queue, ( + 'returned_shots', + 2, + (t1.strong_id(), 0) + )) + assert manager.process_message() + assert manager.shared_worker_output_queue.empty() + assert manager.state_summary() == """ +worker 0: asked_to_drop_shots=66666666 assigned_shots=100000000 assigned_work_key=a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604 +worker 1: asked_to_drop_shots=0 assigned_shots=0 assigned_work_key=a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604 +worker 2: asked_to_drop_shots=0 assigned_shots=0 assigned_work_key=a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604 +task task.strong_id='a9165b6e4ab1053c04c017d0739a7bfff0910d62091fc9ee81716833eda7f604': + workers_assigned=[0, 1, 2] + shot_return_requests=1 + shots_left=100000000 + errors_left=100 + shots_unassigned=0 +""" + + _assert_drain_queue(manager.worker_states[0].input_queue, [ + ('return_shots', (t0.strong_id(), 66666666)), + ]) + _assert_drain_queue(manager.worker_states[1].input_queue, [ + ('change_job', (t0, sinter.CollectionOptions(max_errors=100), 100)), + ]) + _assert_drain_queue(manager.worker_states[2].input_queue, [ + ('return_shots', (t1.strong_id(), 1000000)), + ('change_job', (t0, sinter.CollectionOptions(max_errors=100), 100)), + ]) diff --git a/glue/sample/src/sinter/_collection_test.py b/glue/sample/src/sinter/_collection/_collection_test.py similarity index 78% rename from glue/sample/src/sinter/_collection_test.py rename to glue/sample/src/sinter/_collection/_collection_test.py index 3ca72a5e9..2956d4194 100644 --- a/glue/sample/src/sinter/_collection_test.py +++ b/glue/sample/src/sinter/_collection/_collection_test.py @@ -1,6 +1,9 @@ import collections +import math import pathlib +import sys import tempfile +import time import pytest import stim @@ -208,3 +211,64 @@ def test_iter_collect_worker_fails(): ), ]), )) + + +class FixedSizeSampler(sinter.Sampler, sinter.CompiledSampler): + def compiled_sampler_for_task(self, task: sinter.Task) -> sinter.CompiledSampler: + return self + + def sample(self, suggested_shots: int) -> 'sinter.AnonTaskStats': + return sinter.AnonTaskStats( + shots=1024, + errors=5, + ) + + +def test_fixed_size_sampler(): + results = sinter.collect( + num_workers=2, + tasks=[ + sinter.Task( + circuit=stim.Circuit(), + decoder='fixed_size_sampler', + json_metadata={}, + collection_options=sinter.CollectionOptions( + max_shots=100_000, + max_errors=1_000, + ), + ) + ], + custom_decoders={'fixed_size_sampler': FixedSizeSampler()} + ) + assert 100_000 <= results[0].shots <= 100_000 + 3000 + + +class MockTimingSampler(sinter.Sampler, sinter.CompiledSampler): + def compiled_sampler_for_task(self, task: sinter.Task) -> sinter.CompiledSampler: + return self + + def sample(self, suggested_shots: int) -> 'sinter.AnonTaskStats': + actual_shots = -(-suggested_shots // 1024) * 1024 + time.sleep(actual_shots * 0.00001) + return sinter.AnonTaskStats( + shots=actual_shots, + errors=5, + seconds=actual_shots * 0.00001, + ) + + +def test_mock_timing_sampler(): + results = sinter.collect( + num_workers=12, + tasks=[ + sinter.Task( + circuit=stim.Circuit(), + decoder='MockTimingSampler', + json_metadata={}, + ) + ], + max_shots=1_000_000, + max_errors=10_000, + custom_decoders={'MockTimingSampler': MockTimingSampler()}, + ) + assert 1_000_000 <= results[0].shots <= 1_000_000 + 12000 diff --git a/glue/sample/src/sinter/_collection/_collection_worker_loop.py b/glue/sample/src/sinter/_collection/_collection_worker_loop.py new file mode 100644 index 000000000..1467315fb --- /dev/null +++ b/glue/sample/src/sinter/_collection/_collection_worker_loop.py @@ -0,0 +1,35 @@ +import os +from typing import Optional, TYPE_CHECKING + +from sinter._decoding import Sampler +from sinter._collection._collection_worker_state import CollectionWorkerState + +if TYPE_CHECKING: + import multiprocessing + + +def collection_worker_loop( + flush_period: float, + worker_id: int, + sampler: Sampler, + inp: 'multiprocessing.Queue', + out: 'multiprocessing.Queue', + core_affinity: Optional[int], + custom_error_count_key: Optional[str], +) -> None: + try: + if core_affinity is not None and hasattr(os, 'sched_setaffinity'): + os.sched_setaffinity(0, {core_affinity}) + except: + # If setting the core affinity fails, we keep going regardless. + pass + + worker = CollectionWorkerState( + flush_period=flush_period, + worker_id=worker_id, + sampler=sampler, + inp=inp, + out=out, + custom_error_count_key=custom_error_count_key, + ) + worker.run_message_loop() diff --git a/glue/sample/src/sinter/_collection/_collection_worker_state.py b/glue/sample/src/sinter/_collection/_collection_worker_state.py new file mode 100644 index 000000000..ba8967e6a --- /dev/null +++ b/glue/sample/src/sinter/_collection/_collection_worker_state.py @@ -0,0 +1,256 @@ +import queue +import time +from typing import Any +from typing import Optional +from typing import TYPE_CHECKING + +import stim + +from sinter._data import AnonTaskStats +from sinter._data import CollectionOptions +from sinter._data import Task +from sinter._decoding import CompiledSampler +from sinter._decoding import Sampler + +if TYPE_CHECKING: + import multiprocessing + + +def _fill_in_task(task: Task) -> Task: + changed = False + circuit = task.circuit + if circuit is None: + circuit = stim.Circuit.from_file(task.circuit_path) + changed = True + dem = task.detector_error_model + if dem is None: + try: + dem = circuit.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True) + except ValueError: + dem = circuit.detector_error_model(approximate_disjoint_errors=True) + changed = True + if not changed: + return task + return Task( + circuit=circuit, + decoder=task.decoder, + detector_error_model=dem, + postselection_mask=task.postselection_mask, + postselected_observables_mask=task.postselected_observables_mask, + json_metadata=task.json_metadata, + collection_options=task.collection_options, + ) + + +class CollectionWorkerState: + def __init__( + self, + *, + flush_period: float, + worker_id: int, + inp: 'multiprocessing.Queue', + out: 'multiprocessing.Queue', + sampler: Sampler, + custom_error_count_key: Optional[str], + ): + assert isinstance(flush_period, (int, float)) + assert isinstance(sampler, Sampler) + self.max_flush_period = flush_period + self.cur_flush_period = 0.01 + self.inp = inp + self.out = out + self.sampler = sampler + self.compiled_sampler: CompiledSampler | None = None + self.worker_id = worker_id + + self.current_task: Task | None = None + self.current_error_cutoff: int | None = None + self.custom_error_count_key = custom_error_count_key + self.current_task_shots_left: int = 0 + self.unflushed_results: AnonTaskStats = AnonTaskStats() + self.last_flush_message_time = time.monotonic() + self.soft_error_flush_threshold: int = 1 + + def _send_message_to_manager(self, message: Any): + self.out.put(message) + + def state_summary(self) -> str: + lines = [ + f'Worker(id={self.worker_id}) [', + f' max_flush_period={self.max_flush_period}', + f' cur_flush_period={self.cur_flush_period}', + f' sampler={self.sampler}', + f' compiled_sampler={self.compiled_sampler}', + f' current_task={self.current_task}', + f' current_error_cutoff={self.current_error_cutoff}', + f' custom_error_count_key={self.custom_error_count_key}', + f' current_task_shots_left={self.current_task_shots_left}', + f' unflushed_results={self.unflushed_results}', + f' last_flush_message_time={self.last_flush_message_time}', + f' soft_error_flush_threshold={self.soft_error_flush_threshold}', + f']', + ] + return '\n' + '\n'.join(lines) + '\n' + + def flush_results(self): + if self.unflushed_results.shots > 0: + self.last_flush_message_time = time.monotonic() + self.cur_flush_period = min(self.cur_flush_period * 1.4, self.max_flush_period) + self._send_message_to_manager(( + 'flushed_results', + self.worker_id, + (self.current_task.strong_id(), self.unflushed_results), + )) + self.unflushed_results = AnonTaskStats() + return True + return False + + def accept_shots(self, *, shots_delta: int): + assert shots_delta >= 0 + self.current_task_shots_left += shots_delta + self._send_message_to_manager(( + 'accepted_shots', + self.worker_id, + (self.current_task.strong_id(), shots_delta), + )) + + def return_shots(self, *, requested_shots: int): + assert requested_shots >= 0 + returned_shots = max(0, min(requested_shots, self.current_task_shots_left)) + self.current_task_shots_left -= returned_shots + if self.current_task_shots_left <= 0: + self.flush_results() + self._send_message_to_manager(( + 'returned_shots', + self.worker_id, + (self.current_task.strong_id(), returned_shots), + )) + + def compute_strong_id(self, *, new_task: Task): + strong_id = _fill_in_task(new_task).strong_id() + self._send_message_to_manager(( + 'computed_strong_id', + self.worker_id, + strong_id, + )) + + def change_job(self, *, new_task: Task, new_collection_options: CollectionOptions): + self.flush_results() + + self.current_task = _fill_in_task(new_task) + self.current_error_cutoff = new_collection_options.max_errors + self.compiled_sampler = self.sampler.compiled_sampler_for_task(self.current_task) + assert self.current_task.strong_id() is not None + self.current_task_shots_left = 0 + self.last_flush_message_time = time.monotonic() + + self._send_message_to_manager(( + 'changed_job', + self.worker_id, + (self.current_task.strong_id(),), + )) + + def process_messages(self) -> int: + num_processed = 0 + while True: + try: + message = self.inp.get_nowait() + except queue.Empty: + return num_processed + + num_processed += 1 + message_type, message_body = message + + if message_type == 'stop': + return -1 + + elif message_type == 'flush_results': + self.flush_results() + + elif message_type == 'compute_strong_id': + assert isinstance(message_body, Task) + self.compute_strong_id(new_task=message_body) + + elif message_type == 'change_job': + new_task, new_collection_options, soft_error_flush_threshold = message_body + self.cur_flush_period = 0.01 + self.soft_error_flush_threshold = soft_error_flush_threshold + assert isinstance(new_task, Task) + self.change_job(new_task=new_task, new_collection_options=new_collection_options) + + elif message_type == 'set_soft_error_flush_threshold': + soft_error_flush_threshold = message_body + self.soft_error_flush_threshold = soft_error_flush_threshold + + elif message_type == 'accept_shots': + job_key, shots_delta = message_body + assert isinstance(shots_delta, int) + assert job_key == self.current_task.strong_id() + self.accept_shots(shots_delta=shots_delta) + + elif message_type == 'return_shots': + job_key, requested_shots = message_body + assert isinstance(requested_shots, int) + assert job_key == self.current_task.strong_id() + self.return_shots(requested_shots=requested_shots) + + else: + raise NotImplementedError(f'{message_type=}') + + def num_unflushed_errors(self) -> int: + if self.custom_error_count_key is not None: + return self.unflushed_results.custom_counts[self.custom_error_count_key] + return self.unflushed_results.errors + + def do_some_work(self) -> bool: + did_some_work = False + + # Sample some stats. + if self.current_task_shots_left > 0: + # Don't keep sampling if we've exceeded the number of errors needed. + if self.current_error_cutoff is not None and self.current_error_cutoff <= 0: + return self.flush_results() + + some_work_done = self.compiled_sampler.sample(self.current_task_shots_left) + if some_work_done.shots < 1: + raise ValueError(f"Sampler didn't do any work. It returned statistics with shots == 0: {some_work_done}.") + assert isinstance(some_work_done, AnonTaskStats) + self.current_task_shots_left -= some_work_done.shots + if self.current_error_cutoff is not None: + errors_done = some_work_done.custom_counts[self.custom_error_count_key] if self.custom_error_count_key is not None else some_work_done.errors + self.current_error_cutoff -= errors_done + self.unflushed_results += some_work_done + did_some_work = True + + # Report them periodically. + should_flush = False + if self.num_unflushed_errors() >= self.soft_error_flush_threshold: + should_flush = True + if self.unflushed_results.shots > 0: + if self.current_task_shots_left <= 0 or self.last_flush_message_time + self.cur_flush_period < time.monotonic(): + should_flush = True + if should_flush: + did_some_work |= self.flush_results() + + return did_some_work + + def run_message_loop(self): + try: + while True: + num_messages_processed = self.process_messages() + if num_messages_processed == -1: + break + did_some_work = self.do_some_work() + if not did_some_work and num_messages_processed == 0: + time.sleep(0.01) + + except KeyboardInterrupt: + pass + + except BaseException as ex: + import traceback + self._send_message_to_manager(( + 'stopped_due_to_exception', + self.worker_id, + (None if self.current_task is None else self.current_task.strong_id(), self.current_task_shots_left, self.unflushed_results, traceback.format_exc(), ex), + )) diff --git a/glue/sample/src/sinter/_collection/_collection_worker_test.py b/glue/sample/src/sinter/_collection/_collection_worker_test.py new file mode 100644 index 000000000..db0518dc4 --- /dev/null +++ b/glue/sample/src/sinter/_collection/_collection_worker_test.py @@ -0,0 +1,214 @@ +import collections +import multiprocessing +import time +from typing import Any, List + +import sinter +import stim + +from sinter._collection._collection_worker_state import CollectionWorkerState + + +class MockWorkHandler(sinter.Sampler, sinter.CompiledSampler): + def __init__(self): + self.expected_task = None + self.expected = collections.deque() + + def compiled_sampler_for_task(self, task: sinter.Task) -> sinter.CompiledSampler: + assert task == self.expected_task + return self + + def handles_throttling(self) -> bool: + return True + + def sample(self, shots: int) -> sinter.AnonTaskStats: + assert self.expected + expected_shots, response = self.expected.popleft() + assert shots == expected_shots + return response + + +def _assert_drain_queue(q: multiprocessing.Queue, expected_contents: List[Any]): + for v in expected_contents: + assert q.get(timeout=0.1) == v + assert q.empty() + + +def _put_wait_not_empty(q: multiprocessing.Queue, item: Any): + q.put(item) + while q.empty(): + time.sleep(0.0001) + + +def test_worker_stop(): + handler = MockWorkHandler() + + inp = multiprocessing.Queue() + out = multiprocessing.Queue() + inp.cancel_join_thread() + out.cancel_join_thread() + + worker = CollectionWorkerState( + flush_period=-1, + worker_id=5, + sampler=handler, + inp=inp, + out=out, + custom_error_count_key=None, + ) + + assert worker.process_messages() == 0 + _assert_drain_queue(out, []) + + t0 = sinter.Task( + circuit=stim.Circuit('H 0'), + detector_error_model=stim.DetectorErrorModel(), + decoder='mock', + collection_options=sinter.CollectionOptions(max_shots=100_000_000), + json_metadata={'a': 3}, + ) + handler.expected_task = t0 + + _put_wait_not_empty(inp, ('change_job', (t0, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000))) + assert worker.process_messages() == 1 + _assert_drain_queue(out, [('changed_job', 5, (t0.strong_id(),))]) + + _put_wait_not_empty(inp, ('stop', None)) + assert worker.process_messages() == -1 + + +def test_worker_skip_work(): + handler = MockWorkHandler() + + inp = multiprocessing.Queue() + out = multiprocessing.Queue() + inp.cancel_join_thread() + out.cancel_join_thread() + + worker = CollectionWorkerState( + flush_period=-1, + worker_id=5, + sampler=handler, + inp=inp, + out=out, + custom_error_count_key=None, + ) + + assert worker.process_messages() == 0 + _assert_drain_queue(out, []) + + t0 = sinter.Task( + circuit=stim.Circuit('H 0'), + detector_error_model=stim.DetectorErrorModel(), + decoder='mock', + collection_options=sinter.CollectionOptions(max_shots=100_000_000), + json_metadata={'a': 3}, + ) + handler.expected_task = t0 + _put_wait_not_empty(inp, ('change_job', (t0, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000))) + assert worker.process_messages() == 1 + _assert_drain_queue(out, [('changed_job', 5, (t0.strong_id(),))]) + + _put_wait_not_empty(inp, ('accept_shots', (t0.strong_id(), 10000))) + assert worker.process_messages() == 1 + _assert_drain_queue(out, [('accepted_shots', 5, (t0.strong_id(), 10000))]) + + assert worker.current_task == t0 + assert worker.current_task_shots_left == 10000 + assert worker.process_messages() == 0 + _assert_drain_queue(out, []) + + _put_wait_not_empty(inp, ('return_shots', (t0.strong_id(), 2000))) + assert worker.process_messages() == 1 + _assert_drain_queue(out, [ + ('returned_shots', 5, (t0.strong_id(), 2000)), + ]) + + _put_wait_not_empty(inp, ('return_shots', (t0.strong_id(), 20000000))) + assert worker.process_messages() == 1 + _assert_drain_queue(out, [ + ('returned_shots', 5, (t0.strong_id(), 8000)), + ]) + + assert not worker.do_some_work() + + +def test_worker_finish_work(): + handler = MockWorkHandler() + + inp = multiprocessing.Queue() + out = multiprocessing.Queue() + inp.cancel_join_thread() + out.cancel_join_thread() + + worker = CollectionWorkerState( + flush_period=-1, + worker_id=5, + sampler=handler, + inp=inp, + out=out, + custom_error_count_key=None, + ) + + assert worker.process_messages() == 0 + _assert_drain_queue(out, []) + + ta = sinter.Task( + circuit=stim.Circuit('H 0'), + detector_error_model=stim.DetectorErrorModel(), + decoder='mock', + collection_options=sinter.CollectionOptions(max_shots=100_000_000), + json_metadata={'a': 3}, + ) + handler.expected_task = ta + _put_wait_not_empty(inp, ('change_job', (ta, sinter.CollectionOptions(max_errors=100_000_000), 100_000_000))) + _put_wait_not_empty(inp, ('accept_shots', (ta.strong_id(), 10000))) + assert worker.process_messages() == 2 + _assert_drain_queue(out, [ + ('changed_job', 5, (ta.strong_id(),)), + ('accepted_shots', 5, (ta.strong_id(), 10000)), + ]) + + assert worker.current_task == ta + assert worker.current_task_shots_left == 10000 + assert worker.process_messages() == 0 + _assert_drain_queue(out, []) + + handler.expected.append(( + 10000, + sinter.AnonTaskStats( + shots=1000, + errors=23, + discards=0, + seconds=1, + ), + )) + + assert worker.do_some_work() + worker.flush_results() + _assert_drain_queue(out, [ + ('flushed_results', 5, (ta.strong_id(), sinter.AnonTaskStats(shots=1000, errors=23, discards=0, seconds=1)))]) + + handler.expected.append(( + 9000, + sinter.AnonTaskStats( + shots=9000, + errors=13, + discards=0, + seconds=1, + ), + )) + + assert worker.do_some_work() + worker.flush_results() + _assert_drain_queue(out, [ + ('flushed_results', 5, (ta.strong_id(), sinter.AnonTaskStats( + shots=9000, + errors=13, + discards=0, + seconds=1, + ))), + ]) + assert not worker.do_some_work() + worker.flush_results() + _assert_drain_queue(out, []) diff --git a/glue/sample/src/sinter/_collection/_mux_sampler.py b/glue/sample/src/sinter/_collection/_mux_sampler.py new file mode 100755 index 000000000..d0db3caa2 --- /dev/null +++ b/glue/sample/src/sinter/_collection/_mux_sampler.py @@ -0,0 +1,56 @@ +import pathlib +from typing import Optional +from typing import Union + +from sinter._data import Task +from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_SAMPLERS +from sinter._decoding._decoding_decoder_class import Decoder +from sinter._decoding._sampler import CompiledSampler +from sinter._decoding._sampler import Sampler +from sinter._decoding._stim_then_decode_sampler import StimThenDecodeSampler + + +class MuxSampler(Sampler): + """Looks up the sampler to use for a task, by the task's decoder name.""" + + def __init__( + self, + *, + custom_decoders: Union[dict[str, Union[Decoder, Sampler]], None], + count_observable_error_combos: bool, + count_detection_events: bool, + tmp_dir: Optional[pathlib.Path], + ): + self.custom_decoders = custom_decoders + self.count_observable_error_combos = count_observable_error_combos + self.count_detection_events = count_detection_events + self.tmp_dir = tmp_dir + + def compiled_sampler_for_task(self, task: Task) -> CompiledSampler: + return self._resolve_sampler(task.decoder).compiled_sampler_for_task(task) + + def _resolve_sampler(self, name: str) -> Sampler: + sub_sampler: Union[Decoder, Sampler] + + if name in self.custom_decoders: + sub_sampler = self.custom_decoders[name] + elif name in BUILT_IN_SAMPLERS: + sub_sampler = BUILT_IN_SAMPLERS[name] + else: + raise NotImplementedError(f'Not a recognized decoder or sampler: {name=}. Did you forget to specify custom_decoders?') + + if isinstance(sub_sampler, Sampler): + if self.count_detection_events: + raise NotImplementedError("'count_detection_events' not supported when using a custom Sampler (instead of a custom Decoder).") + if self.count_observable_error_combos: + raise NotImplementedError("'count_observable_error_combos' not supported when using a custom Sampler (instead of a custom Decoder).") + return sub_sampler + elif isinstance(sub_sampler, Decoder) or hasattr(sub_sampler, 'compile_decoder_for_dem'): + return StimThenDecodeSampler( + decoder=sub_sampler, + count_detection_events=self.count_detection_events, + count_observable_error_combos=self.count_observable_error_combos, + tmp_dir=self.tmp_dir, + ) + else: + raise NotImplementedError(f"Don't know how to turn this into a Sampler: {sub_sampler!r}") diff --git a/glue/sample/src/sinter/_printer.py b/glue/sample/src/sinter/_collection/_printer.py similarity index 100% rename from glue/sample/src/sinter/_printer.py rename to glue/sample/src/sinter/_collection/_printer.py diff --git a/glue/sample/src/sinter/_collection/_sampler_ramp_throttled.py b/glue/sample/src/sinter/_collection/_sampler_ramp_throttled.py new file mode 100755 index 000000000..5f8ae056a --- /dev/null +++ b/glue/sample/src/sinter/_collection/_sampler_ramp_throttled.py @@ -0,0 +1,66 @@ +import time + +from sinter._decoding import Sampler, CompiledSampler +from sinter._data import Task, AnonTaskStats + + +class RampThrottledSampler(Sampler): + """Wraps a sampler to adjust requested shots to hit a target time. + + This sampler will initially only take 1 shot per call. If the time taken + significantly undershoots the target time, the maximum number of shots per + call is increased by a constant factor. If it exceeds the target time, the + maximum is reduced by a constant factor. The result is that the sampler + "ramps up" how many shots it does per call until it takes roughly the target + time, and then dynamically adapts to stay near it. + """ + + def __init__(self, sub_sampler: Sampler, target_batch_seconds: float, max_batch_shots: int): + self.sub_sampler = sub_sampler + self.target_batch_seconds = target_batch_seconds + self.max_batch_shots = max_batch_shots + + def __str__(self) -> str: + return f'CompiledRampThrottledSampler({self.sub_sampler})' + + def compiled_sampler_for_task(self, task: Task) -> CompiledSampler: + compiled_sub_sampler = self.sub_sampler.compiled_sampler_for_task(task) + if compiled_sub_sampler.handles_throttling(): + return compiled_sub_sampler + + return CompiledRampThrottledSampler( + sub_sampler=compiled_sub_sampler, + target_batch_seconds=self.target_batch_seconds, + max_batch_shots=self.max_batch_shots, + ) + + +class CompiledRampThrottledSampler(CompiledSampler): + def __init__(self, sub_sampler: CompiledSampler, target_batch_seconds: float, max_batch_shots: int): + self.sub_sampler = sub_sampler + self.target_batch_seconds = target_batch_seconds + self.batch_shots = 1 + self.max_batch_shots = max_batch_shots + + def __str__(self) -> str: + return f'CompiledRampThrottledSampler({self.sub_sampler})' + + def sample(self, max_shots: int) -> AnonTaskStats: + t0 = time.monotonic() + actual_shots = min(max_shots, self.batch_shots) + result = self.sub_sampler.sample(actual_shots) + dt = time.monotonic() - t0 + + # Rebalance number of shots. + if self.batch_shots > 1 and dt > self.target_batch_seconds * 1.3: + self.batch_shots //= 2 + if result.shots * 2 >= actual_shots: + for _ in range(4): + if self.batch_shots * 2 > self.max_batch_shots: + break + if dt > self.target_batch_seconds * 0.3: + break + self.batch_shots *= 2 + dt *= 2 + + return result diff --git a/glue/sample/src/sinter/_collection_tracker_for_single_task.py b/glue/sample/src/sinter/_collection_tracker_for_single_task.py deleted file mode 100644 index 0932b55e6..000000000 --- a/glue/sample/src/sinter/_collection_tracker_for_single_task.py +++ /dev/null @@ -1,230 +0,0 @@ -import math -from typing import Iterator -from typing import Optional - -import stim - -from sinter._anon_task_stats import AnonTaskStats -from sinter._existing_data import ExistingData -from sinter._task import Task -from sinter._worker import WorkIn -from sinter._worker import WorkOut - - -DEFAULT_MAX_BATCH_SECONDS = 120 - - -class CollectionTrackerForSingleTask: - def __init__( - self, - *, - task: Task, - count_observable_error_combos: bool, - count_detection_events: bool, - circuit_path: str, - dem_path: str, - existing_data: ExistingData, - custom_error_count_key: Optional[str], - ): - self.unfilled_task = task - self.count_observable_error_combos = count_observable_error_combos - self.count_detection_events = count_detection_events - self.task_strong_id = None - self.circuit_path = circuit_path - self.dem_path = dem_path - self.finished_stats = AnonTaskStats() - self.existing_data = existing_data - self.deployed_shots = 0 - self.waiting_for_worker_computing_dem_and_strong_id = False - self.deployed_processes = 0 - self.custom_error_count_key = custom_error_count_key - - task.circuit.to_file(circuit_path) - if task.detector_error_model is not None: - task.detector_error_model.to_file(dem_path) - self.task_strong_id = task.strong_id() - existing = self.existing_data.data.get(self.task_strong_id) - if existing is not None: - self.finished_stats += existing.to_anon_stats() - if self.copts.max_shots is None and self.copts.max_errors is None: - raise ValueError('Neither the task nor the collector specified max_shots or max_errors. Must specify one.') - - @property - def copts(self): - return self.unfilled_task.collection_options - - def expected_shots_remaining( - self, *, safety_factor_on_shots_per_error: float = 1) -> float: - """Doesn't include deployed shots.""" - result: float = float('inf') - - if self.copts.max_shots is not None: - result = self.copts.max_shots - self.finished_stats.shots - - errs = self._seen_errors() - if errs and self.copts.max_errors is not None: - shots_per_error = self.finished_stats.shots / errs - errors_left = self.copts.max_errors - errs - result = min(result, errors_left * shots_per_error * safety_factor_on_shots_per_error) - - return result - - def _seen_errors(self) -> int: - if self.custom_error_count_key is not None: - return self.finished_stats.custom_counts.get(self.custom_error_count_key, 0) - return self.finished_stats.errors - - def expected_time_per_shot(self) -> Optional[float]: - if self.finished_stats.shots == 0: - return None - return self.finished_stats.seconds / self.finished_stats.shots - - def expected_errors_per_shot(self) -> Optional[float]: - return (self._seen_errors() + 1) / (self.finished_stats.shots + 1) - - def expected_time_remaining(self) -> Optional[float]: - dt = self.expected_time_per_shot() - n = self.expected_shots_remaining() - if dt is None or n == float('inf'): - return None - return dt * n - - def work_completed(self, result: WorkOut) -> None: - if self.waiting_for_worker_computing_dem_and_strong_id: - self.task_strong_id = result.strong_id - existing = self.existing_data.data.get(self.task_strong_id) - if existing is not None: - self.finished_stats += existing.to_anon_stats() - self.waiting_for_worker_computing_dem_and_strong_id = False - else: - self.deployed_shots -= result.stats.shots - self.finished_stats += result.stats - self.deployed_processes -= 1 - - def is_done(self) -> bool: - if self.task_strong_id is None or self.waiting_for_worker_computing_dem_and_strong_id: - return False - enough_shots = False - if self.copts.max_shots is not None and self.finished_stats.shots >= self.copts.max_shots: - enough_shots = True - if self.copts.max_errors is not None and self._seen_errors() >= self.copts.max_errors: - enough_shots = True - return enough_shots and self.deployed_shots == 0 - - def iter_batch_size_limits(self, *, desperate: bool) -> Iterator[float]: - if self.finished_stats.shots == 0: - if self.deployed_shots > 0: - yield 0 - elif self.copts.start_batch_size is None: - yield 100 - else: - yield self.copts.start_batch_size - return - - # Do exponential ramp-up of batch sizes. - yield self.finished_stats.shots * 2 - - # Don't go super parallel before reaching other maximums. - if not desperate: - yield self.finished_stats.shots * 5 - self.deployed_shots - - # Don't take more shots than requested. - if self.copts.max_shots is not None: - yield self.copts.max_shots - self.finished_stats.shots - self.deployed_shots - - # Don't take more errors than requested. - if self.copts.max_errors is not None: - errors_left = self.copts.max_errors - self._seen_errors() - errors_left += 2 # oversample once count gets low - de = self.expected_errors_per_shot() - yield errors_left / de - self.deployed_shots - - # Don't exceed max batch size. - if self.copts.max_batch_size is not None: - yield self.copts.max_batch_size - - # If no maximum on batch size is specified, default to 30s maximum. - max_batch_seconds = self.copts.max_batch_seconds - if max_batch_seconds is None and self.copts.max_batch_size is None: - max_batch_seconds = DEFAULT_MAX_BATCH_SECONDS - - # Try not to exceed max batch duration. - if max_batch_seconds is not None: - dt = self.expected_time_per_shot() - if dt is not None and dt > 0: - yield max(1, math.floor(max_batch_seconds / dt)) - - def next_shot_count(self, *, desperate: bool) -> int: - return math.ceil(min(self.iter_batch_size_limits(desperate=desperate))) - - def provide_more_work(self, *, desperate: bool) -> Optional[WorkIn]: - if self.task_strong_id is None: - if self.waiting_for_worker_computing_dem_and_strong_id: - return None - self.waiting_for_worker_computing_dem_and_strong_id = True - self.deployed_processes += 1 - return WorkIn( - work_key=None, - circuit_path=self.circuit_path, - dem_path=self.dem_path, - decoder=self.unfilled_task.decoder, - postselection_mask=self.unfilled_task.postselection_mask, - postselected_observables_mask=self.unfilled_task.postselected_observables_mask, - json_metadata=self.unfilled_task.json_metadata, - strong_id=None, - num_shots=-1, - count_observable_error_combos=self.count_observable_error_combos, - count_detection_events=self.count_detection_events, - ) - - # Wait to have *some* data before starting to sample in parallel. - num_shots = self.next_shot_count(desperate=desperate) - if num_shots <= 0: - return None - - self.deployed_shots += num_shots - self.deployed_processes += 1 - return WorkIn( - work_key=None, - strong_id=self.task_strong_id, - circuit_path=self.circuit_path, - dem_path=self.dem_path, - decoder=self.unfilled_task.decoder, - postselection_mask=self.unfilled_task.postselection_mask, - postselected_observables_mask=self.unfilled_task.postselected_observables_mask, - json_metadata=self.unfilled_task.json_metadata, - num_shots=num_shots, - count_observable_error_combos=self.count_observable_error_combos, - count_detection_events=self.count_detection_events, - ) - - def status(self) -> str: - t = self.expected_time_remaining() - if t is not None: - t /= 60 - t = math.ceil(t) - t = f'{t}' - terms = [ - f'{self.unfilled_task.decoder} '.rjust(22), - f'processes={self.deployed_processes}'.ljust(13), - f'~core_mins_left={t}'.ljust(24), - ] - if self.task_strong_id is None: - terms.append(f'(initializing...) ') - else: - if self.copts.max_shots is not None: - terms.append(f'shots_left={max(0, self.copts.max_shots - self.finished_stats.shots)}'.ljust(20)) - if self.copts.max_errors is not None: - terms.append(f'errors_left={max(0, self.copts.max_errors - self._seen_errors())}'.ljust(20)) - if isinstance(self.unfilled_task.json_metadata, dict): - keys = self.unfilled_task.json_metadata.keys() - try: - keys = sorted(keys) - except: - keys = list(keys) - meta_desc = '{' + ','.join(f'{k}={self.unfilled_task.json_metadata[k]}' for k in keys) + '}' - else: - meta_desc = f'{self.unfilled_task.json_metadata}' - terms.append(meta_desc) - - return ''.join(terms) diff --git a/glue/sample/src/sinter/_collection_work_manager.py b/glue/sample/src/sinter/_collection_work_manager.py deleted file mode 100644 index d32db6985..000000000 --- a/glue/sample/src/sinter/_collection_work_manager.py +++ /dev/null @@ -1,275 +0,0 @@ -import os - -import contextlib -import multiprocessing -import pathlib -import tempfile -import stim -from typing import cast, Iterable, Optional, Iterator, Tuple, Dict, List - -from sinter._decoding_decoder_class import Decoder -from sinter._collection_options import CollectionOptions -from sinter._existing_data import ExistingData -from sinter._task_stats import TaskStats -from sinter._task import Task -from sinter._anon_task_stats import AnonTaskStats -from sinter._collection_tracker_for_single_task import CollectionTrackerForSingleTask -from sinter._worker import worker_loop, WorkIn, WorkOut - - -class CollectionWorkManager: - def __init__( - self, - *, - tasks_iter: Iterator[Task], - global_collection_options: CollectionOptions, - additional_existing_data: Optional[ExistingData], - count_observable_error_combos: bool, - count_detection_events: bool, - decoders: Optional[Iterable[str]], - custom_decoders: Dict[str, Decoder], - custom_error_count_key: Optional[str], - allowed_cpu_affinity_ids: Optional[List[int]], - ): - self.custom_decoders = custom_decoders - self.queue_from_workers: Optional[multiprocessing.Queue] = None - self.queue_to_workers: Optional[multiprocessing.Queue] = None - self.additional_existing_data = ExistingData() if additional_existing_data is None else additional_existing_data - self.tmp_dir: Optional[pathlib.Path] = None - self.exit_stack: Optional[contextlib.ExitStack] = None - self.custom_error_count_key = custom_error_count_key - self.allowed_cpu_affinity_ids = None if allowed_cpu_affinity_ids is None else sorted(set(allowed_cpu_affinity_ids)) - - self.global_collection_options = global_collection_options - self.decoders: Optional[Tuple[str, ...]] = None if decoders is None else tuple(decoders) - self.did_work = False - - self.workers: List[multiprocessing.Process] = [] - self.active_collectors: Dict[int, CollectionTrackerForSingleTask] = {} - self.next_collector_key: int = 0 - self.finished_count = 0 - self.deployed_jobs: Dict[int, WorkIn] = {} - self.next_job_id = 0 - self.count_observable_error_combos = count_observable_error_combos - self.count_detection_events = count_detection_events - - self.tasks_with_decoder_iter: Iterator[Task] = _iter_tasks_with_assigned_decoders( - tasks_iter=tasks_iter, - default_decoders=self.decoders, - global_collections_options=self.global_collection_options) - - def start_workers(self, num_workers: int) -> None: - assert self.tmp_dir is not None - current_method = multiprocessing.get_start_method() - try: - # To ensure the child processes do not accidentally share ANY state - # related to, we use 'spawn' instead of 'fork'. - multiprocessing.set_start_method('spawn', force=True) - # Create queues after setting start method to work around a deadlock - # bug that occurs otherwise. - self.queue_from_workers = multiprocessing.Queue() - self.queue_to_workers = multiprocessing.Queue() - self.queue_from_workers.cancel_join_thread() - self.queue_to_workers.cancel_join_thread() - - if self.allowed_cpu_affinity_ids is None: - cpus = range(os.cpu_count()) - else: - num_cpus = os.cpu_count() - cpus = [e for e in self.allowed_cpu_affinity_ids if e < num_cpus] - for index in range(num_workers): - cpu_pin = None if len(cpus) == 0 else cpus[index % len(cpus)] - w = multiprocessing.Process( - target=worker_loop, - args=(self.tmp_dir, self.queue_to_workers, self.queue_from_workers, self.custom_decoders, cpu_pin)) - w.start() - self.workers.append(w) - finally: - multiprocessing.set_start_method(current_method, force=True) - - def __enter__(self): - self.exit_stack = contextlib.ExitStack().__enter__() - self.tmp_dir = pathlib.Path(self.exit_stack.enter_context(tempfile.TemporaryDirectory())) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.shut_down_workers() - self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - self.exit_stack = None - self.tmp_dir = None - - def shut_down_workers(self) -> None: - removed_workers = self.workers - self.workers = [] - - # SIGKILL everything. - for w in removed_workers: - # This is supposed to be safe because all state on disk was put - # in the specified tmp directory which we will handle deleting. - w.kill() - # Wait for them to be done. - for w in removed_workers: - w.join() - - def fill_work_queue(self) -> bool: - while len(self.deployed_jobs) < len(self.workers): - work = self.provide_more_work() - if work is None: - break - self.did_work = True - self.queue_to_workers.put(work.with_work_key((self.next_job_id, work.work_key))) - self.deployed_jobs[self.next_job_id] = work - self.next_job_id += 1 - return bool(self.deployed_jobs) - - def wait_for_next_sample(self, - *, - timeout: Optional[float] = None, - ) -> TaskStats: - result = self.queue_from_workers.get(timeout=timeout) - assert isinstance(result, WorkOut) - if result.msg_error is not None: - msg, error = result.msg_error - if isinstance(error, KeyboardInterrupt): - raise KeyboardInterrupt() - raise RuntimeError(f"Worker failed: {msg}") from error - - else: - job_id, sub_key = result.work_key - stats = result.stats - work_in = self.deployed_jobs[job_id] - - self.work_completed(WorkOut( - work_key=sub_key, - stats=stats, - strong_id=result.strong_id, - msg_error=result.msg_error, - )) - del self.deployed_jobs[job_id] - if stats is None: - stats = AnonTaskStats() - return TaskStats( - strong_id=result.strong_id, - decoder=work_in.decoder, - json_metadata=work_in.json_metadata, - shots=stats.shots, - errors=stats.errors, - custom_counts=stats.custom_counts, - discards=stats.discards, - seconds=stats.seconds, - ) - - def _iter_draw_collectors(self, *, prefer_started: bool) -> Iterator[Tuple[int, CollectionTrackerForSingleTask]]: - if prefer_started: - yield from self.active_collectors.items() - while True: - key = self.next_collector_key - try: - task = next(self.tasks_with_decoder_iter) - except StopIteration: - break - collector = CollectionTrackerForSingleTask( - task=task, - circuit_path=str((self.tmp_dir / f'circuit_{self.next_collector_key}.stim').absolute()), - dem_path=str((self.tmp_dir / f'dem_{self.next_collector_key}.dem').absolute()), - existing_data=self.additional_existing_data, - count_detection_events=self.count_detection_events, - count_observable_error_combos=self.count_observable_error_combos, - custom_error_count_key=self.custom_error_count_key, - ) - if collector.is_done(): - self.finished_count += 1 - continue - self.next_collector_key += 1 - self.active_collectors[key] = collector - yield key, collector - if not prefer_started: - yield from self.active_collectors.items() - - def is_done(self) -> bool: - return len(self.active_collectors) == 0 - - def work_completed(self, result: WorkOut): - assert isinstance(result.work_key, int) - collector_index = cast(int, result.work_key) - collector = self.active_collectors[collector_index] - collector.work_completed(result) - if collector.is_done(): - self.finished_count += 1 - del self.active_collectors[collector_index] - - def provide_more_work(self) -> Optional[WorkIn]: - iter_collectors = self._iter_draw_collectors( - prefer_started=len(self.active_collectors) >= 2) - for desperate in False, True: - for collector_index, collector in iter_collectors: - w = collector.provide_more_work(desperate=desperate) - if w is not None: - assert w.work_key is None - return w.with_work_key(collector_index) - return None - - def status(self, *, num_circuits: Optional[int]) -> str: - if self.is_done(): - if self.did_work: - main_status = 'Done collecting' - else: - main_status = 'There was nothing additional to collect' - elif num_circuits is not None: - main_status = f'{num_circuits - self.finished_count} cases left:' - else: - main_status = "Running..." - collector_statuses = [ - collector.status() - for collector in self.active_collectors.values() - ] - if len(collector_statuses) > 24: - collector_statuses = collector_statuses[:24] + ['\n...'] - - min_indent = 0 - while collector_statuses and all(min_indent < len(c) and c[min_indent] == ' ' for c in collector_statuses): - min_indent += 1 - if min_indent > 4: - collector_statuses = [c[min_indent - 4:] for c in collector_statuses] - collector_statuses = ['\n' + c for c in collector_statuses] - - return main_status + ''.join(collector_statuses) - - -def _iter_tasks_with_assigned_decoders( - *, - tasks_iter: Iterator[Task], - default_decoders: Optional[Iterable[str]], - global_collections_options: CollectionOptions, -) -> Iterator[Task]: - for task in tasks_iter: - if task.circuit is None: - task = Task( - circuit=stim.Circuit.from_file(task.circuit_path), - decoder=task.decoder, - detector_error_model=task.detector_error_model, - postselection_mask=task.postselection_mask, - postselected_observables_mask=task.postselected_observables_mask, - json_metadata=task.json_metadata, - collection_options=task.collection_options, - circuit_path=task.circuit_path, - ) - - if task.decoder is None and default_decoders is None: - raise ValueError("Decoders to use was not specified. decoders is None and task.decoder is None") - task_decoders = [] - if default_decoders is not None: - task_decoders.extend(default_decoders) - if task.decoder is not None and task.decoder not in task_decoders: - task_decoders.append(task.decoder) - for decoder in task_decoders: - yield Task( - circuit=task.circuit, - decoder=decoder, - detector_error_model=task.detector_error_model, - postselection_mask=task.postselection_mask, - postselected_observables_mask=task.postselected_observables_mask, - json_metadata=task.json_metadata, - collection_options=task.collection_options.combine(global_collections_options), - skip_validation=True, - ) diff --git a/glue/sample/src/sinter/_command/__init__.py b/glue/sample/src/sinter/_command/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/glue/sample/src/sinter/_main.py b/glue/sample/src/sinter/_command/_main.py similarity index 83% rename from glue/sample/src/sinter/_main.py rename to glue/sample/src/sinter/_command/_main.py index ef5f3c35d..30488ca33 100644 --- a/glue/sample/src/sinter/_main.py +++ b/glue/sample/src/sinter/_command/_main.py @@ -8,16 +8,16 @@ def main(*, command_line_args: Optional[List[str]] = None): mode = command_line_args[0] if command_line_args else None if mode == 'combine': - from sinter._main_combine import main_combine + from sinter._command._main_combine import main_combine return main_combine(command_line_args=command_line_args[1:]) if mode == 'collect': - from sinter._main_collect import main_collect + from sinter._command._main_collect import main_collect return main_collect(command_line_args=command_line_args[1:]) if mode == 'plot': - from sinter._main_plot import main_plot + from sinter._command._main_plot import main_plot return main_plot(command_line_args=command_line_args[1:]) if mode == 'predict': - from sinter._main_predict import main_predict + from sinter._command._main_predict import main_predict return main_predict(command_line_args=command_line_args[1:]) want_help = mode in ['help', 'h', '--help', '-help', '-h', '--h'] diff --git a/glue/sample/src/sinter/_main_collect.py b/glue/sample/src/sinter/_command/_main_collect.py similarity index 96% rename from glue/sample/src/sinter/_main_collect.py rename to glue/sample/src/sinter/_command/_main_collect.py index b53dc392d..5f008ae7a 100644 --- a/glue/sample/src/sinter/_main_collect.py +++ b/glue/sample/src/sinter/_command/_main_collect.py @@ -8,12 +8,11 @@ import numpy as np import stim -import sinter -from sinter._printer import ThrottledProgressPrinter -from sinter._task import Task +from sinter._collection import ThrottledProgressPrinter +from sinter._data import Task from sinter._collection import collect, Progress, post_selection_mask_from_predicate -from sinter._decoding_all_built_in_decoders import BUILT_IN_DECODERS -from sinter._main_combine import ExistingData, CSV_HEADER +from sinter._command._main_combine import ExistingData, CSV_HEADER +from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_SAMPLERS def iter_file_paths_into_goals(circuit_paths: Iterator[str], @@ -82,7 +81,7 @@ def parse_args(args: List[str]) -> Any: default=None, help='Sampling of a circuit will stop if this many errors have been seen.') parser.add_argument('--processes', - required=True, + default='auto', type=str, help='Number of processes to use for simultaneous sampling and decoding. ' 'Must be either a number or "auto" which sets it to the number of ' @@ -203,6 +202,7 @@ def parse_args(args: List[str]) -> Any: ''' --metadata_func "auto"\n''' ''' --metadata_func "{'n': circuit.num_qubits, 'p': float(path.split('/')[-1].split('.')[0])}"\n''' ) + import sinter a = parser.parse_args(args=args) if a.metadata_func == 'auto': a.metadata_func = "sinter.comma_separated_key_values(path)" @@ -243,9 +243,9 @@ def parse_args(args: List[str]) -> Any: else: a.custom_decoders = None for decoder in a.decoders: - if decoder not in BUILT_IN_DECODERS and (a.custom_decoders is None or decoder not in a.custom_decoders): - message = f"Not a recognized decoder: {decoder=}.\n" - message += f"Available built-in decoders: {sorted(e for e in BUILT_IN_DECODERS.keys() if 'internal' not in e)}.\n" + if decoder not in BUILT_IN_SAMPLERS and (a.custom_decoders is None or decoder not in a.custom_decoders): + message = f"Not a recognized decoder or sampler: {decoder=}.\n" + message += f"Available built-in decoders and samplers: {sorted(e for e in BUILT_IN_SAMPLERS.keys() if 'internal' not in e)}.\n" if a.custom_decoders is None: message += f"No custom decoders are available. --custom_decoders_module_function wasn't specified." else: @@ -296,7 +296,7 @@ def main_collect(*, command_line_args: List[str]): printer = ThrottledProgressPrinter( outs=[], print_progress=not args.quiet, - min_progress_delay=0.03 if args.also_print_results_to_stdout else 0.2, + min_progress_delay=0.03 if args.also_print_results_to_stdout else 0.1, ) if print_to_stdout: printer.outs.append(sys.stdout) diff --git a/glue/sample/src/sinter/_main_collect_test.py b/glue/sample/src/sinter/_command/_main_collect_test.py similarity index 92% rename from glue/sample/src/sinter/_main_collect_test.py rename to glue/sample/src/sinter/_command/_main_collect_test.py index 09f45d890..408fde653 100644 --- a/glue/sample/src/sinter/_main_collect_test.py +++ b/glue/sample/src/sinter/_command/_main_collect_test.py @@ -6,8 +6,8 @@ import pytest import sinter -from sinter._main import main -from sinter._main_combine import ExistingData +from sinter._command._main import main +from sinter._command._main_combine import ExistingData from sinter._plotting import split_by @@ -144,7 +144,7 @@ def test_main_collect_with_custom_decoder(): "--decoders", "NOTEXIST", "--custom_decoders_module_function", - "sinter._main_collect_test:_make_custom_decoders", + "sinter._command._main_collect_test:_make_custom_decoders", "--processes", "2", "--quiet", @@ -162,7 +162,7 @@ def test_main_collect_with_custom_decoder(): "--decoders", "alternate", "--custom_decoders_module_function", - "sinter._main_collect_test:_make_custom_decoders", + "sinter._command._main_collect_test:_make_custom_decoders", "--processes", "2", "--quiet", @@ -450,3 +450,33 @@ def test_auto_processes(): ]) data = sinter.stats_from_csv_files(d / "out.csv") assert len(data) == 1 + + +def test_implicit_auto_processes(): + with tempfile.TemporaryDirectory() as d: + d = pathlib.Path(d) + stim.Circuit.generated( + 'repetition_code:memory', + rounds=5, + distance=3, + after_clifford_depolarization=0.1, + ).to_file(d / 'a=3.stim') + + # Collects requested stats. + main(command_line_args=[ + "collect", + "--circuits", + str(d / 'a=3.stim'), + "--max_shots", + "200", + "--quiet", + "--metadata_func", + "auto", + "--decoders", + "perfectionist", + "--save_resume_filepath", + str(d / "out.csv"), + ]) + data = sinter.stats_from_csv_files(d / "out.csv") + assert len(data) == 1 + assert data[0].discards > 0 diff --git a/glue/sample/src/sinter/_main_combine.py b/glue/sample/src/sinter/_command/_main_combine.py similarity index 97% rename from glue/sample/src/sinter/_main_combine.py rename to glue/sample/src/sinter/_command/_main_combine.py index 86c5e0cf8..216815e9b 100644 --- a/glue/sample/src/sinter/_main_combine.py +++ b/glue/sample/src/sinter/_command/_main_combine.py @@ -6,8 +6,7 @@ from typing import List, Any import sinter -from sinter._csv_out import CSV_HEADER -from sinter._existing_data import ExistingData +from sinter._data import CSV_HEADER, ExistingData from sinter._plotting import better_sorted_str_terms diff --git a/glue/sample/src/sinter/_main_combine_test.py b/glue/sample/src/sinter/_command/_main_combine_test.py similarity index 99% rename from glue/sample/src/sinter/_main_combine_test.py rename to glue/sample/src/sinter/_command/_main_combine_test.py index f89ae4634..6419c7d22 100644 --- a/glue/sample/src/sinter/_main_combine_test.py +++ b/glue/sample/src/sinter/_command/_main_combine_test.py @@ -3,7 +3,7 @@ import pathlib import tempfile -from sinter._main import main +from sinter._command._main import main def test_main_combine(): diff --git a/glue/sample/src/sinter/_main_plot.py b/glue/sample/src/sinter/_command/_main_plot.py similarity index 82% rename from glue/sample/src/sinter/_main_plot.py rename to glue/sample/src/sinter/_command/_main_plot.py index c6defeb90..4c1c38da8 100644 --- a/glue/sample/src/sinter/_main_plot.py +++ b/glue/sample/src/sinter/_command/_main_plot.py @@ -4,11 +4,12 @@ import argparse import matplotlib.pyplot as plt +import numpy as np -from sinter import shot_error_rate_to_piece_error_rate -from sinter._main_combine import ExistingData +from sinter._command._main_combine import ExistingData from sinter._plotting import plot_discard_rate, plot_custom from sinter._plotting import plot_error_rate +from sinter._probability_util import shot_error_rate_to_piece_error_rate, Fit if TYPE_CHECKING: import sinter @@ -32,6 +33,16 @@ def parse_args(args: List[str]) -> Any: 'Examples:\n' ''' --filter_func "decoder=='pymatching'"\n''' ''' --filter_func "0.001 < metadata['p'] < 0.005"\n''') + parser.add_argument('--preprocess_stats_func', + type=str, + default=None, + help='An expression that operates on a `stats` value, returning a new list of stats to plot.\n' + 'For example, this could double add a field to json_metadata or merge stats together.\n' + 'Examples:\n' + ''' --preprocess_stats_func "[stat for stat in stats if stat.errors > 0]\n''' + ''' --preprocess_stats_func "[stat.with_edits(errors=stat.custom_counts['severe_errors']) for stat in stats]\n''' + ''' --preprocess_stats_func "__import__('your_custom_module').your_custom_function(stats)"\n''' + ) parser.add_argument('--x_func', type=str, default="1", @@ -49,6 +60,21 @@ def parse_args(args: List[str]) -> Any: ''' --x_func m.p\n''' ''' --x_func "metadata['path'].split('/')[-1].split('.')[0]"\n''' ) + parser.add_argument('--point_label_func', + type=str, + default="None", + help='A python expression that determines text to put next to data points.\n' + 'Values available to the python expression:\n' + ' metadata: The parsed value from the json_metadata for the data point.\n' + ' m: `m.key` is a shorthand for `metadata.get("key", None)`.\n' + ' decoder: The decoder that decoded the data for the data point.\n' + ' strong_id: The cryptographic hash of the case that was sampled for the data point.\n' + ' stat: The sinter.TaskStats object for the data point.\n' + 'Expected expression type:\n' + ' Something Falsy (no label), or something that can be given to `str` to get a string.\n' + 'Examples:\n' + ''' --point_label_func "f'p={m.p}'"\n''' + ) parser.add_argument('--y_func', type=str, default=None, @@ -62,7 +88,8 @@ def parse_args(args: List[str]) -> Any: ' strong_id: The cryptographic hash of the case that was sampled for the data point.\n' ' stat: The sinter.TaskStats object for the data point.\n' 'Expected expression type:\n' - ' Something that can be given to `float` to get a float.\n' + ' A `sinter.Fit` specifying an uncertainty region,.\n' + ' or else something that can be given to `float` to get a float.\n' 'Examples:\n' ''' --x_func "metadata['p']"\n''' ''' --x_func "metadata['path'].split('/')[-1].split('.')[0]"\n''' @@ -76,6 +103,7 @@ def parse_args(args: List[str]) -> Any: type=str, default="'all data (use -group_func and -x_func to group into curves)'", help='A python expression that determines how points are grouped into curves.\n' + 'If this evaluates to a dict, different keys control different groupings (e.g. "color" and "marker")\n' 'Values available to the python expression:\n' ' metadata: The parsed value from the json_metadata for the data point.\n' ' m: `m.key` is a shorthand for `metadata.get("key", None)`.\n' @@ -83,10 +111,16 @@ def parse_args(args: List[str]) -> Any: ' strong_id: The cryptographic hash of the case that was sampled for the data point.\n' ' stat: The sinter.TaskStats object for the data point.\n' 'Expected expression type:\n' - ' Something that can be given to `str` to get a useful string.\n' + ' A dict, or something that can be given to `str` to get a useful string.\n' + 'Recognized dict keys:\n' + ' "color": controls color grouping\n' + ' "marker": controls marker grouping\n' + ' "linestyle": controls linestyle grouping\n' + ' "order": controls ordering in the legend\n' + ' "label": the text shown in the legend\n' 'Examples:\n' - ''' --group_func "(decoder, metadata['d'])"\n''' - ''' --group_func m.d\n''' + ''' --group_func "(decoder, m.d)"\n''' + ''' --group_func "{'color': decoder, 'marker': m.d, 'label': (decoder, m.d)}"\n''' ''' --group_func "metadata['path'].split('/')[-2]"\n''' ) parser.add_argument('--failure_unit_name', @@ -151,7 +185,7 @@ def parse_args(args: List[str]) -> Any: ) parser.add_argument('--plot_args_func', type=str, - default='''{'marker': 'ov*sp^<>8P+xXhHDd|'[index % 18]}''', + default='''{}''', help='A python expression used to customize the look of curves.\n' 'Values available to the python expression:\n' ' index: A unique integer identifying the curve.\n' @@ -181,9 +215,8 @@ def parse_args(args: List[str]) -> Any: parser.add_argument('--out', type=str, default=None, - help='Output file to write the plot to.\n' - 'The file extension determines the type of image.\n' - 'Either this or --show must be specified.') + help='Write the plot to a file instead of showing it.\n' + '(Use --show to still show the plot.)') parser.add_argument('--xaxis', type=str, default='[log]', @@ -207,8 +240,7 @@ def parse_args(args: List[str]) -> Any: "stats.") parser.add_argument('--show', action='store_true', - help='Displays the plot in a window.\n' - 'Either this or --out must be specified.') + help='Displays the plot in a window even when --out is specified.') parser.add_argument('--xmin', default=None, type=float, @@ -246,8 +278,6 @@ def parse_args(args: List[str]) -> Any: help='Adds dashed line fits to every curve.') a = parser.parse_args(args=args) - if not a.show and a.out is None: - raise ValueError("Must specify '--out file' or '--show'.") if 'custom_y' in a.type and a.y_func is None: raise ValueError("--type custom_y requires --y_func.") if a.y_func is not None and a.type and 'custom_y' not in a.type: @@ -265,35 +295,51 @@ def parse_args(args: List[str]) -> Any: a.failure_values_func = "1" if a.failure_unit_name is None: a.failure_unit_name = 'shot' - a.x_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.x_func}', - filename='x_func:command_line_arg', + + def _compile_argument_into_func(arg_name: str, arg_val: Any = ()): + if arg_val == (): + arg_val = getattr(a, arg_name) + raw_func = eval(compile( + f'lambda *, stat, decoder, metadata, m, strong_id, sinter, math, np: {arg_val}', + filename=f'{arg_name}:command_line_arg', + mode='eval', + )) + import sinter + return lambda stat: raw_func( + sinter=sinter, + math=math, + np=np, + stat=stat, + decoder=stat.decoder, + metadata=stat.json_metadata, + m=_FieldToMetadataWrapper(stat.json_metadata), + strong_id=stat.strong_id) + + a.preprocess_stats_func = None if a.preprocess_stats_func is None else eval(compile( + f'lambda *, stats: {a.preprocess_stats_func}', + filename='preprocess_stats_func:command_line_arg', mode='eval')) + a.x_func = _compile_argument_into_func('x_func', a.x_func) if a.y_func is not None: - a.y_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.y_func}', - filename='x_func:command_line_arg', - mode='eval')) - a.group_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.group_func}', - filename='group_func:command_line_arg', - mode='eval')) - a.filter_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.filter_func}', - filename='filter_func:command_line_arg', - mode='eval')) - a.failure_units_per_shot_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.failure_units_per_shot_func}', - filename='failure_units_per_shot_func:command_line_arg', - mode='eval')) - a.failure_values_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.failure_values_func}', - filename='failure_values_func:command_line_arg', - mode='eval')) - a.plot_args_func = eval(compile( + a.y_func = _compile_argument_into_func('y_func') + a.point_label_func = _compile_argument_into_func('point_label_func') + a.group_func = _compile_argument_into_func('group_func') + a.filter_func = _compile_argument_into_func('filter_func') + a.failure_units_per_shot_func = _compile_argument_into_func('failure_units_per_shot_func') + a.failure_values_func = _compile_argument_into_func('failure_values_func') + raw_plot_args_func = eval(compile( f'lambda *, index, key, stats, stat, decoder, metadata, m, strong_id: {a.plot_args_func}', filename='plot_args_func:command_line_arg', mode='eval')) + a.plot_args_func = lambda index, group_key, stats: raw_plot_args_func( + index=index, + key=group_key, + stats=stats, + stat=stats[0], + decoder=stats[0].decoder, + metadata=stats[0].json_metadata, + m=_FieldToMetadataWrapper(stats[0].json_metadata), + strong_id=stats[0].strong_id) return a @@ -361,12 +407,21 @@ def _pick_min_max( want_strictly_positive: bool, ) -> Tuple[float, float]: assert default_max >= default_min - vs = [ - v - for stat in plotted_stats - if (v := v_func(stat)) is not None - if v > 0 or not want_positive - ] + vs = [] + for stat in plotted_stats: + v = v_func(stat) + if isinstance(v, (int, float)): + vs.append(v) + elif isinstance(v, Fit): + for e in [v.low, v.best, v.high]: + if e is not None: + vs.append(e) + elif v is None: + pass + else: + raise NotImplementedError(f'{v=}') + if want_positive: + vs = [v for v in vs if v > 0] min_v = min(vs, default=default_min) max_v = max(vs, default=default_max) @@ -429,16 +484,24 @@ def _set_axis_scale_label_ticks( elif scale_name == 'log': set_scale('log') min_v, max_v, major_ticks, minor_ticks = _log_ticks(min_v, max_v) + if forced_min_v is not None: + min_v = forced_min_v + if forced_max_v is not None: + max_v = forced_max_v set_ticks(major_ticks) set_ticks(minor_ticks, minor=True) set_lim(min_v, max_v) elif scale_name == 'sqrt': from matplotlib.scale import FuncScale min_v, max_v, major_ticks, minor_ticks = _sqrt_ticks(min_v, max_v) - set_lim(min_v, max_v) + if forced_min_v is not None: + min_v = forced_min_v + if forced_max_v is not None: + max_v = forced_max_v set_scale(FuncScale(ax, (lambda e: e**0.5, lambda e: e**2))) set_ticks(major_ticks) set_ticks(minor_ticks, minor=True) + set_lim(min_v, max_v) else: raise NotImplemented(f'{scale_name=}') return scale_name @@ -468,6 +531,7 @@ def _plot_helper( samples: Union[Iterable['sinter.TaskStats'], ExistingData], group_func: Callable[['sinter.TaskStats'], Any], filter_func: Callable[['sinter.TaskStats'], Any], + preprocess_stats_func: Optional[Callable], failure_units_per_shot_func: Callable[['sinter.TaskStats'], Any], failure_values_func: Callable[['sinter.TaskStats'], Any], x_func: Callable[['sinter.TaskStats'], Any], @@ -486,6 +550,7 @@ def _plot_helper( fig_size: Optional[Tuple[int, int]], plot_args_func: Callable[[int, Any, List['sinter.TaskStats']], Dict[str, Any]], line_fits: bool, + point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None, ) -> Tuple[plt.Figure, List[plt.Axes]]: if isinstance(samples, ExistingData): total = samples @@ -497,6 +562,12 @@ def _plot_helper( for k, v in total.data.items() if bool(filter_func(v))} + if preprocess_stats_func is not None: + processed_stats = preprocess_stats_func(stats=list(total.data.values())) + total.data = {} + for stat in processed_stats: + total.add_sample(stat) + if not plot_types: if y_func is not None: plot_types = ['custom_y'] @@ -529,7 +600,6 @@ def _plot_helper( plotted_stats: List['sinter.TaskStats'] = [ stat for stat in total.data.values() - if filter_func(stat) ] def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]: @@ -581,6 +651,7 @@ def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]: highlight_max_likelihood_factor=highlight_max_likelihood_factor, plot_args_func=plot_args_func, line_fits=None if not line_fits else (x_scale_name, y_scale_name), + point_label_func=point_label_func, ) ax_err.grid(which='major', color='#000000') ax_err.grid(which='minor', color='#DDDDDD') @@ -595,6 +666,7 @@ def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]: x_func=x_func, highlight_max_likelihood_factor=highlight_max_likelihood_factor, plot_args_func=plot_args_func, + point_label_func=point_label_func, ) ax_dis.set_yticks([p / 10 for p in range(11)], labels=[f'{10*p}%' for p in range(11)]) ax_dis.set_ylim(0, 1) @@ -626,9 +698,9 @@ def stat_to_err_rate(stat: 'sinter.TaskStats') -> Optional[float]: x_func=x_func, y_func=y_func, group_func=group_func, - filter_func=filter_func, plot_args_func=plot_args_func, line_fits=None if not line_fits else (x_scale_name, y_scale_name), + point_label_func=point_label_func, ) ax_cus.grid(which='major', color='#000000') ax_cus.grid(which='minor', color='#DDDDDD') @@ -710,51 +782,14 @@ def main_plot(*, command_line_args: List[str]): fig, _ = _plot_helper( samples=total, - group_func=lambda stat: args.group_func( - stat=stat, - decoder=stat.decoder, - metadata=stat.json_metadata, - m=_FieldToMetadataWrapper(stat.json_metadata), - strong_id=stat.strong_id), - x_func=lambda stat: args.x_func( - stat=stat, - decoder=stat.decoder, - metadata=stat.json_metadata, - m=_FieldToMetadataWrapper(stat.json_metadata), - strong_id=stat.strong_id), - y_func=None if args.y_func is None else lambda stat: args.y_func( - stat=stat, - decoder=stat.decoder, - metadata=stat.json_metadata, - m=_FieldToMetadataWrapper(stat.json_metadata), - strong_id=stat.strong_id), - filter_func=lambda stat: args.filter_func( - stat=stat, - decoder=stat.decoder, - metadata=stat.json_metadata, - m=_FieldToMetadataWrapper(stat.json_metadata), - strong_id=stat.strong_id), - failure_units_per_shot_func=lambda stat: args.failure_units_per_shot_func( - stat=stat, - decoder=stat.decoder, - metadata=stat.json_metadata, - m=_FieldToMetadataWrapper(stat.json_metadata), - strong_id=stat.strong_id), - failure_values_func=lambda stat: args.failure_values_func( - stat=stat, - decoder=stat.decoder, - metadata=stat.json_metadata, - m=_FieldToMetadataWrapper(stat.json_metadata), - strong_id=stat.strong_id), - plot_args_func=lambda index, group_key, stats: args.plot_args_func( - index=index, - key=group_key, - stats=stats, - stat=stats[0], - decoder=stats[0].decoder, - metadata=stats[0].json_metadata, - m=_FieldToMetadataWrapper(stats[0].json_metadata), - strong_id=stats[0].strong_id), + group_func=args.group_func, + x_func=args.x_func, + point_label_func=args.point_label_func, + y_func=args.y_func, + filter_func=args.filter_func, + failure_units_per_shot_func=args.failure_units_per_shot_func, + failure_values_func=args.failure_values_func, + plot_args_func=args.plot_args_func, failure_unit=args.failure_unit_name, plot_types=args.type, xaxis=args.xaxis, @@ -768,8 +803,9 @@ def main_plot(*, command_line_args: List[str]): title=args.title, subtitle=args.subtitle, line_fits=args.line_fits, + preprocess_stats_func=args.preprocess_stats_func, ) if args.out is not None: fig.savefig(args.out) - if args.show: + if args.show or args.out is None: plt.show() diff --git a/glue/sample/src/sinter/_main_plot_test.py b/glue/sample/src/sinter/_command/_main_plot_test.py similarity index 99% rename from glue/sample/src/sinter/_main_plot_test.py rename to glue/sample/src/sinter/_command/_main_plot_test.py index 689199a58..3408a08c3 100644 --- a/glue/sample/src/sinter/_main_plot_test.py +++ b/glue/sample/src/sinter/_command/_main_plot_test.py @@ -4,8 +4,8 @@ import tempfile import pytest -from sinter._main import main -from sinter._main_plot import _log_ticks, _sqrt_ticks +from sinter._command._main import main +from sinter._command._main_plot import _log_ticks, _sqrt_ticks def test_main_plot(): diff --git a/glue/sample/src/sinter/_main_predict.py b/glue/sample/src/sinter/_command/_main_predict.py similarity index 100% rename from glue/sample/src/sinter/_main_predict.py rename to glue/sample/src/sinter/_command/_main_predict.py diff --git a/glue/sample/src/sinter/_main_predict_test.py b/glue/sample/src/sinter/_command/_main_predict_test.py similarity index 95% rename from glue/sample/src/sinter/_main_predict_test.py rename to glue/sample/src/sinter/_command/_main_predict_test.py index 2dec070c8..4e819668b 100644 --- a/glue/sample/src/sinter/_main_predict_test.py +++ b/glue/sample/src/sinter/_command/_main_predict_test.py @@ -1,7 +1,7 @@ import pathlib import tempfile -from sinter._main import main +from sinter._command._main import main def test_main_predict(): diff --git a/glue/sample/src/sinter/_data/__init__.py b/glue/sample/src/sinter/_data/__init__.py new file mode 100644 index 000000000..41a0194a4 --- /dev/null +++ b/glue/sample/src/sinter/_data/__init__.py @@ -0,0 +1,20 @@ +from sinter._data._anon_task_stats import ( + AnonTaskStats, +) +from sinter._data._collection_options import ( + CollectionOptions, +) +from sinter._data._csv_out import ( + CSV_HEADER, +) +from sinter._data._existing_data import ( + read_stats_from_csv_files, + stats_from_csv_files, + ExistingData, +) +from sinter._data._task import ( + Task, +) +from sinter._data._task_stats import ( + TaskStats, +) diff --git a/glue/sample/src/sinter/_anon_task_stats.py b/glue/sample/src/sinter/_data/_anon_task_stats.py similarity index 84% rename from glue/sample/src/sinter/_anon_task_stats.py rename to glue/sample/src/sinter/_data/_anon_task_stats.py index d5e62f95b..60fee918b 100644 --- a/glue/sample/src/sinter/_anon_task_stats.py +++ b/glue/sample/src/sinter/_data/_anon_task_stats.py @@ -1,6 +1,9 @@ import collections import dataclasses -from typing import Counter +from typing import Counter, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from sinter._data._task_stats import TaskStats @dataclasses.dataclass(frozen=True) @@ -74,13 +77,13 @@ def __add__(self, other: 'AnonTaskStats') -> 'AnonTaskStats': >>> a + b sinter.AnonTaskStats(shots=1100, errors=220) """ - if not isinstance(other, AnonTaskStats): - return NotImplemented + if isinstance(other, AnonTaskStats): + return AnonTaskStats( + shots=self.shots + other.shots, + errors=self.errors + other.errors, + discards=self.discards + other.discards, + seconds=self.seconds + other.seconds, + custom_counts=self.custom_counts + other.custom_counts, + ) - return AnonTaskStats( - shots=self.shots + other.shots, - errors=self.errors + other.errors, - discards=self.discards + other.discards, - seconds=self.seconds + other.seconds, - custom_counts=self.custom_counts + other.custom_counts, - ) + return NotImplemented diff --git a/glue/sample/src/sinter/_anon_task_stats_test.py b/glue/sample/src/sinter/_data/_anon_task_stats_test.py similarity index 100% rename from glue/sample/src/sinter/_anon_task_stats_test.py rename to glue/sample/src/sinter/_data/_anon_task_stats_test.py diff --git a/glue/sample/src/sinter/_collection_options.py b/glue/sample/src/sinter/_data/_collection_options.py similarity index 100% rename from glue/sample/src/sinter/_collection_options.py rename to glue/sample/src/sinter/_data/_collection_options.py diff --git a/glue/sample/src/sinter/_collection_options_test.py b/glue/sample/src/sinter/_data/_collection_options_test.py similarity index 100% rename from glue/sample/src/sinter/_collection_options_test.py rename to glue/sample/src/sinter/_data/_collection_options_test.py diff --git a/glue/sample/src/sinter/_csv_out.py b/glue/sample/src/sinter/_data/_csv_out.py similarity index 100% rename from glue/sample/src/sinter/_csv_out.py rename to glue/sample/src/sinter/_data/_csv_out.py diff --git a/glue/sample/src/sinter/_existing_data.py b/glue/sample/src/sinter/_data/_existing_data.py similarity index 96% rename from glue/sample/src/sinter/_existing_data.py rename to glue/sample/src/sinter/_data/_existing_data.py index 925b2110f..077c01c6e 100644 --- a/glue/sample/src/sinter/_existing_data.py +++ b/glue/sample/src/sinter/_data/_existing_data.py @@ -3,9 +3,9 @@ import pathlib from typing import Any, Dict, List, TYPE_CHECKING -from sinter._task_stats import TaskStats -from sinter._task import Task -from sinter._decoding import AnonTaskStats +from sinter._data._task_stats import TaskStats +from sinter._data._task import Task +from sinter._data._anon_task_stats import AnonTaskStats if TYPE_CHECKING: import sinter @@ -26,8 +26,9 @@ def stats_for(self, case: Task) -> AnonTaskStats: def add_sample(self, sample: TaskStats) -> None: k = sample.strong_id - if k in self.data: - self.data[k] += sample + current = self.data.get(k) + if current is not None: + self.data[k] = current + sample else: self.data[k] = sample diff --git a/glue/sample/src/sinter/_existing_data_test.py b/glue/sample/src/sinter/_data/_existing_data_test.py similarity index 100% rename from glue/sample/src/sinter/_existing_data_test.py rename to glue/sample/src/sinter/_data/_existing_data_test.py diff --git a/glue/sample/src/sinter/_task.py b/glue/sample/src/sinter/_data/_task.py similarity index 99% rename from glue/sample/src/sinter/_task.py rename to glue/sample/src/sinter/_data/_task.py index c358bde3b..4f19b034a 100644 --- a/glue/sample/src/sinter/_task.py +++ b/glue/sample/src/sinter/_data/_task.py @@ -8,7 +8,7 @@ import numpy as np -from sinter._collection_options import CollectionOptions +from sinter._data._collection_options import CollectionOptions if TYPE_CHECKING: import sinter diff --git a/glue/sample/src/sinter/_task_stats.py b/glue/sample/src/sinter/_data/_task_stats.py similarity index 62% rename from glue/sample/src/sinter/_task_stats.py rename to glue/sample/src/sinter/_data/_task_stats.py index 73a2dd162..4d846e89e 100644 --- a/glue/sample/src/sinter/_task_stats.py +++ b/glue/sample/src/sinter/_data/_task_stats.py @@ -1,9 +1,27 @@ import collections import dataclasses from typing import Counter, List, Any +from typing import Optional +from typing import Union +from typing import overload -from sinter._anon_task_stats import AnonTaskStats -from sinter._csv_out import csv_line +from sinter._data._anon_task_stats import AnonTaskStats +from sinter._data._csv_out import csv_line + + +def _is_equal_json_values(json1: Any, json2: Any): + if json1 == json2: + return True + + if type(json1) == type(json2): + if isinstance(json1, dict): + return json1.keys() == json2.keys() and all(_is_equal_json_values(json1[k], json2[k]) for k in json1.keys()) + elif isinstance(json1, (list, tuple)): + return len(json1) == len(json2) and all(_is_equal_json_values(a, b) for a, b in zip(json1, json2)) + elif isinstance(json1, (list, tuple)) and isinstance(json2, (list, tuple)): + return _is_equal_json_values(tuple(json1), tuple(json2)) + + return False @dataclasses.dataclass(frozen=True) @@ -67,22 +85,70 @@ def __post_init__(self): assert self.shots >= self.errors + self.discards assert all(isinstance(k, str) and isinstance(v, int) for k, v in self.custom_counts.items()) - def __add__(self, other: 'TaskStats') -> 'TaskStats': - if self.strong_id != other.strong_id: - raise ValueError(f'{self.strong_id=} != {other.strong_id=}') - total = self.to_anon_stats() + other.to_anon_stats() - + def with_edits( + self, + *, + strong_id: Optional[str] = None, + decoder: Optional[str] = None, + json_metadata: Optional[Any] = None, + shots: Optional[int] = None, + errors: Optional[int] = None, + discards: Optional[int] = None, + seconds: Optional[float] = None, + custom_counts: Optional[Counter[str]] = None, + ) -> 'TaskStats': return TaskStats( - decoder=self.decoder, - strong_id=self.strong_id, - json_metadata=self.json_metadata, - shots=total.shots, - errors=total.errors, - discards=total.discards, - seconds=total.seconds, - custom_counts=total.custom_counts, + strong_id=self.strong_id if strong_id is None else strong_id, + decoder=self.decoder if decoder is None else decoder, + json_metadata=self.json_metadata if json_metadata is None else json_metadata, + shots=self.shots if shots is None else shots, + errors=self.errors if errors is None else errors, + discards=self.discards if discards is None else discards, + seconds=self.seconds if seconds is None else seconds, + custom_counts=self.custom_counts if custom_counts is None else custom_counts, ) + @overload + def __add__(self, other: AnonTaskStats) -> AnonTaskStats: + pass + @overload + def __add__(self, other: 'TaskStats') -> 'TaskStats': + pass + def __add__(self, other: Union[AnonTaskStats, 'TaskStats']) -> Union[AnonTaskStats, 'TaskStats']: + if isinstance(other, AnonTaskStats): + return self.to_anon_stats() + other + + if isinstance(other, TaskStats): + if self.strong_id != other.strong_id: + raise ValueError(f'{self.strong_id=} != {other.strong_id=}') + if not _is_equal_json_values(self.json_metadata, other.json_metadata) or self.decoder != other.decoder: + raise ValueError( + "A stat had the same strong id as another, but their other identifying information (json_metadata, decoder) differed.\n" + "The strong id is supposed to be a cryptographic hash that uniquely identifies what was sampled, so this is an error.\n" + "\n" + "This failure can occur when post-processing data (e.g. combining X basis stats and Z basis stats into synthetic both-basis stats).\n" + "To fix it, ensure any post-processing sets the strong id of the synthetic data in some cryptographically secure way.\n" + "\n" + "In some cases this can be caused by attempting to add a value that has gone through JSON serialization+parsing to one\n" + "that hasn't, which causes things like tuples transforming into lists.\n" + "\n" + f"The two stats:\n1. {self!r}\n2. {other!r}") + + total = self.to_anon_stats() + other.to_anon_stats() + return TaskStats( + decoder=self.decoder, + strong_id=self.strong_id, + json_metadata=self.json_metadata, + shots=total.shots, + errors=total.errors, + discards=total.discards, + seconds=total.seconds, + custom_counts=total.custom_counts, + ) + + return NotImplemented + __radd__ = __add__ + def to_anon_stats(self) -> AnonTaskStats: """Returns a `sinter.AnonTaskStats` with the same statistics. diff --git a/glue/sample/src/sinter/_task_stats_test.py b/glue/sample/src/sinter/_data/_task_stats_test.py similarity index 53% rename from glue/sample/src/sinter/_task_stats_test.py rename to glue/sample/src/sinter/_data/_task_stats_test.py index 0847b003f..4060c2f1d 100644 --- a/glue/sample/src/sinter/_task_stats_test.py +++ b/glue/sample/src/sinter/_data/_task_stats_test.py @@ -3,6 +3,7 @@ import pytest import sinter +from sinter._data._task_stats import _is_equal_json_values def test_repr(): @@ -87,3 +88,53 @@ def test_add(): seconds=52, custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}), ) + + +def test_with_edits(): + v = sinter.TaskStats( + decoder='pymatching', + json_metadata={'a': 2}, + strong_id='abcdefDIFFERENT', + shots=270, + errors=34, + discards=43, + seconds=52, + custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}), + ) + assert v.with_edits(json_metadata={'b': 3}) == sinter.TaskStats( + decoder='pymatching', + json_metadata={'b': 3}, + strong_id='abcdefDIFFERENT', + shots=270, + errors=34, + discards=43, + seconds=52, + custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}), + ) + assert v == sinter.TaskStats(strong_id='', json_metadata={}, decoder='').with_edits( + decoder='pymatching', + json_metadata={'a': 2}, + strong_id='abcdefDIFFERENT', + shots=270, + errors=34, + discards=43, + seconds=52, + custom_counts=collections.Counter({'a': 11, 'b': 20, 'c': 3}), + ) + + +def test_is_equal_json_values(): + assert _is_equal_json_values([1, 2], (1, 2)) + assert _is_equal_json_values([1, [3, (5, 6)]], (1, (3, [5, 6]))) + assert not _is_equal_json_values([1, [3, (5, 6)]], (1, (3, [5, 7]))) + assert not _is_equal_json_values([1, [3, (5, 6)]], (1, (3, [5]))) + assert not _is_equal_json_values([1, 2], (1, 3)) + assert not _is_equal_json_values([1, 2], {1, 2}) + assert _is_equal_json_values({'x': [1, 2]}, {'x': (1, 2)}) + assert _is_equal_json_values({'x': (1, 2)}, {'x': (1, 2)}) + assert not _is_equal_json_values({'x': (1, 2)}, {'y': (1, 2)}) + assert not _is_equal_json_values({'x': (1, 2)}, {'x': (1, 2), 'y': []}) + assert not _is_equal_json_values({'x': (1, 2), 'y': []}, {'x': (1, 2)}) + assert not _is_equal_json_values({'x': (1, 2)}, {'x': (1, 3)}) + assert not _is_equal_json_values(1, 2) + assert _is_equal_json_values(1, 1) diff --git a/glue/sample/src/sinter/_task_test.py b/glue/sample/src/sinter/_data/_task_test.py similarity index 100% rename from glue/sample/src/sinter/_task_test.py rename to glue/sample/src/sinter/_data/_task_test.py diff --git a/glue/sample/src/sinter/_decoding/__init__.py b/glue/sample/src/sinter/_decoding/__init__.py new file mode 100644 index 000000000..c0aaebfde --- /dev/null +++ b/glue/sample/src/sinter/_decoding/__init__.py @@ -0,0 +1,16 @@ +from sinter._decoding._decoding import ( + streaming_post_select, + sample_decode, +) +from sinter._decoding._decoding_decoder_class import ( + CompiledDecoder, + Decoder, +) +from sinter._decoding._decoding_all_built_in_decoders import ( + BUILT_IN_DECODERS, + BUILT_IN_SAMPLERS, +) +from sinter._decoding._sampler import ( + Sampler, + CompiledSampler, +) diff --git a/glue/sample/src/sinter/_decoding.py b/glue/sample/src/sinter/_decoding/_decoding.py similarity index 98% rename from glue/sample/src/sinter/_decoding.py rename to glue/sample/src/sinter/_decoding/_decoding.py index 7170d443e..1e54f87ef 100644 --- a/glue/sample/src/sinter/_decoding.py +++ b/glue/sample/src/sinter/_decoding/_decoding.py @@ -11,9 +11,9 @@ import numpy as np import stim -from sinter._anon_task_stats import AnonTaskStats -from sinter._decoding_all_built_in_decoders import BUILT_IN_DECODERS -from sinter._decoding_decoder_class import CompiledDecoder, Decoder +from sinter._data import AnonTaskStats +from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_DECODERS +from sinter._decoding._decoding_decoder_class import CompiledDecoder, Decoder if TYPE_CHECKING: import sinter diff --git a/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py new file mode 100644 index 000000000..b495d23f7 --- /dev/null +++ b/glue/sample/src/sinter/_decoding/_decoding_all_built_in_decoders.py @@ -0,0 +1,20 @@ +from typing import Dict +from typing import Union + +from sinter._decoding._decoding_decoder_class import Decoder +from sinter._decoding._decoding_fusion_blossom import FusionBlossomDecoder +from sinter._decoding._decoding_pymatching import PyMatchingDecoder +from sinter._decoding._decoding_vacuous import VacuousDecoder +from sinter._decoding._perfectionist_sampler import PerfectionistSampler +from sinter._decoding._sampler import Sampler + +BUILT_IN_DECODERS: Dict[str, Decoder] = { + 'vacuous': VacuousDecoder(), + 'pymatching': PyMatchingDecoder(), + 'fusion_blossom': FusionBlossomDecoder(), +} + +BUILT_IN_SAMPLERS: Dict[str, Union[Decoder, Sampler]] = { + **BUILT_IN_DECODERS, + 'perfectionist': PerfectionistSampler(), +} diff --git a/glue/sample/src/sinter/_decoding_decoder_class.py b/glue/sample/src/sinter/_decoding/_decoding_decoder_class.py similarity index 100% rename from glue/sample/src/sinter/_decoding_decoder_class.py rename to glue/sample/src/sinter/_decoding/_decoding_decoder_class.py diff --git a/glue/sample/src/sinter/_decoding_fusion_blossom.py b/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py similarity index 99% rename from glue/sample/src/sinter/_decoding_fusion_blossom.py rename to glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py index 966b69a4f..ac3e5bfb3 100644 --- a/glue/sample/src/sinter/_decoding_fusion_blossom.py +++ b/glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py @@ -5,7 +5,7 @@ import numpy as np import stim -from sinter._decoding_decoder_class import Decoder, CompiledDecoder +from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder if TYPE_CHECKING: import fusion_blossom diff --git a/glue/sample/src/sinter/_decoding_pymatching.py b/glue/sample/src/sinter/_decoding/_decoding_pymatching.py similarity index 97% rename from glue/sample/src/sinter/_decoding_pymatching.py rename to glue/sample/src/sinter/_decoding/_decoding_pymatching.py index c8fb7464c..b57bb32bc 100644 --- a/glue/sample/src/sinter/_decoding_pymatching.py +++ b/glue/sample/src/sinter/_decoding/_decoding_pymatching.py @@ -1,4 +1,4 @@ -from sinter._decoding_decoder_class import Decoder, CompiledDecoder +from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder class PyMatchingCompiledDecoder(CompiledDecoder): diff --git a/glue/sample/src/sinter/_decoding_test.py b/glue/sample/src/sinter/_decoding/_decoding_test.py similarity index 98% rename from glue/sample/src/sinter/_decoding_test.py rename to glue/sample/src/sinter/_decoding/_decoding_test.py index 2ca9fbbca..50b31b26b 100644 --- a/glue/sample/src/sinter/_decoding_test.py +++ b/glue/sample/src/sinter/_decoding/_decoding_test.py @@ -11,9 +11,9 @@ import stim from sinter._collection import post_selection_mask_from_4th_coord -from sinter._decoding import sample_decode -from sinter._decoding_all_built_in_decoders import BUILT_IN_DECODERS -from sinter._decoding_vacuous import VacuousDecoder +from sinter._decoding._decoding_all_built_in_decoders import BUILT_IN_DECODERS +from sinter._decoding._decoding import sample_decode +from sinter._decoding._decoding_vacuous import VacuousDecoder def get_test_decoders() -> Tuple[List[str], Dict[str, sinter.Decoder]]: diff --git a/glue/sample/src/sinter/_decoding_vacuous.py b/glue/sample/src/sinter/_decoding/_decoding_vacuous.py similarity index 94% rename from glue/sample/src/sinter/_decoding_vacuous.py rename to glue/sample/src/sinter/_decoding/_decoding_vacuous.py index d3d3113b0..8e24d5516 100644 --- a/glue/sample/src/sinter/_decoding_vacuous.py +++ b/glue/sample/src/sinter/_decoding/_decoding_vacuous.py @@ -1,6 +1,6 @@ import numpy as np -from sinter._decoding_decoder_class import Decoder, CompiledDecoder +from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder class VacuousDecoder(Decoder): diff --git a/glue/sample/src/sinter/_decoding/_perfectionist_sampler.py b/glue/sample/src/sinter/_decoding/_perfectionist_sampler.py new file mode 100755 index 000000000..7478164f8 --- /dev/null +++ b/glue/sample/src/sinter/_decoding/_perfectionist_sampler.py @@ -0,0 +1,38 @@ +import time + +import numpy as np + +from sinter._data import Task, AnonTaskStats +from sinter._decoding._sampler import Sampler, CompiledSampler + + +class PerfectionistSampler(Sampler): + """Predicts obs aren't flipped. Discards shots with any detection events.""" + def compiled_sampler_for_task(self, task: Task) -> CompiledSampler: + return CompiledPerfectionistSampler(task) + + +class CompiledPerfectionistSampler(CompiledSampler): + def __init__(self, task: Task): + self.stim_sampler = task.circuit.compile_detector_sampler() + + def sample(self, max_shots: int) -> AnonTaskStats: + t0 = time.monotonic() + dets, obs = self.stim_sampler.sample( + shots=max_shots, + bit_packed=True, + separate_observables=True, + ) + num_shots = dets.shape[0] + discards = np.any(dets, axis=1) + errors = np.any(obs, axis=1) + num_discards = np.count_nonzero(discards) + num_errors = np.count_nonzero(errors & ~discards) + t1 = time.monotonic() + + return AnonTaskStats( + shots=num_shots, + errors=num_errors, + discards=num_discards, + seconds=t1 - t0, + ) diff --git a/glue/sample/src/sinter/_decoding/_sampler.py b/glue/sample/src/sinter/_decoding/_sampler.py new file mode 100644 index 000000000..0fe40ca11 --- /dev/null +++ b/glue/sample/src/sinter/_decoding/_sampler.py @@ -0,0 +1,72 @@ +import abc +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import sinter + + +class CompiledSampler(metaclass=abc.ABCMeta): + """A sampler that has been configured for efficiently sampling some task.""" + + @abc.abstractmethod + def sample(self, suggested_shots: int) -> 'sinter.AnonTaskStats': + """Samples shots and returns statistics. + + Args: + suggested_shots: The number of shots being requested. The sampler + may perform more shots or fewer shots than this, so technically + this argument can just be ignored. If a sampler is optimized for + a specific batch size, it can simply return one batch per call + regardless of this parameter. + + However, this parameter is a useful hint about the amount of + work being done. The sampler can use this to optimize its + behavior. For example, it could adjust its batch size downward + if the suggested shots is very small. Whereas if the suggested + shots is very high, the sampler should focus entirely on + achieving the best possible throughput. + + Note that, in typical workloads, the sampler will be called + repeatedly with the same value of suggested_shots. Therefore it + is reasonable to allocate buffers sized to accomodate the + current suggested_shots, expecting them to be useful again for + the next shot. + + Returns: + A sinter.AnonTaskStats saying how many shots were actually taken, + how many errors were seen, etc. + + The returned stats must have at least one shot. + """ + pass + + def handles_throttling(self) -> bool: + """Return True to disable sinter wrapping samplers with throttling. + + By default, sinter will wrap samplers so that they initially only do + a small number of shots then slowly ramp up. Sometimes this behavior + is not desired (e.g. in unit tests). Override this method to return True + to disable it. + """ + return False + + +class Sampler(metaclass=abc.ABCMeta): + """A strategy for producing stats from tasks. + + Call `sampler.compiled_sampler_for_task(task)` to get a compiled sampler for + a task, then call `compiled_sampler.sample(shots)` to collect statistics. + + A sampler differs from a `sinter.Decoder` because the sampler is responsible + for the full sampling process (e.g. simulating the circuit), whereas a + decoder can do nothing except predict observable flips from detection event + data. This prevents the decoders from cheating, but makes them less flexible + overall. A sampler can do things like use simulators other than stim, or + really anything at all as long as it ends with returning statistics about + shot counts, error counts, and etc. + """ + + @abc.abstractmethod + def compiled_sampler_for_task(self, task: 'sinter.Task') -> 'sinter.CompiledSampler': + """Creates, configures, and returns an object for sampling the task.""" + pass diff --git a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py new file mode 100755 index 000000000..ee28c53ec --- /dev/null +++ b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py @@ -0,0 +1,222 @@ +import collections +import pathlib +import random +import time +from typing import Optional +from typing import Union + +import numpy as np + +from sinter._data import Task, AnonTaskStats +from sinter._decoding._sampler import Sampler, CompiledSampler +from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder + + +class StimThenDecodeSampler(Sampler): + """Samples shots using stim, then decodes using the given decoder. + + This is the default sampler; the one used to wrap decoders with no + specified sampler. + + The decoder's goal is to predict the observable flips given the detection + event data. Errors are when the prediction is wrong. Discards are when the + decoder returns an extra byte of prediction data for each shot, and the + extra byte is not zero. + """ + def __init__( + self, + *, + decoder: Decoder, + count_observable_error_combos: bool, + count_detection_events: bool, + tmp_dir: Optional[pathlib.Path], + ): + self.decoder = decoder + self.count_observable_error_combos = count_observable_error_combos + self.count_detection_events = count_detection_events + self.tmp_dir = tmp_dir + + def compiled_sampler_for_task(self, task: Task) -> CompiledSampler: + return _CompiledStimThenDecodeSampler( + decoder=self.decoder, + task=task, + count_detection_events=self.count_detection_events, + count_observable_error_combos=self.count_observable_error_combos, + tmp_dir=self.tmp_dir, + ) + + +def classify_discards_and_errors( + *, + actual_obs: np.ndarray, + predictions: np.ndarray, + postselected_observables_mask: Union[np.ndarray, None], + out_count_observable_error_combos: Union[None, collections.Counter[str]], + num_obs: int, +) -> tuple[int, int]: + num_discards = 0 + + # Added bytes are used for signalling discards. + if predictions.shape[1] == actual_obs.shape[1] + 1: + discard_mask = predictions[:, -1] != 0 + predictions = predictions[:, :-1] + num_discards += np.count_nonzero(discard_mask) + discard_mask ^= True + actual_obs = actual_obs[discard_mask] + predictions = predictions[discard_mask] + + # Mispredicted observables can be used for signalling discards. + if postselected_observables_mask is not None: + discard_mask = np.any((actual_obs ^ predictions) & postselected_observables_mask, axis=1) + num_discards += np.count_nonzero(discard_mask) + discard_mask ^= True + actual_obs = actual_obs[discard_mask] + predictions = predictions[discard_mask] + + fail_mask = np.any(actual_obs != predictions, axis=1) + if out_count_observable_error_combos is not None: + for k in np.flatnonzero(fail_mask): + mistakes = np.unpackbits(actual_obs[k] ^ predictions[k], count=num_obs, bitorder='little') + err_key = "obs_mistake_mask=" + ''.join('_E'[b] for b in mistakes) + out_count_observable_error_combos[err_key] += 1 + + num_errors = np.count_nonzero(fail_mask) + return num_discards, num_errors + + +class DiskDecoder(CompiledDecoder): + def __init__(self, decoder: Decoder, task: Task, tmp_dir: pathlib.Path): + self.decoder = decoder + self.task = task + self.top_tmp_dir: pathlib.Path = tmp_dir + + while True: + k = random.randint(0, 2**64) + self.top_tmp_dir = tmp_dir / f'disk_decoder_{k}' + try: + self.top_tmp_dir.mkdir() + break + except FileExistsError: + pass + self.decoder_tmp_dir: pathlib.Path = self.top_tmp_dir / 'dec' + self.decoder_tmp_dir.mkdir() + self.num_obs = task.detector_error_model.num_observables + self.num_dets = task.detector_error_model.num_detectors + self.dem_path = self.top_tmp_dir / 'dem.dem' + self.dets_b8_in_path = self.top_tmp_dir / 'dets.b8' + self.obs_predictions_b8_out_path = self.top_tmp_dir / 'obs.b8' + self.task.detector_error_model.to_file(self.dem_path) + + def decode_shots_bit_packed( + self, + *, + bit_packed_detection_event_data: np.ndarray, + ) -> np.ndarray: + num_shots = bit_packed_detection_event_data.shape[0] + with open(self.dets_b8_in_path, 'wb') as f: + bit_packed_detection_event_data.tofile(f) + self.decoder.decode_via_files( + num_shots=num_shots, + num_obs=self.num_obs, + num_dets=self.num_dets, + dem_path=self.dem_path, + dets_b8_in_path=self.dets_b8_in_path, + obs_predictions_b8_out_path=self.obs_predictions_b8_out_path, + tmp_dir=self.decoder_tmp_dir, + ) + num_obs_bytes = (self.num_obs + 7) // 8 + with open(self.obs_predictions_b8_out_path, 'rb') as f: + prediction = np.fromfile(f, dtype=np.uint8, count=num_obs_bytes * num_shots) + assert prediction.shape == (num_obs_bytes * num_shots,) + self.obs_predictions_b8_out_path.unlink() + self.dets_b8_in_path.unlink() + return prediction.reshape((num_shots, num_obs_bytes)) + + +def _compile_decoder_with_disk_fallback( + decoder: Decoder, + task: Task, + tmp_dir: Optional[pathlib.Path], +) -> CompiledDecoder: + try: + return decoder.compile_decoder_for_dem(dem=task.detector_error_model) + except (NotImplementedError, ValueError): + pass + if tmp_dir is None: + raise ValueError(f"Decoder {task.decoder=} didn't implement `compile_decoder_for_dem`, but no temporary directory was provided for falling back to `decode_via_files`.") + return DiskDecoder(decoder, task, tmp_dir) + + +class _CompiledStimThenDecodeSampler(CompiledSampler): + def __init__( + self, + *, + decoder: Decoder, + task: Task, + count_observable_error_combos: bool, + count_detection_events: bool, + tmp_dir: Optional[pathlib.Path], + ): + self.task = task + self.compiled_decoder = _compile_decoder_with_disk_fallback(decoder, task, tmp_dir) + self.stim_sampler = task.circuit.compile_detector_sampler() + self.count_observable_error_combos = count_observable_error_combos + self.count_detection_events = count_detection_events + self.num_det = self.task.circuit.num_detectors + self.num_obs = self.task.circuit.num_observables + + def sample(self, max_shots: int) -> AnonTaskStats: + t0 = time.monotonic() + dets, actual_obs = self.stim_sampler.sample( + shots=max_shots, + bit_packed=True, + separate_observables=True, + ) + num_shots = dets.shape[0] + + custom_counts = collections.Counter() + if self.count_detection_events: + custom_counts['detectors_checked'] += self.num_det * num_shots + for b in range(8): + custom_counts['detection_events'] += np.count_nonzero(dets & (1 << b)) + + # Discard any shots that contain a postselected detection events. + if self.task.postselection_mask is not None: + discarded_flags = np.any(dets & self.task.postselection_mask, axis=1) + num_discards_1 = np.count_nonzero(discarded_flags) + if num_discards_1: + dets = dets[~discarded_flags, :] + actual_obs = actual_obs[~discarded_flags, :] + else: + num_discards_1 = 0 + + predictions = self.compiled_decoder.decode_shots_bit_packed(bit_packed_detection_event_data=dets) + if not isinstance(predictions, np.ndarray): + raise ValueError("not isinstance(predictions, np.ndarray)") + if predictions.dtype != np.uint8: + raise ValueError("predictions.dtype != np.uint8") + if len(predictions.shape) != 2: + raise ValueError("len(predictions.shape) != 2") + if predictions.shape[0] != num_shots: + raise ValueError("predictions.shape[0] != num_shots") + if predictions.shape[1] < actual_obs.shape[1]: + raise ValueError("predictions.shape[1] < actual_obs.shape[1]") + if predictions.shape[1] > actual_obs.shape[1] + 1: + raise ValueError("predictions.shape[1] > actual_obs.shape[1] + 1") + + num_discards_2, num_errors = classify_discards_and_errors( + actual_obs=actual_obs, + predictions=predictions, + postselected_observables_mask=self.task.postselected_observables_mask, + out_count_observable_error_combos=custom_counts if self.count_observable_error_combos else None, + num_obs=self.num_obs, + ) + t1 = time.monotonic() + + return AnonTaskStats( + shots=num_shots, + errors=num_errors, + discards=num_discards_1 + num_discards_2, + seconds=t1 - t0, + custom_counts=custom_counts, + ) diff --git a/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py new file mode 100755 index 000000000..413015f0d --- /dev/null +++ b/glue/sample/src/sinter/_decoding/_stim_then_decode_sampler_test.py @@ -0,0 +1,192 @@ +import collections + +import numpy as np + +from sinter._decoding._stim_then_decode_sampler import \ + classify_discards_and_errors + + +def test_classify_discards_and_errors(): + assert classify_discards_and_errors( + actual_obs=np.array([ + [1, 2], + [2, 2], + [3, 2], + [4, 3], + [1, 3], + [0, 3], + [0, 3], + ], dtype=np.uint8), + predictions=np.array([ + [1, 2], + [2, 2], + [3, 2], + [4, 3], + [1, 3], + [0, 3], + [0, 3], + ], dtype=np.uint8), + postselected_observables_mask=None, + out_count_observable_error_combos=None, + num_obs=16, + ) == (0, 0) + + assert classify_discards_and_errors( + actual_obs=np.array([ + [1, 2], + [2, 2], + [3, 2], + [4, 3], + [1, 3], + [0, 3], + [0, 3], + ], dtype=np.uint8), + predictions=np.array([ + [0, 0], + [2, 2], + [3, 2], + [4, 1], + [1, 3], + [0, 3], + [0, 3], + ], dtype=np.uint8), + postselected_observables_mask=None, + out_count_observable_error_combos=None, + num_obs=16, + ) == (0, 2) + + assert classify_discards_and_errors( + actual_obs=np.array([ + [1, 2], + [2, 2], + [3, 2], + [4, 3], + [1, 3], + [0, 3], + [0, 3], + ], dtype=np.uint8), + predictions=np.array([ + [0, 0, 0], + [2, 2, 0], + [3, 2, 0], + [4, 1, 0], + [1, 3, 0], + [0, 3, 0], + [0, 3, 0], + ], dtype=np.uint8), + postselected_observables_mask=None, + out_count_observable_error_combos=None, + num_obs=16, + ) == (0, 2) + + assert classify_discards_and_errors( + actual_obs=np.array([ + [1, 2], + [2, 2], + [3, 2], + [4, 3], + [1, 3], + [0, 3], + [0, 3], + ], dtype=np.uint8), + predictions=np.array([ + [0, 0, 0], + [2, 2, 1], + [3, 2, 0], + [4, 1, 0], + [1, 3, 0], + [0, 3, 0], + [0, 3, 0], + ], dtype=np.uint8), + postselected_observables_mask=None, + out_count_observable_error_combos=None, + num_obs=16, + ) == (1, 2) + + assert classify_discards_and_errors( + actual_obs=np.array([ + [1, 2], + [2, 2], + [3, 2], + [4, 3], + [1, 3], + [0, 3], + [0, 3], + ], dtype=np.uint8), + predictions=np.array([ + [0, 0, 1], + [2, 2, 0], + [3, 2, 0], + [4, 1, 0], + [1, 3, 0], + [0, 3, 0], + [0, 3, 0], + ], dtype=np.uint8), + postselected_observables_mask=None, + out_count_observable_error_combos=None, + num_obs=16, + ) == (1, 1) + + assert classify_discards_and_errors( + actual_obs=np.array([ + [1, 2], + [2, 2], + [3, 2], + [4, 3], + [1, 3], + [0, 3], + [0, 3], + ], dtype=np.uint8), + predictions=np.array([ + [0, 0, 1], + [2, 2, 1], + [3, 2, 0], + [4, 1, 0], + [1, 3, 0], + [0, 3, 0], + [0, 3, 0], + ], dtype=np.uint8), + postselected_observables_mask=None, + out_count_observable_error_combos=None, + num_obs=16, + ) == (2, 1) + + assert classify_discards_and_errors( + actual_obs=np.array([ + [1, 2], + [2, 2], + [3, 2], + [4, 3], + [1, 3], + [2, 3], + [1, 3], + ], dtype=np.uint8), + predictions=np.array([ + [0, 0, 1], + [2, 2, 1], + [3, 2, 0], + [4, 1, 0], + [1, 3, 0], + [0, 3, 0], + [0, 3, 0], + ], dtype=np.uint8), + postselected_observables_mask=np.array([1, 0]), + out_count_observable_error_combos=None, + num_obs=16, + ) == (3, 2) + + counter = collections.Counter() + assert classify_discards_and_errors( + actual_obs=np.array([ + [1, 2], + [1, 2], + ], dtype=np.uint8), + predictions=np.array([ + [1, 0], + [1, 2], + ], dtype=np.uint8), + postselected_observables_mask=np.array([1, 0]), + out_count_observable_error_combos=counter, + num_obs=13, + ) == (0, 1) + assert counter == collections.Counter(["obs_mistake_mask=_________E___"]) diff --git a/glue/sample/src/sinter/_decoding_all_built_in_decoders.py b/glue/sample/src/sinter/_decoding_all_built_in_decoders.py deleted file mode 100644 index a9fc5e76c..000000000 --- a/glue/sample/src/sinter/_decoding_all_built_in_decoders.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Dict - -from sinter._decoding_decoder_class import Decoder -from sinter._decoding_fusion_blossom import FusionBlossomDecoder -from sinter._decoding_pymatching import PyMatchingDecoder -from sinter._decoding_vacuous import VacuousDecoder - -BUILT_IN_DECODERS: Dict[str, Decoder] = { - 'vacuous': VacuousDecoder(), - 'pymatching': PyMatchingDecoder(), - 'fusion_blossom': FusionBlossomDecoder(), -} diff --git a/glue/sample/src/sinter/_plotting.py b/glue/sample/src/sinter/_plotting.py index 24acfa209..a8e36cbaa 100644 --- a/glue/sample/src/sinter/_plotting.py +++ b/glue/sample/src/sinter/_plotting.py @@ -1,6 +1,6 @@ import math -import sys from typing import Callable, TypeVar, List, Any, Iterable, Optional, TYPE_CHECKING, Dict, Union, Literal, Tuple +from typing import Sequence from typing import cast import numpy as np @@ -13,6 +13,25 @@ MARKERS: str = "ov*sp^<>8PhH+xXDd|" * 100 +LINESTYLES: tuple[str, ...] = ( + 'solid', + 'dotted', + 'dashed', + 'dashdot', + 'loosely dotted', + 'dotted', + 'densely dotted', + 'long dash with offset', + 'loosely dashed', + 'dashed', + 'densely dashed', + 'loosely dashdotted', + 'dashdotted', + 'densely dashdotted', + 'dashdotdotted', + 'loosely dashdotdotted', + 'densely dashdotdotted', +) T = TypeVar('T') TVal = TypeVar('TVal') TKey = TypeVar('TKey') @@ -37,21 +56,30 @@ def split_by(vs: Iterable[T], key_func: Callable[[T], Any]) -> List[List[T]]: class LooseCompare: def __init__(self, val: Any): - self.val = val + self.val: Any = None - def __lt__(self, other): - if isinstance(other, LooseCompare): - other_val = other.val - else: - other_val = other + self.val = val.val if isinstance(val, LooseCompare) else val + + def __lt__(self, other: Any) -> bool: + other_val = other.val if isinstance(other, LooseCompare) else other if isinstance(self.val, (int, float)) and isinstance(other_val, (int, float)): return self.val < other_val + if isinstance(self.val, (tuple, list)) and isinstance(other_val, (tuple, list)): + return tuple(LooseCompare(e) for e in self.val) < tuple(LooseCompare(e) for e in other_val) return str(self.val) < str(other_val) - def __str__(self): + def __gt__(self, other: Any) -> bool: + other_val = other.val if isinstance(other, LooseCompare) else other + if isinstance(self.val, (int, float)) and isinstance(other_val, (int, float)): + return self.val > other_val + if isinstance(self.val, (tuple, list)) and isinstance(other_val, (tuple, list)): + return tuple(LooseCompare(e) for e in self.val) > tuple(LooseCompare(e) for e in other_val) + return str(self.val) > str(other_val) + + def __str__(self) -> str: return str(self.val) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, LooseCompare): other_val = other.val else: @@ -94,11 +122,12 @@ def better_sorted_str_terms(val: Any) -> Any: distance=199999, rounds=3 distance=199999, rounds=199999 """ - + if val is None: + return 'None' if isinstance(val, tuple): return tuple(better_sorted_str_terms(e) for e in val) if not isinstance(val, str): - return val + return LooseCompare(val) terms = split_by(val, lambda c: c in '.0123456789') result = [] for term in terms: @@ -117,6 +146,8 @@ def better_sorted_str_terms(val: Any) -> Any: except ValueError: pass result.append(term) + if len(result) == 1 and isinstance(result[0], (int, float)): + return LooseCompare(result[0]) return tuple(LooseCompare(e) for e in result) @@ -155,6 +186,45 @@ def group_by(items: Iterable[TVal], TCurveId = TypeVar('TCurveId') +class _FrozenDict: + def __init__(self, v: dict): + self._v = dict(v) + self._eq = frozenset(v.items()) + self._hash = hash(self._eq) + + terms = [] + for k in sorted(self._v.keys(), key=lambda e: (e != 'sort', e)): + terms.append(k) + terms.append(better_sorted_str_terms(self._v[k]) + ) + self._order = tuple(terms) + + def __eq__(self, other): + if isinstance(other, _FrozenDict): + return self._eq == other._eq + return NotImplemented + + def __lt__(self, other): + if isinstance(other, _FrozenDict): + return self._order < other._order + return NotImplemented + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return self._hash + + def __getitem__(self, item): + return self._v[item] + + def get(self, item, alternate = None): + return self._v.get(item, alternate) + + def __str__(self): + return " ".join(str(v) for _, v in sorted(self._v.items())) + + def plot_discard_rate( *, ax: 'plt.Axes', @@ -165,6 +235,7 @@ def plot_discard_rate( filter_func: Callable[['sinter.TaskStats'], Any] = lambda _: True, plot_args_func: Callable[[int, TCurveId, List['sinter.TaskStats']], Dict[str, Any]] = lambda index, group_key, group_stats: dict(), highlight_max_likelihood_factor: Optional[float] = 1e3, + point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None, ) -> None: """Plots discard rates in curves with uncertainty highlights. @@ -181,11 +252,21 @@ def plot_discard_rate( group_func: Optional. When specified, multiple curves will be plotted instead of one curve. The statistics are grouped into curves based on whether or not they get the same result out of this function. For example, this could be `group_func=lambda stat: stat.decoder`. + If the result of the function is a dictionary, then optional keys in the dictionary will + also control the plotting of each curve. Available keys are: + 'label': the label added to the legend for the curve + 'color': the color used for plotting the curve + 'marker': the marker used for the curve + 'linestyle': the linestyle used for the curve + 'sort': the order in which the curves will be plotted and added to the legend + e.g. if two curves (with different resulting dictionaries from group_func) share the same + value for key 'marker', they will be plotted with the same marker. + Colors, markers and linestyles are assigned in order, sorted by the values for those keys. filter_func: Optional. When specified, some curves will not be plotted. The statistics are filtered and only plotted if filter_func(stat) returns True. For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats where the saved metadata indicates the basis was 'x'. - plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to + plot_args_func: Optional. Specifies additional arguments to give the underlying calls to `plot` and `fill_between` used to do the actual plotting. For example, this can be used to specify markers and colors. Takes the index of the curve in sorted order and also a curve_id (these will be 0 and None respectively if group_func is not specified). For example, @@ -198,6 +279,7 @@ def plot_discard_rate( highlight_max_likelihood_factor: Controls how wide the uncertainty highlight region around curves is. Must be 1 or larger. Hypothesis probabilities at most that many times as unlikely as the max likelihood hypothesis will be highlighted. + point_label_func: Optional. Specifies text to draw next to data points. """ if highlight_max_likelihood_factor is None: highlight_max_likelihood_factor = 1 @@ -228,6 +310,7 @@ def y_func(stat: 'sinter.TaskStats') -> Union[float, 'sinter.Fit']: group_func=group_func, filter_func=filter_func, plot_args_func=plot_args_func, + point_label_func=point_label_func, ) @@ -243,6 +326,7 @@ def plot_error_rate( plot_args_func: Callable[[int, TCurveId, List['sinter.TaskStats']], Dict[str, Any]] = lambda index, group_key, group_stats: dict(), highlight_max_likelihood_factor: Optional[float] = 1e3, line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None, + point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None, ) -> None: """Plots error rates in curves with uncertainty highlights. @@ -263,11 +347,21 @@ def plot_error_rate( group_func: Optional. When specified, multiple curves will be plotted instead of one curve. The statistics are grouped into curves based on whether or not they get the same result out of this function. For example, this could be `group_func=lambda stat: stat.decoder`. + If the result of the function is a dictionary, then optional keys in the dictionary will + also control the plotting of each curve. Available keys are: + 'label': the label added to the legend for the curve + 'color': the color used for plotting the curve + 'marker': the marker used for the curve + 'linestyle': the linestyle used for the curve + 'sort': the order in which the curves will be plotted and added to the legend + e.g. if two curves (with different resulting dictionaries from group_func) share the same + value for key 'marker', they will be plotted with the same marker. + Colors, markers and linestyles are assigned in order, sorted by the values for those keys. filter_func: Optional. When specified, some curves will not be plotted. The statistics are filtered and only plotted if filter_func(stat) returns True. For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats where the saved metadata indicates the basis was 'x'. - plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to + plot_args_func: Optional. Specifies additional arguments to give the underlying calls to `plot` and `fill_between` used to do the actual plotting. For example, this can be used to specify markers and colors. Takes the index of the curve in sorted order and also a curve_id (these will be 0 and None respectively if group_func is not specified). For example, @@ -283,6 +377,7 @@ def plot_error_rate( line_fits: Defaults to None. Set this to a tuple (x_scale, y_scale) to include a dashed line fit to every curve. The scales determine how to transform the coordinates before performing the fit, and can be set to 'linear', 'sqrt', or 'log'. + point_label_func: Optional. Specifies text to draw next to data points. """ if highlight_max_likelihood_factor is None: highlight_max_likelihood_factor = 1 @@ -320,16 +415,17 @@ def y_func(stat: 'sinter.TaskStats') -> Union[float, 'sinter.Fit']: filter_func=filter_func, plot_args_func=plot_args_func, line_fits=line_fits, + point_label_func=point_label_func, ) -def _rescale(v: np.ndarray, scale: str, invert: bool) -> np.ndarray: +def _rescale(v: Sequence[float], scale: str, invert: bool) -> np.ndarray: if scale == 'linear': - return v + return np.array(v) elif scale == 'log': return np.exp(v) if invert else np.log(v) elif scale == 'sqrt': - return v**2 if invert else np.sqrt(v) + return np.array(v)**2 if invert else np.sqrt(v) else: raise NotImplementedError(f'{scale=}') @@ -341,6 +437,7 @@ def plot_custom( x_func: Callable[['sinter.TaskStats'], Any], y_func: Callable[['sinter.TaskStats'], Union['sinter.Fit', float, int]], group_func: Callable[['sinter.TaskStats'], TCurveId] = lambda _: None, + point_label_func: Callable[['sinter.TaskStats'], Any] = lambda _: None, filter_func: Callable[['sinter.TaskStats'], Any] = lambda _: True, plot_args_func: Callable[[int, TCurveId, List['sinter.TaskStats']], Dict[str, Any]] = lambda index, group_key, group_stats: dict(), line_fits: Optional[Tuple[Literal['linear', 'log', 'sqrt'], Literal['linear', 'log', 'sqrt']]] = None, @@ -358,11 +455,22 @@ def plot_custom( group_func: Optional. When specified, multiple curves will be plotted instead of one curve. The statistics are grouped into curves based on whether or not they get the same result out of this function. For example, this could be `group_func=lambda stat: stat.decoder`. + If the result of the function is a dictionary, then optional keys in the dictionary will + also control the plotting of each curve. Available keys are: + 'label': the label added to the legend for the curve + 'color': the color used for plotting the curve + 'marker': the marker used for the curve + 'linestyle': the linestyle used for the curve + 'sort': the order in which the curves will be plotted and added to the legend + e.g. if two curves (with different resulting dictionaries from group_func) share the same + value for key 'marker', they will be plotted with the same marker. + Colors, markers and linestyles are assigned in order, sorted by the values for those keys. + point_label_func: Optional. Specifies text to draw next to data points. filter_func: Optional. When specified, some curves will not be plotted. The statistics are filtered and only plotted if filter_func(stat) returns True. For example, `filter_func=lambda s: s.json_metadata['basis'] == 'x'` would plot only stats where the saved metadata indicates the basis was 'x'. - plot_args_func: Optional. Specifies additional arguments to give the the underlying calls to + plot_args_func: Optional. Specifies additional arguments to give the underlying calls to `plot` and `fill_between` used to do the actual plotting. For example, this can be used to specify markers and colors. Takes the index of the curve in sorted order and also a curve_id (these will be 0 and None respectively if group_func is not specified). For example, @@ -380,6 +488,10 @@ def plot_custom( performing the fit, and can be set to 'linear', 'sqrt', or 'log'. """ + def group_dict_func(item: 'sinter.TaskStats') -> _FrozenDict: + e = group_func(item) + return _FrozenDict(e if isinstance(e, dict) else {'label': str(e)}) + # Backwards compatibility to when the group stats argument wasn't present. import inspect if len(inspect.signature(plot_args_func).parameters) == 2: @@ -392,44 +504,95 @@ def plot_custom( if filter_func(stat) ] - curve_groups = group_by(filtered_stats, key=group_func) - for k, curve_id in enumerate(sorted(curve_groups.keys(), key=better_sorted_str_terms)): - this_group_stats = sorted(curve_groups[curve_id], key=x_func) - - xs = [] - ys = [] - xs_range = [] - ys_low = [] - ys_high = [] - saw_fit = False - for stat in this_group_stats: - num_kept = stat.shots - stat.discards - if num_kept == 0: - continue - x = float(x_func(stat)) - y = y_func(stat) + curve_groups = group_by(filtered_stats, key=group_dict_func) + colors = { + k: f'C{i}' + for i, k in enumerate(sorted({g.get('color', g) for g in curve_groups.keys()}, key=better_sorted_str_terms)) + } + markers = { + k: MARKERS[i % len(MARKERS)] + for i, k in enumerate(sorted({g.get('marker', g) for g in curve_groups.keys()}, key=better_sorted_str_terms)) + } + linestyles = { + k: LINESTYLES[i % len(LINESTYLES)] + for i, k in enumerate(sorted({g.get('linestyle', None) for g in curve_groups.keys()}, key=better_sorted_str_terms)) + } + + def sort_key(a: Any) -> Any: + if isinstance(a, _FrozenDict): + return a.get('sort', better_sorted_str_terms(a)) + return better_sorted_str_terms(a) + + for k, group_key in enumerate(sorted(curve_groups.keys(), key=sort_key)): + group = curve_groups[group_key] + group = sorted(group, key=x_func) + color = colors[group_key.get('color', group_key)] + marker = markers[group_key.get('marker', group_key)] + linestyle = linestyles[group_key.get('linestyle', None)] + label = str(group_key.get('label', group_key)) + xs_label: list[float] = [] + ys_label: list[float] = [] + vs_label: list[float] = [] + xs_best: list[float] = [] + ys_best: list[float] = [] + xs_low_high: list[float] = [] + ys_low: list[float] = [] + ys_high: list[float] = [] + for item in group: + x = x_func(item) + y = y_func(item) + point_label = point_label_func(item) if isinstance(y, Fit): - xs_range.append(x) - ys_low.append(y.low) - ys_high.append(y.high) - saw_fit = True - y = y.best - if not math.isnan(y): - xs.append(x) - ys.append(y) - - kwargs: Dict[str, Any] = dict(plot_args_func(k, curve_id, this_group_stats)) - kwargs.setdefault('marker', MARKERS[k]) - if curve_id is not None: - kwargs.setdefault('label', str(curve_id)) - kwargs.setdefault('color', f'C{k}') - kwargs.setdefault('color', 'black') - ax.plot(xs, ys, **kwargs) - - if line_fits is not None and len(set(xs)) >= 2: + if y.low is not None and y.high is not None and not math.isnan(y.low) and not math.isnan(y.high): + xs_low_high.append(x) + ys_low.append(y.low) + ys_high.append(y.high) + if y.best is not None and not math.isnan(y.best): + ys_best.append(y.best) + xs_best.append(x) + + if point_label: + cy = None + for e in [y.best, y.high, y.low]: + if e is not None and not math.isnan(e): + cy = e + break + if cy is not None: + xs_label.append(x) + ys_label.append(cy) + vs_label.append(point_label) + elif not math.isnan(y): + xs_best.append(x) + ys_best.append(y) + if point_label: + xs_label.append(x) + ys_label.append(y) + vs_label.append(point_label) + args = dict(plot_args_func(k, group_func(group[0]), group)) + if 'linestyle' not in args: + args['linestyle'] = linestyle + if 'marker' not in args: + args['marker'] = marker + if 'color' not in args: + args['color'] = color + if 'label' not in args: + args['label'] = label + ax.plot(xs_best, ys_best, **args) + for x, y, lbl in zip(xs_label, ys_label, vs_label): + if lbl: + ax.annotate(lbl, (x, y)) + if len(xs_low_high) > 1: + ax.fill_between(xs_low_high, ys_low, ys_high, color=args['color'], alpha=0.2, zorder=-100) + elif len(xs_low_high) == 1: + l, = ys_low + h, = ys_high + m = (l + h) / 2 + ax.errorbar(xs_low_high, [m], yerr=([m - l], [h - m]), marker='', elinewidth=1, ecolor=color, capsize=5) + + if line_fits is not None and len(set(xs_best)) >= 2: x_scale, y_scale = line_fits - fit_xs = _rescale(xs, x_scale, False) - fit_ys = _rescale(ys, y_scale, False) + fit_xs = _rescale(xs_best, x_scale, False) + fit_ys = _rescale(ys_best, y_scale, False) from scipy.stats import linregress line_fit = linregress(fit_xs, fit_ys) @@ -447,21 +610,10 @@ def plot_custom( out_xs = _rescale(out_xs, x_scale, True) out_ys = _rescale(out_ys, y_scale, True) - line_kwargs = kwargs.copy() - line_kwargs.pop('marker', None) - line_kwargs.pop('label', None) - line_kwargs['linestyle'] = '--' - line_kwargs.setdefault('linewidth', 1) - line_kwargs['linewidth'] /= 2 - ax.plot(out_xs, out_ys, **line_kwargs) - - if saw_fit: - fit_kwargs = kwargs.copy() - fit_kwargs.setdefault('zorder', 0) - fit_kwargs.setdefault('alpha', 1) - fit_kwargs['zorder'] -= 100 - fit_kwargs['alpha'] *= 0.25 - fit_kwargs.pop('marker', None) - fit_kwargs.pop('linestyle', None) - fit_kwargs.pop('label', None) - ax.fill_between(xs_range, ys_low, ys_high, **fit_kwargs) + line_fit_kwargs = args.copy() + line_fit_kwargs.pop('marker', None) + line_fit_kwargs.pop('label', None) + line_fit_kwargs['linestyle'] = '--' + line_fit_kwargs.setdefault('linewidth', 1) + line_fit_kwargs['linewidth'] /= 2 + ax.plot(out_xs, out_ys, **line_fit_kwargs) diff --git a/glue/sample/src/sinter/_plotting_test.py b/glue/sample/src/sinter/_plotting_test.py index 088a165f8..acf69102a 100644 --- a/glue/sample/src/sinter/_plotting_test.py +++ b/glue/sample/src/sinter/_plotting_test.py @@ -13,6 +13,9 @@ def test_better_sorted_str_terms(): assert f('a1b2') == ('a', 1, 'b', 2) assert f('a1.5b2') == ('a', 1.5, 'b', 2) assert f('a1.5.3b2') == ('a', (1, 5, 3), 'b', 2) + assert f(1) < f(None) + assert f(1) < f('2') + assert f('2') > f(1) assert sorted([ "planar d=10 r=30", "planar d=16 r=36", diff --git a/glue/sample/src/sinter/_predict.py b/glue/sample/src/sinter/_predict.py index eeefba040..f3a6336c8 100644 --- a/glue/sample/src/sinter/_predict.py +++ b/glue/sample/src/sinter/_predict.py @@ -8,9 +8,7 @@ from typing import Optional, Union, Dict, TYPE_CHECKING from sinter._collection import post_selection_mask_from_4th_coord -from sinter._decoding_decoder_class import Decoder -from sinter._decoding_all_built_in_decoders import BUILT_IN_DECODERS -from sinter._decoding import streaming_post_select +from sinter._decoding import Decoder, BUILT_IN_DECODERS, streaming_post_select if TYPE_CHECKING: import sinter diff --git a/glue/sample/src/sinter/_probability_util.py b/glue/sample/src/sinter/_probability_util.py index 62cf0ed71..76849dcdd 100644 --- a/glue/sample/src/sinter/_probability_util.py +++ b/glue/sample/src/sinter/_probability_util.py @@ -2,6 +2,7 @@ import math import pathlib from typing import Any, Dict, Union, Callable, Sequence, TYPE_CHECKING, overload +from typing import Optional import numpy as np @@ -208,9 +209,9 @@ class Fit: of the best fit's square error, or whose likelihood was within some maximum Bayes factor of the max likelihood hypothesis. """ - low: float - best: float - high: float + low: Optional[float] + best: Optional[float] + high: Optional[float] def __repr__(self) -> str: return f'sinter.Fit(low={self.low!r}, best={self.best!r}, high={self.high!r})' diff --git a/glue/sample/src/sinter/_worker.py b/glue/sample/src/sinter/_worker.py deleted file mode 100644 index c01cd5ff9..000000000 --- a/glue/sample/src/sinter/_worker.py +++ /dev/null @@ -1,212 +0,0 @@ -import os - -from typing import Any, Optional, Tuple, TYPE_CHECKING, Dict -import tempfile - -if TYPE_CHECKING: - import multiprocessing - import numpy as np - import pathlib - import sinter - import stim - - -class WorkIn: - def __init__( - self, - *, - work_key: Any, - circuit_path: str, - dem_path: str, - decoder: str, - strong_id: Optional[str], - postselection_mask: 'Optional[np.ndarray]', - postselected_observables_mask: 'Optional[np.ndarray]', - json_metadata: Any, - count_observable_error_combos: bool, - count_detection_events: bool, - num_shots: int): - self.work_key = work_key - self.circuit_path = circuit_path - self.dem_path = dem_path - self.decoder = decoder - self.strong_id = strong_id - self.postselection_mask = postselection_mask - self.postselected_observables_mask = postselected_observables_mask - self.json_metadata = json_metadata - self.count_observable_error_combos = count_observable_error_combos - self.count_detection_events = count_detection_events - self.num_shots = num_shots - - def with_work_key(self, work_key: Any) -> 'WorkIn': - return WorkIn( - work_key=work_key, - circuit_path=self.circuit_path, - dem_path=self.dem_path, - decoder=self.decoder, - postselection_mask=self.postselection_mask, - postselected_observables_mask=self.postselected_observables_mask, - json_metadata=self.json_metadata, - strong_id=self.strong_id, - num_shots=self.num_shots, - count_observable_error_combos=self.count_observable_error_combos, - count_detection_events=self.count_detection_events, - ) - - -def auto_dem(circuit: 'stim.Circuit') -> 'stim.DetectorErrorModel': - """Converts a circuit into a detector error model, with some fallbacks. - - First attempts to do it with folding and decomposition, then tries - giving up on the folding, then tries giving up on the decomposition. - """ - try: - return circuit.detector_error_model( - allow_gauge_detectors=False, - approximate_disjoint_errors=True, - block_decomposition_from_introducing_remnant_edges=False, - decompose_errors=True, - flatten_loops=False, - ignore_decomposition_failures=False, - ) - except ValueError: - pass - - # This might be https://github.com/quantumlib/Stim/issues/393 - # Try turning off loop flattening. - try: - return circuit.detector_error_model( - allow_gauge_detectors=False, - approximate_disjoint_errors=True, - block_decomposition_from_introducing_remnant_edges=False, - decompose_errors=True, - flatten_loops=False, - ignore_decomposition_failures=False, - ) - except ValueError: - pass - - # Maybe decomposition is impossible, but the decoder might not need it. - # Try turning off error decomposition. - try: - return circuit.detector_error_model( - allow_gauge_detectors=False, - approximate_disjoint_errors=True, - block_decomposition_from_introducing_remnant_edges=False, - decompose_errors=False, - flatten_loops=True, - ignore_decomposition_failures=False, - ) - except ValueError: - pass - - # Okay turn them both off... - return circuit.detector_error_model( - allow_gauge_detectors=False, - approximate_disjoint_errors=True, - block_decomposition_from_introducing_remnant_edges=False, - decompose_errors=False, - flatten_loops=True, - ignore_decomposition_failures=False, - ) - - -class WorkOut: - def __init__( - self, - *, - work_key: Any, - stats: Optional['sinter.AnonTaskStats'], - strong_id: str, - msg_error: Optional[Tuple[str, BaseException]]): - self.work_key = work_key - self.stats = stats - self.strong_id = strong_id - self.msg_error = msg_error - - -def worker_loop(tmp_dir: 'pathlib.Path', - inp: 'multiprocessing.Queue', - out: 'multiprocessing.Queue', - custom_decoders: Optional[Dict[str, 'sinter.Decoder']], - core_affinity: Optional[int]) -> None: - try: - if core_affinity is not None and hasattr(os, 'sched_setaffinity'): - os.sched_setaffinity(0, {core_affinity}) - except: - # If setting the core affinity fails, we keep going regardless. - pass - - try: - with tempfile.TemporaryDirectory(dir=tmp_dir) as child_dir: - while True: - work: Optional[WorkIn] = inp.get() - if work is None: - return - out.put(do_work_safely(work, child_dir, custom_decoders)) - except KeyboardInterrupt: - pass - - -def do_work_safely(work: WorkIn, child_dir: str, custom_decoders: Dict[str, 'sinter.Decoder']) -> WorkOut: - try: - return do_work(work, child_dir, custom_decoders) - except BaseException as ex: - import traceback - return WorkOut( - work_key=work.work_key, - stats=None, - strong_id=work.strong_id, - msg_error=(traceback.format_exc(), ex), - ) - - -def do_work(work: WorkIn, child_dir: str, custom_decoders: Dict[str, 'sinter.Decoder']) -> WorkOut: - import stim - from sinter._task import Task - from sinter._decoding import sample_decode - - if work.strong_id is None: - # The work is to compute the DEM, as opposed to taking shots. - - circuit = stim.Circuit.from_file(work.circuit_path) - dem = auto_dem(circuit) - dem.to_file(work.dem_path) - - task = Task( - circuit=circuit, - decoder=work.decoder, - detector_error_model=dem, - postselection_mask=work.postselection_mask, - postselected_observables_mask=work.postselected_observables_mask, - json_metadata=work.json_metadata, - ) - - return WorkOut( - work_key=work.work_key, - stats=None, - strong_id=task.strong_id(), - msg_error=None, - ) - - stats: 'sinter.AnonTaskStats' = sample_decode( - num_shots=work.num_shots, - circuit_path=work.circuit_path, - circuit_obj=None, - dem_path=work.dem_path, - dem_obj=None, - post_mask=work.postselection_mask, - postselected_observable_mask=work.postselected_observables_mask, - decoder=work.decoder, - count_observable_error_combos=work.count_observable_error_combos, - count_detection_events=work.count_detection_events, - tmp_dir=child_dir, - custom_decoders=custom_decoders, - ) - - return WorkOut( - stats=stats, - work_key=work.work_key, - strong_id=work.strong_id, - msg_error=None, - ) diff --git a/glue/sample/src/sinter/_worker_test.py b/glue/sample/src/sinter/_worker_test.py deleted file mode 100644 index d8049ee10..000000000 --- a/glue/sample/src/sinter/_worker_test.py +++ /dev/null @@ -1,134 +0,0 @@ -import multiprocessing -import pathlib -import tempfile - -import stim - -from sinter._worker import WorkIn, WorkOut, worker_loop -from sinter._worker import auto_dem - - -def test_worker_loop_infers_dem(): - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_dir = pathlib.Path(tmp_dir) - circuit = stim.Circuit(""" - M(0.2) 0 1 - DETECTOR rec[-1] - OBSERVABLE_INCLUDE(0) rec[-1] rec[-2] - """) - circuit_path = str(tmp_dir / 'input_circuit.stim') - dem_path = str(tmp_dir / 'input_dem.dem') - circuit.to_file(circuit_path) - inp = multiprocessing.Queue() - out = multiprocessing.Queue() - inp.put(WorkIn( - work_key='test1', - circuit_path=circuit_path, - dem_path=dem_path, - decoder='pymatching', - json_metadata=5, - strong_id=None, - num_shots=-1, - postselected_observables_mask=None, - postselection_mask=None, - count_detection_events=False, - count_observable_error_combos=False, - )) - inp.put(None) - worker_loop(tmp_dir, inp, out, None, 0) - result: WorkOut = out.get(timeout=1) - assert out.empty() - - assert result.stats is None - assert result.work_key == 'test1' - assert result.msg_error is None - assert stim.DetectorErrorModel.from_file(dem_path) == circuit.detector_error_model() - assert result.strong_id is not None - - -def test_worker_loop_does_not_recompute_dem(): - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_dir = pathlib.Path(tmp_dir) - circuit_path = str(tmp_dir / 'input_circuit.stim') - dem_path = str(tmp_dir / 'input_dem.dem') - stim.Circuit(""" - M(0.2) 0 1 - DETECTOR rec[-1] - OBSERVABLE_INCLUDE(0) rec[-1] rec[-2] - """).to_file(circuit_path) - stim.DetectorErrorModel(""" - error(0.234567) D0 L0 - """).to_file(dem_path) - - inp = multiprocessing.Queue() - out = multiprocessing.Queue() - inp.put(WorkIn( - work_key='test1', - circuit_path=circuit_path, - dem_path=dem_path, - decoder='pymatching', - json_metadata=5, - strong_id="fake", - num_shots=1000, - postselected_observables_mask=None, - postselection_mask=None, - count_detection_events=False, - count_observable_error_combos=False, - )) - inp.put(None) - worker_loop(tmp_dir, inp, out, None, 0) - result: WorkOut = out.get(timeout=1) - assert out.empty() - - assert result.stats.shots == 1000 - assert result.stats.discards == 0 - assert 0 < result.stats.errors < 1000 - assert result.work_key == 'test1' - assert result.msg_error is None - assert result.strong_id == 'fake' - - -def test_auto_dem(): - assert auto_dem(stim.Circuit(""" - REPEAT 100 { - CORRELATED_ERROR(0.125) X0 X1 - CORRELATED_ERROR(0.125) X0 X1 X2 X3 - MR 0 1 2 3 - DETECTOR rec[-4] - DETECTOR rec[-3] - DETECTOR rec[-2] - DETECTOR rec[-1] - } - """)) == stim.DetectorErrorModel(""" - REPEAT 99 { - error(0.125) D0 D1 - error(0.125) D0 D1 ^ D2 D3 - shift_detectors 4 - } - error(0.125) D0 D1 - error(0.125) D0 D1 ^ D2 D3 - """) - - assert auto_dem(stim.Circuit(""" - CORRELATED_ERROR(0.125) X0 X1 - CORRELATED_ERROR(0.125) X0 X1 X2 X3 - M 0 1 2 3 - DETECTOR rec[-4] - DETECTOR rec[-3] - DETECTOR rec[-2] - DETECTOR rec[-1] - """)) == stim.DetectorErrorModel(""" - error(0.125) D0 D1 - error(0.125) D0 D1 ^ D2 D3 - """) - - assert auto_dem(stim.Circuit(""" - CORRELATED_ERROR(0.125) X0 X1 X2 X3 - M 0 1 2 3 - DETECTOR rec[-4] - DETECTOR rec[-3] - DETECTOR rec[-2] - DETECTOR rec[-1] - """)) == stim.DetectorErrorModel(""" - error(0.125) D0 D1 D2 D3 - """) diff --git a/src/stim/circuit/circuit.pybind.cc b/src/stim/circuit/circuit.pybind.cc index 8c801e5cf..7798587b3 100644 --- a/src/stim/circuit/circuit.pybind.cc +++ b/src/stim/circuit/circuit.pybind.cc @@ -3270,7 +3270,7 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_ 'stim._DiagramHelper': + @signature def diagram(self, type: str = 'timeline-text', *, tick: Union[None, int, range] = None, filter_coords: Iterable[Union[Iterable[float], stim.DemTarget]] = ((),), rows: int | None = None) -> 'stim._DiagramHelper': Returns a diagram of the circuit, from a variety of options. Args: diff --git a/src/stim/util_bot/probability_util.h b/src/stim/util_bot/probability_util.h index 0c28d2a31..b338ba50d 100644 --- a/src/stim/util_bot/probability_util.h +++ b/src/stim/util_bot/probability_util.h @@ -70,9 +70,19 @@ struct RareErrorIterator { std::vector sample_hit_indices(float probability, size_t attempts, std::mt19937_64 &rng); +/// Create a fresh random number generator seeded by entropy from the operating system. std::mt19937_64 externally_seeded_rng(); + +/// Create a random number generator either seeded by a --seed argument, or else by entropy from the operating system. std::mt19937_64 optionally_seeded_rng(int argc, const char **argv); +/// Overwrite the given span with random data where bits are set with the given probability. +/// +/// Args: +/// probability: The chance that each bit will be on. +/// start: Inclusive start of the memory span to overwrite. +/// end: Exclusive end of the memory span to overwrite. +/// rng: The random number generator to use to generate entropy. void biased_randomize_bits(float probability, uint64_t *start, uint64_t *end, std::mt19937_64 &rng); } // namespace stim