Skip to content

Commit

Permalink
Merge branch 'main' into vecmem_typedef_tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
StewMH authored Nov 7, 2024
2 parents 7cabd3d + 0174113 commit 15b65b6
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct gain_matrix_updater {
TRACCC_HOST_DEVICE inline bool operator()(
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
track_state<algebra_t>& trk_state,
bound_track_parameters& bound_params) const {
const bound_track_parameters& bound_params) const {

using shape_type = typename mask_group_t::value_type::shape;

Expand All @@ -57,7 +57,7 @@ struct gain_matrix_updater {
template <size_type D, typename shape_t>
TRACCC_HOST_DEVICE inline bool update(
track_state<algebra_t>& trk_state,
bound_track_parameters& bound_params) const {
const bound_track_parameters& bound_params) const {

static_assert(((D == 1u) || (D == 2u)),
"The measurement dimension should be 1 or 2");
Expand Down Expand Up @@ -123,25 +123,21 @@ struct gain_matrix_updater {
const matrix_type<1, 1> chi2 = matrix_operator().transpose(residual) *
matrix_operator().inverse(R) * residual;

// Set the stepper parameter
bound_params.set_vector(filtered_vec);
bound_params.set_covariance(filtered_cov);

// Return false if track is parallel to z-axis or phi is not finite
const scalar theta = bound_params.theta();
if (theta <= 0.f || theta >= constant<traccc::scalar>::pi ||
!std::isfinite(bound_params.phi())) {
return false;
}

// Wrap the phi in the range of [-pi, pi]
wrap_phi(bound_params);

// Set the track state parameters
trk_state.filtered().set_vector(filtered_vec);
trk_state.filtered().set_covariance(filtered_cov);
trk_state.filtered_chi2() = matrix_operator().element(chi2, 0, 0);

// Wrap the phi in the range of [-pi, pi]
wrap_phi(trk_state.filtered());

return true;
}
};
Expand Down
5 changes: 4 additions & 1 deletion core/include/traccc/fitting/kalman_filter/kalman_actor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ struct kalman_actor : detray::actor {
TRACCC_HOST_DEVICE void operator()(state& actor_state,
propagator_state_t& propagation) const {

const auto& stepping = propagation._stepping;
auto& stepping = propagation._stepping;
auto& navigation = propagation._navigation;

// If the iterator reaches the end, terminate the propagation
Expand Down Expand Up @@ -114,6 +114,9 @@ struct kalman_actor : detray::actor {
return;
}

// Update the propagation flow
stepping.bound_params() = trk_state.filtered();

// Set full jacobian
trk_state.jacobian() = stepping.full_jacobian();

Expand Down
3 changes: 1 addition & 2 deletions core/src/finding/find_tracks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,9 @@ track_candidate_container_types::host find_tracks(
track_state<algebra_type> trk_state(meas);

// Run the Kalman update on a copy of the track parameters
bound_track_parameters bound_param(in_param);
const bool res =
sf.template visit_mask<gain_matrix_updater<algebra_type>>(
trk_state, bound_param);
trk_state, in_param);

// The chi2 from Kalman update should be less than chi2_max
if (res && trk_state.filtered_chi2() < config.chi2_max) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ TRACCC_DEVICE inline void find_tracks(
owner_local_thread_id +
thread_id.getBlockDimX() * thread_id.getBlockIdX();
assert(in_params_liveness.at(owner_global_thread_id) != 0u);
bound_track_parameters in_par =
const bound_track_parameters& in_par =
in_params.at(owner_global_thread_id);
const unsigned int meas_idx =
shared_payload.shared_candidates[thread_id.getLocalThreadIdX()]
Expand Down

0 comments on commit 15b65b6

Please sign in to comment.