Skip to content

Commit

Permalink
Dropped elementwise access to single_store (#857)
Browse files Browse the repository at this point in the history
Replaced all calls to single_store::operator[]() with calls to single_store::at()
  • Loading branch information
tsulaiav authored Oct 18, 2024
1 parent e3724d6 commit e4dbe80
Show file tree
Hide file tree
Showing 15 changed files with 39 additions and 46 deletions.
2 changes: 1 addition & 1 deletion core/include/detray/builders/cylinder_portal_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class cylinder_portal_generator final
double mean{0.};

for (const auto &sf_desc : surfaces) {
const auto &trf = transforms[sf_desc.transform()];
const auto &trf = transforms.at(sf_desc.transform());
mean += getter::perp(trf.translation());
}

Expand Down
2 changes: 1 addition & 1 deletion core/include/detray/builders/detail/bin_association.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ static inline void bin_association(const context_t & /*context*/,
}

// Unroll the mask container and generate vertices
const auto &transform = transforms[sf.transform()];
const auto &transform = transforms.at(sf.transform());

auto vertices_per_masks = surface_masks.template visit<
detail::vertexer<point2_t, point3_t>>(sf.mask());
Expand Down
28 changes: 10 additions & 18 deletions core/include/detray/core/detail/single_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ class single_store {
return m_container.empty();
}

/// @returns the collections iterator at the start position.
/// @returns the collections iterator at the start position
DETRAY_HOST_DEVICE
constexpr auto begin(const context_type & /*ctx*/ = {}) {
constexpr auto begin(const context_type & /*ctx*/ = {}) const {
return m_container.begin();
}

/// @returns the collections iterator sentinel.
/// @returns the collections iterator sentinel
DETRAY_HOST_DEVICE
constexpr auto end(const context_type & /*ctx*/ = {}) {
constexpr auto end(const context_type & /*ctx*/ = {}) const {
return m_container.end();
}

Expand All @@ -130,30 +130,22 @@ class single_store {
return m_container;
}

/// Elementwise access. Needs @c operator[] for storage type - non-const
DETRAY_HOST_DEVICE
constexpr decltype(auto) operator[](const dindex i) {
return m_container[i];
}

/// Elementwise access. Needs @c operator[] for storage type - const
DETRAY_HOST_DEVICE
constexpr decltype(auto) operator[](const dindex i) const {
return m_container[i];
}

/// @returns context based access to an element (also range checked)
DETRAY_HOST_DEVICE
constexpr auto at(const dindex i,
const context_type & /*ctx*/) const noexcept
const context_type &ctx = {}) const noexcept
-> const T & {
[[maybe_unused]] context_type tmp_ctx{
ctx}; // Temporary measure to avoid warnings
return m_container.at(i);
}

/// @returns context based access to an element (also range checked)
DETRAY_HOST_DEVICE
constexpr auto at(const dindex i, const context_type & /*ctx*/) noexcept
constexpr auto at(const dindex i, const context_type &ctx = {}) noexcept
-> T & {
[[maybe_unused]] context_type tmp_ctx{
ctx}; // Temporary measure to avoid warnings
return m_container.at(i);
}

Expand Down
2 changes: 1 addition & 1 deletion core/include/detray/core/detector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void set_transform(detector_t &det, const transform3_t &trf, unsigned int i) {
<< "WARNING: Modifying transforms in the detector will be deprecated! "
"Please, use a separate geometry context in this case"
<< std::endl;
det._transforms[i] = trf;
det._transforms.at(i) = trf;
}
} // namespace detail

Expand Down
2 changes: 1 addition & 1 deletion core/include/detray/geometry/tracking_volume.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class tracking_volume {
DETRAY_HOST_DEVICE
constexpr auto transform() const -> const
typename detector_t::transform3_type & {
return m_detector.transform_store()[m_desc.transform()];
return m_detector.transform_store().at(m_desc.transform());
}

/// @returns the center point of the volume.
Expand Down
4 changes: 2 additions & 2 deletions core/include/detray/navigation/intersection_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct intersection_initialize {
using mask_t = typename mask_group_t::value_type;
using algebra_t = typename mask_t::algebra_type;

const auto &ctf = contextual_transforms[surface.transform()];
const auto &ctf = contextual_transforms.at(surface.transform());

// Run over the masks that belong to the surface (only one can be hit)
for (const auto &mask :
Expand Down Expand Up @@ -143,7 +143,7 @@ struct intersection_update {
using mask_t = typename mask_group_t::value_type;
using algebra_t = typename mask_t::algebra_type;

const auto &ctf = contextual_transforms[sfi.sf_desc.transform()];
const auto &ctf = contextual_transforms.at(sfi.sf_desc.transform());

// Run over the masks that belong to the surface
for (const auto &mask :
Expand Down
2 changes: 1 addition & 1 deletion core/include/detray/utils/grid/grid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class grid_impl {
const track_t &track, const config_t &cfg) const {

// Track position in grid coordinates
const auto &trf = det.transform_store()[volume.transform()];
const auto &trf = det.transform_store().at(volume.transform());
const auto loc_pos = project(trf, track.pos(), track.dir());

// Grid lookup
Expand Down
2 changes: 1 addition & 1 deletion io/include/detray/io/common/geometry_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class geometry_writer {
vol_data.index = detail::basic_converter::convert(vol_desc.index());
vol_data.name = name;
vol_data.transform =
convert<detector_t>(det.transform_store()[vol_desc.transform()]);
convert<detector_t>(det.transform_store().at(vol_desc.transform()));
vol_data.type = vol_desc.id();

// Count the surfaces belonging to this volume
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ GTEST_TEST(detray_builders, detector_builder_with_material) {
// Check the volume placement
typename detector_t::transform3_type trf{t};
EXPECT_TRUE(vol.transform() == trf);
EXPECT_TRUE(d.transform_store()[0u] == trf);
EXPECT_TRUE(d.transform_store().at(0u) == trf);

EXPECT_EQ(d.surfaces().size(), 7u);
EXPECT_EQ(d.mask_store().template size<mask_id::e_rectangle2>(), 3u);
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/cpu/builders/volume_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ GTEST_TEST(detray_builders, tracking_volume_construction) {

// Check the volume placement
typename detector_t::transform3_type trf{t};
EXPECT_TRUE(d.transform_store()[first_trf] == trf);
EXPECT_TRUE(d.transform_store().at(first_trf) == trf);

// Check the acceleration data structure link
dtyped_index<accel_id, dindex> acc_link{accel_id::e_default, 1u};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ std::pair<euler_rotation<algebra_type>, std::array<scalar, 3u>> tilt_surface(

const auto& sf = det.surface(sf_id);
const auto& trf_link = sf.transform();
auto& trf = det.transform_store()[trf_link];
auto& trf = det.transform_store().at(trf_link);

euler_rotation<algebra_type> euler;
euler.alpha = alpha;
Expand Down Expand Up @@ -620,7 +620,7 @@ bound_track_parameters<algebra_type> get_initial_parameter(

const auto& departure_sf = det.surface(0u);
const auto& trf_link = departure_sf.transform();
const auto& departure_trf = det.transform_store()[trf_link];
const auto& departure_trf = det.transform_store().at(trf_link);
const auto& mask_link = departure_sf.mask();
const auto& departure_mask =
det.mask_store().template get<mask_id>().at(mask_link.index());
Expand Down Expand Up @@ -985,7 +985,7 @@ bound_param_vector_type get_displaced_bound_vector_helix(

const auto& destination_sf = det.surface(1u);
const auto& trf_link = destination_sf.transform();
const auto& destination_trf = det.transform_store()[trf_link];
const auto& destination_trf = det.transform_store().at(trf_link);
const auto& mask_link = destination_sf.mask();
const auto& destination_mask =
det.mask_store().template get<mask_id>().at(mask_link.index());
Expand Down Expand Up @@ -1047,7 +1047,7 @@ void evaluate_jacobian_difference_helix(

const auto& destination_sf = det.surface(1u);
const auto& trf_link = destination_sf.transform();
const auto& destination_trf = det.transform_store()[trf_link];
const auto& destination_trf = det.transform_store().at(trf_link);
const auto& mask_link = destination_sf.mask();
const auto& destination_mask =
det.mask_store().template get<mask_id>().at(mask_link.index());
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/cpu/builders/detector_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ GTEST_TEST(detray_builders, detector_builder) {
// Check the volume placements for both volumes
typename detector_t::transform3_type identity{};
EXPECT_TRUE(vol0.transform() == identity);
EXPECT_TRUE(d.transform_store()[0u] == identity);
EXPECT_TRUE(d.transform_store().at(0u) == identity);
typename detector_t::transform3_type trf{t};
EXPECT_TRUE(vol1.transform() == trf);
EXPECT_TRUE(d.transform_store()[4u] == trf);
EXPECT_TRUE(d.transform_store().at(4u) == trf);

EXPECT_EQ(d.surfaces().size(), 12u);
EXPECT_EQ(d.mask_store().template size<mask_id::e_portal_cylinder2>(), 0u);
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/cpu/core/containers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ GTEST_TEST(detray_core, single_store) {
EXPECT_EQ(store.size(), 4u);

// Check access to the data
EXPECT_NEAR(store[0], 1., tol_double);
EXPECT_NEAR(store[2], 10.5, tol_double);
EXPECT_NEAR(store.at(0, ctx), 1., tol_double);
EXPECT_NEAR(store.at(2, ctx), 10.5, tol_double);
EXPECT_NEAR(store.at(1, ctx), 2., tol_double);
EXPECT_NEAR(store.at(3, ctx), 7.6, tol_double);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,17 +169,18 @@ GTEST_TEST(detray_intersection, intersection_kernel_ray) {

if (sfi_init[i].sf_desc.mask().id() == e_rectangle2) {
global =
rect.to_global_frame(transform_store[0], sfi_init[i].local);
rect.to_global_frame(transform_store.at(0), sfi_init[i].local);
} else if (sfi_init[i].sf_desc.mask().id() == e_trapezoid2) {
global =
trap.to_global_frame(transform_store[1], sfi_init[i].local);
trap.to_global_frame(transform_store.at(1), sfi_init[i].local);
} else if (sfi_init[i].sf_desc.mask().id() == e_annulus2) {
global =
annl.to_global_frame(transform_store[2], sfi_init[i].local);
annl.to_global_frame(transform_store.at(2), sfi_init[i].local);
} else if (sfi_init[i].sf_desc.mask().id() == e_cylinder2) {
global = cyl.to_global_frame(transform_store[3], sfi_init[i].local);
global =
cyl.to_global_frame(transform_store.at(3), sfi_init[i].local);
} else if (sfi_init[i].sf_desc.mask().id() == e_cylinder2_portal) {
global = cyl_portal.to_global_frame(transform_store[4],
global = cyl_portal.to_global_frame(transform_store.at(4),
sfi_init[i].local);
}

Expand Down Expand Up @@ -284,13 +285,13 @@ GTEST_TEST(detray_intersection, intersection_kernel_helix) {

if (surface.mask().id() == e_rectangle2) {
global =
rect.to_global_frame(transform_store[0], sfi_helix[0].local);
rect.to_global_frame(transform_store.at(0), sfi_helix[0].local);
} else if (surface.mask().id() == e_trapezoid2) {
global =
trap.to_global_frame(transform_store[1], sfi_helix[0].local);
trap.to_global_frame(transform_store.at(1), sfi_helix[0].local);
} else if (surface.mask().id() == e_annulus2) {
global =
annl.to_global_frame(transform_store[2], sfi_helix[0].local);
annl.to_global_frame(transform_store.at(2), sfi_helix[0].local);
}

ASSERT_NEAR(global[0], expected_points[sf_idx][0], is_close);
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/device/cuda/container_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ TEST(container_cuda, single_store) {
EXPECT_EQ(mng_store.size(), 4u);

// Check the host-side access to the data
EXPECT_NEAR(mng_store[0], 1., tol);
EXPECT_NEAR(mng_store[2], 10.5, tol);
EXPECT_NEAR(mng_store.at(0, ctx), 1., tol);
EXPECT_NEAR(mng_store.at(2, ctx), 10.5, tol);
EXPECT_NEAR(mng_store.at(1, ctx), 2., tol);
EXPECT_NEAR(mng_store.at(3, ctx), 7.6, tol);

Expand Down

0 comments on commit e4dbe80

Please sign in to comment.