Skip to content

Commit

Permalink
Remove unneeded state from Runge-Kutta stepper
Browse files Browse the repository at this point in the history
Currently, the Runge-Kutta stepper holds a lot of state which is used
only internally. The size of this state is 164 bytes, of which 140
bytes are not needed after the step. Keeping this unneeded state between
steps therefore takes up a lot of memory.

In this PR, I remove this internal state from the externally facing
stepper state and turn it into an intermediate struct which is not
stored persistently, saving 140 bytes on the size of the stepper state
(and, with it, the propagator state).
  • Loading branch information
stephenswat committed Oct 28, 2024
1 parent d6bc161 commit bc1ea73
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 84 deletions.
52 changes: 26 additions & 26 deletions core/include/detray/propagator/rk_stepper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ class rk_stepper final

rk_stepper() = default;

/// Intermediate data of a single Runge-Kutta-Nystrom (RKN4) step.
///
/// An object of this type is created locally for each step and passed by
/// const reference to @c state::advance_track and
/// @c state::advance_jacobian, rather than being kept as a member of the
/// stepper state. The four-element arrays hold one entry per Runge-Kutta
/// evaluation point (indices 0 to 3).
struct intermediate_state {
// Magnetic field vector used at the first evaluation point (index 0)
vector3_type b_first{0.f, 0.f, 0.f};
// Magnetic field vector used at the two middle evaluation points
// (indices 1 and 2)
vector3_type b_middle{0.f, 0.f, 0.f};
// Magnetic field vector used at the last evaluation point (index 3)
vector3_type b_last{0.f, 0.f, 0.f};
// t = tangential direction = dr/ds
std::array<vector3_type, 4u> t;
// q/p
std::array<scalar_type, 4u> qop;
// dt/ds = d^2r/ds^2 = q/p ( t X B )
std::array<vector3_type, 4u> dtds;
// d(q/p)/ds
std::array<scalar_type, 4u> dqopds;
};

struct state : public base_type::state {

friend rk_stepper;
Expand All @@ -72,29 +86,27 @@ class rk_stepper final
/// @returns the B-field view
magnetic_field_type field() const { return m_magnetic_field; }

/// @returns access to the step data
const auto& step_data() const { return m_step_data; }

/// Update the track state by Runge-Kutta-Nystrom integration.
DETRAY_HOST_DEVICE
void advance_track();
void advance_track(const intermediate_state& sd);

/// Update the jacobian transport from free propagation
DETRAY_HOST_DEVICE
void advance_jacobian(const stepping::config& cfg = {});
void advance_jacobian(const stepping::config& cfg,
const intermediate_state&);

/// Evaluate dqopds for a given step size and material
DETRAY_HOST_DEVICE
scalar_type evaluate_dqopds(const std::size_t i, const scalar_type h,
const scalar_type dqopds_prev,
const detray::stepping::config& cfg);
detray::pair<scalar_type, scalar_type> evaluate_dqopds(
const std::size_t i, const scalar_type h,
const scalar_type dqopds_prev, const detray::stepping::config& cfg);

/// Evaluate dtds for Runge-Kutta stepping
DETRAY_HOST_DEVICE
vector3_type evaluate_dtds(const vector3_type& b_field,
const std::size_t i, const scalar_type h,
const vector3_type& dtds_prev,
const scalar_type qop);
detray::pair<vector3_type, vector3_type> evaluate_dtds(
const vector3_type& b_field, const std::size_t i,
const scalar_type h, const vector3_type& dtds_prev,
const scalar_type qop);

DETRAY_HOST_DEVICE
matrix_type<3, 3> evaluate_field_gradient(const point3_type& pos);
Expand Down Expand Up @@ -128,20 +140,8 @@ class rk_stepper final
}

private:
/// stepping data required for RKN4
struct {
vector3_type b_first{0.f, 0.f, 0.f};
vector3_type b_middle{0.f, 0.f, 0.f};
vector3_type b_last{0.f, 0.f, 0.f};
// t = tangential direction = dr/ds
std::array<vector3_type, 4u> t;
// q/p
std::array<scalar_type, 4u> qop;
// dt/ds = d^2r/ds^2 = q/p ( t X B )
std::array<vector3_type, 4u> dtds;
// d(q/p)/ds
std::array<scalar_type, 4u> dqopds;
} m_step_data;
vector3_type m_dtds_3;
scalar_type m_dqopds_3;

/// Magnetic field view
const magnetic_field_t m_magnetic_field;
Expand Down
97 changes: 50 additions & 47 deletions core/include/detray/propagator/rk_stepper.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@

