Skip to content

Commit

Permalink
Use enum class instead of enum, and autoformat the code (#652)
Browse files Browse the repository at this point in the history
Enum class does less pollution of the namespace. This change involved some additional things like defining `&` operators since `enum class` isn't implicitly an int like `enum` is.
  • Loading branch information
Strilanc authored Nov 13, 2023
1 parent 7083eb5 commit c128b57
Show file tree
Hide file tree
Showing 70 changed files with 520 additions and 459 deletions.
16 changes: 8 additions & 8 deletions src/stim/benchmark_util.perf.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ inline void add_benchmark(RegisteredBenchmark benchmark) {
all_registered_benchmarks_data->push_back(benchmark);
}

#define BENCHMARK(name) \
void BENCH_##name##_METHOD(); \
struct BENCH_STARTUP_TYPE_##name { \
BENCH_STARTUP_TYPE_##name() { \
add_benchmark({#name, BENCH_##name##_METHOD}); \
} \
}; \
static BENCH_STARTUP_TYPE_##name BENCH_STARTUP_INSTANCE_##name; \
#define BENCHMARK(name) \
void BENCH_##name##_METHOD(); \
struct BENCH_STARTUP_TYPE_##name { \
BENCH_STARTUP_TYPE_##name() { \
add_benchmark({#name, BENCH_##name##_METHOD}); \
} \
}; \
static BENCH_STARTUP_TYPE_##name BENCH_STARTUP_INSTANCE_##name; \
void BENCH_##name##_METHOD()

// HACK: Templating the body function type makes inlining significantly more likely.
Expand Down
24 changes: 12 additions & 12 deletions src/stim/circuit/circuit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

using namespace stim;

enum READ_CONDITION {
enum class READ_CONDITION {
READ_AS_LITTLE_AS_POSSIBLE,
READ_UNTIL_END_OF_BLOCK,
READ_UNTIL_END_OF_FILE,
Expand Down Expand Up @@ -351,13 +351,13 @@ void circuit_read_operations(Circuit &circuit, SOURCE read_char, READ_CONDITION
int c = read_char();
read_past_dead_space_between_commands(c, read_char);
if (c == EOF) {
if (read_condition == READ_UNTIL_END_OF_BLOCK) {
if (read_condition == READ_CONDITION::READ_UNTIL_END_OF_BLOCK) {
throw std::invalid_argument("Unterminated block. Got a '{' without an eventual '}'.");
}
return;
}
if (c == '}') {
if (read_condition != READ_UNTIL_END_OF_BLOCK) {
if (read_condition != READ_CONDITION::READ_UNTIL_END_OF_BLOCK) {
throw std::invalid_argument("Uninitiated block. Got a '}' without a '{'.");
}
return;
Expand All @@ -378,7 +378,7 @@ void circuit_read_operations(Circuit &circuit, SOURCE read_char, READ_CONDITION

// Read block.
circuit.blocks.emplace_back();
circuit_read_operations(circuit.blocks.back(), read_char, READ_UNTIL_END_OF_BLOCK);
circuit_read_operations(circuit.blocks.back(), read_char, READ_CONDITION::READ_UNTIL_END_OF_BLOCK);

// Rewrite target data to reference the parsed block.
circuit.target_buf.ensure_available(3);
Expand All @@ -390,7 +390,7 @@ void circuit_read_operations(Circuit &circuit, SOURCE read_char, READ_CONDITION

// Fuse operations.
circuit.try_fuse_last_two_ops();
} while (read_condition != READ_AS_LITTLE_AS_POSSIBLE);
} while (read_condition != READ_CONDITION::READ_AS_LITTLE_AS_POSSIBLE);
}

void Circuit::append_from_text(const char *text) {
Expand All @@ -400,7 +400,7 @@ void Circuit::append_from_text(const char *text) {
[&]() {
return text[k] != 0 ? text[k++] : EOF;
},
READ_UNTIL_END_OF_FILE);
READ_CONDITION::READ_UNTIL_END_OF_FILE);
}

void Circuit::safe_append(const CircuitInstruction &operation) {
Expand Down Expand Up @@ -433,7 +433,7 @@ void Circuit::safe_append_u(
}

void Circuit::safe_append(GateType gate_type, SpanRef<const GateTarget> targets, SpanRef<const double> args) {
auto flags = GATE_DATA.items[gate_type].flags;
auto flags = GATE_DATA[gate_type].flags;
if (flags & GATE_IS_BLOCK) {
throw std::invalid_argument("Can't append a block like a normal operation.");
}
Expand All @@ -460,11 +460,11 @@ void Circuit::append_from_file(FILE *file, bool stop_asap) {
[&]() {
return getc(file);
},
stop_asap ? READ_AS_LITTLE_AS_POSSIBLE : READ_UNTIL_END_OF_FILE);
stop_asap ? READ_CONDITION::READ_AS_LITTLE_AS_POSSIBLE : READ_CONDITION::READ_UNTIL_END_OF_FILE);
}

std::ostream &stim::operator<<(std::ostream &out, const CircuitInstruction &instruction) {
out << GATE_DATA.items[instruction.gate_type].name;
out << GATE_DATA[instruction.gate_type].name;
if (!instruction.args.empty()) {
out << '(';
bool first = true;
Expand Down Expand Up @@ -762,7 +762,7 @@ const Circuit Circuit::aliased_noiseless_circuit() const {
// HACK: result has pointers into `circuit`!
Circuit result;
for (const auto &op : operations) {
auto flags = GATE_DATA.items[op.gate_type].flags;
auto flags = GATE_DATA[op.gate_type].flags;
if (flags & GATE_PRODUCES_RESULTS) {
if (op.gate_type == GateType::HERALDED_ERASE || op.gate_type == GateType::HERALDED_PAULI_CHANNEL_1) {
// Replace heralded errors with fixed MPAD.
Expand Down Expand Up @@ -794,7 +794,7 @@ const Circuit Circuit::aliased_noiseless_circuit() const {
Circuit Circuit::without_noise() const {
Circuit result;
for (const auto &op : operations) {
auto flags = GATE_DATA.items[op.gate_type].flags;
auto flags = GATE_DATA[op.gate_type].flags;
if (flags & GATE_PRODUCES_RESULTS) {
if (op.gate_type == GateType::HERALDED_ERASE || op.gate_type == GateType::HERALDED_PAULI_CHANNEL_1) {
// Replace heralded errors with fixed MPAD.
Expand Down Expand Up @@ -886,7 +886,7 @@ Circuit Circuit::inverse(bool allow_weak_inverse) const {
}

SpanRef<const double> args = op.args;
const auto &gate_data = GATE_DATA.items[op.gate_type];
const auto &gate_data = GATE_DATA[op.gate_type];
auto flags = gate_data.flags;
if (flags & GATE_IS_UNITARY) {
// Unitary gates always have an inverse.
Expand Down
2 changes: 1 addition & 1 deletion src/stim/circuit/circuit.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_<Ci
for (auto t : op.args) {
args.append(t);
}
const auto &gate_data = GATE_DATA.items[op.gate_type];
const auto &gate_data = GATE_DATA[op.gate_type];
if (op.args.empty()) {
// Backwards compatibility.
result.append(pybind11::make_tuple(gate_data.name, targets, 0));
Expand Down
2 changes: 1 addition & 1 deletion src/stim/circuit/circuit.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1692,7 +1692,7 @@ Circuit stim::generate_test_circuit_with_all_operations() {

TEST(circuit, generate_test_circuit_with_all_operations) {
auto c = generate_test_circuit_with_all_operations();
std::set<GateType> seen{NOT_A_GATE};
std::set<GateType> seen{GateType::NOT_A_GATE};
for (const auto &instruction : c.operations) {
seen.insert(instruction.gate_type);
}
Expand Down
8 changes: 4 additions & 4 deletions src/stim/circuit/circuit_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ CircuitStats CircuitInstruction::compute_stats(const Circuit *host) const {
}

void CircuitInstruction::add_stats_to(CircuitStats &out, const Circuit *host) const {
if (gate_type == REPEAT) {
if (gate_type == GateType::REPEAT) {
if (host == nullptr) {
throw std::invalid_argument("gate_type == REPEAT && host == nullptr");
}
Expand Down Expand Up @@ -110,7 +110,7 @@ CircuitInstruction::CircuitInstruction(
}

void CircuitInstruction::validate() const {
const Gate &gate = GATE_DATA.items[gate_type];
const Gate &gate = GATE_DATA[gate_type];

if (gate.flags == GateFlags::NO_GATE_FLAG) {
throw std::invalid_argument("Unrecognized gate_type. Associated flag is NO_GATE_FLAG.");
Expand Down Expand Up @@ -248,7 +248,7 @@ void CircuitInstruction::validate() const {
}

uint64_t CircuitInstruction::count_measurement_results() const {
auto flags = GATE_DATA.items[gate_type].flags;
auto flags = GATE_DATA[gate_type].flags;
if (!(flags & GATE_PRODUCES_RESULTS)) {
return 0;
}
Expand All @@ -266,7 +266,7 @@ uint64_t CircuitInstruction::count_measurement_results() const {
}

bool CircuitInstruction::can_fuse(const CircuitInstruction &other) const {
auto flags = GATE_DATA.items[gate_type].flags;
auto flags = GATE_DATA[gate_type].flags;
return gate_type == other.gate_type && args == other.args && !(flags & GATE_IS_NOT_FUSABLE);
}

Expand Down
2 changes: 1 addition & 1 deletion src/stim/circuit/circuit_instruction.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ PyCircuitInstruction::operator CircuitInstruction() const {
return as_operation_ref();
}
std::string PyCircuitInstruction::name() const {
return GATE_DATA.items[gate_type].name;
return GATE_DATA[gate_type].name;
}
std::vector<uint32_t> PyCircuitInstruction::raw_targets() const {
std::vector<uint32_t> result;
Expand Down
12 changes: 5 additions & 7 deletions src/stim/circuit/gate_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

#include <complex>

#include "stim/circuit/stabilizer_flow.h"

using namespace stim;

GateDataMap::GateDataMap() {
Expand Down Expand Up @@ -76,7 +74,7 @@ std::vector<std::vector<std::complex<float>>> Gate::unitary() const {
const Gate &Gate::inverse() const {
std::string inv_name = name;
if ((flags & GATE_IS_UNITARY) || id == GateType::TICK) {
return GATE_DATA.items[static_cast<uint8_t>(best_candidate_inverse_id)];
return GATE_DATA[best_candidate_inverse_id];
}
throw std::out_of_range(inv_name + " has no inverse.");
}
Expand All @@ -101,16 +99,16 @@ Gate::Gate(
}

void GateDataMap::add_gate(bool &failed, const Gate &gate) {
assert(gate.id < NUM_DEFINED_GATES);
assert((size_t)gate.id < NUM_DEFINED_GATES);
const char *c = gate.name;
auto h = gate_name_to_hash(c);
auto &hash_loc = hashed_name_to_gate_type_table[h];
if (hash_loc.expected_name_len != 0) {
std::cerr << "GATE COLLISION " << gate.name << " vs " << items[hash_loc.id].name << "\n";
std::cerr << "GATE COLLISION " << gate.name << " vs " << items[(size_t)hash_loc.id].name << "\n";
failed = true;
return;
}
items[gate.id] = gate;
items[(size_t)gate.id] = gate;
hash_loc.id = gate.id;
hash_loc.expected_name = gate.name;
hash_loc.expected_name_len = gate.name_len;
Expand All @@ -120,7 +118,7 @@ void GateDataMap::add_gate_alias(bool &failed, const char *alt_name, const char
auto h_alt = gate_name_to_hash(alt_name);
auto &hash_loc = hashed_name_to_gate_type_table[h_alt];
if (hash_loc.expected_name_len != 0) {
std::cerr << "GATE COLLISION " << alt_name << " vs " << items[hash_loc.id].name << "\n";
std::cerr << "GATE COLLISION " << alt_name << " vs " << items[(size_t)hash_loc.id].name << "\n";
failed = true;
return;
}
Expand Down
14 changes: 9 additions & 5 deletions src/stim/circuit/gate_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ constexpr inline uint16_t gate_name_to_hash(const char *c) {

constexpr const size_t NUM_DEFINED_GATES = 67;

enum GateType : uint8_t {
enum class GateType : uint8_t {
NOT_A_GATE = 0,
// Annotations
DETECTOR,
Expand Down Expand Up @@ -258,7 +258,7 @@ struct Gate {

template <size_t W>
Tableau<W> tableau() const {
if (!(flags & GATE_IS_UNITARY)) {
if (!(flags & GateFlags::GATE_IS_UNITARY)) {
throw std::invalid_argument(std::string(name) + " isn't unitary so it doesn't have a tableau.");
}
const auto &tableau_data = extra_data_func().flow_data;
Expand All @@ -274,9 +274,9 @@ struct Gate {

template <size_t W>
std::vector<StabilizerFlow<W>> flows() const {
if (flags & GATE_IS_UNITARY) {
if (flags & GateFlags::GATE_IS_UNITARY) {
auto t = tableau<W>();
if (flags & GATE_TARGETS_PAIRS) {
if (flags & GateFlags::GATE_TARGETS_PAIRS) {
return {
StabilizerFlow<W>{stim::PauliString<W>::from_str("X_"), t.xs[0], {}},
StabilizerFlow<W>{stim::PauliString<W>::from_str("Z_"), t.zs[0], {}},
Expand Down Expand Up @@ -340,14 +340,18 @@ struct GateDataMap {
std::array<Gate, NUM_DEFINED_GATES> items;
GateDataMap();

inline const Gate &operator[](GateType g) const {
return items[(uint64_t)g];
}

inline const Gate &at(const char *text, size_t text_len) const {
auto h = gate_name_to_hash(text, text_len);
const auto &entry = hashed_name_to_gate_type_table[h];
if (_case_insensitive_mismatch(text, text_len, entry.expected_name, entry.expected_name_len)) {
throw std::out_of_range("Gate not found: '" + std::string(text, text_len) + "'");
}
// Canonicalize.
return items[entry.id];
return (*this)[entry.id];
}

inline const Gate &at(const char *text) const {
Expand Down
2 changes: 1 addition & 1 deletion src/stim/circuit/gate_data.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void stim_pybind::pybind_gate_data_methods(pybind11::module &m, pybind11::class_

std::map<std::string, Gate> result;
for (const auto &g : GATE_DATA.items) {
if (g.id != NOT_A_GATE) {
if (g.id != GateType::NOT_A_GATE) {
result.insert({g.name, g});
}
}
Expand Down
28 changes: 15 additions & 13 deletions src/stim/circuit/gate_data.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,32 +39,34 @@ TEST(gate_data, lookup) {
}

TEST(gate_data, zero_flag_means_not_a_gate) {
ASSERT_EQ(GATE_DATA.items[0].id, 0);
ASSERT_EQ(GATE_DATA.items[0].flags, GateFlags::NO_GATE_FLAG);
ASSERT_EQ((GateType)0, GateType::NOT_A_GATE);
ASSERT_EQ(GATE_DATA[(GateType)0].id, (GateType)0);
ASSERT_EQ(GATE_DATA[(GateType)0].flags, GateFlags::NO_GATE_FLAG);
for (size_t k = 0; k < GATE_DATA.items.size(); k++) {
const auto &g = GATE_DATA.items[k];
if (g.id != 0) {
const auto &g = GATE_DATA[(GateType)k];
if (g.id != GateType::NOT_A_GATE) {
EXPECT_NE(g.flags, GateFlags::NO_GATE_FLAG) << g.name;
}
}
}

TEST(gate_data, one_step_to_canonical_gate) {
for (size_t k = 0; k < GATE_DATA.items.size(); k++) {
const auto &g = GATE_DATA.items[k];
if (g.id != 0) {
EXPECT_TRUE(g.id == k || GATE_DATA.items[g.id].id == g.id) << g.name;
const auto &g = GATE_DATA[(GateType)k];
if (g.id != GateType::NOT_A_GATE) {
EXPECT_TRUE(g.id == (GateType)k || GATE_DATA[g.id].id == g.id) << g.name;
}
}
}

TEST(gate_data, hash_matches_storage_location) {
ASSERT_EQ(GATE_DATA.items[0].id, 0);
ASSERT_EQ(GATE_DATA.items[0].flags, GateFlags::NO_GATE_FLAG);
ASSERT_EQ((GateType)0, GateType::NOT_A_GATE);
ASSERT_EQ(GATE_DATA[(GateType)0].id, (GateType)0);
ASSERT_EQ(GATE_DATA[(GateType)0].flags, GateFlags::NO_GATE_FLAG);
for (size_t k = 0; k < GATE_DATA.items.size(); k++) {
const auto &g = GATE_DATA.items[k];
EXPECT_EQ(g.id, k) << g.name;
if (g.id != 0) {
const auto &g = GATE_DATA[(GateType)k];
EXPECT_EQ(g.id, (GateType)k) << g.name;
if (g.id != GateType::NOT_A_GATE) {
EXPECT_EQ(GATE_DATA.hashed_name_to_gate_type_table[gate_name_to_hash(g.name)].id, g.id) << g.name;
}
}
Expand Down Expand Up @@ -132,7 +134,7 @@ TEST_EACH_WORD_SIZE_W(gate_data, unitary_inverses_are_correct, {
for (const auto &g : GATE_DATA.items) {
if (g.flags & GATE_IS_UNITARY) {
auto g_t_inv = g.tableau<W>().inverse(false);
auto g_inv_t = GATE_DATA.items[static_cast<uint8_t>(g.best_candidate_inverse_id)].tableau<W>();
auto g_inv_t = GATE_DATA[g.best_candidate_inverse_id].tableau<W>();
EXPECT_EQ(g_t_inv, g_inv_t) << g.name;
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/stim/circuit/gate_data_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ struct GateVTable {
#ifndef NDEBUG
std::array<bool, NUM_DEFINED_GATES> seen{};
for (const auto &[gate_id, value] : gate_data_pairs) {
seen[gate_id] = true;
seen[(size_t)gate_id] = true;
}
for (const auto &gate : GATE_DATA.items) {
if (!seen[gate.id]) {
if (!seen[(size_t)gate.id]) {
throw std::invalid_argument(
"Missing gate data! A value was not defined for '" + std::string(gate.name) + "'.");
}
Expand Down
6 changes: 3 additions & 3 deletions src/stim/circuit/gate_decomposition.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ TEST(gate_decomposition, decompose_pair_instruction_into_segments_with_single_us
for (size_t k = 0; k < segment.targets.size(); k += 2) {
evens.push_back(segment.targets[k]);
}
out.safe_append(CircuitInstruction{stim::CX, {}, segment.targets});
out.safe_append(CircuitInstruction{stim::MX, segment.args, evens});
out.safe_append(CircuitInstruction{stim::CX, {}, segment.targets});
out.safe_append(CircuitInstruction{GateType::CX, {}, segment.targets});
out.safe_append(CircuitInstruction{GateType::MX, segment.args, evens});
out.safe_append(CircuitInstruction{GateType::CX, {}, segment.targets});
out.append_from_text("TICK");
};
decompose_pair_instruction_into_segments_with_single_use_controls(
Expand Down
2 changes: 1 addition & 1 deletion src/stim/cmd/command_convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ int stim::command_convert(int argc, const char **argv) {
// convert arbitrary bits.
if (!details.include_measurements && !details.include_detectors && !details.include_observables) {
// dets outputs explicit value types, which we don't know if we get here.
if (out_format.id == SAMPLE_FORMAT_DETS) {
if (out_format.id == SampleFormat::SAMPLE_FORMAT_DETS) {
std::cerr
<< "\033[31mNot enough information given to parse input file to write to dets. Please given a circuit "
"with --types, a DEM file, or explicit number of each desired type\n";
Expand Down
2 changes: 1 addition & 1 deletion src/stim/cmd/command_detect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ int stim::command_detect(int argc, const char **argv) {
find_argument("--shots", argc, argv) ? (uint64_t)find_int64_argument("--shots", 1, 0, INT64_MAX, argc, argv)
: find_argument("--detect", argc, argv) ? (uint64_t)find_int64_argument("--detect", 1, 0, INT64_MAX, argc, argv)
: 1;
if (out_format.id == SAMPLE_FORMAT_DETS && !append_observables) {
if (out_format.id == SampleFormat::SAMPLE_FORMAT_DETS && !append_observables) {
prepend_observables = true;
}

Expand Down
Loading

0 comments on commit c128b57

Please sign in to comment.