Skip to content

Commit

Permalink
Add stim.Circuit.detecting_regions
Browse files Browse the repository at this point in the history
Fixes #349
  • Loading branch information
Strilanc committed Mar 12, 2024
1 parent 52694bd commit bc0b422
Show file tree
Hide file tree
Showing 8 changed files with 356 additions and 2 deletions.
1 change: 1 addition & 0 deletions file_lists/test_files
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ src/stim/cmd/command_gen.test.cc
src/stim/cmd/command_m2d.test.cc
src/stim/cmd/command_sample.test.cc
src/stim/cmd/command_sample_dem.test.cc
src/stim/dem/dem_instruction.test.cc
src/stim/dem/detector_error_model.test.cc
src/stim/diagram/ascii_diagram.test.cc
src/stim/diagram/base64.test.cc
Expand Down
229 changes: 229 additions & 0 deletions src/stim/circuit/circuit.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,89 @@
using namespace stim;
using namespace stim_pybind;

std::set<DemTarget> py_dem_filter_to_dem_target_set(
const Circuit &circuit, const CircuitStats &stats, const pybind11::object &included_targets_filter) {
std::set<DemTarget> result;
auto add_all_dets = [&]() {
for (uint64_t k = 0; k < stats.num_detectors; k++) {
result.insert(DemTarget::relative_detector_id(k));
}
};
auto add_all_obs = [&]() {
for (uint64_t k = 0; k < stats.num_observables; k++) {
result.insert(DemTarget::observable_id(k));
}
};

bool has_coords = false;
std::map<uint64_t, std::vector<double>> cached_coords;
auto get_coords_cached = [&]() -> const std::map<uint64_t, std::vector<double>> & {
std::set<uint64_t> all_dets;
for (uint64_t k = 0; k < stats.num_detectors; k++) {
all_dets.insert(k);
}
if (!has_coords) {
cached_coords = circuit.get_detector_coordinates(all_dets);
has_coords = true;
}
return cached_coords;
};

if (included_targets_filter.is_none()) {
add_all_dets();
add_all_obs();
return result;
}
for (const auto &filter : included_targets_filter) {
bool fail = false;
if (pybind11::isinstance<ExposedDemTarget>(filter)) {
result.insert(pybind11::cast<ExposedDemTarget>(filter));
} else if (pybind11::isinstance<pybind11::str>(filter)) {
std::string s = pybind11::cast<std::string>(filter);
if (s == "D") {
add_all_dets();
} else if (s == "L") {
add_all_obs();
} else if (s.starts_with("D") || s.starts_with("L")) {
result.insert(DemTarget::from_text(s));
} else {
fail = true;
}
} else {
std::vector<double> prefix;
for (auto e : filter) {
if (pybind11::isinstance<pybind11::int_>(e) || pybind11::isinstance<pybind11::float_>(e)) {
prefix.push_back(pybind11::cast<double>(e));
} else {
fail = true;
break;
}
}
if (!fail) {
for (const auto &[target, coord] : get_coords_cached()) {
if (coord.size() >= prefix.size()) {
bool match = true;
for (size_t k = 0; k < prefix.size(); k++) {
match &= prefix[k] == coord[k];
}
if (match) {
result.insert(DemTarget::relative_detector_id(target));
}
}
}
}
}
if (fail) {
std::stringstream ss;
ss << "Don't know how to interpret '";
ss << pybind11::cast<std::string>(pybind11::repr(filter));
ss << "' as a dem target filter.";
throw std::invalid_argument(ss.str());
}
}
return result;
}

std::string circuit_repr(const Circuit &self) {
if (self.operations.empty()) {
return "stim.Circuit()";
Expand Down Expand Up @@ -2118,6 +2201,152 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_<Ci
)DOC")
.data());