template <typename magnetic_field_t, typename algebra_t, typename constraint_t,
typename policy_t, typename inspector_t>
DETRAY_HOST_DEVICE inline void
detray::rk_stepper<magnetic_field_t, algebra_t, constraint_t, policy_t,
inspector_t>::state::advance_track() {
DETRAY_HOST_DEVICE inline void detray::rk_stepper<
magnetic_field_t, algebra_t, constraint_t, policy_t, inspector_t>::state::
advance_track(
const detray::rk_stepper<magnetic_field_t, algebra_t, constraint_t,
policy_t, inspector_t>::intermediate_state&
sd) {

const auto& sd = m_step_data;
const scalar_type h{this->step_size()};
const scalar_type h_6{h * static_cast<scalar_type>(1. / 6.)};
auto& track = (*this)();
Expand Down Expand Up @@ -49,7 +51,8 @@ template <typename magnetic_field_t, typename algebra_t, typename constraint_t,
typename policy_t, typename inspector_t>
DETRAY_HOST_DEVICE inline void detray::rk_stepper<
magnetic_field_t, algebra_t, constraint_t, policy_t,
inspector_t>::state::advance_jacobian(const detray::stepping::config& cfg) {
inspector_t>::state::advance_jacobian(const detray::stepping::config& cfg,
const intermediate_state& sd) {
/// The calculations are based on ATL-SOFT-PUB-2009-002. The update of the
/// Jacobian matrix requires only the calculation of eq. 17 and 18.
/// Since the terms of eq. 18 are currently 0, this matrix is not needed
Expand All @@ -73,7 +76,6 @@ DETRAY_HOST_DEVICE inline void detray::rk_stepper<
//( JacTransport = D * JacTransport )
auto D = matrix_operator().template identity<e_free_size, e_free_size>();

const auto& sd = m_step_data;
const scalar_type h{this->step_size()};
auto& track = (*this)();

Expand Down Expand Up @@ -334,32 +336,25 @@ DETRAY_HOST_DEVICE inline auto detray::rk_stepper<
const scalar_type h,
const scalar_type dqopds_prev,
const detray::stepping::config& cfg)
-> scalar_type {
-> detray::pair<scalar_type, scalar_type> {

const auto& track = (*this)();
const scalar_type qop = track.qop();
auto& sd = m_step_data;

if (!(this->volume_has_material())) {
sd.qop[i] = qop;
return 0.f;
const scalar_type qop = track.qop();
return detray::make_pair(0.f, qop);
} else if (cfg.use_mean_loss && i != 0u) {
// qop_n is calculated recursively like the direction of
// evaluate_dtds.
//
// https://doi.org/10.1016/0029-554X(81)90063-X says:
// "For y we have similar formulae as for x, for y' and
// \lambda similar formulae as for x'"
const scalar_type qop = track.qop() + h * dqopds_prev;
return detray::make_pair(this->dqopds(qop), qop);
} else {

if (cfg.use_mean_loss) {
if (i == 0u) {
sd.qop[i] = qop;
} else {

// qop_n is calculated recursively like the direction of
// evaluate_dtds.
//
// https://doi.org/10.1016/0029-554X(81)90063-X says:
// "For y we have similar formulae as for x, for y' and
// \lambda similar formulae as for x'"
sd.qop[i] = qop + h * dqopds_prev;
}
}
return this->dqopds(sd.qop[i]);
const scalar_type qop = track.qop();
return detray::make_pair(this->dqopds(qop), qop);
}
}

Expand All @@ -370,20 +365,22 @@ DETRAY_HOST_DEVICE inline auto detray::rk_stepper<
inspector_t>::state::evaluate_dtds(const vector3_type& b_field,
const std::size_t i, const scalar_type h,
const vector3_type& dtds_prev,
const scalar_type qop) -> vector3_type {
const scalar_type qop)
-> detray::pair<vector3_type, vector3_type> {
auto& track = (*this)();
const auto dir = track.dir();
auto& sd = m_step_data;

vector3_type t;

if (i == 0u) {
sd.t[i] = dir;
t = dir;
} else {
// Eq (84) of https://doi.org/10.1016/0029-554X(81)90063-X
sd.t[i] = dir + h * dtds_prev;
t = dir + h * dtds_prev;
}

// dtds = qop * (t X B) from Lorentz force
return qop * vector::cross(sd.t[i], b_field);
return detray::make_pair(vector3_type{qop * vector::cross(t, b_field)}, t);
}

template <typename magnetic_field_t, typename algebra_t, typename constraint_t,
Expand Down Expand Up @@ -445,7 +442,8 @@ detray::rk_stepper<magnetic_field_t, algebra_t, constraint_t, policy_t,

return (*this)().qop() * vector::cross((*this)().dir(), bvec);
}
return m_step_data.dtds[3u];

return m_dtds_3;
}

template <typename magnetic_field_t, typename algebra_t, typename constraint_t,
Expand All @@ -459,7 +457,7 @@ detray::rk_stepper<magnetic_field_t, algebra_t, constraint_t, policy_t,
return this->dqopds((*this)().qop());
}

return m_step_data.dqopds[3u];
return m_dqopds_3;
}

template <typename magnetic_field_t, typename algebra_t, typename constraint_t,
Expand Down Expand Up @@ -555,7 +553,7 @@ DETRAY_HOST_DEVICE inline bool detray::rk_stepper<

