From 1213574c730088baf50d17de5474e150864b382f Mon Sep 17 00:00:00 2001 From: Augustin Bussy Date: Wed, 29 May 2024 13:57:00 +0200 Subject: [PATCH] Allow turning off DBCSR ACC with env variable --- src/core/dbcsr_config.F | 42 +++++++++++++++++++++++++++++++++---- src/core/dbcsr_lib.F | 6 +++--- src/mm/dbcsr_mm.F | 12 +++++------ src/mm/dbcsr_mm_3d.F | 24 ++++++++++----------- src/mm/dbcsr_mm_cannon.F | 44 +++++++++++++++++++-------------------- src/mm/dbcsr_mm_hostdrv.F | 6 +++--- src/mm/dbcsr_mm_sched.F | 12 +++++------ 7 files changed, 90 insertions(+), 56 deletions(-) diff --git a/src/core/dbcsr_config.F b/src/core/dbcsr_config.F index ddc296aeb63..dd377fe059d 100644 --- a/src/core/dbcsr_config.F +++ b/src/core/dbcsr_config.F @@ -153,6 +153,7 @@ MODULE dbcsr_config SET_PARAMETER_DEFAULT(COMM_THREAD_LOAD, CONF_PAR_INT, 100) SET_PARAMETER_DEFAULT(MM_DENSE, CONF_PAR_LOGICAL,.NOT. has_acc) SET_PARAMETER_DEFAULT(MULTREC_LIMIT, CONF_PAR_INT, 512) + SET_PARAMETER_DEFAULT(TURN_OFF_ACC, CONF_PAR_LOGICAL, .FALSE.) SET_PARAMETER_DEFAULT(ACCDRV_THREAD_BUFFERS, CONF_PAR_INT, 8) TYPE(CONF_PAR_LOGICAL) :: ACCDRV_AVOID_AFTER_BUSY = & CONF_PAR_LOGICAL(name="ACCDRV_AVOID_AFTER_BUSY", val=.FALSE., defval=.FALSE.) 
@@ -184,6 +185,7 @@ MODULE dbcsr_config PUBLIC :: dbcsr_set_config, dbcsr_get_default_config, dbcsr_print_config PUBLIC :: max_kernel_dim PUBLIC :: get_accdrv_active_device_id, set_accdrv_active_device_id, reset_accdrv_active_device_id + PUBLIC :: use_acc CONTAINS @@ -340,6 +342,7 @@ SUBROUTINE dbcsr_set_config( & comm_thread_load, & mm_dense, & multrec_limit, & + turn_off_acc, & accdrv_thread_buffers, & accdrv_avoid_after_busy, & accdrv_min_flop_process, & @@ -368,6 +371,7 @@ SUBROUTINE dbcsr_set_config( & LOGICAL, INTENT(IN), OPTIONAL :: use_comm_thread INTEGER, INTENT(IN), OPTIONAL :: comm_thread_load LOGICAL, INTENT(IN), OPTIONAL :: mm_dense + LOGICAL, INTENT(IN), OPTIONAL :: turn_off_acc INTEGER, INTENT(IN), OPTIONAL :: multrec_limit, accdrv_thread_buffers LOGICAL, INTENT(IN), OPTIONAL :: accdrv_avoid_after_busy INTEGER, INTENT(IN), OPTIONAL :: accdrv_min_flop_process @@ -389,7 +393,7 @@ SUBROUTINE dbcsr_set_config( & CALL dbcsr_cfg%num_layers_3D%set(num_layers_3D) CALL dbcsr_cfg%use_comm_thread%set(use_comm_thread) CALL dbcsr_cfg%multrec_limit%set(multrec_limit) - CALL dbcsr_cfg%mm_dense%set(mm_dense) + CALL dbcsr_cfg%turn_off_acc%set(turn_off_acc) CALL dbcsr_cfg%accdrv_thread_buffers%set(accdrv_thread_buffers) CALL dbcsr_cfg%accdrv_avoid_after_busy%set(accdrv_avoid_after_busy) CALL dbcsr_cfg%accdrv_min_flop_process%set(accdrv_min_flop_process) @@ -419,9 +423,22 @@ SUBROUTINE dbcsr_set_config( & CALL dbcsr_cfg%comm_thread_load%set(comm_thread_load) CALL dbcsr_cfg%n_stacks%set(nstacks) - CALL dbcsr_cfg%mm_stack_size%set(mm_stack_size) CALL dbcsr_cfg%mm_driver%set(mm_driver) + ! If ACC is turned-off, use the CPU defaults + IF (.NOT. PRESENT(mm_stack_size) .AND. dbcsr_cfg%turn_off_acc%val) THEN + CALL dbcsr_cfg%mm_stack_size%set(1000) + ELSE + CALL dbcsr_cfg%mm_stack_size%set(mm_stack_size) + END IF + + IF (.NOT. PRESENT(mm_dense) .AND. dbcsr_cfg%turn_off_acc%val) THEN + CALL dbcsr_cfg%mm_dense%set(.TRUE.) 
+ ELSE + CALL dbcsr_cfg%mm_dense%set(mm_dense) + END IF + + END SUBROUTINE dbcsr_set_config SUBROUTINE dbcsr_get_default_config( & @@ -435,6 +452,7 @@ SUBROUTINE dbcsr_get_default_config( & use_comm_thread, & comm_thread_load, & mm_dense, & + turn_off_acc, & multrec_limit, & accdrv_thread_buffers, & accdrv_avoid_after_busy, & @@ -455,7 +473,7 @@ SUBROUTINE dbcsr_get_default_config( & INTEGER, INTENT(OUT), OPTIONAL :: num_layers_3D LOGICAL, INTENT(OUT), OPTIONAL :: use_comm_thread INTEGER, INTENT(OUT), OPTIONAL :: comm_thread_load - LOGICAL, INTENT(OUT), OPTIONAL :: mm_dense + LOGICAL, INTENT(OUT), OPTIONAL :: mm_dense, turn_off_acc INTEGER, INTENT(OUT), OPTIONAL :: multrec_limit, accdrv_thread_buffers LOGICAL, INTENT(OUT), OPTIONAL :: accdrv_avoid_after_busy INTEGER, INTENT(OUT), OPTIONAL :: accdrv_min_flop_process @@ -478,6 +496,7 @@ SUBROUTINE dbcsr_get_default_config( & IF (PRESENT(comm_thread_load)) comm_thread_load = dbcsr_cfg%comm_thread_load%defval IF (PRESENT(mm_dense)) mm_dense = dbcsr_cfg%mm_dense%defval IF (PRESENT(multrec_limit)) multrec_limit = dbcsr_cfg%multrec_limit%defval + IF (PRESENT(turn_off_acc)) turn_off_acc = dbcsr_cfg%turn_off_acc%defval IF (PRESENT(accdrv_thread_buffers)) accdrv_thread_buffers = dbcsr_cfg%accdrv_thread_buffers%defval IF (PRESENT(accdrv_avoid_after_busy)) accdrv_avoid_after_busy = dbcsr_cfg%accdrv_avoid_after_busy%defval IF (PRESENT(accdrv_min_flop_process)) accdrv_min_flop_process = dbcsr_cfg%accdrv_min_flop_process%defval @@ -607,7 +626,12 @@ SUBROUTINE dbcsr_print_config(unit_nr) END IF END BLOCK - IF (has_acc) THEN + IF (dbcsr_cfg%turn_off_acc%val) THEN + WRITE (UNIT=unit_nr, FMT='(1X,A,T81,A4)') & + "DBCSR| ACC is turned off: only CPU is used", dbcsr_cfg%turn_off_acc%print_source() + END IF + + IF (use_acc()) THEN WRITE (UNIT=unit_nr, FMT='(1X,A,T70,I11)') & "DBCSR| ACC: Number of devices/node", dbcsr_acc_get_ndevices() WRITE (UNIT=unit_nr, FMT='(1X,A,T70,I11,A4)') & @@ -671,4 +695,14 @@ SUBROUTINE 
reset_accdrv_active_device_id() accdrv_active_device_id = default_accdrv_active_device_id END SUBROUTINE reset_accdrv_active_device_id + FUNCTION use_acc() + LOGICAL :: use_acc + + IF (has_acc .AND. .NOT. dbcsr_cfg%turn_off_acc%val) THEN + use_acc = .TRUE. + ELSE + use_acc = .FALSE. + END IF + END FUNCTION use_acc + END MODULE dbcsr_config diff --git a/src/core/dbcsr_lib.F b/src/core/dbcsr_lib.F index 293be20e73d..401abe931da 100644 --- a/src/core/dbcsr_lib.F +++ b/src/core/dbcsr_lib.F @@ -15,7 +15,7 @@ MODULE dbcsr_lib USE dbcsr_config, ONLY: set_accdrv_active_device_id, & reset_accdrv_active_device_id, & dbcsr_set_config, & - has_acc + use_acc USE dbcsr_kinds, ONLY: int_1_size, & int_2_size, & int_4_size, & @@ -213,7 +213,7 @@ SUBROUTINE dbcsr_init_lib_pre(mp_comm, io_unit, accdrv_active_device_id) #endif ! Initialize Acc and set active device - IF (has_acc) THEN + IF (use_acc()) THEN IF (PRESENT(accdrv_active_device_id)) THEN CALL set_accdrv_active_device_id(accdrv_active_device_id) ELSEIF (dbcsr_acc_get_ndevices() > 0) THEN @@ -313,7 +313,7 @@ SUBROUTINE dbcsr_finalize_lib() #endif ! Reset Acc ID CALL reset_accdrv_active_device_id() - IF (has_acc) THEN + IF (use_acc()) THEN CALL acc_finalize() END IF diff --git a/src/mm/dbcsr_mm.F b/src/mm/dbcsr_mm.F index ba1d29ed204..b3d475310e6 100644 --- a/src/mm/dbcsr_mm.F +++ b/src/mm/dbcsr_mm.F @@ -26,7 +26,7 @@ MODULE dbcsr_mm USE dbcsr_config, ONLY: dbcsr_cfg, & dbcsr_set_config, & default_resize_factor, & - has_acc + use_acc USE dbcsr_data_methods, ONLY: dbcsr_data_set_size_referenced, & dbcsr_scalar_are_equal, & dbcsr_scalar_one, & @@ -153,7 +153,7 @@ SUBROUTINE dbcsr_multiply_lib_init() ! Each thread has its own working-matrix and its own mempool ALLOCATE (memtype_product_wm(ithread)%p) - CALL dbcsr_memtype_setup(memtype_product_wm(ithread)%p, has_pool=dbcsr_cfg%use_mempools_cpu%val .OR. has_acc) + CALL dbcsr_memtype_setup(memtype_product_wm(ithread)%p, has_pool=dbcsr_cfg%use_mempools_cpu%val .OR. 
use_acc()) CALL dbcsr_mempool_limit_capacity(memtype_product_wm(ithread)%p%pool, capacity=MAX(1, dbcsr_cfg%num_layers_3D%val)) END SUBROUTINE dbcsr_multiply_lib_init @@ -438,7 +438,7 @@ SUBROUTINE dbcsr_multiply_generic(transa, transb, & CALL array_nullify(dense_row_sizes) ! Reset GPU errors - IF (has_acc) THEN + IF (use_acc()) THEN CALL dbcsr_acc_clear_errors() END IF @@ -447,12 +447,12 @@ SUBROUTINE dbcsr_multiply_generic(transa, transb, & ! give any performance benefit) CALL check_openmpi_rma() - use_mempools = dbcsr_cfg%use_mempools_cpu%val .OR. has_acc + use_mempools = dbcsr_cfg%use_mempools_cpu%val .OR. use_acc() ! setup driver-dependent memory-types and their memory-pools --------------- ! the ab_buffers are shared by all threads - IF (has_acc) THEN + IF (use_acc()) THEN IF (.NOT. acc_stream_associated(stream_1)) THEN CALL acc_stream_create(stream_1, "MemCpy (odd ticks)") CALL acc_stream_create(stream_2, "MemCpy (even ticks)") @@ -616,7 +616,7 @@ SUBROUTINE dbcsr_multiply_generic(transa, transb, & END IF ab_dense = use_dense_mult ! Use memory pools when no dense - IF (.NOT. has_acc) THEN + IF (.NOT. use_acc()) THEN CALL dbcsr_memtype_setup(memtype_abpanel_1, has_pool=.NOT. ab_dense .AND. use_mempools, mpi=.TRUE.) CALL dbcsr_memtype_setup(memtype_abpanel_2, has_pool=.NOT. ab_dense .AND. use_mempools, mpi=.TRUE.) 
END IF diff --git a/src/mm/dbcsr_mm_3d.F b/src/mm/dbcsr_mm_3d.F index 8974bd6aa40..dd3f9a03618 100644 --- a/src/mm/dbcsr_mm_3d.F +++ b/src/mm/dbcsr_mm_3d.F @@ -28,7 +28,7 @@ MODULE dbcsr_mm_3d dbcsr_data_clear, & dbcsr_data_set USE dbcsr_config, ONLY: dbcsr_cfg, & - has_acc + use_acc USE dbcsr_data_methods, ONLY: & dbcsr_data_clear_pointer, dbcsr_data_ensure_size, dbcsr_data_exists, & dbcsr_data_get_memory_type, dbcsr_data_get_size, dbcsr_data_get_size_referenced, & @@ -1436,7 +1436,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & IF (ASSOCIATED(memtype_abpanel_2%pool)) & CALL dbcsr_mempool_limit_capacity(memtype_abpanel_2%pool, & capacity=2) - IF (has_acc) THEN + IF (use_acc()) THEN ! enumerate the blocksizes to keep the following 2D-arrays small. CALL enumerate_blk_sizes(matrix_right%row_blk_size%low%data, & dbcsr_max_row_size(matrix_right), & @@ -1544,7 +1544,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & ! ! Evaluate sizes for workspaces size_guess_init = 1 - IF (.NOT. keep_sparsity .AND. has_acc) THEN + IF (.NOT. keep_sparsity .AND. use_acc()) THEN size_guess_init = product_matrix_size_guess(matrix_left, matrix_right, product_matrix, & left_max_data_size, right_max_data_size, & left_col_nimages, right_row_nimages, & @@ -1733,7 +1733,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & IF (is_not_comm) THEN ! Right IF (do_comm_right(icol3D)) THEN - IF (has_acc) THEN + IF (use_acc()) THEN CALL timeset(routineN//"_acc_sync_right", handle2) CALL acc_event_synchronize(right_buffer_p%data%d%acc_ready) CALL timestop(handle2) @@ -1749,7 +1749,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & END IF ! 
Left IF (do_comm_left(irow3D)) THEN - IF (has_acc) THEN + IF (use_acc()) THEN CALL timeset(routineN//"_acc_sync_left", handle2) CALL acc_event_synchronize(left_buffer_p%data%d%acc_ready) CALL timestop(handle2) @@ -1967,7 +1967,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & ileft_buffer_calc = MIN(ileft_buffer_calc, nbuffers_norms) ! check if right matrix was already initialized IF (.NOT. right_buffer_p%matrix%valid) THEN - IF (has_acc) CALL dbcsr_data_host2dev(right_buffer_p%data) + IF (use_acc()) CALL dbcsr_data_host2dev(right_buffer_p%data) ! Repoint indices of matrices CALL make_meta(right_buffer_p, & right_row_total_nimages, & @@ -1985,7 +1985,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & right_norms(:, iright_buffer_calc), & k_sizes, n_sizes(icol3D)%sizes) END IF - IF (has_acc) THEN + IF (use_acc()) THEN CALL acc_transpose_blocks(right_buffer_p%matrix, & right_buffer_p%trs_stackbuf, & k_sizes, n_sizes(icol3D)%sizes, & @@ -1996,7 +1996,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & END IF ! check if left matrix was already initialized IF (.NOT. left_buffer_p%matrix%valid) THEN - IF (has_acc) CALL dbcsr_data_host2dev(left_buffer_p%data) + IF (use_acc()) CALL dbcsr_data_host2dev(left_buffer_p%data) ! Repoint indices of matrices CALL make_meta(left_buffer_p, & left_col_total_nimages, & @@ -2012,7 +2012,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & END IF END IF ! 
Wait for left and right buffers transfer to device before proceeding - IF (has_acc) THEN + IF (use_acc()) THEN CALL timeset(routineN//"_sync_h2d", handle2) CALL acc_device_synchronize() CALL timestop(handle2) @@ -2224,7 +2224,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & CALL dbcsr_data_release(right_buffers(v_ki)%data) NULLIFY (right_buffers(v_ki)%matrix%index) CALL dbcsr_release(right_buffers(v_ki)%matrix) - IF (has_acc) THEN + IF (use_acc()) THEN CALL dbcsr_data_clear_pointer(right_buffers(v_ki)%trs_stackbuf) IF (right_buffers(v_ki)%trs_stackbuf%d%memory_type%acc_devalloc) THEN CALL acc_event_destroy(right_buffers(v_ki)%trs_stackbuf%d%acc_ready) @@ -2235,7 +2235,7 @@ SUBROUTINE multiply_3D(imgdist_left, imgdist_right, & DEALLOCATE (left_buffers, right_buffers) DEALLOCATE (do_comm_left, do_comm_right) DEALLOCATE (right_vcol, left_vrow) - IF (has_acc) THEN + IF (use_acc()) THEN DEALLOCATE (row_blk_sizes2enum, enum2row_blk_sizes) DEALLOCATE (col_blk_sizes2enum, enum2col_blk_sizes) END IF @@ -2544,7 +2544,7 @@ SUBROUTINE buffer_init(buffer, data_type, & CALL dbcsr_data_init(buffer%data_before_resize) CALL dbcsr_data_new(buffer%data_before_resize, data_type, memory_type=data_memory_type) END IF - new_trs_stackbuf = PRESENT(trs_memory_type) .AND. has_acc + new_trs_stackbuf = PRESENT(trs_memory_type) .AND. use_acc() ! IF (buffer%is_valid) THEN ! 
Invalid buffers if data_type is different diff --git a/src/mm/dbcsr_mm_cannon.F b/src/mm/dbcsr_mm_cannon.F index fffd7f5d109..53b1f7faf7c 100644 --- a/src/mm/dbcsr_mm_cannon.F +++ b/src/mm/dbcsr_mm_cannon.F @@ -32,7 +32,7 @@ MODULE dbcsr_mm_cannon dbcsr_block_transpose_aa, & dbcsr_block_transpose USE dbcsr_config, ONLY: dbcsr_cfg, & - has_acc + use_acc USE dbcsr_data_methods, ONLY: & dbcsr_data_clear_pointer, dbcsr_data_ensure_size, dbcsr_data_get_size, & dbcsr_data_get_size_referenced, dbcsr_data_hold, dbcsr_data_host2dev, dbcsr_data_init, & @@ -1172,7 +1172,7 @@ SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, & !! ! Evaluate sizes for workspaces IF (.NOT. keep_sparsity) THEN - IF (has_acc) THEN + IF (use_acc()) THEN size_guess_init = product_matrix_size_guess(left_set%mats(1, 1), right_set%mats(1, 1), product_matrix, & left_max_nze, right_max_nze, & left_col_nimages, right_row_nimages, & @@ -1214,7 +1214,7 @@ SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, & IF (ASSOCIATED(memtype_abpanel_2%pool)) & CALL dbcsr_mempool_limit_capacity(memtype_abpanel_2%pool, & capacity=left_row_mult*left_col_nimages + right_row_nimages*right_col_mult + 1) - IF (has_acc) THEN + IF (use_acc()) THEN ! enumerate the blocksizes to keep the following 2D-arrays small. CALL enumerate_blk_sizes(right_set%mats(1, 1)%row_blk_size%low%data, & dbcsr_max_row_size(right_set%mats(1, 1)), & @@ -1319,7 +1319,7 @@ SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, & CALL dbcsr_data_new(right_data_rp, data_type) ! Setup transpose stackbuffers - IF (has_acc) THEN + IF (use_acc()) THEN CALL dbcsr_data_init(trs_stackbuf_1) CALL dbcsr_data_init(trs_stackbuf_2) CALL dbcsr_data_new(trs_stackbuf_1, data_type=dbcsr_type_int_4, & @@ -1422,7 +1422,7 @@ SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, & vprow_shift=metronome + right_row_nimages, & shifting='R') ! 
- IF (has_acc) THEN + IF (use_acc()) THEN CALL timeset(routineN//"_acc_sync_right", handle3) CALL acc_event_synchronize(right_buffer_comm%mats(v_ki + 1, 1)%data_area%d%acc_ready) CALL timestop(handle3) @@ -1545,7 +1545,7 @@ SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, & vpcol_shift=metronome + left_col_nimages, & shifting='L') ! - IF (has_acc) THEN + IF (use_acc()) THEN CALL timeset(routineN//"_acc_sync_left", handle3) CALL acc_event_synchronize(left_buffer_comm%mats(1, v_ki + 1)%data_area%d%acc_ready) CALL timestop(handle3) @@ -1617,7 +1617,7 @@ SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, & array_data(right_buffer_calc%mats(v_ki_right, 1)%local_rows), & k_sizes) ! - IF (has_acc) THEN + IF (use_acc()) THEN CALL dbcsr_data_host2dev(left_buffer_calc%mats(1, v_ki_left)%data_area) CALL dbcsr_data_host2dev(right_buffer_calc%mats(v_ki_right, 1)%data_area) CALL acc_transpose_blocks(right_buffer_calc%mats(v_ki_right, 1), trs_stackbuf_calc, & @@ -1637,7 +1637,7 @@ SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, & END IF ! Wait for left and right buffers transfer to device before proceeding - IF (has_acc) THEN + IF (use_acc()) THEN CALL timeset(routineN//"_sync_h2d", handle2) CALL acc_device_synchronize() CALL timestop(handle2) @@ -1716,7 +1716,7 @@ SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, & CALL m_memory(mem) max_memory = MAX(max_memory, REAL(mem)) - IF (has_acc) THEN + IF (use_acc()) THEN CALL dbcsr_data_release(trs_stackbuf_1) CALL dbcsr_data_release(trs_stackbuf_2) DEALLOCATE (row_blk_sizes2enum, enum2row_blk_sizes) @@ -2121,7 +2121,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & !! ! Evaluate sizes for workspaces IF (.NOT. 
keep_sparsity) THEN - IF (has_acc) THEN + IF (use_acc()) THEN size_guess_init = product_matrix_size_guess(left_set%mats(1, 1), right_set%mats(1, 1), product_matrix, & left_max_nze, right_max_nze, & left_col_nimages, right_row_nimages, & @@ -2163,7 +2163,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & IF (ASSOCIATED(memtype_abpanel_2%pool)) & CALL dbcsr_mempool_limit_capacity(memtype_abpanel_2%pool, & capacity=left_row_mult*left_col_nimages + right_row_nimages*right_col_mult + 1) - IF (has_acc) THEN + IF (use_acc()) THEN ! enumerate the blocksizes to keep the following 2D-arrays small. CALL enumerate_blk_sizes(right_set%mats(1, 1)%row_blk_size%low%data, & dbcsr_max_row_size(right_set%mats(1, 1)), & @@ -2226,7 +2226,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & IF (stat .NE. 0) otf_filtering = .FALSE. END IF - IF (has_acc .and. otf_filtering) THEN + IF (use_acc() .and. otf_filtering) THEN max_nblks = MAX(left_max_nblks, right_max_nblks) CALL dbcsr_data_init(normsbuf) CALL dbcsr_data_new(normsbuf, data_type=dbcsr_type_real_4, & @@ -2295,7 +2295,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & CALL dbcsr_data_new(right_data_rp, data_type) ! Setup transpose stackbuffers - IF (has_acc) THEN + IF (use_acc()) THEN CALL dbcsr_data_init(trs_stackbuf_1) CALL dbcsr_data_init(trs_stackbuf_2) CALL dbcsr_data_new(trs_stackbuf_1, data_type=dbcsr_type_int_4, & @@ -2389,7 +2389,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & k_sizes) ! ! Transfer left and right buffers from host to GPU memory - IF (has_acc) THEN + IF (use_acc()) THEN IF (copy_left) THEN ! copy left buffer images to device DO v_ki = 1, left_col_nimages @@ -2493,7 +2493,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & vprow_shift=metronome + right_row_nimages, & shifting='R') ! 
- IF (has_acc) THEN + IF (use_acc()) THEN CALL timeset(routineN//"_acc_sync_right", handle3) CALL acc_event_synchronize(right_buffer_comm%mats(v_ki + 1, 1)%data_area%d%acc_ready) CALL timestop(handle3) @@ -2522,7 +2522,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & END IF ! CALL timeset(routineN//"_metrocomm2", handle2) - IF (.not. has_acc) THEN + IF (.not. use_acc()) THEN CALL dbcsr_irecv_any(right_data_rp, right_recv_p, & grp, right_data_rr(v_ki + 1), tag=right_src_vrow) ELSE @@ -2538,7 +2538,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & END IF CALL mp_irecv(right_index_rp, right_recv_p, & grp, right_index_rr(v_ki + 1), tag=right_src_vrow) - IF (.not. has_acc) THEN + IF (.not. use_acc()) THEN CALL dbcsr_isend_any(right_data_sp, right_send_p, & grp, right_data_sr(v_ki + 1), tag=right_dst_vrow) ELSE @@ -2620,7 +2620,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & vpcol_shift=metronome + left_col_nimages, & shifting='L') ! - IF (has_acc) THEN + IF (use_acc()) THEN CALL timeset(routineN//"_acc_sync_left", handle3) CALL acc_event_synchronize(left_buffer_comm%mats(1, v_ki + 1)%data_area%d%acc_ready) CALL timestop(handle3) @@ -2649,7 +2649,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & END IF ! CALL timeset(routineN//"_metrocomm4", handle2) - IF (.not. has_acc) THEN + IF (.not. use_acc()) THEN CALL dbcsr_irecv_any(left_data_rp, left_recv_p, & grp, left_data_rr(v_ki + 1), tag=left_src_vcol) ELSE @@ -2665,7 +2665,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & END IF CALL mp_irecv(left_index_rp, left_recv_p, & grp, left_index_rr(v_ki + 1), tag=left_src_vcol) - IF (.not. has_acc) THEN + IF (.not. use_acc()) THEN CALL dbcsr_isend_any(left_data_sp, left_send_p, & grp, left_data_sr(v_ki + 1), tag=left_dst_vcol) ELSE @@ -2693,7 +2693,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & ! Do multiplication ! 
If no GPU backend, calculate norms on the CPU - IF (otf_filtering .and. .not. has_acc) THEN + IF (otf_filtering .and. .not. use_acc()) THEN left_norms(:) = huge_norm right_norms(:) = huge_norm CALL calculate_norms(right_buffer_calc%mats(v_ki_right, 1), & @@ -2776,7 +2776,7 @@ SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, & CALL m_memory(mem) max_memory = MAX(max_memory, REAL(mem)) - IF (has_acc) THEN + IF (use_acc()) THEN CALL dbcsr_data_release(trs_stackbuf_1) CALL dbcsr_data_release(trs_stackbuf_2) DEALLOCATE (row_blk_sizes2enum, enum2row_blk_sizes) diff --git a/src/mm/dbcsr_mm_hostdrv.F b/src/mm/dbcsr_mm_hostdrv.F index 7089d2f3a58..83e0aa3c2e9 100644 --- a/src/mm/dbcsr_mm_hostdrv.F +++ b/src/mm/dbcsr_mm_hostdrv.F @@ -10,7 +10,7 @@ MODULE dbcsr_mm_hostdrv !! Stacks of small matrix multiplications USE dbcsr_config, ONLY: dbcsr_cfg, & - has_acc, & + use_acc, & mm_driver_blas, & mm_driver_matmul, & mm_driver_smm, & @@ -108,7 +108,7 @@ SUBROUTINE dbcsr_mm_hostdrv_process(this, left, right, params, stack_size, & INTEGER :: error_handle, sp REAL(KIND=dp) :: rnd - IF (has_acc) & !for cpu-only runs this is called too often + IF (use_acc()) & !for cpu-only runs this is called too often CALL timeset(routineN, error_handle) success = .TRUE. 
!host driver never fails...hopefully @@ -236,7 +236,7 @@ SUBROUTINE dbcsr_mm_hostdrv_process(this, left, right, params, stack_size, & DBCSR_ABORT("Invalid multiplication driver") END SELECT - IF (has_acc) & !for cpu-only runs this is called too often + IF (use_acc()) & !for cpu-only runs this is called too often CALL timestop(error_handle) END SUBROUTINE dbcsr_mm_hostdrv_process diff --git a/src/mm/dbcsr_mm_sched.F b/src/mm/dbcsr_mm_sched.F index 27c3e350b67..78843e2fa19 100644 --- a/src/mm/dbcsr_mm_sched.F +++ b/src/mm/dbcsr_mm_sched.F @@ -20,7 +20,7 @@ MODULE dbcsr_mm_sched USE dbcsr_block_operations, ONLY: dbcsr_data_clear USE dbcsr_config, ONLY: dbcsr_cfg, & default_resize_factor, & - has_acc + use_acc USE dbcsr_data_methods, ONLY: dbcsr_data_ensure_size, & dbcsr_data_get_size USE dbcsr_kinds, ONLY: int_4, int_8, real_8 @@ -194,7 +194,7 @@ SUBROUTINE dbcsr_mm_sched_init(this, product_wm, nlayers, keep_product_data) CALL dbcsr_mm_hostdrv_init(this%hostdrv, product_wm) - IF (has_acc) & + IF (use_acc()) & CALL dbcsr_mm_accdrv_init(this%accdrv, product_wm, nlayers, keep_product_data) CALL timestop(handle) @@ -215,7 +215,7 @@ SUBROUTINE dbcsr_mm_sched_finalize(this) CALL ensure_product_wm_cleared(this) !CALL dbcsr_mm_hostdrv_finalize(this%hostdrv) ! not needed - IF (has_acc) & + IF (use_acc()) & CALL dbcsr_mm_accdrv_finalize(this%accdrv) CALL timestop(handle) @@ -232,7 +232,7 @@ SUBROUTINE dbcsr_mm_sched_dev2host_init(this) CALL timeset(routineN, handle) - IF (has_acc) & + IF (use_acc()) & CALL dbcsr_mm_accdrv_dev2host_init(this%accdrv) CALL timestop(handle) @@ -258,7 +258,7 @@ SUBROUTINE dbcsr_mm_sched_barrier() !CALL dbcsr_mm_hostdrv_barrier(this%hostdrv) ! 
not needed - IF (has_acc) & + IF (use_acc()) & CALL dbcsr_mm_accdrv_barrier() END SUBROUTINE dbcsr_mm_sched_barrier @@ -324,7 +324,7 @@ SUBROUTINE dbcsr_mm_sched_process(this, left, right, stack_data, & flop_per_entry = INT(2, KIND=int_8)*stack_descr%max_m*stack_descr%max_n*stack_descr%max_k total_flop = stack_fillcount*flop_per_entry - IF (has_acc .AND. & + IF (use_acc() .AND. & flop_per_entry > dbcsr_cfg%accdrv_min_flop_process%val .AND. & (.NOT. this%avoid_accdrv) .AND. & (stack_descr%defined_mnk .OR. dbcsr_cfg%accdrv_do_inhomogenous%val)) THEN