Skip to content

Commit

Permalink
Fix HERALDED_PAULI_CHANNEL_1 permuting X/Y/Z error argument components
Browse files Browse the repository at this point in the history
- :foreheadslap:
- autoformat
- Add `stim::circuit_to_dem` to the C++ API for easier conversions with named arguments via a struct
  • Loading branch information
Strilanc committed Jul 29, 2024
1 parent c0627e2 commit cec9b46
Show file tree
Hide file tree
Showing 13 changed files with 326 additions and 191 deletions.
1 change: 1 addition & 0 deletions file_lists/test_files
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ src/stim/util_bot/twiddle.test.cc
src/stim/util_top/circuit_flow_generators.test.cc
src/stim/util_top/circuit_inverse_qec.test.cc
src/stim/util_top/circuit_inverse_unitary.test.cc
src/stim/util_top/circuit_to_dem.test.cc
src/stim/util_top/circuit_to_detecting_regions.test.cc
src/stim/util_top/circuit_vs_amplitudes.test.cc
src/stim/util_top/circuit_vs_tableau.test.cc
Expand Down
1 change: 1 addition & 0 deletions src/stim.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
#include "stim/util_top/circuit_flow_generators.h"
#include "stim/util_top/circuit_inverse_qec.h"
#include "stim/util_top/circuit_inverse_unitary.h"
#include "stim/util_top/circuit_to_dem.h"
#include "stim/util_top/circuit_to_detecting_regions.h"
#include "stim/util_top/circuit_vs_amplitudes.h"
#include "stim/util_top/circuit_vs_tableau.h"
Expand Down
11 changes: 9 additions & 2 deletions src/stim/cmd/command_diagram.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,22 @@ DiagramHelper stim_pybind::circuit_diagram(
type == "timeslice" || type == "time-slice") {
std::stringstream out;
DiagramTimelineSvgDrawer::make_diagram_write_to(
circuit, out, tick_min, num_ticks, DiagramTimelineSvgDrawerMode::SVG_MODE_TIME_SLICE, filter_coords, num_rows);
circuit,
out,
tick_min,
num_ticks,
DiagramTimelineSvgDrawerMode::SVG_MODE_TIME_SLICE,
filter_coords,
num_rows);
DiagramType d_type =
type.find("html") != std::string::npos ? DiagramType::DIAGRAM_TYPE_SVG_HTML : DiagramType::DIAGRAM_TYPE_SVG;
return DiagramHelper{d_type, out.str()};
} else if (
type == "detslice-svg" || type == "detslice" || type == "detslice-html" || type == "detslice-svg-html" ||
type == "detector-slice-svg" || type == "detector-slice") {
std::stringstream out;
DetectorSliceSet::from_circuit_ticks(circuit, tick_min, num_ticks, filter_coords).write_svg_diagram_to(out, num_rows);
DetectorSliceSet::from_circuit_ticks(circuit, tick_min, num_ticks, filter_coords)
.write_svg_diagram_to(out, num_rows);
DiagramType d_type =
type.find("html") != std::string::npos ? DiagramType::DIAGRAM_TYPE_SVG_HTML : DiagramType::DIAGRAM_TYPE_SVG;
return DiagramHelper{d_type, out.str()};
Expand Down
1 change: 0 additions & 1 deletion src/stim/dem/detector_error_model_target.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ pybind11::class_<ExposedDemTarget> stim_pybind::pybind_detector_error_model_targ

void stim_pybind::pybind_detector_error_model_target_methods(
pybind11::module &m, pybind11::class_<ExposedDemTarget> &c) {

c.def(
pybind11::init([](const pybind11::object &arg) -> ExposedDemTarget {
if (pybind11::isinstance<ExposedDemTarget>(arg)) {
Expand Down
200 changes: 100 additions & 100 deletions src/stim/gates/gates.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,110 +47,110 @@ GateDataMap::GateDataMap() {

GateType Gate::hadamard_conjugated(bool ignoring_sign) const {
switch (id) {
case GateType::DETECTOR:
case GateType::OBSERVABLE_INCLUDE:
case GateType::TICK:
case GateType::QUBIT_COORDS:
case GateType::SHIFT_COORDS:
case GateType::MPAD:
case GateType::H:
case GateType::DEPOLARIZE1:
case GateType::DEPOLARIZE2:
case GateType::Y_ERROR:
case GateType::I:
case GateType::Y:
case GateType::SQRT_YY:
case GateType::SQRT_YY_DAG:
case GateType::MYY:
case GateType::SWAP:
return id;
case GateType::DETECTOR:
case GateType::OBSERVABLE_INCLUDE:
case GateType::TICK:
case GateType::QUBIT_COORDS:
case GateType::SHIFT_COORDS:
case GateType::MPAD:
case GateType::H:
case GateType::DEPOLARIZE1:
case GateType::DEPOLARIZE2:
case GateType::Y_ERROR:
case GateType::I:
case GateType::Y:
case GateType::SQRT_YY:
case GateType::SQRT_YY_DAG:
case GateType::MYY:
case GateType::SWAP:
return id;

case GateType::MY:
case GateType::MRY:
case GateType::RY:
case GateType::YCY:
return ignoring_sign ? id : GateType::NOT_A_GATE;
case GateType::MY:
case GateType::MRY:
case GateType::RY:
case GateType::YCY:
return ignoring_sign ? id : GateType::NOT_A_GATE;

case GateType::ISWAP:
case GateType::CZSWAP:
case GateType::ISWAP_DAG:
return GateType::NOT_A_GATE;
case GateType::ISWAP:
case GateType::CZSWAP:
case GateType::ISWAP_DAG:
return GateType::NOT_A_GATE;

case GateType::XCY:
return ignoring_sign ? GateType::CY : GateType::NOT_A_GATE;
case GateType::CY:
return ignoring_sign ? GateType::XCY : GateType::NOT_A_GATE;
case GateType::YCX:
return ignoring_sign ? GateType::YCZ : GateType::NOT_A_GATE;
case GateType::YCZ:
return ignoring_sign ? GateType::YCX : GateType::NOT_A_GATE;
case GateType::C_XYZ:
return ignoring_sign ? GateType::C_ZYX : GateType::NOT_A_GATE;
case GateType::C_ZYX:
return ignoring_sign ? GateType::C_XYZ : GateType::NOT_A_GATE;
case GateType::H_XY:
return ignoring_sign ? GateType::H_YZ : GateType::NOT_A_GATE;
case GateType::H_YZ:
return ignoring_sign ? GateType::H_XY : GateType::NOT_A_GATE;
case GateType::XCY:
return ignoring_sign ? GateType::CY : GateType::NOT_A_GATE;
case GateType::CY:
return ignoring_sign ? GateType::XCY : GateType::NOT_A_GATE;
case GateType::YCX:
return ignoring_sign ? GateType::YCZ : GateType::NOT_A_GATE;
case GateType::YCZ:
return ignoring_sign ? GateType::YCX : GateType::NOT_A_GATE;
case GateType::C_XYZ:
return ignoring_sign ? GateType::C_ZYX : GateType::NOT_A_GATE;
case GateType::C_ZYX:
return ignoring_sign ? GateType::C_XYZ : GateType::NOT_A_GATE;
case GateType::H_XY:
return ignoring_sign ? GateType::H_YZ : GateType::NOT_A_GATE;
case GateType::H_YZ:
return ignoring_sign ? GateType::H_XY : GateType::NOT_A_GATE;

case GateType::X:
return GateType::Z;
case GateType::Z:
return GateType::X;
case GateType::SQRT_Y:
return GateType::SQRT_Y_DAG;
case GateType::SQRT_Y_DAG:
return GateType::SQRT_Y;
case GateType::MX:
return GateType::M;
case GateType::M:
return GateType::MX;
case GateType::MRX:
return GateType::MR;
case GateType::MR:
return GateType::MRX;
case GateType::RX:
return GateType::R;
case GateType::R:
return GateType::RX;
case GateType::XCX:
return GateType::CZ;
case GateType::XCZ:
return GateType::CX;
case GateType::CX:
return GateType::XCZ;
case GateType::CZ:
return GateType::XCX;
case GateType::X_ERROR:
return GateType::Z_ERROR;
case GateType::Z_ERROR:
return GateType::X_ERROR;
case GateType::SQRT_X:
return GateType::S;
case GateType::SQRT_X_DAG:
return GateType::S_DAG;
case GateType::S:
return GateType::SQRT_X;
case GateType::S_DAG:
return GateType::SQRT_X_DAG;
case GateType::SQRT_XX:
return GateType::SQRT_ZZ;
case GateType::SQRT_XX_DAG:
return GateType::SQRT_ZZ_DAG;
case GateType::SQRT_ZZ:
return GateType::SQRT_XX;
case GateType::SQRT_ZZ_DAG:
return GateType::SQRT_XX_DAG;
case GateType::CXSWAP:
return GateType::SWAPCX;
case GateType::SWAPCX:
return GateType::CXSWAP;
case GateType::MXX:
return GateType::MZZ;
case GateType::MZZ:
return GateType::MXX;
default:
return GateType::NOT_A_GATE;
case GateType::X:
return GateType::Z;
case GateType::Z:
return GateType::X;
case GateType::SQRT_Y:
return GateType::SQRT_Y_DAG;
case GateType::SQRT_Y_DAG:
return GateType::SQRT_Y;
case GateType::MX:
return GateType::M;
case GateType::M:
return GateType::MX;
case GateType::MRX:
return GateType::MR;
case GateType::MR:
return GateType::MRX;
case GateType::RX:
return GateType::R;
case GateType::R:
return GateType::RX;
case GateType::XCX:
return GateType::CZ;
case GateType::XCZ:
return GateType::CX;
case GateType::CX:
return GateType::XCZ;
case GateType::CZ:
return GateType::XCX;
case GateType::X_ERROR:
return GateType::Z_ERROR;
case GateType::Z_ERROR:
return GateType::X_ERROR;
case GateType::SQRT_X:
return GateType::S;
case GateType::SQRT_X_DAG:
return GateType::S_DAG;
case GateType::S:
return GateType::SQRT_X;
case GateType::S_DAG:
return GateType::SQRT_X_DAG;
case GateType::SQRT_XX:
return GateType::SQRT_ZZ;
case GateType::SQRT_XX_DAG:
return GateType::SQRT_ZZ_DAG;
case GateType::SQRT_ZZ:
return GateType::SQRT_XX;
case GateType::SQRT_ZZ_DAG:
return GateType::SQRT_XX_DAG;
case GateType::CXSWAP:
return GateType::SWAPCX;
case GateType::SWAPCX:
return GateType::CXSWAP;
case GateType::MXX:
return GateType::MZZ;
case GateType::MZZ:
return GateType::MXX;
default:
return GateType::NOT_A_GATE;
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/stim/gates/gates.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
#include "stim/simulators/tableau_simulator.h"
#include "stim/util_bot/str_util.h"
#include "stim/util_bot/test_util.test.h"
#include "stim/util_top/has_flow.h"
#include "stim/util_top/circuit_flow_generators.h"
#include "stim/util_top/has_flow.h"

using namespace stim;

Expand Down Expand Up @@ -375,8 +375,10 @@ TEST(gate_data, hadamard_conjugated_vs_flow_generators_of_two_qubit_gates) {
GateType actual_s = g.hadamard_conjugated(false);
GateType actual_u = g.hadamard_conjugated(true);
bool found = std::find(other_us.begin(), other_us.end(), actual_u) != other_us.end();
EXPECT_EQ(actual_s, expected_s) << "signed " << g.name << " -> " << GATE_DATA[actual_s].name << " != " << GATE_DATA[expected_s].name;
EXPECT_TRUE(found) << "unsigned " << g.name << " -> " << GATE_DATA[actual_u].name << " not in " << GATE_DATA[other_us[0]].name;
EXPECT_EQ(actual_s, expected_s)
<< "signed " << g.name << " -> " << GATE_DATA[actual_s].name << " != " << GATE_DATA[expected_s].name;
EXPECT_TRUE(found) << "unsigned " << g.name << " -> " << GATE_DATA[actual_u].name << " not in "
<< GATE_DATA[other_us[0]].name;
}
}
}
25 changes: 18 additions & 7 deletions src/stim/simulators/error_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ void ErrorAnalyzer::undo_MZ_with_context(const CircuitInstruction &dat, const ch
}

void ErrorAnalyzer::undo_HERALDED_ERASE(const CircuitInstruction &dat) {
check_can_approximate_disjoint("HERALDED_ERASE", dat.args);
check_can_approximate_disjoint("HERALDED_ERASE", dat.args, false);
double p = dat.args[0] * 0.25;
double i = std::max(0.0, 1.0 - 4 * p);

Expand All @@ -327,7 +327,7 @@ void ErrorAnalyzer::undo_HERALDED_ERASE(const CircuitInstruction &dat) {
}

void ErrorAnalyzer::undo_HERALDED_PAULI_CHANNEL_1(const CircuitInstruction &dat) {
check_can_approximate_disjoint("HERALDED_PAULI_CHANNEL_1", dat.args);
check_can_approximate_disjoint("HERALDED_PAULI_CHANNEL_1", dat.args, true);
double hi = dat.args[0];
double hx = dat.args[1];
double hy = dat.args[2];
Expand All @@ -341,7 +341,7 @@ void ErrorAnalyzer::undo_HERALDED_PAULI_CHANNEL_1(const CircuitInstruction &dat)
SparseXorVec<DemTarget> &herald_symptoms = tracker.rec_bits[tracker.num_measurements_in_past];
if (accumulate_errors) {
add_error_combinations<3>(
{i, 0, 0, 0, hi, hx, hy, hz},
{i, 0, 0, 0, hi, hz, hx, hy},
{tracker.xs[q].range(), tracker.zs[q].range(), herald_symptoms.range()},
true);
}
Expand Down Expand Up @@ -750,7 +750,7 @@ void ErrorAnalyzer::correlated_error_block(const std::vector<CircuitInstruction>
add_composite_error(dats[0].args[0], dats[0].targets);
return;
}
check_can_approximate_disjoint("ELSE_CORRELATED_ERROR", {});
check_can_approximate_disjoint("ELSE_CORRELATED_ERROR", {}, false);

double remaining_p = 1;
for (size_t k = dats.size(); k--;) {
Expand Down Expand Up @@ -820,7 +820,18 @@ void ErrorAnalyzer::undo_ELSE_CORRELATED_ERROR(const CircuitInstruction &dat) {
}
}

void ErrorAnalyzer::check_can_approximate_disjoint(const char *op_name, SpanRef<const double> probabilities) const {
void ErrorAnalyzer::check_can_approximate_disjoint(
const char *op_name, SpanRef<const double> probabilities, bool allow_single_component) const {
if (allow_single_component) {
size_t num_specified = 0;
for (double p : probabilities) {
num_specified += p > 0;
}
if (num_specified <= 1) {
return;
}
}

if (approximate_disjoint_errors_threshold == 0) {
std::stringstream msg;
msg << "Encountered the operation " << op_name
Expand Down Expand Up @@ -854,7 +865,7 @@ void ErrorAnalyzer::undo_PAULI_CHANNEL_1(const CircuitInstruction &dat) {
double iz;
bool is_independent = try_disjoint_to_independent_xyz_errors_approx(dx, dy, dz, &ix, &iy, &iz);
if (!is_independent) {
check_can_approximate_disjoint("PAULI_CHANNEL_1", dat.args);
check_can_approximate_disjoint("PAULI_CHANNEL_1", dat.args, true);
ix = dx;
iy = dy;
iz = dz;
Expand All @@ -875,7 +886,7 @@ void ErrorAnalyzer::undo_PAULI_CHANNEL_1(const CircuitInstruction &dat) {
}

void ErrorAnalyzer::undo_PAULI_CHANNEL_2(const CircuitInstruction &dat) {
check_can_approximate_disjoint("PAULI_CHANNEL_2", dat.args);
check_can_approximate_disjoint("PAULI_CHANNEL_2", dat.args, true);

std::array<double, 16> probabilities;
for (size_t k = 0; k < 15; k++) {
Expand Down
3 changes: 2 additions & 1 deletion src/stim/simulators/error_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ struct ErrorAnalyzer {
void undo_MXX_disjoint_controls_segment(const CircuitInstruction &inst);
void undo_MYY_disjoint_controls_segment(const CircuitInstruction &inst);
void undo_MZZ_disjoint_controls_segment(const CircuitInstruction &inst);
void check_can_approximate_disjoint(const char *op_name, SpanRef<const double> probabilities) const;
void check_can_approximate_disjoint(
const char *op_name, SpanRef<const double> probabilities, bool allow_single_component) const;
void add_composite_error(double probability, SpanRef<const GateTarget> targets);
void correlated_error_block(const std::vector<CircuitInstruction> &dats);
};
Expand Down
Loading

0 comments on commit cec9b46

Please sign in to comment.