c.def(
"detecting_regions",
[](const Circuit &self,
const pybind11::object &included_targets,
const pybind11::object &included_ticks,
bool ignore_anticommutation_errors) -> std::map<ExposedDemTarget, std::map<uint64_t, FlexPauliString>> {
auto stats = self.compute_stats();
auto included_target_set = py_dem_filter_to_dem_target_set(self, stats, included_targets);
std::set<uint64_t> included_tick_set;

if (included_ticks.is_none()) {
for (uint64_t k = 0; k < stats.num_ticks; k++) {
included_tick_set.insert(k);
}
} else {
for (const auto &t : included_ticks) {
included_tick_set.insert(pybind11::cast<uint64_t>(t));
}
}
auto result = circuit_to_detecting_regions(
self, included_target_set, included_tick_set, ignore_anticommutation_errors);
std::map<ExposedDemTarget, std::map<uint64_t, FlexPauliString>> exposed_result;
for (const auto &[k, v] : result) {
exposed_result.insert({ExposedDemTarget(k), std::move(v)});
}
return exposed_result;
},
pybind11::kw_only(),
pybind11::arg("targets") = pybind11::none(),
pybind11::arg("ticks") = pybind11::none(),
pybind11::arg("ignore_anticommutation_errors") = false,
clean_doc_string(R"DOC(
@signature def detecting_regions(self, *, included_targets: Iterable[Iterable[float] | str | stim.DemTarget] | None = None, included_ticks: None | Iterable[int] = None) -> Dict[stim.DemTarget, Dict[int, stim.PauliString]]:
Explains how detector error model errors are produced by circuit errors.
Args:
targets: Defaults to everything (None).
When specified, this should be an iterable of filters where items
matching any one filter are included.
A variety of filters are supported:
stim.DemTarget: Includes the targeted detector or observable.
Iterable[float]: Coordinate prefix match. Includes detectors whose
coordinate data begins with the same floats.
"D": Includes all detectors.
"L": Includes all observables.
"D#" (e.g. "D5"): Includes the detector with the specified index.
"L#" (e.g. "L5"): Includes the observable with the specified index.
ticks: Defaults to everything (None).
When specified, this should be a list of integers corresponding to
the tick indices to report sensitivities for.
ignore_anticommutation_errors: Defaults to False.
When set to False, invalid detecting regions that anticommute with a
reset will cause the method to raise an exception. When set to True,
the offending component will simply be silently dropped. This can
result in broken detectors having apparently enormous detecting
regions.
Returns:
Nested dictionaries keyed first by a `stim.DemTarget` identifying the
detector or observable, then by the index of the tick, leading to a
PauliString with that target's error sensitivity at that tick.
Note you can use `stim.PauliString.pauli_indices` to quickly get to the
non-identity terms in the sensitivity.
Examples:
>>> import stim
>>> detecting_regions = stim.Circuit('''
... R 0
... TICK
... H 0
... TICK
... CX 0 1
... TICK
... MX 0 1
... DETECTOR rec[-1] rec[-2]
... ''').detecting_regions()
>>> for target, tick_regions in detecting_regions.items():
... print("target", target)
... for tick, sensitivity in tick_regions.items():
... print(" tick", tick, "=", sensitivity)
target D0
tick 0 = +Z_
tick 1 = +X_
tick 2 = +XX
>>> circuit = stim.Circuit.generated(
... "surface_code:rotated_memory_x",
... rounds=5,
... distance=4,
... )
>>> detecting_regions = circuit.detecting_regions(
... targets=["L0", (2, 4), stim.DemTarget.relative_detector_id(5)],
... ticks=range(5, 15),
... )
>>> for target, tick_regions in detecting_regions.items():
... print("target", target)
... for tick, sensitivity in tick_regions.items():
... print(" tick", tick, "=", sensitivity)
target D1
tick 5 = +____________________X______________________
tick 6 = +____________________Z______________________
target D5
tick 5 = +______X____________________________________
tick 6 = +______Z____________________________________
target D14
tick 5 = +__________X_X______XXX_____________________
tick 6 = +__________X_X______XZX_____________________
tick 7 = +__________X_X______XZX_____________________
tick 8 = +__________X_X______XXX_____________________
tick 9 = +__________XXX_____XXX______________________
tick 10 = +__________XXX_______X______________________
tick 11 = +__________X_________X______________________
tick 12 = +____________________X______________________
tick 13 = +____________________Z______________________
target D29
tick 7 = +____________________Z______________________
tick 8 = +____________________X______________________
tick 9 = +____________________XX_____________________
tick 10 = +___________________XXX_______X_____________
tick 11 = +____________X______XXXX______X_____________
tick 12 = +__________X_X______XXX_____________________
tick 13 = +__________X_X______XZX_____________________
tick 14 = +__________X_X______XZX_____________________
target D44
tick 14 = +____________________Z______________________
target L0
tick 5 = +_X________X________X________X______________
tick 6 = +_X________X________X________X______________
tick 7 = +_X________X________X________X______________
tick 8 = +_X________X________X________X______________
tick 9 = +_X________X_______XX________X______________
tick 10 = +_X________X________X________X______________
tick 11 = +_X________XX_______X________XX_____________
tick 12 = +_X________X________X________X______________
tick 13 = +_X________X________X________X______________
tick 14 = +_X________X________X________X______________
)DOC")
.data());