const point3_type pos = stepping().pos();

auto& sd = stepping.m_step_data;
intermediate_state sd;

scalar_type error_estimate{0.f};

Expand All @@ -567,8 +565,9 @@ DETRAY_HOST_DEVICE inline bool detray::rk_stepper<

// qop should be recalculated at every point
// Reference: Eq (84) of https://doi.org/10.1016/0029-554X(81)90063-X
sd.dqopds[0u] = stepping.evaluate_dqopds(0u, 0.f, 0.f, cfg);
sd.dtds[0u] = stepping.evaluate_dtds(
detray::tie(sd.dqopds[0u], sd.qop[0u]) =
stepping.evaluate_dqopds(0u, 0.f, 0.f, cfg);
detray::tie(sd.dtds[0u], sd.t[0u]) = stepping.evaluate_dtds(
sd.b_first, 0u, 0.f, vector3_type{0.f, 0.f, 0.f}, sd.qop[0u]);

/// RKN step trial and error estimation
Expand All @@ -587,18 +586,18 @@ DETRAY_HOST_DEVICE inline bool detray::rk_stepper<
sd.b_middle[1] = bvec1[1];
sd.b_middle[2] = bvec1[2];

sd.dqopds[1u] =
detray::tie(sd.dqopds[1u], sd.qop[1u]) =
stepping.evaluate_dqopds(1u, half_h, sd.dqopds[0u], cfg);
sd.dtds[1u] = stepping.evaluate_dtds(sd.b_middle, 1u, half_h,
sd.dtds[0u], sd.qop[1u]);
detray::tie(sd.dtds[1u], sd.t[1u]) = stepping.evaluate_dtds(
sd.b_middle, 1u, half_h, sd.dtds[0u], sd.qop[1u]);

// Third Runge-Kutta point
// qop should be recalculated at every point
// Reference: Eq (84) of https://doi.org/10.1016/0029-554X(81)90063-X
sd.dqopds[2u] =
detray::tie(sd.dqopds[2u], sd.qop[2u]) =
stepping.evaluate_dqopds(2u, half_h, sd.dqopds[1u], cfg);
sd.dtds[2u] = stepping.evaluate_dtds(sd.b_middle, 2u, half_h,
sd.dtds[1u], sd.qop[2u]);
detray::tie(sd.dtds[2u], sd.t[2u]) = stepping.evaluate_dtds(
sd.b_middle, 2u, half_h, sd.dtds[1u], sd.qop[2u]);

// Last Runge-Kutta point
// qop should be recalculated at every point
Expand All @@ -609,8 +608,9 @@ DETRAY_HOST_DEVICE inline bool detray::rk_stepper<
sd.b_last[1] = bvec2[1];
sd.b_last[2] = bvec2[2];

sd.dqopds[3u] = stepping.evaluate_dqopds(3u, h, sd.dqopds[2u], cfg);
sd.dtds[3u] =
detray::tie(sd.dqopds[3u], sd.qop[3u]) =
stepping.evaluate_dqopds(3u, h, sd.dqopds[2u], cfg);
detray::tie(sd.dtds[3u], sd.t[3u]) =
stepping.evaluate_dtds(sd.b_last, 3u, h, sd.dtds[2u], sd.qop[3u]);

// Compute and check the local integration error estimate
Expand Down Expand Up @@ -662,6 +662,9 @@ DETRAY_HOST_DEVICE inline bool detray::rk_stepper<
}
}

stepping.m_dtds_3 = sd.dtds[3u];
stepping.m_dqopds_3 = sd.dqopds[3u];

// Update navigation direction
const step::direction step_dir = stepping.step_size() >= 0.f
? step::direction::e_forward
Expand All @@ -680,11 +683,11 @@ DETRAY_HOST_DEVICE inline bool detray::rk_stepper<
}

// Advance track state
stepping.advance_track();
stepping.advance_track(sd);

// Advance jacobian transport
if (cfg.do_covariance_transport) {
stepping.advance_jacobian(cfg);
stepping.advance_jacobian(cfg, sd);
}

// Save the current step size
Expand Down
11 changes: 0 additions & 11 deletions tests/include/detray/test/utils/inspectors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,17 +450,6 @@ struct print_inspector {
debug_stream << "Step size scale factor"
<< "\t\t" << step_scalor << std::endl;

debug_stream << "Bfield points:" << std::endl;
const auto &f = state.step_data().b_first;
debug_stream << "\tfirst:" << tabs << f[0] << ", " << f[1] << ", "
<< f[2] << std::endl;
const auto &m = state.step_data().b_middle;
debug_stream << "\tmiddle:" << tabs << m[0] << ", " << m[1] << ", "
<< m[2] << std::endl;
const auto &l = state.step_data().b_last;
debug_stream << "\tlast:" << tabs << l[0] << ", " << l[1] << ", "
<< l[2] << std::endl;

debug_stream << std::endl;
}

Expand Down

0 comments on commit bc1ea73

Please sign in to comment.