Skip to content

Commit

Permalink
Tensors: mpi_barrier for accurate timings
Browse files (browse the repository at this point in the history)
  • Loading branch information
pseewald committed Sep 5, 2020
1 parent c27be50 commit d507393
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
13 changes: 12 additions & 1 deletion src/tas/dbcsr_tas_mm.F
Original file line number | Diff line number | Diff line change
Expand Up @@ -48,7 +48,7 @@ MODULE dbcsr_tas_mm
USE dbcsr_kinds, ONLY: &
int_8, real_8, real_4, default_string_length
USE dbcsr_mpiwrap, ONLY: &
mp_comm_compare, mp_environ, mp_sum, mp_comm_free, mp_cart_create, mp_max
mp_comm_compare, mp_environ, mp_sum, mp_comm_free, mp_cart_create, mp_max, mp_sync
USE dbcsr_operations, ONLY: &
dbcsr_scale, dbcsr_get_info, dbcsr_copy, dbcsr_clear, dbcsr_add, dbcsr_zero
USE dbcsr_tas_io, ONLY: &
Expand Down Expand Up @@ -129,6 +129,7 @@ RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a,
TYPE(dbcsr_type) :: matrix_a_mm, matrix_b_mm, matrix_c_mm

CALL timeset(routineN, handle)
CALL mp_sync(matrix_a%dist%info%mp_comm)
CALL timeset("dbcsr_tas_total", handle2)

NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs)
Expand Down Expand Up @@ -259,6 +260,7 @@ RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a,
blk_ind=result_index, nze=nze_c, retain_sparsity=retain_sparsity)

IF (PRESENT(result_index)) THEN
CALL mp_sync(matrix_a%dist%info%mp_comm)
CALL timestop(handle2)
CALL timestop(handle)
RETURN
Expand Down Expand Up @@ -486,6 +488,7 @@ RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a,
info_c = dbcsr_tas_info(matrix_c_rs)
CALL dbcsr_tas_info_hold(info_c)

CALL mp_sync(matrix_a%dist%info%mp_comm)
CALL timeset("dbcsr_tas_dbcsr", handle4)
SELECT CASE (tr_case)
CASE (dbcsr_no_transpose)
Expand All @@ -503,6 +506,7 @@ RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a,

CALL timestop(handle3)
END SELECT
CALL mp_sync(matrix_a%dist%info%mp_comm)
CALL timestop(handle4)

CALL dbcsr_release(matrix_a_mm)
Expand Down Expand Up @@ -574,11 +578,13 @@ RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a,
info_c = dbcsr_tas_info(matrix_c_rep)
CALL dbcsr_tas_info_hold(info_c)

CALL mp_sync(matrix_a%dist%info%mp_comm)
CALL timeset("dbcsr_tas_dbcsr", handle4)
CALL timeset("dbcsr_tas_mm_2", handle3)
CALL dbcsr_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
filter_eps=filter_eps_prv/REAL(nsplit, KIND=real_8), retain_sparsity=retain_sparsity, flop=flop)
CALL mp_sync(matrix_a%dist%info%mp_comm)
CALL timestop(handle3)
CALL timestop(handle4)

Expand Down Expand Up @@ -653,6 +659,7 @@ RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a,
info_c = dbcsr_tas_info(matrix_c_rs)
CALL dbcsr_tas_info_hold(info_c)

CALL mp_sync(matrix_a%dist%info%mp_comm)
CALL timeset("dbcsr_tas_dbcsr", handle4)
SELECT CASE (tr_case)
CASE (dbcsr_no_transpose)
Expand All @@ -668,6 +675,7 @@ RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a,
filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
CALL timestop(handle3)
END SELECT
CALL mp_sync(matrix_a%dist%info%mp_comm)
CALL timestop(handle4)

CALL dbcsr_release(matrix_a_mm)
Expand Down Expand Up @@ -757,6 +765,7 @@ RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a,
CALL dbcsr_tas_release_info(info_b)
CALL dbcsr_tas_release_info(info_c)

