diff --git a/core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp b/core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp index 183a0f50f..2d5618ef4 100644 --- a/core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp +++ b/core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp @@ -39,7 +39,7 @@ struct gain_matrix_updater { TRACCC_HOST_DEVICE inline bool operator()( const mask_group_t& /*mask_group*/, const index_t& /*index*/, track_state& trk_state, - bound_track_parameters& bound_params) const { + const bound_track_parameters& bound_params) const { using shape_type = typename mask_group_t::value_type::shape; @@ -57,7 +57,7 @@ struct gain_matrix_updater { template TRACCC_HOST_DEVICE inline bool update( track_state& trk_state, - bound_track_parameters& bound_params) const { + const bound_track_parameters& bound_params) const { static_assert(((D == 1u) || (D == 2u)), "The measurement dimension should be 1 or 2"); @@ -123,10 +123,6 @@ struct gain_matrix_updater { const matrix_type<1, 1> chi2 = matrix_operator().transpose(residual) * matrix_operator().inverse(R) * residual; - // Set the stepper parameter - bound_params.set_vector(filtered_vec); - bound_params.set_covariance(filtered_cov); - // Return false if track is parallel to z-axis or phi is not finite const scalar theta = bound_params.theta(); if (theta <= 0.f || theta >= constant::pi || @@ -134,14 +130,14 @@ struct gain_matrix_updater { return false; } - // Wrap the phi in the range of [-pi, pi] - wrap_phi(bound_params); - // Set the track state parameters trk_state.filtered().set_vector(filtered_vec); trk_state.filtered().set_covariance(filtered_cov); trk_state.filtered_chi2() = matrix_operator().element(chi2, 0, 0); + // Wrap the phi in the range of [-pi, pi] + wrap_phi(trk_state.filtered()); + return true; } }; diff --git a/core/include/traccc/fitting/kalman_filter/kalman_actor.hpp b/core/include/traccc/fitting/kalman_filter/kalman_actor.hpp index 23a886490..9db5b1fd3 100644 --- a/core/include/traccc/fitting/kalman_filter/kalman_actor.hpp +++ b/core/include/traccc/fitting/kalman_filter/kalman_actor.hpp @@ -78,7 +78,7 @@ struct kalman_actor : detray::actor { TRACCC_HOST_DEVICE void operator()(state& actor_state, propagator_state_t& propagation) const { - const auto& stepping = propagation._stepping; + auto& stepping = propagation._stepping; auto& navigation = propagation._navigation; // If the iterator reaches the end, terminate the propagation @@ -114,6 +114,9 @@ struct kalman_actor : detray::actor { return; } + // Update the propagation flow + stepping.bound_params() = trk_state.filtered(); + // Set full jacobian trk_state.jacobian() = stepping.full_jacobian(); diff --git a/core/src/finding/find_tracks.hpp b/core/src/finding/find_tracks.hpp index 5a33d4b4d..f24cc28f6 100644 --- a/core/src/finding/find_tracks.hpp +++ b/core/src/finding/find_tracks.hpp @@ -239,10 +239,9 @@ track_candidate_container_types::host find_tracks( track_state trk_state(meas); // Run the Kalman update on a copy of the track parameters - bound_track_parameters bound_param(in_param); const bool res = sf.template visit_mask>( - trk_state, bound_param); + trk_state, in_param); // The chi2 from Kalman update should be less than chi2_max if (res && trk_state.filtered_chi2() < config.chi2_max) { diff --git a/device/common/include/traccc/finding/device/impl/find_tracks.ipp b/device/common/include/traccc/finding/device/impl/find_tracks.ipp index fe816bff8..b0024dca2 100644 --- a/device/common/include/traccc/finding/device/impl/find_tracks.ipp +++ b/device/common/include/traccc/finding/device/impl/find_tracks.ipp @@ -182,7 +182,7 @@ TRACCC_DEVICE inline void find_tracks( owner_local_thread_id + thread_id.getBlockDimX() * thread_id.getBlockIdX(); assert(in_params_liveness.at(owner_global_thread_id) != 0u); - bound_track_parameters in_par = + const bound_track_parameters& in_par = in_params.at(owner_global_thread_id); const unsigned int meas_idx = shared_payload.shared_candidates[thread_id.getLocalThreadIdX()]