Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sinter custom sampler #735

Closed
wants to merge 9 commits into from
8 changes: 8 additions & 0 deletions glue/sample/src/sinter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from sinter._decoding_all_built_in_decoders import (
BUILT_IN_DECODERS,
)
from sinter._sampling_all_built_in_samplers import (
BUILT_IN_SAMPLERS,
)
from sinter._existing_data import (
read_stats_from_csv_files,
stats_from_csv_files,
Expand Down Expand Up @@ -54,3 +57,8 @@
CompiledDecoder,
Decoder,
)

from sinter._sampling_sampler_class import (
CompiledSampler,
Sampler,
)
27 changes: 27 additions & 0 deletions glue/sample/src/sinter/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ def iter_collect(*,
additional_existing_data: Optional[ExistingData] = None,
max_shots: Optional[int] = None,
max_errors: Optional[int] = None,
samplers: Optional[Iterable[str]] = ('stim',),
decoders: Optional[Iterable[str]] = None,
max_batch_seconds: Optional[int] = None,
max_batch_size: Optional[int] = None,
start_batch_size: Optional[int] = None,
count_observable_error_combos: bool = False,
count_detection_events: bool = False,
custom_samplers: Optional[Dict[str, "sinter.Sampler"]] = None,
custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None,
custom_error_count_key: Optional[str] = None,
allowed_cpu_affinity_ids: Optional[Iterable[int]] = None,
Expand All @@ -74,6 +76,11 @@ def iter_collect(*,
additional_existing_data: Defaults to None (no additional data).
Statistical data that has already been collected, in addition to
anything included in each task's `previous_stats` field.
samplers: Defaults to ('stim',). The names of the samplers to use on each
Task. It must either be the case that each Task specifies a sampler
and this is set to None, or this is an iterable and each Task has
its sampler set to None. If both are set, samplers specified here
will be used instead.
decoders: Defaults to None (specified by each Task). The names of the
decoders to use on each Task. It must either be the case that each
Task specifies a decoder and this is set to None, or this is an
Expand Down Expand Up @@ -107,6 +114,9 @@ def iter_collect(*,
taken per shot. This information is then used to predict the
biggest batch size that can finish in under the given number of
seconds. Limits each batch to be no larger than that.
custom_samplers: Custom samplers that can be used if requested by name.
If not specified, only samplers built into sinter, such as 'stim',
can be used.
custom_decoders: Custom decoders that can be used if requested by name.
If not specified, only decoders built into sinter, such as
'pymatching' and 'fusion_blossom', can be used.
Expand Down Expand Up @@ -156,6 +166,8 @@ def iter_collect(*,
>>> print(total_shots)
200
"""
if isinstance(samplers, str):
samplers = [samplers]
if isinstance(decoders, str):
decoders = [decoders]

Expand All @@ -175,10 +187,12 @@ def iter_collect(*,
start_batch_size=start_batch_size,
max_batch_size=max_batch_size,
),
samplers=samplers,
decoders=decoders,
count_observable_error_combos=count_observable_error_combos,
count_detection_events=count_detection_events,
additional_existing_data=additional_existing_data,
custom_samplers=custom_samplers or dict(),
custom_decoders=custom_decoders,
custom_error_count_key=custom_error_count_key,
allowed_cpu_affinity_ids=allowed_cpu_affinity_ids,
Expand Down Expand Up @@ -228,12 +242,14 @@ def collect(*,
max_errors: Optional[int] = None,
count_observable_error_combos: bool = False,
count_detection_events: bool = False,
samplers: Optional[Iterable[str]] = ('stim',),
decoders: Optional[Iterable[str]] = None,
max_batch_seconds: Optional[int] = None,
max_batch_size: Optional[int] = None,
start_batch_size: Optional[int] = None,
print_progress: bool = False,
hint_num_tasks: Optional[int] = None,
custom_samplers: Optional[Dict[str, "sinter.Sampler"]] = None,
custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None,
custom_error_count_key: Optional[str] = None,
allowed_cpu_affinity_ids: Optional[Iterable[int]] = None,
Expand All @@ -260,6 +276,11 @@ def collect(*,
hint_num_tasks: If `tasks` is an iterator or a generator, its length
can be given here so that progress printouts can say how many cases
are left.
samplers: Defaults to ('stim',). The names of the samplers to use on each
Task. It must either be the case that each Task specifies a sampler
and this is set to None, or this is an iterable and each Task has
its sampler set to None. If both are set, samplers specified here
will be used instead.
decoders: Defaults to None (specified by each Task). The names of the
decoders to use on each Task. It must either be the case that each
Task specifies a decoder and this is set to None, or this is an
Expand Down Expand Up @@ -295,6 +316,10 @@ def collect(*,
taken per shot. This information is then used to predict the
biggest batch size that can finish in under the given number of
seconds. Limits each batch to be no larger than that.
custom_samplers: Named child classes of `sinter.Sampler`, that can be
used if requested by name by a task or by the samplers list.
If not specified, only samplers with support built into sinter, such
as 'stim', can be used.
custom_decoders: Named child classes of `sinter.Decoder`, that can be
used if requested by name by a task or by the decoders list.
If not specified, only decoders with support built into sinter, such
Expand Down Expand Up @@ -386,10 +411,12 @@ def collect(*,
max_batch_size=max_batch_size,
count_observable_error_combos=count_observable_error_combos,
count_detection_events=count_detection_events,
samplers=samplers,
decoders=decoders,
tasks=tasks,
hint_num_tasks=hint_num_tasks,
additional_existing_data=additional_existing_data,
custom_samplers=custom_samplers,
custom_decoders=custom_decoders,
custom_error_count_key=custom_error_count_key,
allowed_cpu_affinity_ids=allowed_cpu_affinity_ids,
Expand Down
4 changes: 2 additions & 2 deletions glue/sample/src/sinter/_collection_tracker_for_single_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from typing import Iterator
from typing import Optional

import stim

from sinter._anon_task_stats import AnonTaskStats
from sinter._existing_data import ExistingData
from sinter._task import Task
Expand Down Expand Up @@ -167,6 +165,7 @@ def provide_more_work(self, *, desperate: bool) -> Optional[WorkIn]:
work_key=None,
circuit_path=self.circuit_path,
dem_path=self.dem_path,
sampler=self.unfilled_task.sampler,
decoder=self.unfilled_task.decoder,
postselection_mask=self.unfilled_task.postselection_mask,
postselected_observables_mask=self.unfilled_task.postselected_observables_mask,
Expand All @@ -189,6 +188,7 @@ def provide_more_work(self, *, desperate: bool) -> Optional[WorkIn]:
strong_id=self.task_strong_id,
circuit_path=self.circuit_path,
dem_path=self.dem_path,
sampler=self.unfilled_task.sampler,
decoder=self.unfilled_task.decoder,
postselection_mask=self.unfilled_task.postselection_mask,
postselected_observables_mask=self.unfilled_task.postselected_observables_mask,
Expand Down
44 changes: 37 additions & 7 deletions glue/sample/src/sinter/_collection_work_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import pathlib
import tempfile
import stim
import itertools
from typing import cast, Iterable, Optional, Iterator, Tuple, Dict, List

from sinter._sampling_sampler_class import Sampler
from sinter._decoding_decoder_class import Decoder
from sinter._collection_options import CollectionOptions
from sinter._existing_data import ExistingData
Expand All @@ -26,11 +28,14 @@ def __init__(
additional_existing_data: Optional[ExistingData],
count_observable_error_combos: bool,
count_detection_events: bool,
samplers: Optional[Iterable[str]] = ('stim',),
custom_samplers: Dict[str, Sampler],
decoders: Optional[Iterable[str]],
custom_decoders: Dict[str, Decoder],
custom_error_count_key: Optional[str],
allowed_cpu_affinity_ids: Optional[List[int]],
):
self.custom_samplers = custom_samplers
self.custom_decoders = custom_decoders
self.queue_from_workers: Optional[multiprocessing.Queue] = None
self.queue_to_workers: Optional[multiprocessing.Queue] = None
Expand All @@ -41,6 +46,7 @@ def __init__(
self.allowed_cpu_affinity_ids = None if allowed_cpu_affinity_ids is None else sorted(set(allowed_cpu_affinity_ids))

self.global_collection_options = global_collection_options
self.samplers: Optional[Tuple[str, ...]] = None if samplers is None else tuple(samplers)
self.decoders: Optional[Tuple[str, ...]] = None if decoders is None else tuple(decoders)
self.did_work = False

Expand All @@ -53,10 +59,12 @@ def __init__(
self.count_observable_error_combos = count_observable_error_combos
self.count_detection_events = count_detection_events

self.tasks_with_decoder_iter: Iterator[Task] = _iter_tasks_with_assigned_decoders(
self.tasks_with_sampler_decoder_iter: Iterator[Task] = _iter_tasks_with_assigned_samplers_decoders(
tasks_iter=tasks_iter,
default_samplers=self.samplers,
default_decoders=self.decoders,
global_collections_options=self.global_collection_options)
global_collections_options=self.global_collection_options,
)

def start_workers(self, num_workers: int) -> None:
assert self.tmp_dir is not None
Expand All @@ -81,15 +89,25 @@ def start_workers(self, num_workers: int) -> None:
cpu_pin = None if len(cpus) == 0 else cpus[index % len(cpus)]
w = multiprocessing.Process(
target=worker_loop,
args=(self.tmp_dir, self.queue_to_workers, self.queue_from_workers, self.custom_decoders, cpu_pin))
args=(
self.tmp_dir,
self.queue_to_workers,
self.queue_from_workers,
self.custom_samplers,
self.custom_decoders,
cpu_pin,
),
)
w.start()
self.workers.append(w)
finally:
multiprocessing.set_start_method(current_method, force=True)

def __enter__(self):
self.exit_stack = contextlib.ExitStack().__enter__()
self.tmp_dir = pathlib.Path(self.exit_stack.enter_context(tempfile.TemporaryDirectory()))
self.tmp_dir = pathlib.Path(
self.exit_stack.enter_context(tempfile.TemporaryDirectory())
)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
Expand Down Expand Up @@ -150,6 +168,7 @@ def wait_for_next_sample(self,
stats = AnonTaskStats()
return TaskStats(
strong_id=result.strong_id,
sampler=work_in.sampler,
decoder=work_in.decoder,
json_metadata=work_in.json_metadata,
shots=stats.shots,
Expand All @@ -165,7 +184,7 @@ def _iter_draw_collectors(self, *, prefer_started: bool) -> Iterator[Tuple[int,
while True:
key = self.next_collector_key
try:
task = next(self.tasks_with_decoder_iter)
task = next(self.tasks_with_sampler_decoder_iter)
except StopIteration:
break
collector = CollectionTrackerForSingleTask(
Expand Down Expand Up @@ -236,16 +255,18 @@ def status(self, *, num_circuits: Optional[int]) -> str:
return main_status + ''.join(collector_statuses)


def _iter_tasks_with_assigned_decoders(
def _iter_tasks_with_assigned_samplers_decoders(
*,
tasks_iter: Iterator[Task],
default_samplers: Optional[Iterable[str]],
default_decoders: Optional[Iterable[str]],
global_collections_options: CollectionOptions,
) -> Iterator[Task]:
for task in tasks_iter:
if task.circuit is None:
task = Task(
circuit=stim.Circuit.from_file(task.circuit_path),
sampler=task.sampler,
decoder=task.decoder,
detector_error_model=task.detector_error_model,
postselection_mask=task.postselection_mask,
Expand All @@ -255,16 +276,25 @@ def _iter_tasks_with_assigned_decoders(
circuit_path=task.circuit_path,
)

if default_samplers is not None:
task_samplers = list(default_samplers)
elif task.sampler is not None:
            task_samplers = [task.sampler]
        else:
raise ValueError("Samplers to use was not specified. samplers is None and task.sampler is None")

if task.decoder is None and default_decoders is None:
raise ValueError("Decoders to use was not specified. decoders is None and task.decoder is None")

task_decoders = []
if default_decoders is not None:
task_decoders.extend(default_decoders)
if task.decoder is not None and task.decoder not in task_decoders:
task_decoders.append(task.decoder)
for decoder in task_decoders:
for sampler, decoder in itertools.product(task_samplers, task_decoders):
yield Task(
circuit=task.circuit,
sampler=sampler,
decoder=decoder,
detector_error_model=task.detector_error_model,
postselection_mask=task.postselection_mask,
Expand Down
4 changes: 4 additions & 0 deletions glue/sample/src/sinter/_csv_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def csv_line(*,
errors: Any,
discards: Any,
seconds: Any,
sampler: Any,
decoder: Any,
strong_id: Any,
json_metadata: Any,
Expand Down Expand Up @@ -50,13 +51,15 @@ def csv_line(*,
errors = escape_csv(errors, 10)
discards = escape_csv(discards, 10)
seconds = escape_csv(seconds, 8)
sampler = escape_csv(sampler, None)
decoder = escape_csv(decoder, None)
strong_id = escape_csv(strong_id, None)
json_metadata = escape_csv(json_metadata, None)
return (f'{shots},'
f'{errors},'
f'{discards},'
f'{seconds},'
f'{sampler},'
f'{decoder},'
f'{strong_id},'
f'{json_metadata},'
Expand All @@ -68,6 +71,7 @@ def csv_line(*,
discards='discards',
seconds='seconds',
strong_id='strong_id',
sampler='sampler',
decoder='decoder',
json_metadata='json_metadata',
custom_counts='custom_counts',
Expand Down
13 changes: 8 additions & 5 deletions glue/sample/src/sinter/_existing_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sinter._task_stats import TaskStats
from sinter._task import Task
from sinter._decoding import AnonTaskStats
from sinter._sampling_and_decoding import AnonTaskStats

if TYPE_CHECKING:
import sinter
Expand Down Expand Up @@ -38,6 +38,8 @@ def __iadd__(self, other: 'ExistingData') -> 'ExistingData':

@staticmethod
def from_file(path_or_file: Any) -> 'ExistingData':
# Do not expect 'sampler' field in CSV files.
# This is for backwards compatibility.
expected_fields = {
"shots",
"discards",
Expand Down Expand Up @@ -81,6 +83,7 @@ def from_file(path_or_file: Any) -> 'ExistingData':
custom_counts=custom_counts,
seconds=float(row['seconds']),
strong_id=row['strong_id'],
sampler=row.get('sampler', 'stim'),
decoder=row['decoder'],
json_metadata=json.loads(row['json_metadata']),
))
Expand Down Expand Up @@ -122,8 +125,8 @@ def stats_from_csv_files(*paths_or_files: Any) -> List['sinter.TaskStats']:
>>> stats = sinter.stats_from_csv_files(in_memory_file)
>>> for stat in stats:
... print(repr(stat))
sinter.TaskStats(strong_id='9c31908e2b', decoder='pymatching', json_metadata={'d': 9}, shots=4000, errors=66, seconds=0.25)
sinter.TaskStats(strong_id='deadbeef08', decoder='pymatching', json_metadata={'d': 7}, shots=1000, errors=250, seconds=0.125)
sinter.TaskStats(strong_id='9c31908e2b', sampler='stim', decoder='pymatching', json_metadata={'d': 9}, shots=4000, errors=66, seconds=0.25)
sinter.TaskStats(strong_id='deadbeef08', sampler='stim', decoder='pymatching', json_metadata={'d': 7}, shots=1000, errors=250, seconds=0.125)
"""
result = ExistingData()
for p in paths_or_files:
Expand Down Expand Up @@ -163,8 +166,8 @@ def read_stats_from_csv_files(*paths_or_files: Any) -> List['sinter.TaskStats']:
>>> stats = sinter.read_stats_from_csv_files(in_memory_file)
>>> for stat in stats:
... print(repr(stat))
sinter.TaskStats(strong_id='9c31908e2b', decoder='pymatching', json_metadata={'d': 9}, shots=4000, errors=66, seconds=0.25)
sinter.TaskStats(strong_id='deadbeef08', decoder='pymatching', json_metadata={'d': 7}, shots=1000, errors=250, seconds=0.125)
sinter.TaskStats(strong_id='9c31908e2b', sampler='stim', decoder='pymatching', json_metadata={'d': 9}, shots=4000, errors=66, seconds=0.25)
sinter.TaskStats(strong_id='deadbeef08', sampler='stim', decoder='pymatching', json_metadata={'d': 7}, shots=1000, errors=250, seconds=0.125)
"""
result = ExistingData()
for p in paths_or_files:
Expand Down
13 changes: 13 additions & 0 deletions glue/sample/src/sinter/_existing_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,16 @@ def test_read_stats_from_csv_files():
sinter.TaskStats(strong_id='abc123', decoder='pymatching', json_metadata={'d': 3}, shots=2600, errors=8, discards=120, seconds=8.0, custom_counts=collections.Counter({'dets': 1234})),
sinter.TaskStats(strong_id='def456', decoder='pymatching', json_metadata={'d': 5}, shots=4000, errors=0, discards=20, seconds=4.0),
]

with open(d / 'tmp3.csv', 'w') as f:
print("""
shots,errors,discards,seconds,sampler,decoder,strong_id,json_metadata,custom_counts
300, 1, 20, 1.0,stim,pymatching,abc123,"{""d"":3}","{""dets"":1234}"
1000, 3, 40, 3.0,stim,pymatching,abc123,"{""d"":3}",
2000, 0, 10, 2.0,mock,pymatching,def456,"{""d"":5}"
""".strip(), file=f)

assert sinter.read_stats_from_csv_files(d / 'tmp3.csv') == [
sinter.TaskStats(strong_id='abc123', sampler='stim', decoder='pymatching', json_metadata={'d': 3}, shots=1300, errors=4, discards=60, seconds=4.0, custom_counts=collections.Counter({'dets': 1234})),
sinter.TaskStats(strong_id='def456', sampler='mock', decoder='pymatching', json_metadata={'d': 5}, shots=2000, errors=0, discards=10, seconds=2.0),
]
Loading
Loading