CALL mp_sync(matrix_a%dist%info%mp_comm)
CALL timestop(handle2)
CALL timestop(handle)

Expand Down Expand Up @@ -1588,6 +1597,7 @@ SUBROUTINE dbcsr_tas_batched_mm_finalize(matrix)
TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix
INTEGER :: handle
CALL mp_sync(matrix%dist%info%mp_comm)
CALL timeset("dbcsr_tas_total", handle)
IF (matrix%do_batched == 0) RETURN
Expand All @@ -1603,6 +1613,7 @@ SUBROUTINE dbcsr_tas_batched_mm_finalize(matrix)
DEALLOCATE (matrix%mm_storage)
CALL dbcsr_tas_set_batched_state(matrix, state=0)
CALL mp_sync(matrix%dist%info%mp_comm)
CALL timestop(handle)
END SUBROUTINE
Expand Down
8 changes: 7 additions & 1 deletion src/tensors/dbcsr_tensor.F
Original file line number | Diff line number | Diff line change
Expand Up @@ -55,7 +55,7 @@ MODULE dbcsr_tensor
USE dbcsr_kinds, ONLY: &
${uselist(dtype_float_prec)}$, default_string_length, int_8, dp
USE dbcsr_mpiwrap, ONLY: &
mp_environ, mp_max, mp_sum, mp_comm_free, mp_cart_create
mp_environ, mp_max, mp_sum, mp_comm_free, mp_cart_create, mp_sync
USE dbcsr_toollib, ONLY: &
sort
USE dbcsr_tensor_reshape, ONLY: &
Expand Down Expand Up @@ -125,13 +125,15 @@ SUBROUTINE dbcsr_t_copy(tensor_in, tensor_out, order, summation, bounds, move_da
INTEGER, INTENT(IN), OPTIONAL :: unit_nr
INTEGER :: handle
CALL mp_sync(tensor_in%pgrid%mp_comm_2d)
CALL timeset("dbcsr_t_total", handle)
! make sure that it is safe to use dbcsr_t_copy during a batched contraction
CALL dbcsr_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.TRUE.)
CALL dbcsr_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.TRUE.)
CALL dbcsr_t_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
CALL mp_sync(tensor_in%pgrid%mp_comm_2d)
CALL timestop(handle)
END SUBROUTINE
Expand Down Expand Up @@ -517,6 +519,7 @@ SUBROUTINE dbcsr_t_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
INTEGER :: handle
CALL mp_sync(tensor_1%pgrid%mp_comm_2d)
CALL timeset("dbcsr_t_total", handle)
CALL dbcsr_t_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
contract_1, notcontract_1, &
Expand All @@ -535,6 +538,7 @@ SUBROUTINE dbcsr_t_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
retain_sparsity=retain_sparsity, &
unit_nr=unit_nr, &
log_verbose=log_verbose)
CALL mp_sync(tensor_1%pgrid%mp_comm_2d)
CALL timestop(handle)
END SUBROUTINE
Expand Down Expand Up @@ -2084,6 +2088,7 @@ SUBROUTINE dbcsr_t_batched_contract_finalize(tensor, unit_nr)
LOGICAL :: do_write
INTEGER :: unit_nr_prv, handle

CALL mp_sync(tensor%pgrid%mp_comm_2d)
CALL timeset("dbcsr_t_total", handle)
unit_nr_prv = prep_output_unit(unit_nr)

Expand All @@ -2107,6 +2112,7 @@ SUBROUTINE dbcsr_t_batched_contract_finalize(tensor, unit_nr)

CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
DEALLOCATE (tensor%contraction_storage)
CALL mp_sync(tensor%pgrid%mp_comm_2d)
CALL timestop(handle)

END SUBROUTINE
Expand Down

0 comments on commit d507393

Please sign in to comment.