From 2e84e90be33fe979b5244fd4b987d22e48e1c677 Mon Sep 17 00:00:00 2001 From: Yiming Zhang Date: Thu, 28 Mar 2024 18:15:56 +0800 Subject: [PATCH 1/8] FEAT: add custom sampler for sinter --- glue/sample/src/sinter/__init__.py | 8 + glue/sample/src/sinter/_collection.py | 25 ++ .../_collection_tracker_for_single_task.py | 4 +- .../src/sinter/_collection_work_manager.py | 44 ++- glue/sample/src/sinter/_csv_out.py | 4 + glue/sample/src/sinter/_existing_data.py | 3 +- glue/sample/src/sinter/_existing_data_test.py | 13 + glue/sample/src/sinter/_main_collect.py | 49 ++++ glue/sample/src/sinter/_main_collect_test.py | 3 +- glue/sample/src/sinter/_main_combine.py | 2 + glue/sample/src/sinter/_main_combine_test.py | 100 +++---- glue/sample/src/sinter/_predict.py | 2 +- .../sinter/_sampling_all_built_in_samplers.py | 10 + ..._decoding.py => _sampling_and_decoding.py} | 230 ++++++++++------ ...test.py => _sampling_and_decoding_test.py} | 260 +++++++++++++----- .../src/sinter/_sampling_sampler_class.py | 123 +++++++++ glue/sample/src/sinter/_sampling_stim.py | 44 +++ glue/sample/src/sinter/_sampling_vacuous.py | 47 ++++ glue/sample/src/sinter/_task.py | 15 + glue/sample/src/sinter/_task_stats.py | 7 + glue/sample/src/sinter/_task_stats_test.py | 2 +- glue/sample/src/sinter/_worker.py | 30 +- glue/sample/src/sinter/_worker_test.py | 6 +- 23 files changed, 806 insertions(+), 225 deletions(-) create mode 100644 glue/sample/src/sinter/_sampling_all_built_in_samplers.py rename glue/sample/src/sinter/{_decoding.py => _sampling_and_decoding.py} (67%) rename glue/sample/src/sinter/{_decoding_test.py => _sampling_and_decoding_test.py} (62%) create mode 100644 glue/sample/src/sinter/_sampling_sampler_class.py create mode 100644 glue/sample/src/sinter/_sampling_stim.py create mode 100644 glue/sample/src/sinter/_sampling_vacuous.py diff --git a/glue/sample/src/sinter/__init__.py b/glue/sample/src/sinter/__init__.py index 1237b79b4..a9af7a592 100644 --- a/glue/sample/src/sinter/__init__.py 
+++ b/glue/sample/src/sinter/__init__.py @@ -18,6 +18,9 @@ from sinter._decoding_all_built_in_decoders import ( BUILT_IN_DECODERS, ) +from sinter._sampling_all_built_in_samplers import ( + BUILT_IN_SAMPLERS, +) from sinter._existing_data import ( read_stats_from_csv_files, stats_from_csv_files, @@ -54,3 +57,8 @@ CompiledDecoder, Decoder, ) + +from sinter._sampling_sampler_class import ( + CompiledSampler, + Sampler, +) diff --git a/glue/sample/src/sinter/_collection.py b/glue/sample/src/sinter/_collection.py index 40bfceef6..d1d918d13 100644 --- a/glue/sample/src/sinter/_collection.py +++ b/glue/sample/src/sinter/_collection.py @@ -45,12 +45,14 @@ def iter_collect(*, additional_existing_data: Optional[ExistingData] = None, max_shots: Optional[int] = None, max_errors: Optional[int] = None, + samplers: Optional[Iterable[str]] = None, decoders: Optional[Iterable[str]] = None, max_batch_seconds: Optional[int] = None, max_batch_size: Optional[int] = None, start_batch_size: Optional[int] = None, count_observable_error_combos: bool = False, count_detection_events: bool = False, + custom_samplers: Optional[Dict[str, "sinter.Sampler"]] = None, custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None, custom_error_count_key: Optional[str] = None, allowed_cpu_affinity_ids: Optional[Iterable[int]] = None, @@ -74,6 +76,10 @@ def iter_collect(*, additional_existing_data: Defaults to None (no additional data). Statistical data that has already been collected, in addition to anything included in each task's `previous_stats` field. + samplers: Defaults to None (specified by each Task). The names of the + samplers to use on each Task. It must either be the case that each + Task specifies a sampler and this is set to None, or this is an + iterable and each Task has its sampler set to None. decoders: Defaults to None (specified by each Task). The names of the decoders to use on each Task. 
It must either be the case that each Task specifies a decoder and this is set to None, or this is an @@ -107,6 +113,9 @@ def iter_collect(*, taken per shot. This information is then used to predict the biggest batch size that can finish in under the given number of seconds. Limits each batch to be no larger than that. + custom_samplers: Custom samplers that can be used if requested by name. + If not specified, only samplers built into sinter, such as 'stim', + can be used. custom_decoders: Custom decoders that can be used if requested by name. If not specified, only decoders built into sinter, such as 'pymatching' and 'fusion_blossom', can be used. @@ -156,6 +165,8 @@ def iter_collect(*, >>> print(total_shots) 200 """ + if isinstance(samplers, str): + samplers = [samplers] if isinstance(decoders, str): decoders = [decoders] @@ -175,10 +186,12 @@ def iter_collect(*, start_batch_size=start_batch_size, max_batch_size=max_batch_size, ), + samplers=samplers, decoders=decoders, count_observable_error_combos=count_observable_error_combos, count_detection_events=count_detection_events, additional_existing_data=additional_existing_data, + custom_samplers=custom_samplers, custom_decoders=custom_decoders, custom_error_count_key=custom_error_count_key, allowed_cpu_affinity_ids=allowed_cpu_affinity_ids, @@ -228,12 +241,14 @@ def collect(*, max_errors: Optional[int] = None, count_observable_error_combos: bool = False, count_detection_events: bool = False, + samplers: Optional[Iterable[str]] = None, decoders: Optional[Iterable[str]] = None, max_batch_seconds: Optional[int] = None, max_batch_size: Optional[int] = None, start_batch_size: Optional[int] = None, print_progress: bool = False, hint_num_tasks: Optional[int] = None, + custom_samplers: Optional[Dict[str, "sinter.Sampler"]] = None, custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None, custom_error_count_key: Optional[str] = None, allowed_cpu_affinity_ids: Optional[Iterable[int]] = None, @@ -260,6 +275,10 @@ def 
collect(*, hint_num_tasks: If `tasks` is an iterator or a generator, its length can be given here so that progress printouts can say how many cases are left. + samplers: Defaults to None (specified by each Task). The names of the + samplers to use on each Task. It must either be the case that each + Task specifies a sampler and this is set to None, or this is an + iterable and each Task has its sampler set to None. decoders: Defaults to None (specified by each Task). The names of the decoders to use on each Task. It must either be the case that each Task specifies a decoder and this is set to None, or this is an @@ -295,6 +314,10 @@ def collect(*, taken per shot. This information is then used to predict the biggest batch size that can finish in under the given number of seconds. Limits each batch to be no larger than that. + custom_samplers: Named child classes of `sinter.Sampler`, that can be + used if requested by name by a task or by the samplers list. + If not specified, only samplers with support built into sinter, such + as 'stim', can be used. custom_decoders: Named child classes of `sinter.decoder`, that can be used if requested by name by a task or by the decoders list. 
If not specified, only decoders with support built into sinter, such @@ -386,10 +409,12 @@ def collect(*, max_batch_size=max_batch_size, count_observable_error_combos=count_observable_error_combos, count_detection_events=count_detection_events, + samplers=samplers, decoders=decoders, tasks=tasks, hint_num_tasks=hint_num_tasks, additional_existing_data=additional_existing_data, + custom_samplers=custom_samplers, custom_decoders=custom_decoders, custom_error_count_key=custom_error_count_key, allowed_cpu_affinity_ids=allowed_cpu_affinity_ids, diff --git a/glue/sample/src/sinter/_collection_tracker_for_single_task.py b/glue/sample/src/sinter/_collection_tracker_for_single_task.py index 0932b55e6..e34197b00 100644 --- a/glue/sample/src/sinter/_collection_tracker_for_single_task.py +++ b/glue/sample/src/sinter/_collection_tracker_for_single_task.py @@ -2,8 +2,6 @@ from typing import Iterator from typing import Optional -import stim - from sinter._anon_task_stats import AnonTaskStats from sinter._existing_data import ExistingData from sinter._task import Task @@ -167,6 +165,7 @@ def provide_more_work(self, *, desperate: bool) -> Optional[WorkIn]: work_key=None, circuit_path=self.circuit_path, dem_path=self.dem_path, + sampler=self.unfilled_task.sampler, decoder=self.unfilled_task.decoder, postselection_mask=self.unfilled_task.postselection_mask, postselected_observables_mask=self.unfilled_task.postselected_observables_mask, @@ -189,6 +188,7 @@ def provide_more_work(self, *, desperate: bool) -> Optional[WorkIn]: strong_id=self.task_strong_id, circuit_path=self.circuit_path, dem_path=self.dem_path, + sampler=self.unfilled_task.sampler, decoder=self.unfilled_task.decoder, postselection_mask=self.unfilled_task.postselection_mask, postselected_observables_mask=self.unfilled_task.postselected_observables_mask, diff --git a/glue/sample/src/sinter/_collection_work_manager.py b/glue/sample/src/sinter/_collection_work_manager.py index d32db6985..6d795cd4a 100644 --- 
a/glue/sample/src/sinter/_collection_work_manager.py +++ b/glue/sample/src/sinter/_collection_work_manager.py @@ -5,8 +5,10 @@ import pathlib import tempfile import stim +import itertools from typing import cast, Iterable, Optional, Iterator, Tuple, Dict, List +from sinter._sampling_sampler_class import Sampler from sinter._decoding_decoder_class import Decoder from sinter._collection_options import CollectionOptions from sinter._existing_data import ExistingData @@ -26,11 +28,14 @@ def __init__( additional_existing_data: Optional[ExistingData], count_observable_error_combos: bool, count_detection_events: bool, + samplers: Optional[Iterable[str]], + custom_samplers: Dict[str, Sampler], decoders: Optional[Iterable[str]], custom_decoders: Dict[str, Decoder], custom_error_count_key: Optional[str], allowed_cpu_affinity_ids: Optional[List[int]], ): + self.custom_samplers = custom_samplers self.custom_decoders = custom_decoders self.queue_from_workers: Optional[multiprocessing.Queue] = None self.queue_to_workers: Optional[multiprocessing.Queue] = None @@ -41,6 +46,7 @@ def __init__( self.allowed_cpu_affinity_ids = None if allowed_cpu_affinity_ids is None else sorted(set(allowed_cpu_affinity_ids)) self.global_collection_options = global_collection_options + self.samplers: Optional[Tuple[str, ...]] = None if samplers is None else tuple(samplers) self.decoders: Optional[Tuple[str, ...]] = None if decoders is None else tuple(decoders) self.did_work = False @@ -53,10 +59,12 @@ def __init__( self.count_observable_error_combos = count_observable_error_combos self.count_detection_events = count_detection_events - self.tasks_with_decoder_iter: Iterator[Task] = _iter_tasks_with_assigned_decoders( + self.tasks_with_sampler_decoder_iter: Iterator[Task] = _iter_tasks_with_assigned_samplers_decoders( tasks_iter=tasks_iter, + default_samplers=self.samplers, default_decoders=self.decoders, - global_collections_options=self.global_collection_options) + 
global_collections_options=self.global_collection_options, + ) def start_workers(self, num_workers: int) -> None: assert self.tmp_dir is not None @@ -81,7 +89,15 @@ def start_workers(self, num_workers: int) -> None: cpu_pin = None if len(cpus) == 0 else cpus[index % len(cpus)] w = multiprocessing.Process( target=worker_loop, - args=(self.tmp_dir, self.queue_to_workers, self.queue_from_workers, self.custom_decoders, cpu_pin)) + args=( + self.tmp_dir, + self.queue_to_workers, + self.queue_from_workers, + self.custom_samplers, + self.custom_decoders, + cpu_pin, + ), + ) w.start() self.workers.append(w) finally: @@ -89,7 +105,9 @@ def start_workers(self, num_workers: int) -> None: def __enter__(self): self.exit_stack = contextlib.ExitStack().__enter__() - self.tmp_dir = pathlib.Path(self.exit_stack.enter_context(tempfile.TemporaryDirectory())) + self.tmp_dir = pathlib.Path( + self.exit_stack.enter_context(tempfile.TemporaryDirectory()) + ) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -150,6 +168,7 @@ def wait_for_next_sample(self, stats = AnonTaskStats() return TaskStats( strong_id=result.strong_id, + sampler=work_in.sampler, decoder=work_in.decoder, json_metadata=work_in.json_metadata, shots=stats.shots, @@ -165,7 +184,7 @@ def _iter_draw_collectors(self, *, prefer_started: bool) -> Iterator[Tuple[int, while True: key = self.next_collector_key try: - task = next(self.tasks_with_decoder_iter) + task = next(self.tasks_with_sampler_decoder_iter) except StopIteration: break collector = CollectionTrackerForSingleTask( @@ -236,9 +255,10 @@ def status(self, *, num_circuits: Optional[int]) -> str: return main_status + ''.join(collector_statuses) -def _iter_tasks_with_assigned_decoders( +def _iter_tasks_with_assigned_samplers_decoders( *, tasks_iter: Iterator[Task], + default_samplers: Optional[Iterable[str]], default_decoders: Optional[Iterable[str]], global_collections_options: CollectionOptions, ) -> Iterator[Task]: @@ -246,6 +266,7 @@ def 
_iter_tasks_with_assigned_decoders( if task.circuit is None: task = Task( circuit=stim.Circuit.from_file(task.circuit_path), + sampler=task.sampler, decoder=task.decoder, detector_error_model=task.detector_error_model, postselection_mask=task.postselection_mask, @@ -255,16 +276,25 @@ def _iter_tasks_with_assigned_decoders( circuit_path=task.circuit_path, ) + if task.sampler is None and default_samplers is None: + raise ValueError("Samplers to use was not specified. samplers is None and task.sampler is None") if task.decoder is None and default_decoders is None: raise ValueError("Decoders to use was not specified. decoders is None and task.decoder is None") + task_samplers = [] + if default_samplers is not None: + task_samplers.extend(default_samplers) + if task.sampler is not None and task.sampler not in task_samplers: + task_samplers.append(task.sampler) + task_decoders = [] if default_decoders is not None: task_decoders.extend(default_decoders) if task.decoder is not None and task.decoder not in task_decoders: task_decoders.append(task.decoder) - for decoder in task_decoders: + for sampler, decoder in itertools.product(task_samplers, task_decoders): yield Task( circuit=task.circuit, + sampler=sampler, decoder=decoder, detector_error_model=task.detector_error_model, postselection_mask=task.postselection_mask, diff --git a/glue/sample/src/sinter/_csv_out.py b/glue/sample/src/sinter/_csv_out.py index 2feb0fa2b..054bc43c2 100644 --- a/glue/sample/src/sinter/_csv_out.py +++ b/glue/sample/src/sinter/_csv_out.py @@ -19,6 +19,7 @@ def csv_line(*, errors: Any, discards: Any, seconds: Any, + sampler: Any, decoder: Any, strong_id: Any, json_metadata: Any, @@ -50,6 +51,7 @@ def csv_line(*, errors = escape_csv(errors, 10) discards = escape_csv(discards, 10) seconds = escape_csv(seconds, 8) + sampler = escape_csv(sampler, None) decoder = escape_csv(decoder, None) strong_id = escape_csv(strong_id, None) json_metadata = escape_csv(json_metadata, None) @@ -57,6 +59,7 @@ def 
csv_line(*, f'{errors},' f'{discards},' f'{seconds},' + f'{sampler},' f'{decoder},' f'{strong_id},' f'{json_metadata},' @@ -68,6 +71,7 @@ def csv_line(*, discards='discards', seconds='seconds', strong_id='strong_id', + sampler='sampler', decoder='decoder', json_metadata='json_metadata', custom_counts='custom_counts', diff --git a/glue/sample/src/sinter/_existing_data.py b/glue/sample/src/sinter/_existing_data.py index 925b2110f..2deb1d608 100644 --- a/glue/sample/src/sinter/_existing_data.py +++ b/glue/sample/src/sinter/_existing_data.py @@ -5,7 +5,7 @@ from sinter._task_stats import TaskStats from sinter._task import Task -from sinter._decoding import AnonTaskStats +from sinter._sampling_and_decoding import AnonTaskStats if TYPE_CHECKING: import sinter @@ -81,6 +81,7 @@ def from_file(path_or_file: Any) -> 'ExistingData': custom_counts=custom_counts, seconds=float(row['seconds']), strong_id=row['strong_id'], + sampler=row.get('sampler', 'stim'), decoder=row['decoder'], json_metadata=json.loads(row['json_metadata']), )) diff --git a/glue/sample/src/sinter/_existing_data_test.py b/glue/sample/src/sinter/_existing_data_test.py index 250c76541..db742f3dd 100644 --- a/glue/sample/src/sinter/_existing_data_test.py +++ b/glue/sample/src/sinter/_existing_data_test.py @@ -39,3 +39,16 @@ def test_read_stats_from_csv_files(): sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=2600, errors=8, discards=120, seconds=8.0, custom_counts=collections.Counter({'dets': 1234})), sinter.TaskStats(strong_id='def456', decoder='pymatching', json_metadata={'d': 5}, shots=4000, errors=0, discards=20, seconds=4.0), ] + + with open(d / 'tmp3.csv', 'w') as f: + print(""" +shots,errors,discards,seconds,sampler,decoder,strong_id,json_metadata,custom_counts + 300, 1, 20, 1.0,stim,pymatching,abc123,"{""d"":3}","{""dets"":1234}" + 1000, 3, 40, 3.0,stim,pymatching,abc123,"{""d"":3}", + 2000, 0, 10, 2.0,mock,pymatching,def456,"{""d"":5}" + """.strip(), file=f) + + 
assert sinter.read_stats_from_csv_files(d / 'tmp3.csv') == [ + sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0, custom_counts=collections.Counter({'dets': 1234})), + sinter.TaskStats(strong_id='def456', sampler='mock', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), + ] diff --git a/glue/sample/src/sinter/_main_collect.py b/glue/sample/src/sinter/_main_collect.py index b53dc392d..ed0e2a1df 100644 --- a/glue/sample/src/sinter/_main_collect.py +++ b/glue/sample/src/sinter/_main_collect.py @@ -13,6 +13,7 @@ from sinter._task import Task from sinter._collection import collect, Progress, post_selection_mask_from_predicate from sinter._decoding_all_built_in_decoders import BUILT_IN_DECODERS +from sinter._sampling_all_built_in_samplers import BUILT_IN_SAMPLERS from sinter._main_combine import ExistingData, CSV_HEADER @@ -61,6 +62,18 @@ def parse_args(args: List[str]) -> Any: required=True, help='Circuit files to sample from and decode.\n' 'This parameter can be given multiple arguments.') + parser.add_argument("--samplers", + default=["stim"], + type=str, + nargs="+", + help="The samplers to use to sample detectors from circuits, defaults to `sinter.StimDetectorSampler`.") + parser.add_argument("--custom_samplers_module_function", + default=None, + nargs="+", + help='Use the syntax "module:function" to "import function from module" ' + 'and use the result of "function()" as the custom_samplers ' + "dictionary. 
The dictionary must map strings to sinter.Sampler " + "instances.") parser.add_argument('--decoders', type=str, nargs='+', @@ -226,6 +239,40 @@ def parse_args(args: List[str]) -> Any: 'lambda index, metadata, coords: ' + cast(str, a.postselected_detectors_predicate), filename='postselected_detectors_predicate:command_line_arg', mode='eval')) + + if a.custom_samplers_module_function is not None: + all_custom_samplers = {} + for entry in a.custom_samplers_module_function: + terms = entry.split(":") + if len(terms) != 2: + raise ValueError( + "--custom_samplers_module_function didn't have exactly one colon " + "separating a module name from a function name. Expected an argument " + "of the form --custom_samplers_module_function 'module:function'" + ) + module, function = terms + vals = {"__name__": "[]"} + exec(f"from {module} import {function} as _custom_samplers", vals) + custom_samplers = vals["_custom_samplers"]() + all_custom_samplers = {**all_custom_samplers, **custom_samplers} + a.custom_samplers = all_custom_samplers + else: + a.custom_samplers = None + + for sampler in a.samplers: + if sampler not in BUILT_IN_SAMPLERS and ( + a.custom_samplers is None or sampler not in a.custom_samplers + ): + message = f"Not a recognized sampler: {sampler=}.\n" + message += f"Available built-in samplers: {sorted(e for e in BUILT_IN_SAMPLERS.keys() if 'internal' not in e)}.\n" + if a.custom_samplers is None: + message += "No custom samplers are available. --custom_samplers_module_function wasn't specified." + else: + message += ( + f"Available custom samplers: {sorted(a.custom_samplers.keys())}." 
+ ) + raise ValueError(message) + if a.custom_decoders_module_function is not None: all_custom_decoders = {} for entry in a.custom_decoders_module_function: @@ -338,10 +385,12 @@ def on_progress(sample: Progress) -> None: max_shots=args.max_shots, count_detection_events=args.count_detection_events, count_observable_error_combos=args.count_observable_error_combos, + samplers=args.samplers, decoders=args.decoders, max_batch_seconds=args.max_batch_seconds, max_batch_size=args.max_batch_size, start_batch_size=args.start_batch_size, + custom_samplers=args.custom_samplers, custom_decoders=args.custom_decoders, custom_error_count_key=args.custom_error_count_key, allowed_cpu_affinity_ids=args.allowed_cpu_affinity_ids, diff --git a/glue/sample/src/sinter/_main_collect_test.py b/glue/sample/src/sinter/_main_collect_test.py index 09f45d890..72b02dfe4 100644 --- a/glue/sample/src/sinter/_main_collect_test.py +++ b/glue/sample/src/sinter/_main_collect_test.py @@ -1,4 +1,3 @@ -import collections import pathlib import tempfile @@ -127,7 +126,7 @@ def _make_custom_decoders(): def test_main_collect_with_custom_decoder(): with tempfile.TemporaryDirectory() as d: d = pathlib.Path(d) - with open(d / f'tmp.stim', 'w') as f: + with open(d / 'tmp.stim', 'w') as f: print(""" M(0.1) 0 DETECTOR rec[-1] diff --git a/glue/sample/src/sinter/_main_combine.py b/glue/sample/src/sinter/_main_combine.py index 86c5e0cf8..baed98da9 100644 --- a/glue/sample/src/sinter/_main_combine.py +++ b/glue/sample/src/sinter/_main_combine.py @@ -42,6 +42,7 @@ def main_combine(*, command_line_args: List[str]): total = [ sinter.TaskStats( strong_id=task.strong_id, + sampler=task.sampler, decoder=task.decoder, json_metadata=task.json_metadata, shots=task.shots, @@ -55,6 +56,7 @@ def main_combine(*, command_line_args: List[str]): total = [ sinter.TaskStats( strong_id=task.strong_id, + sampler=task.sampler, decoder=task.decoder, json_metadata=task.json_metadata, shots=task.shots, diff --git 
a/glue/sample/src/sinter/_main_combine_test.py b/glue/sample/src/sinter/_main_combine_test.py index f89ae4634..757d5fd17 100644 --- a/glue/sample/src/sinter/_main_combine_test.py +++ b/glue/sample/src/sinter/_main_combine_test.py @@ -9,12 +9,12 @@ def test_main_combine(): with tempfile.TemporaryDirectory() as d: d = pathlib.Path(d) - with open(d / f'input.csv', 'w') as f: + with open(d / 'input.csv', 'w') as f: print(""" -shots,errors,discards,seconds,decoder,strong_id,json_metadata -300,1,20,1.0,pymatching,f256bab362f516ebe4d59a08ae67330ff7771ff738757cd738f4b30605ddccf6,"{""path"":""a.stim""}" -300,100,200,2.0,pymatching,f256bab362f516ebe4d59a08ae67330ff7771ff738757cd738f4b30605ddccf6,"{""path"":""a.stim""}" -9,5,4,6.0,pymatching,5fe5a6cd4226b1a910d57e5479d1ba6572e0b3115983c9516360916d1670000f,"{""path"":""b.stim""}" +shots,errors,discards,seconds,sampler,decoder,strong_id,json_metadata +300,1,20,1.0,stim,pymatching,f256bab362f516ebe4d59a08ae67330ff7771ff738757cd738f4b30605ddccf6,"{""path"":""a.stim""}" +300,100,200,2.0,stim,pymatching,f256bab362f516ebe4d59a08ae67330ff7771ff738757cd738f4b30605ddccf6,"{""path"":""a.stim""}" +9,5,4,6.0,stim,pymatching,5fe5a6cd4226b1a910d57e5479d1ba6572e0b3115983c9516360916d1670000f,"{""path"":""b.stim""}" """.strip(), file=f) out = io.StringIO() @@ -23,9 +23,9 @@ def test_main_combine(): "combine", str(d / "input.csv"), ]) - assert out.getvalue() == """ shots, errors, discards, seconds,decoder,strong_id,json_metadata,custom_counts - 600, 101, 220, 3.00,pymatching,f256bab362f516ebe4d59a08ae67330ff7771ff738757cd738f4b30605ddccf6,"{""path"":""a.stim""}", - 9, 5, 4, 6.00,pymatching,5fe5a6cd4226b1a910d57e5479d1ba6572e0b3115983c9516360916d1670000f,"{""path"":""b.stim""}", + assert out.getvalue() == """ shots, errors, discards, seconds,sampler,decoder,strong_id,json_metadata,custom_counts + 600, 101, 220, 3.00,stim,pymatching,f256bab362f516ebe4d59a08ae67330ff7771ff738757cd738f4b30605ddccf6,"{""path"":""a.stim""}", + 9, 5, 4, 
6.00,stim,pymatching,5fe5a6cd4226b1a910d57e5479d1ba6572e0b3115983c9516360916d1670000f,"{""path"":""b.stim""}", """ out = io.StringIO() @@ -35,25 +35,25 @@ def test_main_combine(): str(d / "input.csv"), str(d / "input.csv"), ]) - assert out.getvalue() == """ shots, errors, discards, seconds,decoder,strong_id,json_metadata,custom_counts - 1200, 202, 440, 6.00,pymatching,f256bab362f516ebe4d59a08ae67330ff7771ff738757cd738f4b30605ddccf6,"{""path"":""a.stim""}", - 18, 10, 8, 12.0,pymatching,5fe5a6cd4226b1a910d57e5479d1ba6572e0b3115983c9516360916d1670000f,"{""path"":""b.stim""}", + assert out.getvalue() == """ shots, errors, discards, seconds,sampler,decoder,strong_id,json_metadata,custom_counts + 1200, 202, 440, 6.00,stim,pymatching,f256bab362f516ebe4d59a08ae67330ff7771ff738757cd738f4b30605ddccf6,"{""path"":""a.stim""}", + 18, 10, 8, 12.0,stim,pymatching,5fe5a6cd4226b1a910d57e5479d1ba6572e0b3115983c9516360916d1670000f,"{""path"":""b.stim""}", """ def test_main_combine_legacy_custom_counts(): with tempfile.TemporaryDirectory() as d: d = pathlib.Path(d) - with open(d / f'old.csv', 'w') as f: + with open(d / 'old.csv', 'w') as f: print(""" -shots,errors,discards,seconds,decoder,strong_id,json_metadata -100,1,20,1.0,pymatching,abc123,"{""path"":""a.stim""}" +shots,errors,discards,seconds,sampler,decoder,strong_id,json_metadata +100,1,20,1.0,stim,pymatching,abc123,"{""path"":""a.stim""}" """.strip(), file=f) - with open(d / f'new.csv', 'w') as f: + with open(d / 'new.csv', 'w') as f: print(""" -shots,errors,discards,seconds,decoder,strong_id,json_metadata,custom_counts -300,1,20,1.0,pymatching,abc123,"{""path"":""a.stim""}","{""x"":2}" -300,1,20,1.0,pymatching,abc123,"{""path"":""a.stim""}","{""y"":3}" +shots,errors,discards,seconds,sampler,decoder,strong_id,json_metadata,custom_counts +300,1,20,1.0,stim,pymatching,abc123,"{""path"":""a.stim""}","{""x"":2}" +300,1,20,1.0,stim,pymatching,abc123,"{""path"":""a.stim""}","{""y"":3}" """.strip(), file=f) out = io.StringIO() @@ 
-63,22 +63,22 @@ def test_main_combine_legacy_custom_counts(): str(d / "old.csv"), str(d / "new.csv"), ]) - assert out.getvalue() == """ shots, errors, discards, seconds,decoder,strong_id,json_metadata,custom_counts - 700, 3, 60, 3.00,pymatching,abc123,"{""path"":""a.stim""}","{""x"":2,""y"":3}" + assert out.getvalue() == """ shots, errors, discards, seconds,sampler,decoder,strong_id,json_metadata,custom_counts + 700, 3, 60, 3.00,stim,pymatching,abc123,"{""path"":""a.stim""}","{""x"":2,""y"":3}" """ def test_order_flag(): with tempfile.TemporaryDirectory() as d: d = pathlib.Path(d) - with open(d / f'input.csv', 'w') as f: + with open(d / 'input.csv', 'w') as f: print(""" -shots,errors,discards,seconds,decoder, strong_id,json_metadata -1000, 100, 4, 2.0, pymatching,deadbeef0,"{""d"":19}" -2000, 300, 3, 3.0, pymatching,deadbeef1,"{""d"":9}" -3000, 200, 2000, 5.0, pymatching,deadbeef2,"{""d"":200}" -4000, 100, 1, 7.0, pymatching,deadbeef3,"{""d"":3}" -5000, 100, 0, 11, pymatching,deadbeef4,"{""d"":5}" +shots,errors,discards,seconds,sampler,decoder, strong_id,json_metadata +1000, 100, 4, 2.0, stim,pymatching,deadbeef0,"{""d"":19}" +2000, 300, 3, 3.0, stim,pymatching,deadbeef1,"{""d"":9}" +3000, 200, 2000, 5.0, stim,pymatching,deadbeef2,"{""d"":200}" +4000, 100, 1, 7.0, stim,pymatching,deadbeef3,"{""d"":3}" +5000, 100, 0, 11, stim,pymatching,deadbeef4,"{""d"":5}" """.strip(), file=f) out = io.StringIO() @@ -90,12 +90,12 @@ def test_order_flag(): str(d / "input.csv"), str(d / "input.csv"), ]) - assert out.getvalue() == """ shots, errors, discards, seconds,decoder,strong_id,json_metadata,custom_counts - 2000, 200, 8, 4.00,pymatching,deadbeef0,"{""d"":19}", - 4000, 600, 6, 6.00,pymatching,deadbeef1,"{""d"":9}", - 6000, 400, 4000, 10.0,pymatching,deadbeef2,"{""d"":200}", - 8000, 200, 2, 14.0,pymatching,deadbeef3,"{""d"":3}", - 10000, 200, 0, 22.0,pymatching,deadbeef4,"{""d"":5}", + assert out.getvalue() == """ shots, errors, discards, 
seconds,sampler,decoder,strong_id,json_metadata,custom_counts + 2000, 200, 8, 4.00,stim,pymatching,deadbeef0,"{""d"":19}", + 4000, 600, 6, 6.00,stim,pymatching,deadbeef1,"{""d"":9}", + 6000, 400, 4000, 10.0,stim,pymatching,deadbeef2,"{""d"":200}", + 8000, 200, 2, 14.0,stim,pymatching,deadbeef3,"{""d"":3}", + 10000, 200, 0, 22.0,stim,pymatching,deadbeef4,"{""d"":5}", """ out = io.StringIO() @@ -107,12 +107,12 @@ def test_order_flag(): str(d / "input.csv"), str(d / "input.csv"), ]) - assert out.getvalue() == """ shots, errors, discards, seconds,decoder,strong_id,json_metadata,custom_counts - 8000, 200, 2, 14.0,pymatching,deadbeef3,"{""d"":3}", - 10000, 200, 0, 22.0,pymatching,deadbeef4,"{""d"":5}", - 4000, 600, 6, 6.00,pymatching,deadbeef1,"{""d"":9}", - 2000, 200, 8, 4.00,pymatching,deadbeef0,"{""d"":19}", - 6000, 400, 4000, 10.0,pymatching,deadbeef2,"{""d"":200}", + assert out.getvalue() == """ shots, errors, discards, seconds,sampler,decoder,strong_id,json_metadata,custom_counts + 8000, 200, 2, 14.0,stim,pymatching,deadbeef3,"{""d"":3}", + 10000, 200, 0, 22.0,stim,pymatching,deadbeef4,"{""d"":5}", + 4000, 600, 6, 6.00,stim,pymatching,deadbeef1,"{""d"":9}", + 2000, 200, 8, 4.00,stim,pymatching,deadbeef0,"{""d"":19}", + 6000, 400, 4000, 10.0,stim,pymatching,deadbeef2,"{""d"":200}", """ out = io.StringIO() @@ -124,22 +124,22 @@ def test_order_flag(): str(d / "input.csv"), str(d / "input.csv"), ]) - assert out.getvalue() == """ shots, errors, discards, seconds,decoder,strong_id,json_metadata,custom_counts - 10000, 200, 0, 22.0,pymatching,deadbeef4,"{""d"":5}", - 8000, 200, 2, 14.0,pymatching,deadbeef3,"{""d"":3}", - 2000, 200, 8, 4.00,pymatching,deadbeef0,"{""d"":19}", - 4000, 600, 6, 6.00,pymatching,deadbeef1,"{""d"":9}", - 6000, 400, 4000, 10.0,pymatching,deadbeef2,"{""d"":200}", + assert out.getvalue() == """ shots, errors, discards, seconds,sampler,decoder,strong_id,json_metadata,custom_counts + 10000, 200, 0, 22.0,stim,pymatching,deadbeef4,"{""d"":5}", + 8000, 
200, 2, 14.0,stim,pymatching,deadbeef3,"{""d"":3}", + 2000, 200, 8, 4.00,stim,pymatching,deadbeef0,"{""d"":19}", + 4000, 600, 6, 6.00,stim,pymatching,deadbeef1,"{""d"":9}", + 6000, 400, 4000, 10.0,stim,pymatching,deadbeef2,"{""d"":200}", """ def test_order_custom_counts(): with tempfile.TemporaryDirectory() as d: d = pathlib.Path(d) - with open(d / f'input.csv', 'w') as f: + with open(d / 'input.csv', 'w') as f: print(""" -shots,errors,discards,seconds,decoder, strong_id,json_metadata,custom_counts -1000, 100, 4, 2.0, pymatching,deadbeef0,[],"{""d4"":3,""d2"":30}" +shots,errors,discards,seconds,sampler,decoder, strong_id,json_metadata,custom_counts +1000, 100, 4, 2.0, stim,pymatching,deadbeef0,[],"{""d4"":3,""d2"":30}" """.strip(), file=f) out = io.StringIO() @@ -148,6 +148,6 @@ def test_order_custom_counts(): "combine", str(d / "input.csv"), ]) - assert out.getvalue() == """ shots, errors, discards, seconds,decoder,strong_id,json_metadata,custom_counts - 1000, 100, 4, 2.00,pymatching,deadbeef0,[],"{""d2"":30,""d4"":3}" + assert out.getvalue() == """ shots, errors, discards, seconds,sampler,decoder,strong_id,json_metadata,custom_counts + 1000, 100, 4, 2.00,stim,pymatching,deadbeef0,[],"{""d2"":30,""d4"":3}" """ diff --git a/glue/sample/src/sinter/_predict.py b/glue/sample/src/sinter/_predict.py index eeefba040..a1ef9db00 100644 --- a/glue/sample/src/sinter/_predict.py +++ b/glue/sample/src/sinter/_predict.py @@ -10,7 +10,7 @@ from sinter._collection import post_selection_mask_from_4th_coord from sinter._decoding_decoder_class import Decoder from sinter._decoding_all_built_in_decoders import BUILT_IN_DECODERS -from sinter._decoding import streaming_post_select +from sinter._sampling_and_decoding import streaming_post_select if TYPE_CHECKING: import sinter diff --git a/glue/sample/src/sinter/_sampling_all_built_in_samplers.py b/glue/sample/src/sinter/_sampling_all_built_in_samplers.py new file mode 100644 index 000000000..de2f94879 --- /dev/null +++ 
b/glue/sample/src/sinter/_sampling_all_built_in_samplers.py @@ -0,0 +1,10 @@ +from typing import Dict + +from sinter._sampling_sampler_class import Sampler +from sinter._sampling_stim import StimDetectorSampler +from sinter._sampling_vacuous import VacuousSampler + +BUILT_IN_SAMPLERS: Dict[str, Sampler] = { + "stim": StimDetectorSampler(), + "vacuous": VacuousSampler(), +} diff --git a/glue/sample/src/sinter/_decoding.py b/glue/sample/src/sinter/_sampling_and_decoding.py similarity index 67% rename from glue/sample/src/sinter/_decoding.py rename to glue/sample/src/sinter/_sampling_and_decoding.py index 7170d443e..3059d739a 100644 --- a/glue/sample/src/sinter/_decoding.py +++ b/glue/sample/src/sinter/_sampling_and_decoding.py @@ -1,5 +1,4 @@ import collections -from typing import Iterable from typing import Optional, Dict, Tuple, TYPE_CHECKING, Union import contextlib @@ -13,24 +12,30 @@ from sinter._anon_task_stats import AnonTaskStats from sinter._decoding_all_built_in_decoders import BUILT_IN_DECODERS +from sinter._sampling_all_built_in_samplers import BUILT_IN_SAMPLERS from sinter._decoding_decoder_class import CompiledDecoder, Decoder +from sinter._sampling_sampler_class import Sampler, CompiledSampler if TYPE_CHECKING: import sinter -def streaming_post_select(*, - num_dets: int, - num_obs: int, - dets_in_b8: pathlib.Path, - obs_in_b8: Optional[pathlib.Path], - dets_out_b8: pathlib.Path, - obs_out_b8: Optional[pathlib.Path], - discards_out_b8: Optional[pathlib.Path], - num_shots: int, - post_mask: np.ndarray) -> int: +def streaming_post_select( + *, + num_dets: int, + num_obs: int, + dets_in_b8: pathlib.Path, + obs_in_b8: Optional[pathlib.Path], + dets_out_b8: pathlib.Path, + obs_out_b8: Optional[pathlib.Path], + discards_out_b8: Optional[pathlib.Path], + num_shots: int, + post_mask: np.ndarray, +) -> int: if post_mask.shape != ((num_dets + 7) // 8,): - raise ValueError(f"post_mask.shape={post_mask.shape} != (math.ceil(num_detectors / 8),)") + raise ValueError( 
+ f"post_mask.shape={post_mask.shape} != (math.ceil(num_detectors / 8),)" + ) if post_mask.dtype != np.uint8: raise ValueError(f"post_mask.dtype={post_mask.dtype} != np.uint8") assert (obs_in_b8 is None) == (obs_out_b8 is None) @@ -41,30 +46,34 @@ def streaming_post_select(*, num_discards = 0 with contextlib.ExitStack() as ctx: - dets_in_f = ctx.enter_context(open(dets_in_b8, 'rb')) - dets_out_f = ctx.enter_context(open(dets_out_b8, 'wb')) + dets_in_f = ctx.enter_context(open(dets_in_b8, "rb")) + dets_out_f = ctx.enter_context(open(dets_out_b8, "wb")) if obs_in_b8 is not None and obs_out_b8 is not None: - obs_in_f = ctx.enter_context(open(obs_in_b8, 'rb')) - obs_out_f = ctx.enter_context(open(obs_out_b8, 'wb')) + obs_in_f = ctx.enter_context(open(obs_in_b8, "rb")) + obs_out_f = ctx.enter_context(open(obs_out_b8, "wb")) else: obs_in_f = None obs_out_f = None if discards_out_b8 is not None: - discards_out_f = ctx.enter_context(open(discards_out_b8, 'wb')) + discards_out_f = ctx.enter_context(open(discards_out_b8, "wb")) else: discards_out_f = None while num_shots_left: - batch_size = min(num_shots_left, math.ceil(10 ** 6 / max(1, num_dets))) + batch_size = min(num_shots_left, math.ceil(10**6 / max(1, num_dets))) - det_batch = np.fromfile(dets_in_f, dtype=np.uint8, count=num_det_bytes * batch_size) + det_batch = np.fromfile( + dets_in_f, dtype=np.uint8, count=num_det_bytes * batch_size + ) det_batch.shape = (batch_size, num_det_bytes) discarded = np.any(det_batch & post_mask, axis=1) det_left = det_batch[~discarded, :] det_left.tofile(dets_out_f) if obs_in_f is not None and obs_out_f is not None: - obs_batch = np.fromfile(obs_in_f, dtype=np.uint8, count=num_obs_bytes * batch_size) + obs_batch = np.fromfile( + obs_in_f, dtype=np.uint8, count=num_obs_bytes * batch_size + ) obs_batch.shape = (batch_size, num_obs_bytes) obs_left = obs_batch[~discarded, :] obs_left.tofile(obs_out_f) @@ -78,55 +87,69 @@ def streaming_post_select(*, def _streaming_count_mistakes( - *, - 
num_shots: int, - num_obs: int, - num_det: int, - postselected_observable_mask: Optional[np.ndarray] = None, - dets_in: pathlib.Path, - obs_in: pathlib.Path, - predictions_in: pathlib.Path, - count_detection_events: bool, - count_observable_error_combos: bool, + *, + num_shots: int, + num_obs: int, + num_det: int, + postselected_observable_mask: Optional[np.ndarray] = None, + dets_in: pathlib.Path, + obs_in: pathlib.Path, + predictions_in: pathlib.Path, + count_detection_events: bool, + count_observable_error_combos: bool, ) -> Tuple[int, int, collections.Counter]: - num_det_bytes = math.ceil(num_det / 8) num_obs_bytes = math.ceil(num_obs / 8) num_errors = 0 num_discards = 0 custom_counts = collections.Counter() if count_detection_events: - with open(dets_in, 'rb') as dets_in_f: + with open(dets_in, "rb") as dets_in_f: num_shots_left = num_shots while num_shots_left: batch_size = min(num_shots_left, math.ceil(10**6 / max(num_obs, 1))) - det_data = np.fromfile(dets_in_f, dtype=np.uint8, count=num_det_bytes * batch_size) + det_data = np.fromfile( + dets_in_f, dtype=np.uint8, count=num_det_bytes * batch_size + ) for b in range(8): - custom_counts['detection_events'] += np.count_nonzero(det_data & (1 << b)) + custom_counts["detection_events"] += np.count_nonzero( + det_data & (1 << b) + ) num_shots_left -= batch_size - custom_counts['detectors_checked'] += num_shots * num_det + custom_counts["detectors_checked"] += num_shots * num_det - with open(obs_in, 'rb') as obs_in_f: - with open(predictions_in, 'rb') as predictions_in_f: + with open(obs_in, "rb") as obs_in_f: + with open(predictions_in, "rb") as predictions_in_f: num_shots_left = num_shots while num_shots_left: batch_size = min(num_shots_left, math.ceil(10**6 / max(num_obs, 1))) - obs_batch = np.fromfile(obs_in_f, dtype=np.uint8, count=num_obs_bytes * batch_size) - pred_batch = np.fromfile(predictions_in_f, dtype=np.uint8, count=num_obs_bytes * batch_size) + obs_batch = np.fromfile( + obs_in_f, dtype=np.uint8, 
count=num_obs_bytes * batch_size + ) + pred_batch = np.fromfile( + predictions_in_f, dtype=np.uint8, count=num_obs_bytes * batch_size + ) obs_batch.shape = (batch_size, num_obs_bytes) pred_batch.shape = (batch_size, num_obs_bytes) cmp_table = pred_batch ^ obs_batch err_mask = np.any(cmp_table, axis=1) if postselected_observable_mask is not None: - discard_mask = np.any(cmp_table & postselected_observable_mask, axis=1) + discard_mask = np.any( + cmp_table & postselected_observable_mask, axis=1 + ) err_mask &= ~discard_mask num_discards += np.count_nonzero(discard_mask) if count_observable_error_combos: for misprediction_arr in cmp_table[err_mask]: - err_key = "obs_mistake_mask=" + ''.join('_E'[b] for b in np.unpackbits(misprediction_arr, count=num_obs, bitorder='little')) + err_key = "obs_mistake_mask=" + "".join( + "_E"[b] + for b in np.unpackbits( + misprediction_arr, count=num_obs, bitorder="little" + ) + ) custom_counts[err_key] += 1 num_errors += np.count_nonzero(err_mask) @@ -134,21 +157,24 @@ def _streaming_count_mistakes( return num_discards, num_errors, custom_counts -def sample_decode(*, - circuit_obj: Optional[stim.Circuit], - circuit_path: Union[None, str, pathlib.Path], - dem_obj: Optional[stim.DetectorErrorModel], - dem_path: Union[None, str, pathlib.Path], - post_mask: Optional[np.ndarray] = None, - postselected_observable_mask: Optional[np.ndarray] = None, - count_observable_error_combos: bool = False, - count_detection_events: bool = False, - num_shots: int, - decoder: str, - tmp_dir: Union[str, pathlib.Path, None] = None, - custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None, - __private__unstable__force_decode_on_disk: Optional[bool] = None, - ) -> AnonTaskStats: +def sample_decode( + *, + circuit_obj: Optional[stim.Circuit], + circuit_path: Union[None, str, pathlib.Path], + dem_obj: Optional[stim.DetectorErrorModel], + dem_path: Union[None, str, pathlib.Path], + post_mask: Optional[np.ndarray] = None, + postselected_observable_mask: 
Optional[np.ndarray] = None, + count_observable_error_combos: bool = False, + count_detection_events: bool = False, + num_shots: int, + sampler: str, + custom_samplers: Optional[Dict[str, "sinter.Sampler"]] = None, + decoder: str, + tmp_dir: Union[str, pathlib.Path, None] = None, + custom_decoders: Optional[Dict[str, "sinter.Decoder"]] = None, + __private__unstable__force_decode_on_disk: Optional[bool] = None, +) -> AnonTaskStats: """Samples how many times a decoder correctly predicts the logical frame. Args: @@ -177,6 +203,12 @@ def sample_decode(*, were executed. The detection fraction is the ratio of these two numbers. num_shots: The number of sample shots to take from the circuit. + sampler: The name of the sampler to use. Allowed values are: + "stim": + Use stim's built-in detector sampler. + custom_samplers: Custom samplers that can be used if requested by name. + If not specified, only samplers built into sinter, such as 'stim', + can be used. decoder: The name of the decoder to use. Allowed values are: "pymatching": Use pymatching min-weight-perfect-match decoder. @@ -192,12 +224,20 @@ def sample_decode(*, 'pymatching' and 'fusion_blossom', can be used. 
""" if (circuit_obj is None) == (circuit_path is None): - raise ValueError('(circuit_obj is None) == (circuit_path is None)') + raise ValueError("(circuit_obj is None) == (circuit_path is None)") if (dem_obj is None) == (dem_path is None): - raise ValueError('(dem_obj is None) == (dem_path is None)') + raise ValueError("(dem_obj is None) == (dem_path is None)") if num_shots == 0: return AnonTaskStats() + sampler_obj: Optional[Sampler] = None + if custom_samplers is not None: + sampler_obj = custom_samplers.get(sampler) + if sampler_obj is None: + sampler_obj = BUILT_IN_SAMPLERS.get(sampler) + if sampler_obj is None: + raise NotImplementedError(f"Unrecognized sampler: {sampler!r}") + decoder_obj: Optional[Decoder] = None if custom_decoders is not None: decoder_obj = custom_decoders.get(decoder) @@ -222,11 +262,12 @@ def sample_decode(*, try: if __private__unstable__force_decode_on_disk: raise NotImplementedError() + compiled_sampler = sampler_obj.compile_sampler_for_circuit(circuit=circuit) compiled_decoder = decoder_obj.compile_decoder_for_dem(dem=dem) return _sample_decode_helper_using_memory( - circuit=circuit, post_mask=post_mask, postselected_observable_mask=postselected_observable_mask, + compiled_sampler=compiled_sampler, compiled_decoder=compiled_decoder, total_num_shots=num_shots, num_det=circuit.num_detectors, @@ -237,15 +278,20 @@ def sample_decode(*, count_detection_events=count_detection_events, ) except NotImplementedError: - assert __private__unstable__force_decode_on_disk or __private__unstable__force_decode_on_disk is None + assert ( + __private__unstable__force_decode_on_disk + or __private__unstable__force_decode_on_disk is None + ) pass return _sample_decode_helper_using_disk( circuit=circuit, + circuit_path=circuit_path, dem=dem, dem_path=dem_path, post_mask=post_mask, postselected_observable_mask=postselected_observable_mask, num_shots=num_shots, + sampler_obj=sampler_obj, decoder_obj=decoder_obj, tmp_dir=tmp_dir, 
start_time_monotonic=start_time, @@ -256,27 +302,27 @@ def sample_decode(*, def _sample_decode_helper_using_memory( *, - circuit: stim.Circuit, post_mask: Optional[np.ndarray], postselected_observable_mask: Optional[np.ndarray], num_obs: int, num_det: int, total_num_shots: int, mini_batch_size: int, + compiled_sampler: CompiledSampler, compiled_decoder: CompiledDecoder, start_time_monotonic: float, count_observable_error_combos: bool, count_detection_events: bool, ) -> AnonTaskStats: - sampler: stim.CompiledDetectorSampler = circuit.compile_detector_sampler() - out_num_discards = 0 out_num_errors = 0 shots_left = total_num_shots custom_counts = collections.Counter() while shots_left > 0: cur_num_shots = min(shots_left, mini_batch_size) - dets_data, obs_data = sampler.sample(shots=cur_num_shots, separate_observables=True, bit_packed=True) + dets_data, obs_data = compiled_sampler.sample_detectors_bit_packed( + shots=cur_num_shots + ) # Discard any shots that contain a postselected detection events. if post_mask is not None: @@ -288,11 +334,15 @@ def _sample_decode_helper_using_memory( obs_data = obs_data[~discarded_flags, :] # Have the decoder predict which observables are flipped. - predict_data = compiled_decoder.decode_shots_bit_packed(bit_packed_detection_event_data=dets_data) + predict_data = compiled_decoder.decode_shots_bit_packed( + bit_packed_detection_event_data=dets_data + ) # Discard any shots where the decoder predicts a flipped postselected observable. 
if postselected_observable_mask is not None: - discarded_flags = np.any(postselected_observable_mask & (predict_data ^ obs_data), axis=1) + discarded_flags = np.any( + postselected_observable_mask & (predict_data ^ obs_data), axis=1 + ) cur_num_discarded_shots = np.count_nonzero(discarded_flags) if cur_num_discarded_shots: out_num_discards += cur_num_discarded_shots @@ -304,16 +354,23 @@ def _sample_decode_helper_using_memory( err_mask = np.any(mispredictions, axis=1) if count_detection_events: for b in range(8): - custom_counts['detection_events'] += np.count_nonzero(dets_data & (1 << b)) + custom_counts["detection_events"] += np.count_nonzero( + dets_data & (1 << b) + ) if count_observable_error_combos: for misprediction_arr in mispredictions[err_mask]: - err_key = "obs_mistake_mask=" + ''.join('_E'[b] for b in np.unpackbits(misprediction_arr, count=num_obs, bitorder='little')) + err_key = "obs_mistake_mask=" + "".join( + "_E"[b] + for b in np.unpackbits( + misprediction_arr, count=num_obs, bitorder="little" + ) + ) custom_counts[err_key] += 1 out_num_errors += np.count_nonzero(err_mask) shots_left -= cur_num_shots if count_detection_events: - custom_counts['detectors_checked'] += num_det * total_num_shots + custom_counts["detectors_checked"] += num_det * total_num_shots return AnonTaskStats( shots=total_num_shots, errors=out_num_errors, @@ -326,11 +383,13 @@ def _sample_decode_helper_using_memory( def _sample_decode_helper_using_disk( *, circuit: stim.Circuit, + circuit_path: Union[str, pathlib.Path, None], dem: stim.DetectorErrorModel, - dem_path: Union[str, pathlib.Path], + dem_path: Union[str, pathlib.Path, None], post_mask: Optional[np.ndarray], postselected_observable_mask: Optional[np.ndarray], num_shots: int, + sampler_obj: Sampler, decoder_obj: Decoder, tmp_dir: Union[str, pathlib.Path, None], start_time_monotonic: float, @@ -341,28 +400,31 @@ def _sample_decode_helper_using_disk( if tmp_dir is None: tmp_dir = 
exit_stack.enter_context(tempfile.TemporaryDirectory()) tmp_dir = pathlib.Path(tmp_dir) + if circuit_path is None: + circuit_path = tmp_dir / "tmp.stim" + circuit.to_file(circuit_path) + circuit_path = pathlib.Path(circuit_path) if dem_path is None: - dem_path = tmp_dir / 'tmp.dem' + dem_path = tmp_dir / "tmp.dem" dem.to_file(dem_path) dem_path = pathlib.Path(dem_path) - dets_all_path = tmp_dir / 'sinter_dets.all.b8' - obs_all_path = tmp_dir / 'sinter_obs.all.b8' - dets_kept_path = tmp_dir / 'sinter_dets.kept.b8' - obs_kept_path = tmp_dir / 'sinter_obs.kept.b8' - predictions_path = tmp_dir / 'sinter_predictions.b8' + dets_all_path = tmp_dir / "sinter_dets.all.b8" + obs_all_path = tmp_dir / "sinter_obs.all.b8" + dets_kept_path = tmp_dir / "sinter_dets.kept.b8" + obs_kept_path = tmp_dir / "sinter_obs.kept.b8" + predictions_path = tmp_dir / "sinter_predictions.b8" num_dets = circuit.num_detectors num_obs = circuit.num_observables - # Sample data using Stim. - sampler: stim.CompiledDetectorSampler = circuit.compile_detector_sampler() - sampler.sample_write( - num_shots, - filepath=str(dets_all_path), - obs_out_filepath=str(obs_all_path), - format='b8', - obs_out_format='b8', + # Sample data + sampler_obj.sample_detectors_via_files( + shots=num_shots, + circuit_path=circuit_path, + dets_b8_out_path=dets_all_path, + obs_flips_b8_out_path=obs_all_path, + tmp_dir=tmp_dir, ) # Postselect, then split into detection event data and observable data. 
diff --git a/glue/sample/src/sinter/_decoding_test.py b/glue/sample/src/sinter/_sampling_and_decoding_test.py similarity index 62% rename from glue/sample/src/sinter/_decoding_test.py rename to glue/sample/src/sinter/_sampling_and_decoding_test.py index 2ca9fbbca..c0048ffc6 100644 --- a/glue/sample/src/sinter/_decoding_test.py +++ b/glue/sample/src/sinter/_sampling_and_decoding_test.py @@ -11,9 +11,37 @@ import stim from sinter._collection import post_selection_mask_from_4th_coord -from sinter._decoding import sample_decode +from sinter._sampling_and_decoding import sample_decode from sinter._decoding_all_built_in_decoders import BUILT_IN_DECODERS +from sinter._sampling_all_built_in_samplers import BUILT_IN_SAMPLERS from sinter._decoding_vacuous import VacuousDecoder +from sinter._sampling_vacuous import VacuousSampler + + +def get_test_samplers() -> Tuple[List[str], Dict[str, sinter.Sampler]]: + available_samplers = list(BUILT_IN_SAMPLERS.keys()) + custom_samplers = {} + + e = os.environ.get("SINTER_PYTEST_CUSTOM_SAMPLERS") + if e is not None: + for term in e.split(";"): + module, method = term.split(":") + for name, obj in getattr(__import__(module), method)().items(): + custom_samplers[name] = obj + available_samplers.append(name) + + available_samplers.append("also_vacuous") + custom_samplers["also_vacuous"] = VacuousSampler() + return available_samplers, custom_samplers + + +TEST_SAMPLER_NAMES, TEST_CUSTOM_SAMPLERS = get_test_samplers() + +SAMPLER_CASES = [ + (sampler, force_streaming) + for sampler in TEST_SAMPLER_NAMES + for force_streaming in [None, True] +] def get_test_decoders() -> Tuple[List[str], Dict[str, sinter.Decoder]]: @@ -22,16 +50,16 @@ def get_test_decoders() -> Tuple[List[str], Dict[str, sinter.Decoder]]: try: import pymatching except ImportError: - available_decoders.remove('pymatching') + available_decoders.remove("pymatching") try: import fusion_blossom except ImportError: - available_decoders.remove('fusion_blossom') + 
available_decoders.remove("fusion_blossom") - e = os.environ.get('SINTER_PYTEST_CUSTOM_DECODERS') + e = os.environ.get("SINTER_PYTEST_CUSTOM_DECODERS") if e is not None: - for term in e.split(';'): - module, method = term.split(':') + for term in e.split(";"): + module, method = term.split(":") for name, obj in getattr(__import__(module), method)().items(): custom_decoders[name] = obj available_decoders.append(name) @@ -40,6 +68,7 @@ def get_test_decoders() -> Tuple[List[str], Dict[str, sinter.Decoder]]: custom_decoders["also_vacuous"] = VacuousDecoder() return available_decoders, custom_decoders + TEST_DECODER_NAMES, TEST_CUSTOM_DECODERS = get_test_decoders() DECODER_CASES = [ @@ -49,29 +78,59 @@ def get_test_decoders() -> Tuple[List[str], Dict[str, sinter.Decoder]]: ] -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("sampler,force_streaming", SAMPLER_CASES) +def test_sample_repetition_code(sampler: str, force_streaming: Optional[bool]): + circuit = stim.Circuit.generated( + "repetition_code:memory", + rounds=3, + distance=3, + after_clifford_depolarization=0.05, + ) + result = sample_decode( + circuit_obj=circuit, + circuit_path=None, + dem_obj=circuit.detector_error_model(decompose_errors=True), + dem_path=None, + num_shots=1000, + sampler=sampler, + custom_samplers=TEST_CUSTOM_SAMPLERS, + decoder="vacuous", + __private__unstable__force_decode_on_disk=force_streaming, + ) + assert result.discards == 0 + assert result.shots == 1000 + if "vacuous" in sampler: + assert result.errors == 0 + else: + assert 1 <= result.errors <= 100 + + +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_decode_repetition_code(decoder: str, force_streaming: Optional[bool]): - circuit = stim.Circuit.generated('repetition_code:memory', - rounds=3, - distance=3, - after_clifford_depolarization=0.05) + circuit = stim.Circuit.generated( + "repetition_code:memory", + rounds=3, + distance=3, + 
after_clifford_depolarization=0.05, + ) result = sample_decode( circuit_obj=circuit, circuit_path=None, dem_obj=circuit.detector_error_model(decompose_errors=True), dem_path=None, num_shots=1000, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, ) assert result.discards == 0 - if 'vacuous' not in decoder: + if "vacuous" not in decoder: assert 1 <= result.errors <= 100 assert result.shots == 1000 -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_decode_surface_code(decoder: str, force_streaming: Optional[bool]): circuit = stim.Circuit.generated( "surface_code:rotated_memory_x", @@ -85,15 +144,16 @@ def test_decode_surface_code(decoder: str, force_streaming: Optional[bool]): circuit_path=None, dem_obj=circuit.detector_error_model(decompose_errors=True), dem_path=None, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, ) - if 'vacuous' not in decoder: + if "vacuous" not in decoder: assert 0 <= stats.errors <= 50 -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_empty(decoder: str, force_streaming: Optional[bool]): circuit = stim.Circuit() result = sample_decode( @@ -102,6 +162,7 @@ def test_empty(decoder: str, force_streaming: Optional[bool]): dem_obj=circuit.detector_error_model(decompose_errors=True), dem_path=None, num_shots=1000, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, @@ -111,19 +172,22 @@ def test_empty(decoder: str, force_streaming: Optional[bool]): assert result.errors == 0 -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def 
test_no_observables(decoder: str, force_streaming: Optional[bool]): - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ X_ERROR(0.1) 0 M 0 DETECTOR rec[-1] - """) + """ + ) result = sample_decode( circuit_obj=circuit, circuit_path=None, dem_obj=circuit.detector_error_model(decompose_errors=True), dem_path=None, num_shots=1000, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, @@ -133,20 +197,23 @@ def test_no_observables(decoder: str, force_streaming: Optional[bool]): assert result.errors == 0 -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_invincible_observables(decoder: str, force_streaming: Optional[bool]): - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ X_ERROR(0.1) 0 M 0 1 DETECTOR rec[-2] OBSERVABLE_INCLUDE(1) rec[-1] - """) + """ + ) result = sample_decode( circuit_obj=circuit, circuit_path=None, dem_obj=circuit.detector_error_model(decompose_errors=True), dem_path=None, num_shots=1000, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, @@ -156,23 +223,31 @@ def test_invincible_observables(decoder: str, force_streaming: Optional[bool]): assert result.errors == 0 -@pytest.mark.parametrize('decoder,force_streaming,offset', [(a, b, c) for a, b in DECODER_CASES for c in range(8)]) +@pytest.mark.parametrize( + "decoder,force_streaming,offset", + [(a, b, c) for a, b in DECODER_CASES for c in range(8)], +) def test_observable_offsets_mod8(decoder: str, force_streaming: bool, offset: int): - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ X_ERROR(0.1) 0 MR 0 DETECTOR rec[-1] - """) * (8 + offset) + stim.Circuit(""" + """ + ) * (8 + offset) + stim.Circuit( + """ X_ERROR(0.1) 0 MR 0 OBSERVABLE_INCLUDE(0) rec[-1] - """) + """ + ) result = sample_decode( circuit_obj=circuit, 
circuit_path=None, dem_obj=circuit.detector_error_model(decompose_errors=True), dem_path=None, num_shots=1000, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, @@ -182,19 +257,22 @@ def test_observable_offsets_mod8(decoder: str, force_streaming: bool, offset: in assert 50 <= result.errors <= 150 -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_no_detectors(decoder: str, force_streaming: Optional[bool]): - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ X_ERROR(0.1) 0 M 0 OBSERVABLE_INCLUDE(0) rec[-1] - """) + """ + ) result = sample_decode( circuit_obj=circuit, circuit_path=None, dem_obj=circuit.detector_error_model(decompose_errors=True), dem_path=None, num_shots=1000, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, @@ -203,13 +281,15 @@ def test_no_detectors(decoder: str, force_streaming: Optional[bool]): assert 50 <= result.errors <= 150 -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[bool]): - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ X_ERROR(0.1) 0 M 0 OBSERVABLE_INCLUDE(0) rec[-1] - """) + """ + ) result = sample_decode( circuit_obj=circuit, circuit_path=None, @@ -217,6 +297,7 @@ def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[boo dem_path=None, post_mask=np.array([], dtype=np.uint8), num_shots=1000, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, @@ -225,9 +306,10 @@ def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[boo assert 50 <= result.errors <= 150 
-@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_post_selection(decoder: str, force_streaming: Optional[bool]): - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ X_ERROR(0.6) 0 M 0 DETECTOR(2, 0, 0, 1) rec[-1] @@ -241,7 +323,8 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]): X_ERROR(0.1) 2 M 2 OBSERVABLE_INCLUDE(0) rec[-1] - """) + """ + ) result = sample_decode( circuit_obj=circuit, circuit_path=None, @@ -249,24 +332,27 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]): dem_path=None, post_mask=post_selection_mask_from_4th_coord(circuit), num_shots=2000, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, ) assert 1050 <= result.discards <= 1350 - if 'vacuous' not in decoder: + if "vacuous" not in decoder: assert 40 <= result.errors <= 160 -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_observable_post_selection(decoder: str, force_streaming: Optional[bool]): - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ X_ERROR(0.1) 0 X_ERROR(0.2) 1 M 0 1 OBSERVABLE_INCLUDE(0) rec[-1] OBSERVABLE_INCLUDE(1) rec[-1] rec[-2] - """) + """ + ) result = sample_decode( circuit_obj=circuit, circuit_path=None, @@ -275,24 +361,29 @@ def test_observable_post_selection(decoder: str, force_streaming: Optional[bool] post_mask=None, postselected_observable_mask=np.array([1], dtype=np.uint8), num_shots=10000, + sampler='stim', decoder=decoder, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, ) np.testing.assert_allclose(result.discards / result.shots, 0.2, atol=0.1) - if 'vacuous' not in decoder: - np.testing.assert_allclose(result.errors / (result.shots - result.discards), 0.1, atol=0.05) + if "vacuous" 
not in decoder: + np.testing.assert_allclose( + result.errors / (result.shots - result.discards), 0.1, atol=0.05 + ) -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_error_splitting(decoder: str, force_streaming: Optional[bool]): - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ X_ERROR(0.1) 0 X_ERROR(0.2) 1 M 0 1 OBSERVABLE_INCLUDE(0) rec[-1] OBSERVABLE_INCLUDE(1) rec[-1] rec[-2] - """) + """ + ) result = sample_decode( circuit_obj=circuit, circuit_path=None, @@ -300,23 +391,43 @@ def test_error_splitting(decoder: str, force_streaming: Optional[bool]): dem_path=None, post_mask=None, num_shots=10000, + sampler='stim', decoder=decoder, count_observable_error_combos=True, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, ) assert result.discards == 0 - assert set(result.custom_counts.keys()) == {'obs_mistake_mask=E_', 'obs_mistake_mask=_E', 'obs_mistake_mask=EE'} - if 'vacuous' not in decoder: - np.testing.assert_allclose(result.errors / result.shots, 1 - 0.8 * 0.9, atol=0.05) - np.testing.assert_allclose(result.custom_counts['obs_mistake_mask=E_'] / result.shots, 0.1 * 0.2, atol=0.05) - np.testing.assert_allclose(result.custom_counts['obs_mistake_mask=_E'] / result.shots, 0.1 * 0.8, atol=0.05) - np.testing.assert_allclose(result.custom_counts['obs_mistake_mask=EE'] / result.shots, 0.9 * 0.2, atol=0.05) - - -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) + assert set(result.custom_counts.keys()) == { + "obs_mistake_mask=E_", + "obs_mistake_mask=_E", + "obs_mistake_mask=EE", + } + if "vacuous" not in decoder: + np.testing.assert_allclose( + result.errors / result.shots, 1 - 0.8 * 0.9, atol=0.05 + ) + np.testing.assert_allclose( + result.custom_counts["obs_mistake_mask=E_"] / result.shots, + 0.1 * 0.2, + atol=0.05, + ) + np.testing.assert_allclose( + result.custom_counts["obs_mistake_mask=_E"] / 
result.shots, + 0.1 * 0.8, + atol=0.05, + ) + np.testing.assert_allclose( + result.custom_counts["obs_mistake_mask=EE"] / result.shots, + 0.9 * 0.2, + atol=0.05, + ) + + +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_detector_counting(decoder: str, force_streaming: Optional[bool]): - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ X_ERROR(0.1) 0 X_ERROR(0.2) 1 M 0 1 @@ -324,7 +435,8 @@ def test_detector_counting(decoder: str, force_streaming: Optional[bool]): DETECTOR rec[-2] OBSERVABLE_INCLUDE(0) rec[-1] OBSERVABLE_INCLUDE(1) rec[-1] rec[-2] - """) + """ + ) result = sample_decode( circuit_obj=circuit, circuit_path=None, @@ -332,53 +444,61 @@ def test_detector_counting(decoder: str, force_streaming: Optional[bool]): dem_path=None, post_mask=None, num_shots=10000, + sampler='stim', decoder=decoder, count_detection_events=True, __private__unstable__force_decode_on_disk=force_streaming, custom_decoders=TEST_CUSTOM_DECODERS, ) assert result.discards == 0 - assert result.custom_counts['detectors_checked'] == 20000 - assert 0.3 * 10000 * 0.5 <= result.custom_counts['detection_events'] <= 0.3 * 10000 * 2.0 - assert set(result.custom_counts.keys()) == {'detectors_checked', 'detection_events'} + assert result.custom_counts["detectors_checked"] == 20000 + assert ( + 0.3 * 10000 * 0.5 + <= result.custom_counts["detection_events"] + <= 0.3 * 10000 * 2.0 + ) + assert set(result.custom_counts.keys()) == {"detectors_checked", "detection_events"} -@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES) +@pytest.mark.parametrize("decoder,force_streaming", DECODER_CASES) def test_decode_fails_correctly(decoder: str, force_streaming: Optional[bool]): decoder_obj = BUILT_IN_DECODERS.get(decoder) with tempfile.TemporaryDirectory() as d: d = pathlib.Path(d) - circuit = stim.Circuit(""" + circuit = stim.Circuit( + """ REPEAT 9 { MR(0.001) 0 DETECTOR rec[-1] OBSERVABLE_INCLUDE(0) rec[-1] } - """) + """ + ) dem = 
circuit.detector_error_model() - circuit.to_file(d / 'circuit.stim') - dem.to_file(d / 'dem.dem') - with open(d / 'bad_dets.b8', 'wb') as f: - f.write(b'!') + circuit.to_file(d / "circuit.stim") + dem.to_file(d / "dem.dem") + with open(d / "bad_dets.b8", "wb") as f: + f.write(b"!") - if 'vacuous' not in decoder: + if "vacuous" not in decoder: with pytest.raises(Exception): decoder_obj.decode_via_files( num_shots=1, num_dets=dem.num_detectors, num_obs=dem.num_observables, - dem_path=d / 'dem.dem', - dets_b8_in_path=d / 'bad_dets.b8', - obs_predictions_b8_out_path=d / 'predict.b8', + dem_path=d / "dem.dem", + dets_b8_in_path=d / "bad_dets.b8", + obs_predictions_b8_out_path=d / "predict.b8", tmp_dir=d, ) -@pytest.mark.parametrize('decoder', TEST_DECODER_NAMES) +@pytest.mark.parametrize("decoder", TEST_DECODER_NAMES) def test_full_scale(decoder: str): - result, = sinter.collect( + (result,) = sinter.collect( num_workers=2, tasks=[sinter.Task(circuit=stim.Circuit())], + samplers=['stim'], decoders=[decoder], max_shots=1000, custom_decoders=TEST_CUSTOM_DECODERS, diff --git a/glue/sample/src/sinter/_sampling_sampler_class.py b/glue/sample/src/sinter/_sampling_sampler_class.py new file mode 100644 index 000000000..b43a7431c --- /dev/null +++ b/glue/sample/src/sinter/_sampling_sampler_class.py @@ -0,0 +1,123 @@ +from typing import Tuple +import abc +import pathlib + +import numpy as np +import stim + + +class CompiledSampler(metaclass=abc.ABCMeta): + """Abstract class for samplers preconfigured to a specific sampling task. + + This is the type returned by `sinter.Sampler.compile_sampler_for_circuit`. The + idea is that, when many shots of the same sampling task are going to be + performed, it is valuable to pay the cost of configuring the sampler only + once instead of once per batch of shots. Custom samplers can optionally + implement that method, and return this type, to increase sampling + efficiency. 
+ """ + @abc.abstractmethod + def sample_detectors_bit_packed( + self, + *, + shots: int, + ) -> Tuple[np.ndarray, np.ndarray]: + """Samples detectors and observables. + + All data returned must be bit packed with bitorder='little'. + + Args: + shots: The number of shots to sample. + + Returns: + Bit packed detector data and bit packed observable flip data stored as + a tuple of two bit packed numpy arrays. The numpy array must have the + following dtype/shape: + + dtype: uint8 + shape: (shots, ceil(num_bits_per_shot / 8)) + + where `num_bits_per_shot` is `circuit.num_detectors` for the detector + data and `circuit.num_observables` for the observable flip data for the + circuit this instance was compiled to sample. + """ + pass + + +class Sampler(metaclass=abc.ABCMeta): + """Abstract base class for custom samplers. + + Custom samplers can be explained to sinter by inheriting from this class and + implementing its methods. + + Sampler classes MUST be serializable (e.g. via pickling), so that they can + be given to worker processes when using python multiprocessing. + """ + def compile_sampler_for_circuit( + self, + *, + circuit: stim.Circuit, + ) -> CompiledSampler: + """Compiles a sampler for the given circuit. + + This method is optional to implement. By default, it will raise a + NotImplementedError. When sampling, sinter will attempt to use this + method first and otherwise fallback to using `sample_detectors_via_files`. + + The idea is that the preconfigured sampler amortizes the cost of + configuration over more calls. This makes smaller batch sizes efficient, + reducing the amount of memory used for storing each batch, improving + overall efficiency. + + Args: + circuit: A circuit for the sampler to be configured and sample from. + + Returns: + An instance of `sinter.CompiledSampler` that can be used to invoke + the preconfigured sampler. + + Raises: + NotImplementedError: This `sinter.Sampler` doesn't support compiling + for a circuit. 
+ """ + raise NotImplementedError("compile_sampler_for_circuit") + + @abc.abstractmethod + def sample_detectors_via_files( + self, + *, + shots: int, + circuit_path: pathlib.Path, + dets_b8_out_path: pathlib.Path, + obs_flips_b8_out_path: pathlib.Path, + tmp_dir: pathlib.Path, + ) -> None: + """Performs sampling by reading/writing circuit and data from/to disk. + + Args: + shots: The number of shots to sample. + circuit_path: The file path where the circuit should be read from, + e.g. using `stim.Circuit.from_file`. The circuit should be used + to configure the sampler. + dets_b8_out_path: The file path that detection event data should be + write to. Note that the file may be a named pipe instead of a + fixed size object. The detection events will be in b8 format + (see + https://github.com/quantumlib/Stim/blob/main/doc/result_formats.md + ). The number of detection events per shot is available via the + circuit at `circuit_path`. + obs_flips_b8_out_path: The file path that observable flip data should + be write to. Note that the file may be a named pipe instead of a + fixed size object. The observables will be in b8 format + (see + https://github.com/quantumlib/Stim/blob/main/doc/result_formats.md + ). The number of observables per shot is available via the + circuit at `circuit_path`. + tmp_dir: Any temporary files generated by the sampler during its + operation MUST be put into this directory. The reason for this + requirement is because sinter is allowed to kill the sampling + process without warning, without giving it time to clean up any + temporary objects. All cleanup should be done via sinter + deleting this directory after killing the sampler. 
+ """ + pass diff --git a/glue/sample/src/sinter/_sampling_stim.py b/glue/sample/src/sinter/_sampling_stim.py new file mode 100644 index 000000000..9ba61e9e2 --- /dev/null +++ b/glue/sample/src/sinter/_sampling_stim.py @@ -0,0 +1,44 @@ +from typing import Tuple +import pathlib + +import stim +import numpy as np + +from sinter._sampling_sampler_class import Sampler, CompiledSampler + + +class StimCompiledDetectorSampler(CompiledSampler): + def __init__(self, circuit: stim.Circuit): + self.sampler = circuit.compile_detector_sampler() + + def sample_detectors_bit_packed( + self, + *, + shots: int, + ) -> Tuple[np.ndarray, np.ndarray]: + return self.sampler.sample(shots, separate_observables=True, bit_packed=True) + + +class StimDetectorSampler(Sampler): + """Use `stim.CompiledDetectorSampler` to sample detectors from a circuit.""" + def compile_sampler_for_circuit(self, *, circuit: stim.Circuit) -> CompiledSampler: + return StimCompiledDetectorSampler(circuit) + + def sample_detectors_via_files( + self, + *, + shots: int, + circuit_path: pathlib.Path, + dets_b8_out_path: pathlib.Path, + obs_flips_b8_out_path: pathlib.Path, + tmp_dir: pathlib.Path, + ) -> None: + circuit = stim.Circuit.from_file(circuit_path) + sampler = circuit.compile_detector_sampler() + sampler.sample_write( + shots, + filepath=str(dets_b8_out_path), + format="b8", + obs_out_filepath=str(obs_flips_b8_out_path), + obs_out_format="b8", + ) diff --git a/glue/sample/src/sinter/_sampling_vacuous.py b/glue/sample/src/sinter/_sampling_vacuous.py new file mode 100644 index 000000000..8a9e32c3f --- /dev/null +++ b/glue/sample/src/sinter/_sampling_vacuous.py @@ -0,0 +1,47 @@ +from typing import Tuple +import pathlib + +import stim +import numpy as np + +from sinter._sampling_sampler_class import Sampler, CompiledSampler + + +class VacuousCompiledSampler(CompiledSampler): + def __init__(self, detectors_shape: int, obs_shape: int): + self.detectors_shape = detectors_shape + self.obs_shape = obs_shape + + def 
sample_detectors_bit_packed(
+        self,
+        *,
+        shots: int,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        return np.zeros(shape=(shots, self.detectors_shape), dtype=np.uint8), np.zeros(
+            shape=(shots, self.obs_shape), dtype=np.uint8
+        )
+
+
+class VacuousSampler(Sampler):
+    """An example sampler that always samples zero-valued detectors and zero-valued observables."""
+    def compile_sampler_for_circuit(self, *, circuit: stim.Circuit) -> CompiledSampler:
+        return VacuousCompiledSampler(
+            (circuit.num_detectors + 7) // 8, (circuit.num_observables + 7) // 8
+        )
+
+    def sample_detectors_via_files(
+        self,
+        *,
+        shots: int,
+        circuit_path: pathlib.Path,
+        dets_b8_out_path: pathlib.Path,
+        obs_flips_b8_out_path: pathlib.Path,
+        tmp_dir: pathlib.Path,
+    ) -> None:
+        circuit = stim.Circuit.from_file(circuit_path)
+        num_detectors = circuit.num_detectors
+        num_obs = circuit.num_observables
+        with open(dets_b8_out_path, "wb") as f:
+            f.write(b"\0" * ((num_detectors + 7) // 8 * shots))
+        with open(obs_flips_b8_out_path, "wb") as f:
+            f.write(b"\0" * ((num_obs + 7) // 8 * shots))
diff --git a/glue/sample/src/sinter/_task.py b/glue/sample/src/sinter/_task.py
index c358bde3b..ad822dbb6 100644
--- a/glue/sample/src/sinter/_task.py
+++ b/glue/sample/src/sinter/_task.py
@@ -21,6 +21,9 @@ class Task:
     Attributes:
         circuit: The annotated noisy circuit to sample detection event data
            and logical observable data form.
+        sampler: The sampler to use to sample detectors from the circuit.
+            This can be set to None if it will be specified later (e.g. by
+            the call to `collect`). Defaults to 'stim'.
         decoder: The decoder to use to predict the logical observable data
            from the detection event data. This can be set to None if it
            will be specified later (e.g. by the call to `collect`).
@@ -67,6 +70,7 @@ def __init__( self, *, circuit: Optional['stim.Circuit'] = None, + sampler: Optional[str] = 'stim', decoder: Optional[str] = None, detector_error_model: Optional['stim.DetectorErrorModel'] = None, postselection_mask: Optional[np.ndarray] = None, @@ -81,6 +85,7 @@ def __init__( Args: circuit: The annotated noisy circuit to sample detection event data and logical observable data form. + sampler: The sampler to use to sample detectors from the circuit. decoder: The decoder to use to predict the logical observable data from the detection event data. This can be set to None if it will be specified later (e.g. by the call to `collect`). @@ -156,6 +161,7 @@ def __init__( raise ValueError(f"postselected_observables_mask.dtype={postselected_observables_mask.dtype!r} != np.uint8") self.circuit_path = None if circuit_path is None else pathlib.Path(circuit_path) self.circuit = circuit + self.sampler = sampler self.decoder = decoder self.detector_error_model = detector_error_model self.postselection_mask = postselection_mask @@ -185,6 +191,8 @@ def strong_id_value(self) -> Dict[str, Any]: """ if self.circuit is None: raise ValueError("Can't compute strong_id until `circuit` is set.") + if self.sampler is None: + raise ValueError("Can't compute strong_id until `sampler` is set.") if self.decoder is None: raise ValueError("Can't compute strong_id until `decoder` is set.") if self.detector_error_model is None: @@ -201,6 +209,10 @@ def strong_id_value(self) -> Dict[str, Any]: } if self.postselected_observables_mask is not None: result['postselected_observables_mask'] = [int(e) for e in self.postselected_observables_mask] + # Do not include the sampler if it is the default value "stim". + # This is for backwards compatibility. 
+ if self.sampler != "stim": + result["sampler"] = self.sampler return result def strong_id_text(self) -> str: @@ -274,6 +286,8 @@ def __repr__(self) -> str: terms = [] if self.circuit is not None: terms.append(f'circuit={self.circuit!r}') + if self.sampler is not None: + terms.append(f"sampler={self.sampler!r}") if self.decoder is not None: terms.append(f'decoder={self.decoder!r}') if self.detector_error_model is not None: @@ -302,6 +316,7 @@ def __eq__(self, other: Any) -> bool: return ( self.circuit_path == other.circuit_path and self.circuit == other.circuit and + self.sampler == other.sampler and self.decoder == other.decoder and self.detector_error_model == other.detector_error_model and np.array_equal(self.postselection_mask, other.postselection_mask) and diff --git a/glue/sample/src/sinter/_task_stats.py b/glue/sample/src/sinter/_task_stats.py index 73a2dd162..7a8736009 100644 --- a/glue/sample/src/sinter/_task_stats.py +++ b/glue/sample/src/sinter/_task_stats.py @@ -17,6 +17,7 @@ class TaskStats: Attributes: strong_id: The cryptographically unique identifier of the task, from `sinter.Task.strong_id()`. + sampler: The name of the sampler that was used to sample the task. decoder: The name of the decoder that was used to decode the task. Errors are counted when this decoder made a wrong prediction. json_metadata: A JSON-encodable value (such as a dictionary from strings @@ -44,6 +45,7 @@ class TaskStats: strong_id: str decoder: str json_metadata: Any + sampler: str = "stim" # Information describing the results of sampling. 
shots: int = 0 @@ -58,6 +60,7 @@ def __post_init__(self): assert isinstance(self.discards, int) assert isinstance(self.seconds, (int, float)) assert isinstance(self.custom_counts, collections.Counter) + assert isinstance(self.sampler, str) assert isinstance(self.decoder, str) assert isinstance(self.strong_id, str) assert self.json_metadata is None or isinstance(self.json_metadata, (int, float, str, dict, list, tuple)) @@ -73,6 +76,7 @@ def __add__(self, other: 'TaskStats') -> 'TaskStats': total = self.to_anon_stats() + other.to_anon_stats() return TaskStats( + sampler=self.sampler, decoder=self.decoder, strong_id=self.strong_id, json_metadata=self.json_metadata, @@ -132,6 +136,7 @@ def to_csv_line(self) -> str: seconds=self.seconds, discards=self.discards, strong_id=self.strong_id, + sampler=self.sampler, decoder=self.decoder, json_metadata=self.json_metadata, custom_counts=self.custom_counts, @@ -147,6 +152,7 @@ def _split_custom_counts(self, custom_keys: List[str]) -> List['TaskStats']: m.setdefault('original_error_count', self.errors) result.append(TaskStats( strong_id=f'{self.strong_id}:{k}', + sampler=self.sampler, decoder=self.decoder, json_metadata=m, shots=self.shots, @@ -163,6 +169,7 @@ def __str__(self) -> str: def __repr__(self) -> str: terms = [] terms.append(f'strong_id={self.strong_id!r}') + terms.append(f"sampler={self.sampler!r}") terms.append(f'decoder={self.decoder!r}') terms.append(f'json_metadata={self.json_metadata!r}') if self.shots: diff --git a/glue/sample/src/sinter/_task_stats_test.py b/glue/sample/src/sinter/_task_stats_test.py index 0847b003f..cc6b001fa 100644 --- a/glue/sample/src/sinter/_task_stats_test.py +++ b/glue/sample/src/sinter/_task_stats_test.py @@ -28,7 +28,7 @@ def test_to_csv_line(): discards=4, seconds=5, ) - assert v.to_csv_line() == str(v) == ' 22, 3, 4, 5,pymatching,test,"{""a"":[1,2,3]}",' + assert v.to_csv_line() == str(v) == ' 22, 3, 4, 5,stim,pymatching,test,"{""a"":[1,2,3]}",' def test_to_anon_stats(): diff --git 
a/glue/sample/src/sinter/_worker.py b/glue/sample/src/sinter/_worker.py index c01cd5ff9..59faf18a2 100644 --- a/glue/sample/src/sinter/_worker.py +++ b/glue/sample/src/sinter/_worker.py @@ -18,6 +18,7 @@ def __init__( work_key: Any, circuit_path: str, dem_path: str, + sampler: str, decoder: str, strong_id: Optional[str], postselection_mask: 'Optional[np.ndarray]', @@ -29,6 +30,7 @@ def __init__( self.work_key = work_key self.circuit_path = circuit_path self.dem_path = dem_path + self.sampler = sampler self.decoder = decoder self.strong_id = strong_id self.postselection_mask = postselection_mask @@ -43,6 +45,7 @@ def with_work_key(self, work_key: Any) -> 'WorkIn': work_key=work_key, circuit_path=self.circuit_path, dem_path=self.dem_path, + sampler=self.sampler, decoder=self.decoder, postselection_mask=self.postselection_mask, postselected_observables_mask=self.postselected_observables_mask, @@ -128,6 +131,7 @@ def __init__( def worker_loop(tmp_dir: 'pathlib.Path', inp: 'multiprocessing.Queue', out: 'multiprocessing.Queue', + custom_samplers: Optional[Dict[str, 'sinter.Sampler']], custom_decoders: Optional[Dict[str, 'sinter.Decoder']], core_affinity: Optional[int]) -> None: try: @@ -143,16 +147,24 @@ def worker_loop(tmp_dir: 'pathlib.Path', work: Optional[WorkIn] = inp.get() if work is None: return - out.put(do_work_safely(work, child_dir, custom_decoders)) + out.put( + do_work_safely(work, child_dir, custom_samplers, custom_decoders) + ) except KeyboardInterrupt: pass -def do_work_safely(work: WorkIn, child_dir: str, custom_decoders: Dict[str, 'sinter.Decoder']) -> WorkOut: +def do_work_safely( + work: WorkIn, + child_dir: str, + custom_samplers: Dict[str, 'sinter.Sampler'], + custom_decoders: Dict[str, 'sinter.Decoder'], +) -> WorkOut: try: - return do_work(work, child_dir, custom_decoders) + return do_work(work, child_dir, custom_samplers, custom_decoders) except BaseException as ex: import traceback + return WorkOut( work_key=work.work_key, stats=None, @@ -161,10 
+173,15 @@ def do_work_safely(work: WorkIn, child_dir: str, custom_decoders: Dict[str, 'sin ) -def do_work(work: WorkIn, child_dir: str, custom_decoders: Dict[str, 'sinter.Decoder']) -> WorkOut: +def do_work( + work: WorkIn, + child_dir: str, + custom_samplers: Dict[str, 'sinter.Sampler'], + custom_decoders: Dict[str, 'sinter.Decoder'], +) -> WorkOut: import stim from sinter._task import Task - from sinter._decoding import sample_decode + from sinter._sampling_and_decoding import sample_decode if work.strong_id is None: # The work is to compute the DEM, as opposed to taking shots. @@ -175,6 +192,7 @@ def do_work(work: WorkIn, child_dir: str, custom_decoders: Dict[str, 'sinter.Dec task = Task( circuit=circuit, + sampler=work.sampler, decoder=work.decoder, detector_error_model=dem, postselection_mask=work.postselection_mask, @@ -197,6 +215,8 @@ def do_work(work: WorkIn, child_dir: str, custom_decoders: Dict[str, 'sinter.Dec dem_obj=None, post_mask=work.postselection_mask, postselected_observable_mask=work.postselected_observables_mask, + sampler=work.sampler, + custom_samplers=custom_samplers, decoder=work.decoder, count_observable_error_combos=work.count_observable_error_combos, count_detection_events=work.count_detection_events, diff --git a/glue/sample/src/sinter/_worker_test.py b/glue/sample/src/sinter/_worker_test.py index d8049ee10..3f5e65aa4 100644 --- a/glue/sample/src/sinter/_worker_test.py +++ b/glue/sample/src/sinter/_worker_test.py @@ -25,6 +25,7 @@ def test_worker_loop_infers_dem(): work_key='test1', circuit_path=circuit_path, dem_path=dem_path, + sampler='stim', decoder='pymatching', json_metadata=5, strong_id=None, @@ -35,7 +36,7 @@ def test_worker_loop_infers_dem(): count_observable_error_combos=False, )) inp.put(None) - worker_loop(tmp_dir, inp, out, None, 0) + worker_loop(tmp_dir, inp, out, None, None, 0) result: WorkOut = out.get(timeout=1) assert out.empty() @@ -66,6 +67,7 @@ def test_worker_loop_does_not_recompute_dem(): work_key='test1', 
circuit_path=circuit_path, dem_path=dem_path, + sampler='stim', decoder='pymatching', json_metadata=5, strong_id="fake", @@ -76,7 +78,7 @@ def test_worker_loop_does_not_recompute_dem(): count_observable_error_combos=False, )) inp.put(None) - worker_loop(tmp_dir, inp, out, None, 0) + worker_loop(tmp_dir, inp, out, None, None, 0) result: WorkOut = out.get(timeout=1) assert out.empty() From ac27545bb131ed64b24415101e0b3a7cc8f17ce2 Mon Sep 17 00:00:00 2001 From: Yiming Zhang Date: Thu, 28 Mar 2024 19:28:42 +0800 Subject: [PATCH 2/8] Remove default sampler value in `Task` and `TaskStats` --- glue/sample/src/sinter/_collection_test.py | 6 ++++++ glue/sample/src/sinter/_existing_data.py | 2 ++ glue/sample/src/sinter/_existing_data_test.py | 14 +++++++------- glue/sample/src/sinter/_task.py | 4 ++-- glue/sample/src/sinter/_task_stats.py | 3 ++- glue/sample/src/sinter/_task_stats_test.py | 7 +++++++ 6 files changed, 26 insertions(+), 10 deletions(-) diff --git a/glue/sample/src/sinter/_collection_test.py b/glue/sample/src/sinter/_collection_test.py index 3ca72a5e9..94526d2b8 100644 --- a/glue/sample/src/sinter/_collection_test.py +++ b/glue/sample/src/sinter/_collection_test.py @@ -19,6 +19,7 @@ def test_iter_collect(): rounds=3, distance=3, after_clifford_depolarization=p), + sampler='stim', decoder='pymatching', json_metadata={'p': p}, collection_options=sinter.CollectionOptions( @@ -53,6 +54,7 @@ def test_collect(): rounds=3, distance=3, after_clifford_depolarization=p), + sampler='stim', decoder='pymatching', json_metadata={'p': p}, collection_options=sinter.CollectionOptions( @@ -92,6 +94,7 @@ def test_collect_from_paths(): ).to_file(path) tasks.append(sinter.Task( circuit_path=path, + sampler='stim', decoder='pymatching', json_metadata={'p': p}, collection_options=sinter.CollectionOptions( @@ -150,6 +153,7 @@ def test_collect_custom_decoder(): ) ], max_shots=10000, + samplers=['stim'], decoders=['alternate'], custom_decoders={'alternate': 
AlternatingPredictionsDecoder()}, ) @@ -169,6 +173,7 @@ def test_iter_collect_list(): rounds=3, distance=3, after_clifford_depolarization=p), + sampler='stim', decoder='pymatching', json_metadata={'p': p}, collection_options=sinter.CollectionOptions( @@ -196,6 +201,7 @@ def test_iter_collect_list(): def test_iter_collect_worker_fails(): with pytest.raises(RuntimeError, match="Worker failed"): _ = list(sinter.iter_collect( + samplers=['stim'], decoders=['NOT A VALID DECODER'], num_workers=1, tasks=iter([ diff --git a/glue/sample/src/sinter/_existing_data.py b/glue/sample/src/sinter/_existing_data.py index 2deb1d608..d539bc685 100644 --- a/glue/sample/src/sinter/_existing_data.py +++ b/glue/sample/src/sinter/_existing_data.py @@ -38,6 +38,8 @@ def __iadd__(self, other: 'ExistingData') -> 'ExistingData': @staticmethod def from_file(path_or_file: Any) -> 'ExistingData': + # Do not expect 'sampler' field in CSV files. + # This is for backwards compatibility. expected_fields = { "shots", "discards", diff --git a/glue/sample/src/sinter/_existing_data_test.py b/glue/sample/src/sinter/_existing_data_test.py index db742f3dd..9c7f50d23 100644 --- a/glue/sample/src/sinter/_existing_data_test.py +++ b/glue/sample/src/sinter/_existing_data_test.py @@ -18,8 +18,8 @@ def test_read_stats_from_csv_files(): """.strip(), file=f) assert sinter.read_stats_from_csv_files(d / 'tmp.csv') == [ - sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0), - sinter.TaskStats(strong_id='def456', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), + sinter.TaskStats(strong_id='abc123', sampler='stim', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0), + sinter.TaskStats(strong_id='def456', sampler='stim', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), ] with open(d / 'tmp2.csv', 'w') as f: 
@@ -31,13 +31,13 @@ def test_read_stats_from_csv_files(): """.strip(), file=f) assert sinter.read_stats_from_csv_files(d / 'tmp2.csv') == [ - sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0, custom_counts=collections.Counter({'dets': 1234})), - sinter.TaskStats(strong_id='def456', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), + sinter.TaskStats(strong_id='abc123', sampler='stim', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0, custom_counts=collections.Counter({'dets': 1234})), + sinter.TaskStats(strong_id='def456', sampler='stim', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), ] assert sinter.read_stats_from_csv_files(d / 'tmp.csv', d / 'tmp2.csv') == [ - sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=2600, errors=8, discards=120, seconds=8.0, custom_counts=collections.Counter({'dets': 1234})), - sinter.TaskStats(strong_id='def456', decoder='pymatching', json_metadata={'d': 5}, shots=4000, errors=0, discards=20, seconds=4.0), + sinter.TaskStats(strong_id='abc123', sampler='stim', decoder='pymatching', json_metadata={'d': 3}, shots=2600, errors=8, discards=120, seconds=8.0, custom_counts=collections.Counter({'dets': 1234})), + sinter.TaskStats(strong_id='def456', sampler='stim', decoder='pymatching', json_metadata={'d': 5}, shots=4000, errors=0, discards=20, seconds=4.0), ] with open(d / 'tmp3.csv', 'w') as f: @@ -49,6 +49,6 @@ def test_read_stats_from_csv_files(): """.strip(), file=f) assert sinter.read_stats_from_csv_files(d / 'tmp3.csv') == [ - sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0, custom_counts=collections.Counter({'dets': 1234})), + sinter.TaskStats(strong_id='abc123', sampler='stim', decoder='pymatching', 
json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0, custom_counts=collections.Counter({'dets': 1234})), sinter.TaskStats(strong_id='def456', sampler='mock', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), ] diff --git a/glue/sample/src/sinter/_task.py b/glue/sample/src/sinter/_task.py index ad822dbb6..6fd246051 100644 --- a/glue/sample/src/sinter/_task.py +++ b/glue/sample/src/sinter/_task.py @@ -23,7 +23,7 @@ class Task: and logical observable data form. sampler: The sampler to use to sample detectors from the circuit. This can be set to None if it will be specified later (e.g. by - the call to `collect`). Defaults to 'stim'. + the call to `collect`). decoder: The decoder to use to predict the logical observable data from the detection event data. This can be set to None if it will be specified later (e.g. by the call to `collect`). @@ -70,7 +70,7 @@ def __init__( self, *, circuit: Optional['stim.Circuit'] = None, - sampler: Optional[str] = 'stim', + sampler: Optional[str] = None, decoder: Optional[str] = None, detector_error_model: Optional['stim.DetectorErrorModel'] = None, postselection_mask: Optional[np.ndarray] = None, diff --git a/glue/sample/src/sinter/_task_stats.py b/glue/sample/src/sinter/_task_stats.py index 7a8736009..294385643 100644 --- a/glue/sample/src/sinter/_task_stats.py +++ b/glue/sample/src/sinter/_task_stats.py @@ -43,9 +43,9 @@ class TaskStats: # Information describing the problem that was sampled. strong_id: str + sampler: str decoder: str json_metadata: Any - sampler: str = "stim" # Information describing the results of sampling. shots: int = 0 @@ -95,6 +95,7 @@ def to_anon_stats(self) -> AnonTaskStats: >>> stat = sinter.TaskStats( ... strong_id='test', ... json_metadata={'a': [1, 2, 3]}, + ... sampler='stim', ... decoder='pymatching', ... shots=22, ... 
errors=3, diff --git a/glue/sample/src/sinter/_task_stats_test.py b/glue/sample/src/sinter/_task_stats_test.py index cc6b001fa..bbfd35c8f 100644 --- a/glue/sample/src/sinter/_task_stats_test.py +++ b/glue/sample/src/sinter/_task_stats_test.py @@ -9,6 +9,7 @@ def test_repr(): v = sinter.TaskStats( strong_id='test', json_metadata={'a': [1, 2, 3]}, + sampler='stim', decoder='pymatching', shots=22, errors=3, @@ -22,6 +23,7 @@ def test_to_csv_line(): v = sinter.TaskStats( strong_id='test', json_metadata={'a': [1, 2, 3]}, + sampler='stim', decoder='pymatching', shots=22, errors=3, @@ -35,6 +37,7 @@ def test_to_anon_stats(): v = sinter.TaskStats( strong_id='test', json_metadata={'a': [1, 2, 3]}, + sampler='stim', decoder='pymatching', shots=22, errors=3, @@ -46,6 +49,7 @@ def test_to_anon_stats(): def test_add(): a = sinter.TaskStats( + sampler='stim', decoder='pymatching', json_metadata={'a': 2}, strong_id='abcdef', @@ -56,6 +60,7 @@ def test_add(): custom_counts=collections.Counter({'a': 10, 'b': 20}), ) b = sinter.TaskStats( + sampler='stim', decoder='pymatching', json_metadata={'a': 2}, strong_id='abcdef', @@ -66,6 +71,7 @@ def test_add(): custom_counts=collections.Counter({'a': 1, 'c': 3}), ) c = sinter.TaskStats( + sampler='stim', decoder='pymatching', json_metadata={'a': 2}, strong_id='abcdef', @@ -78,6 +84,7 @@ def test_add(): assert a + b == c with pytest.raises(ValueError): a + sinter.TaskStats( + sampler='stim', decoder='pymatching', json_metadata={'a': 2}, strong_id='abcdefDIFFERENT', From 965b58bb12a2dc4556a69ecba3a4a6defb14c895 Mon Sep 17 00:00:00 2001 From: Yiming Zhang Date: Thu, 28 Mar 2024 19:57:34 +0800 Subject: [PATCH 3/8] Revert the default 'stim' sampler for `TaskStats` and `collect` for backwards compatibility --- glue/sample/src/sinter/_collection.py | 24 ++++++++++--------- .../src/sinter/_collection_work_manager.py | 14 +++++------ glue/sample/src/sinter/_existing_data_test.py | 12 +++++----- glue/sample/src/sinter/_task_stats.py | 2 +- 
glue/sample/src/sinter/_task_stats_test.py | 7 ------ 5 files changed, 27 insertions(+), 32 deletions(-) diff --git a/glue/sample/src/sinter/_collection.py b/glue/sample/src/sinter/_collection.py index d1d918d13..58e897702 100644 --- a/glue/sample/src/sinter/_collection.py +++ b/glue/sample/src/sinter/_collection.py @@ -45,7 +45,7 @@ def iter_collect(*, additional_existing_data: Optional[ExistingData] = None, max_shots: Optional[int] = None, max_errors: Optional[int] = None, - samplers: Optional[Iterable[str]] = None, + samplers: Optional[Iterable[str]] = ('stim',), decoders: Optional[Iterable[str]] = None, max_batch_seconds: Optional[int] = None, max_batch_size: Optional[int] = None, @@ -76,10 +76,11 @@ def iter_collect(*, additional_existing_data: Defaults to None (no additional data). Statistical data that has already been collected, in addition to anything included in each task's `previous_stats` field. - samplers: Defaults to None (specified by each Task). The names of the - samplers to use on each Task. It must either be the case that each - Task specifies a sampler and this is set to None, or this is an - iterable and each Task has its sampler set to None. + samplers: Defaults to ('stim',). The names of the samplers to use on each + Task. It must either be the case that each Task specifies a sampler + and this is set to None, or this is an iterable and each Task has + its sampler set to None. If both are set, samplers specified here + will be used instead. decoders: Defaults to None (specified by each Task). The names of the decoders to use on each Task. 
It must either be the case that each Task specifies a decoder and this is set to None, or this is an @@ -191,7 +192,7 @@ def iter_collect(*, count_observable_error_combos=count_observable_error_combos, count_detection_events=count_detection_events, additional_existing_data=additional_existing_data, - custom_samplers=custom_samplers, + custom_samplers=custom_samplers or dict(), custom_decoders=custom_decoders, custom_error_count_key=custom_error_count_key, allowed_cpu_affinity_ids=allowed_cpu_affinity_ids, @@ -241,7 +242,7 @@ def collect(*, max_errors: Optional[int] = None, count_observable_error_combos: bool = False, count_detection_events: bool = False, - samplers: Optional[Iterable[str]] = None, + samplers: Optional[Iterable[str]] = ('stim',), decoders: Optional[Iterable[str]] = None, max_batch_seconds: Optional[int] = None, max_batch_size: Optional[int] = None, @@ -275,10 +276,11 @@ def collect(*, hint_num_tasks: If `tasks` is an iterator or a generator, its length can be given here so that progress printouts can say how many cases are left. - samplers: Defaults to None (specified by each Task). The names of the - samplers to use on each Task. It must either be the case that each - Task specifies a sampler and this is set to None, or this is an - iterable and each Task has its sampler set to None. + samplers: Defaults to ('stim',). The names of the samplers to use on each + Task. It must either be the case that each Task specifies a sampler + and this is set to None, or this is an iterable and each Task has + its sampler set to None. If both are set, samplers specified here + will be used instead. decoders: Defaults to None (specified by each Task). The names of the decoders to use on each Task. 
It must either be the case that each Task specifies a decoder and this is set to None, or this is an diff --git a/glue/sample/src/sinter/_collection_work_manager.py b/glue/sample/src/sinter/_collection_work_manager.py index 6d795cd4a..f508ff2f8 100644 --- a/glue/sample/src/sinter/_collection_work_manager.py +++ b/glue/sample/src/sinter/_collection_work_manager.py @@ -28,7 +28,7 @@ def __init__( additional_existing_data: Optional[ExistingData], count_observable_error_combos: bool, count_detection_events: bool, - samplers: Optional[Iterable[str]], + samplers: Optional[Iterable[str]] = ('stim',), custom_samplers: Dict[str, Sampler], decoders: Optional[Iterable[str]], custom_decoders: Dict[str, Decoder], @@ -276,15 +276,15 @@ def _iter_tasks_with_assigned_samplers_decoders( circuit_path=task.circuit_path, ) - if task.sampler is None and default_samplers is None: + if default_samplers is not None: + task_samplers = list(default_samplers) + elif task.sampler is not None: + task_samplers = [task.sampler] + else: raise ValueError("Samplers to use was not specified. samplers is None and task.sampler is None") + if task.decoder is None and default_decoders is None: raise ValueError("Decoders to use was not specified. 
decoders is None and task.decoder is None") - task_samplers = [] - if default_samplers is not None: - task_samplers.extend(default_samplers) - if task.sampler is not None and task.sampler not in task_samplers: - task_samplers.append(task.sampler) task_decoders = [] if default_decoders is not None: diff --git a/glue/sample/src/sinter/_existing_data_test.py b/glue/sample/src/sinter/_existing_data_test.py index 9c7f50d23..381d04fa3 100644 --- a/glue/sample/src/sinter/_existing_data_test.py +++ b/glue/sample/src/sinter/_existing_data_test.py @@ -18,8 +18,8 @@ def test_read_stats_from_csv_files(): """.strip(), file=f) assert sinter.read_stats_from_csv_files(d / 'tmp.csv') == [ - sinter.TaskStats(strong_id='abc123', sampler='stim', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0), - sinter.TaskStats(strong_id='def456', sampler='stim', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), + sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0), + sinter.TaskStats(strong_id='def456', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), ] with open(d / 'tmp2.csv', 'w') as f: @@ -31,13 +31,13 @@ def test_read_stats_from_csv_files(): """.strip(), file=f) assert sinter.read_stats_from_csv_files(d / 'tmp2.csv') == [ - sinter.TaskStats(strong_id='abc123', sampler='stim', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0, custom_counts=collections.Counter({'dets': 1234})), - sinter.TaskStats(strong_id='def456', sampler='stim', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), + sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0, custom_counts=collections.Counter({'dets': 1234})), + 
sinter.TaskStats(strong_id='def456', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0), ] assert sinter.read_stats_from_csv_files(d / 'tmp.csv', d / 'tmp2.csv') == [ - sinter.TaskStats(strong_id='abc123', sampler='stim', decoder='pymatching', json_metadata={'d': 3}, shots=2600, errors=8, discards=120, seconds=8.0, custom_counts=collections.Counter({'dets': 1234})), - sinter.TaskStats(strong_id='def456', sampler='stim', decoder='pymatching', json_metadata={'d': 5}, shots=4000, errors=0, discards=20, seconds=4.0), + sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=2600, errors=8, discards=120, seconds=8.0, custom_counts=collections.Counter({'dets': 1234})), + sinter.TaskStats(strong_id='def456', decoder='pymatching', json_metadata={'d': 5}, shots=4000, errors=0, discards=20, seconds=4.0), ] with open(d / 'tmp3.csv', 'w') as f: diff --git a/glue/sample/src/sinter/_task_stats.py b/glue/sample/src/sinter/_task_stats.py index 294385643..ce79bd367 100644 --- a/glue/sample/src/sinter/_task_stats.py +++ b/glue/sample/src/sinter/_task_stats.py @@ -43,9 +43,9 @@ class TaskStats: # Information describing the problem that was sampled. strong_id: str - sampler: str decoder: str json_metadata: Any + sampler: str = 'stim' # Information describing the results of sampling. 
shots: int = 0 diff --git a/glue/sample/src/sinter/_task_stats_test.py b/glue/sample/src/sinter/_task_stats_test.py index bbfd35c8f..cc6b001fa 100644 --- a/glue/sample/src/sinter/_task_stats_test.py +++ b/glue/sample/src/sinter/_task_stats_test.py @@ -9,7 +9,6 @@ def test_repr(): v = sinter.TaskStats( strong_id='test', json_metadata={'a': [1, 2, 3]}, - sampler='stim', decoder='pymatching', shots=22, errors=3, @@ -23,7 +22,6 @@ def test_to_csv_line(): v = sinter.TaskStats( strong_id='test', json_metadata={'a': [1, 2, 3]}, - sampler='stim', decoder='pymatching', shots=22, errors=3, @@ -37,7 +35,6 @@ def test_to_anon_stats(): v = sinter.TaskStats( strong_id='test', json_metadata={'a': [1, 2, 3]}, - sampler='stim', decoder='pymatching', shots=22, errors=3, @@ -49,7 +46,6 @@ def test_to_anon_stats(): def test_add(): a = sinter.TaskStats( - sampler='stim', decoder='pymatching', json_metadata={'a': 2}, strong_id='abcdef', @@ -60,7 +56,6 @@ def test_add(): custom_counts=collections.Counter({'a': 10, 'b': 20}), ) b = sinter.TaskStats( - sampler='stim', decoder='pymatching', json_metadata={'a': 2}, strong_id='abcdef', @@ -71,7 +66,6 @@ def test_add(): custom_counts=collections.Counter({'a': 1, 'c': 3}), ) c = sinter.TaskStats( - sampler='stim', decoder='pymatching', json_metadata={'a': 2}, strong_id='abcdef', @@ -84,7 +78,6 @@ def test_add(): assert a + b == c with pytest.raises(ValueError): a + sinter.TaskStats( - sampler='stim', decoder='pymatching', json_metadata={'a': 2}, strong_id='abcdefDIFFERENT', From c4ee7636e68e1c7572ca28933e1c7e9babd5b1a3 Mon Sep 17 00:00:00 2001 From: Yiming Zhang Date: Thu, 28 Mar 2024 20:04:48 +0800 Subject: [PATCH 4/8] Fix doctests failure --- glue/sample/src/sinter/_existing_data.py | 8 ++++---- glue/sample/src/sinter/_task.py | 2 +- glue/sample/src/sinter/_task_stats.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/glue/sample/src/sinter/_existing_data.py b/glue/sample/src/sinter/_existing_data.py index 
d539bc685..b14b4f728 100644 --- a/glue/sample/src/sinter/_existing_data.py +++ b/glue/sample/src/sinter/_existing_data.py @@ -125,8 +125,8 @@ def stats_from_csv_files(*paths_or_files: Any) -> List['sinter.TaskStats']: >>> stats = sinter.stats_from_csv_files(in_memory_file) >>> for stat in stats: ... print(repr(stat)) - sinter.TaskStats(strong_id='9c31908e2b', decoder='pymatching', json_metadata={'d': 9}, shots=4000, errors=66, seconds=0.25) - sinter.TaskStats(strong_id='deadbeef08', decoder='pymatching', json_metadata={'d': 7}, shots=1000, errors=250, seconds=0.125) + sinter.TaskStats(strong_id='9c31908e2b', sampler='stim', decoder='pymatching', json_metadata={'d': 9}, shots=4000, errors=66, seconds=0.25) + sinter.TaskStats(strong_id='deadbeef08', sampler='stim', decoder='pymatching', json_metadata={'d': 7}, shots=1000, errors=250, seconds=0.125) """ result = ExistingData() for p in paths_or_files: @@ -166,8 +166,8 @@ def read_stats_from_csv_files(*paths_or_files: Any) -> List['sinter.TaskStats']: >>> stats = sinter.read_stats_from_csv_files(in_memory_file) >>> for stat in stats: ... 
print(repr(stat)) - sinter.TaskStats(strong_id='9c31908e2b', decoder='pymatching', json_metadata={'d': 9}, shots=4000, errors=66, seconds=0.25) - sinter.TaskStats(strong_id='deadbeef08', decoder='pymatching', json_metadata={'d': 7}, shots=1000, errors=250, seconds=0.125) + sinter.TaskStats(strong_id='9c31908e2b', sampler='stim', decoder='pymatching', json_metadata={'d': 9}, shots=4000, errors=66, seconds=0.25) + sinter.TaskStats(strong_id='deadbeef08', sampler='stim', decoder='pymatching', json_metadata={'d': 7}, shots=1000, errors=250, seconds=0.125) """ result = ExistingData() for p in paths_or_files: diff --git a/glue/sample/src/sinter/_task.py b/glue/sample/src/sinter/_task.py index 6fd246051..70772c659 100644 --- a/glue/sample/src/sinter/_task.py +++ b/glue/sample/src/sinter/_task.py @@ -70,7 +70,7 @@ def __init__( self, *, circuit: Optional['stim.Circuit'] = None, - sampler: Optional[str] = None, + sampler: str = 'stim', decoder: Optional[str] = None, detector_error_model: Optional['stim.DetectorErrorModel'] = None, postselection_mask: Optional[np.ndarray] = None, diff --git a/glue/sample/src/sinter/_task_stats.py b/glue/sample/src/sinter/_task_stats.py index ce79bd367..c6348ffc4 100644 --- a/glue/sample/src/sinter/_task_stats.py +++ b/glue/sample/src/sinter/_task_stats.py @@ -127,9 +127,9 @@ def to_csv_line(self) -> str: ... seconds=5, ... 
) >>> print(sinter.CSV_HEADER) - shots, errors, discards, seconds,decoder,strong_id,json_metadata,custom_counts + shots, errors, discards, seconds,sampler,decoder,strong_id,json_metadata,custom_counts >>> print(stat.to_csv_line()) - 22, 3, 0, 5,pymatching,test,"{""a"":[1,2,3]}", + 22, 3, 0, 5,stim,pymatching,test,"{""a"":[1,2,3]}", """ return csv_line( shots=self.shots, From caf41af9a0369e50519767bf7e3fda40bc412484 Mon Sep 17 00:00:00 2001 From: Yiming Zhang Date: Thu, 28 Mar 2024 20:20:49 +0800 Subject: [PATCH 5/8] Revert unnecessary changes in tests --- glue/sample/src/sinter/_collection_test.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/glue/sample/src/sinter/_collection_test.py b/glue/sample/src/sinter/_collection_test.py index 94526d2b8..3ca72a5e9 100644 --- a/glue/sample/src/sinter/_collection_test.py +++ b/glue/sample/src/sinter/_collection_test.py @@ -19,7 +19,6 @@ def test_iter_collect(): rounds=3, distance=3, after_clifford_depolarization=p), - sampler='stim', decoder='pymatching', json_metadata={'p': p}, collection_options=sinter.CollectionOptions( @@ -54,7 +53,6 @@ def test_collect(): rounds=3, distance=3, after_clifford_depolarization=p), - sampler='stim', decoder='pymatching', json_metadata={'p': p}, collection_options=sinter.CollectionOptions( @@ -94,7 +92,6 @@ def test_collect_from_paths(): ).to_file(path) tasks.append(sinter.Task( circuit_path=path, - sampler='stim', decoder='pymatching', json_metadata={'p': p}, collection_options=sinter.CollectionOptions( @@ -153,7 +150,6 @@ def test_collect_custom_decoder(): ) ], max_shots=10000, - samplers=['stim'], decoders=['alternate'], custom_decoders={'alternate': AlternatingPredictionsDecoder()}, ) @@ -173,7 +169,6 @@ def test_iter_collect_list(): rounds=3, distance=3, after_clifford_depolarization=p), - sampler='stim', decoder='pymatching', json_metadata={'p': p}, collection_options=sinter.CollectionOptions( @@ -201,7 +196,6 @@ def test_iter_collect_list(): def 
test_iter_collect_worker_fails(): with pytest.raises(RuntimeError, match="Worker failed"): _ = list(sinter.iter_collect( - samplers=['stim'], decoders=['NOT A VALID DECODER'], num_workers=1, tasks=iter([ From 0bf14f84eaa429403b7bd750c956ecd4cfa68d7c Mon Sep 17 00:00:00 2001 From: Yiming Zhang Date: Tue, 7 May 2024 15:33:45 +0800 Subject: [PATCH 6/8] add `sampler` arg to plot funcs --- glue/sample/src/sinter/_main_plot.py | 32 ++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/glue/sample/src/sinter/_main_plot.py b/glue/sample/src/sinter/_main_plot.py index c6defeb90..ce9aaaee5 100644 --- a/glue/sample/src/sinter/_main_plot.py +++ b/glue/sample/src/sinter/_main_plot.py @@ -24,6 +24,7 @@ def parse_args(args: List[str]) -> Any: 'Values available to the python expression:\n' ' metadata: The parsed value from the json_metadata for the data point.\n' ' m: `m.key` is a shorthand for `metadata.get("key", None)`.\n' + ' sampler: The sampler that sampled the data for the data point.\n' ' decoder: The decoder that decoded the data for the data point.\n' ' strong_id: The cryptographic hash of the case that was sampled for the data point.\n' ' stat: The sinter.TaskStats object for the data point.\n' @@ -39,6 +40,7 @@ def parse_args(args: List[str]) -> Any: 'Values available to the python expression:\n' ' metadata: The parsed value from the json_metadata for the data point.\n' ' m: `m.key` is a shorthand for `metadata.get("key", None)`.\n' + ' sampler: The sampler that sampled the data for the data point.\n' ' decoder: The decoder that decoded the data for the data point.\n' ' strong_id: The cryptographic hash of the case that was sampled for the data point.\n' ' stat: The sinter.TaskStats object for the data point.\n' @@ -58,6 +60,7 @@ def parse_args(args: List[str]) -> Any: 'Values available to the python expression:\n' ' metadata: The parsed value from the json_metadata for the data point.\n' ' m: `m.key` is a shorthand for 
`metadata.get("key", None)`.\n' + ' sampler: The sampler that sampled the data for the data point.\n' ' decoder: The decoder that decoded the data for the data point.\n' ' strong_id: The cryptographic hash of the case that was sampled for the data point.\n' ' stat: The sinter.TaskStats object for the data point.\n' @@ -79,6 +82,7 @@ def parse_args(args: List[str]) -> Any: 'Values available to the python expression:\n' ' metadata: The parsed value from the json_metadata for the data point.\n' ' m: `m.key` is a shorthand for `metadata.get("key", None)`.\n' + ' sampler: The sampler that sampled the data for the data point.\n' ' decoder: The decoder that decoded the data for the data point.\n' ' strong_id: The cryptographic hash of the case that was sampled for the data point.\n' ' stat: The sinter.TaskStats object for the data point.\n' @@ -112,6 +116,7 @@ def parse_args(args: List[str]) -> Any: 'Values available to the python expression:\n' ' metadata: The parsed value from the json_metadata for the data point.\n' ' m: `m.key` is a shorthand for `metadata.get("key", None)`.\n' + ' sampler: The sampler that sampled the data for the data point.\n' ' decoder: The decoder that decoded the data for the data point.\n' ' strong_id: The cryptographic hash of the case that was sampled for the data point.\n' ' stat: The sinter.TaskStats object for the data point.\n' @@ -138,6 +143,7 @@ def parse_args(args: List[str]) -> Any: 'Values available to the python expression:\n' ' metadata: The parsed value from the json_metadata for the data point.\n' ' m: `m.key` is a shorthand for `metadata.get("key", None)`.\n' + ' sampler: The sampler that sampled the data for the data point.\n' ' decoder: The decoder that decoded the data for the data point.\n' ' strong_id: The cryptographic hash of the case that was sampled for the data point.\n' ' stat: The sinter.TaskStats object for the data point.\n' @@ -159,6 +165,7 @@ def parse_args(args: List[str]) -> Any: ' stats: The list of 
sinter.TaskStats object in the group.\n' ' metadata: (From one arbitrary data point in the group.) The parsed value from the json_metadata for the data point.\n' ' m: `m.key` is a shorthand for `metadata.get("key", None)`.\n' + ' sampler: (From one arbitrary data point in the group.) The sampler that sampled the data for the data point.\n' ' decoder: (From one arbitrary data point in the group.) The decoder that decoded the data for the data point.\n' ' strong_id: (From one arbitrary data point in the group.) The cryptographic hash of the case that was sampled for the data point.\n' ' stat: (From one arbitrary data point in the group.) The sinter.TaskStats object for the data point.\n' @@ -266,32 +273,32 @@ def parse_args(args: List[str]) -> Any: if a.failure_unit_name is None: a.failure_unit_name = 'shot' a.x_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.x_func}', + f'lambda *, stat, sampler, decoder, metadata, m, strong_id: {a.x_func}', filename='x_func:command_line_arg', mode='eval')) if a.y_func is not None: a.y_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.y_func}', + f'lambda *, stat, sampler, decoder, metadata, m, strong_id: {a.y_func}', filename='x_func:command_line_arg', mode='eval')) a.group_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.group_func}', + f'lambda *, stat, sampler, decoder, metadata, m, strong_id: {a.group_func}', filename='group_func:command_line_arg', mode='eval')) a.filter_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.filter_func}', + f'lambda *, stat, sampler, decoder, metadata, m, strong_id: {a.filter_func}', filename='filter_func:command_line_arg', mode='eval')) a.failure_units_per_shot_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.failure_units_per_shot_func}', + f'lambda *, stat, sampler, decoder, metadata, m, strong_id: {a.failure_units_per_shot_func}', 
filename='failure_units_per_shot_func:command_line_arg', mode='eval')) a.failure_values_func = eval(compile( - f'lambda *, stat, decoder, metadata, m, strong_id: {a.failure_values_func}', + f'lambda *, stat, sampler, decoder, metadata, m, strong_id: {a.failure_values_func}', filename='failure_values_func:command_line_arg', mode='eval')) a.plot_args_func = eval(compile( - f'lambda *, index, key, stats, stat, decoder, metadata, m, strong_id: {a.plot_args_func}', + f'lambda *, index, key, stats, stat, sampler, decoder, metadata, m, strong_id: {a.plot_args_func}', filename='plot_args_func:command_line_arg', mode='eval')) return a @@ -456,6 +463,10 @@ def _common_json_properties(stats: List['sinter.TaskStats']) -> Dict[str, Any]: v = stat.json_metadata.get(k) if v is None or isinstance(v, (float, str, int)): vals[k].add(v) + if 'sampler' not in vals: + vals['sampler'] = set() + for stat in stats: + vals['sampler'].add(stat.sampler) if 'decoder' not in vals: vals['decoder'] = set() for stat in stats: @@ -712,36 +723,42 @@ def main_plot(*, command_line_args: List[str]): samples=total, group_func=lambda stat: args.group_func( stat=stat, + sampler=stat.sampler, decoder=stat.decoder, metadata=stat.json_metadata, m=_FieldToMetadataWrapper(stat.json_metadata), strong_id=stat.strong_id), x_func=lambda stat: args.x_func( stat=stat, + sampler=stat.sampler, decoder=stat.decoder, metadata=stat.json_metadata, m=_FieldToMetadataWrapper(stat.json_metadata), strong_id=stat.strong_id), y_func=None if args.y_func is None else lambda stat: args.y_func( stat=stat, + sampler=stat.sampler, decoder=stat.decoder, metadata=stat.json_metadata, m=_FieldToMetadataWrapper(stat.json_metadata), strong_id=stat.strong_id), filter_func=lambda stat: args.filter_func( stat=stat, + sampler=stat.sampler, decoder=stat.decoder, metadata=stat.json_metadata, m=_FieldToMetadataWrapper(stat.json_metadata), strong_id=stat.strong_id), failure_units_per_shot_func=lambda stat: args.failure_units_per_shot_func( 
stat=stat, + sampler=stat.sampler, decoder=stat.decoder, metadata=stat.json_metadata, m=_FieldToMetadataWrapper(stat.json_metadata), strong_id=stat.strong_id), failure_values_func=lambda stat: args.failure_values_func( stat=stat, + sampler=stat.sampler, decoder=stat.decoder, metadata=stat.json_metadata, m=_FieldToMetadataWrapper(stat.json_metadata), @@ -751,6 +768,7 @@ def main_plot(*, command_line_args: List[str]): key=group_key, stats=stats, stat=stats[0], + sampler=stats[0].sampler, decoder=stats[0].decoder, metadata=stats[0].json_metadata, m=_FieldToMetadataWrapper(stats[0].json_metadata), From e692cd2462fbab076108e0f8d145dfc1dc50d08a Mon Sep 17 00:00:00 2001 From: Yiming Zhang Date: Tue, 7 May 2024 15:35:15 +0800 Subject: [PATCH 7/8] make the sampler specified by `Task` overwrite the default samplers by `collect` --- glue/sample/src/sinter/_collection.py | 3 +-- glue/sample/src/sinter/_collection_work_manager.py | 6 +++--- glue/sample/src/sinter/_task.py | 9 ++++++++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/glue/sample/src/sinter/_collection.py b/glue/sample/src/sinter/_collection.py index 58e897702..89153bbd7 100644 --- a/glue/sample/src/sinter/_collection.py +++ b/glue/sample/src/sinter/_collection.py @@ -79,8 +79,7 @@ def iter_collect(*, samplers: Defaults to ('stim',). The names of the samplers to use on each Task. It must either be the case that each Task specifies a sampler and this is set to None, or this is an iterable and each Task has - its sampler set to None. If both are set, samplers specified here - will be used instead. + its sampler set to None. decoders: Defaults to None (specified by each Task). The names of the decoders to use on each Task. 
It must either be the case that each Task specifies a decoder and this is set to None, or this is an diff --git a/glue/sample/src/sinter/_collection_work_manager.py b/glue/sample/src/sinter/_collection_work_manager.py index f508ff2f8..d3b5d86ce 100644 --- a/glue/sample/src/sinter/_collection_work_manager.py +++ b/glue/sample/src/sinter/_collection_work_manager.py @@ -276,10 +276,10 @@ def _iter_tasks_with_assigned_samplers_decoders( circuit_path=task.circuit_path, ) - if default_samplers is not None: - task_samplers = list(default_samplers) - elif task.sampler is not None: + if task.sampler is not None: task_samplers = [task.sampler] + elif default_samplers is not None: + task_samplers = list(default_samplers) else: raise ValueError("Samplers to use was not specified. samplers is None and task.sampler is None") diff --git a/glue/sample/src/sinter/_task.py b/glue/sample/src/sinter/_task.py index 70772c659..3069e70ec 100644 --- a/glue/sample/src/sinter/_task.py +++ b/glue/sample/src/sinter/_task.py @@ -70,7 +70,7 @@ def __init__( self, *, circuit: Optional['stim.Circuit'] = None, - sampler: str = 'stim', + sampler: Optional[str] = None, decoder: Optional[str] = None, detector_error_model: Optional['stim.DetectorErrorModel'] = None, postselection_mask: Optional[np.ndarray] = None, @@ -86,6 +86,8 @@ def __init__( circuit: The annotated noisy circuit to sample detection event data and logical observable data form. sampler: The sampler to use to sample detectors from the circuit. + This can be set to None if it will be specified later (e.g. by + the call to `collect`). decoder: The decoder to use to predict the logical observable data from the detection event data. This can be set to None if it will be specified later (e.g. by the call to `collect`). @@ -184,6 +186,7 @@ def strong_id_value(self) -> Dict[str, Any]: >>> task = sinter.Task( ... circuit=stim.Circuit('H 0'), ... detector_error_model=stim.DetectorErrorModel(), + ... sampler='stim', ... 
decoder='pymatching', ... ) >>> task.strong_id_value() @@ -228,6 +231,7 @@ def strong_id_text(self) -> str: >>> task = sinter.Task( ... circuit=stim.Circuit('H 0'), ... detector_error_model=stim.DetectorErrorModel(), + ... sampler='stim', ... decoder='pymatching', ... ) >>> task.strong_id_text() @@ -247,6 +251,7 @@ def strong_id_bytes(self) -> bytes: >>> task = sinter.Task( ... circuit=stim.Circuit('H 0'), ... detector_error_model=stim.DetectorErrorModel(), + ... sampler='stim', ... decoder='pymatching', ... ) >>> task.strong_id_bytes() @@ -263,6 +268,7 @@ def strong_id(self) -> str: This value is affected by: - The exact circuit. - The exact detector error model. + - The sampler. - The decoder. - The json metadata. - The postselection mask. @@ -273,6 +279,7 @@ def strong_id(self) -> str: >>> task = sinter.Task( ... circuit=stim.Circuit(), ... detector_error_model=stim.DetectorErrorModel(), + ... sampler='stim', ... decoder='pymatching', ... ) >>> task.strong_id() From 9ca48a62e4a5d31d4ddedf471ad0718df2701ab3 Mon Sep 17 00:00:00 2001 From: Yiming Zhang Date: Tue, 7 May 2024 15:35:41 +0800 Subject: [PATCH 8/8] print samplers in the progress status --- glue/sample/src/sinter/_collection_tracker_for_single_task.py | 1 + 1 file changed, 1 insertion(+) diff --git a/glue/sample/src/sinter/_collection_tracker_for_single_task.py b/glue/sample/src/sinter/_collection_tracker_for_single_task.py index e34197b00..ac24a97e6 100644 --- a/glue/sample/src/sinter/_collection_tracker_for_single_task.py +++ b/glue/sample/src/sinter/_collection_tracker_for_single_task.py @@ -205,6 +205,7 @@ def status(self) -> str: t = math.ceil(t) t = f'{t}' terms = [ + f'{self.unfilled_task.sampler} '.rjust(22), f'{self.unfilled_task.decoder} '.rjust(22), f'processes={self.deployed_processes}'.ljust(13), f'~core_mins_left={t}'.ljust(24),