c.def(
"without_noise",
&Circuit::without_noise,
Expand Down
17 changes: 17 additions & 0 deletions src/stim/circuit/circuit_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,3 +1688,20 @@ def test_has_flow_shorthands():
assert c.has_flow("-iX_ -> -iXX xor rec[1] xor rec[3]")
with pytest.raises(ValueError):
c.has_flow("iX_ -> XX")


def test_detecting_regions():
assert stim.Circuit('''
R 0
TICK
H 0
TICK
CX 0 1
TICK
MX 0 1
DETECTOR rec[-1] rec[-2]
''').detecting_regions() == {stim.DemTarget.relative_detector_id(0): {
0: stim.PauliString("Z_"),
1: stim.PauliString("X_"),
2: stim.PauliString("XX"),
}}
27 changes: 25 additions & 2 deletions src/stim/dem/dem_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cmath>

#include "stim/arg_parse.h"
#include "stim/dem/detector_error_model.h"
#include "stim/simulators/error_analyzer.h"
#include "stim/str_util.h"
Expand All @@ -11,14 +12,17 @@ using namespace stim;
constexpr uint64_t OBSERVABLE_BIT = uint64_t{1} << 63;
constexpr uint64_t SEPARATOR_SYGIL = UINT64_MAX;

constexpr uint64_t MAX_OBS = 0xFFFFFFFF;
constexpr uint64_t MAX_DET = (uint64_t{1} << 62) - 1;

DemTarget DemTarget::observable_id(uint64_t id) {
if (id > 0xFFFFFFFF) {
if (id > MAX_OBS) {
throw std::invalid_argument("id > 0xFFFFFFFF");
}
return {OBSERVABLE_BIT | id};
}
DemTarget DemTarget::relative_detector_id(uint64_t id) {
if (id >= (uint64_t{1} << 62)) {
if (id > MAX_DET) {
throw std::invalid_argument("Relative detector id too large.");
}
return {id};
Expand Down Expand Up @@ -75,6 +79,25 @@ void DemTarget::shift_if_detector_id(int64_t offset) {
data = (uint64_t)((int64_t)data + offset);
}
}
DemTarget DemTarget::from_text(std::string_view text) {
if (!text.empty()) {
bool is_det = text[0] == 'D';
bool is_obs = text[0] == 'L';
if (is_det || is_obs) {
int64_t parsed = 0;
if (parse_int64(text.substr(1), &parsed)) {
if (parsed >= 0) {
if (is_det && parsed <= (int64_t)MAX_DET) {
return DemTarget::relative_detector_id(parsed);
} else if (is_obs && parsed <= (int64_t)MAX_OBS) {
return DemTarget::observable_id(parsed);
}
}
}
}
}
throw std::invalid_argument("Failed to parse as a stim.DemTarget: '" + std::string(text) + "'");
}

bool DemInstruction::operator<(const DemInstruction &other) const {
if (type != other.type) {
Expand Down
2 changes: 2 additions & 0 deletions src/stim/dem/dem_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ struct DemTarget {
bool operator!=(const DemTarget &other) const;
bool operator<(const DemTarget &other) const;
std::string str() const;

static DemTarget from_text(std::string_view text);
};

struct DetectorErrorModel;
Expand Down
31 changes: 31 additions & 0 deletions src/stim/dem/dem_instruction.test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "stim/dem/dem_instruction.h"

#include "gtest/gtest.h"

using namespace stim;

TEST(dem_instruction, from_str) {
ASSERT_EQ(DemTarget::from_text("D5"), DemTarget::relative_detector_id(5));
ASSERT_EQ(DemTarget::from_text("D0"), DemTarget::relative_detector_id(0));
ASSERT_EQ(DemTarget::from_text("D4611686018427387903"), DemTarget::relative_detector_id(4611686018427387903));

ASSERT_EQ(DemTarget::from_text("L5"), DemTarget::observable_id(5));
ASSERT_EQ(DemTarget::from_text("L0"), DemTarget::observable_id(0));
ASSERT_EQ(DemTarget::from_text("L4294967295"), DemTarget::observable_id(4294967295));

ASSERT_THROW({ DemTarget::from_text("D4611686018427387904"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("L4294967296"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("L-1"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("L-1"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("D-1"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("Da"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("Da "); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text(" Da"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("X"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text(""); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("1"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("-1"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("0"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text("'"); }, std::invalid_argument);
ASSERT_THROW({ DemTarget::from_text(" "); }, std::invalid_argument);
}
Loading

0 comments on commit bc0b422

Please sign in to comment.