Skip to content

Commit

Permalink
add do concurrent to diag manager outer loop
Browse files Browse the repository at this point in the history
  • Loading branch information
uramirez8707 committed Dec 30, 2024
1 parent 9256dfd commit 1e3a044
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 64 deletions.
7 changes: 4 additions & 3 deletions diag_manager/diag_data.F90
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ end subroutine fms_add_attribute

!> @brief gets the type of a variable
!> @return the type of the variable (r4,r8,i4,i8,string)
function get_var_type(var) &
pure function get_var_type(var) &
result(var_type)
class(*), intent(in) :: var !< Variable to get the type for
integer :: var_type !< The variable's type
Expand All @@ -611,8 +611,9 @@ function get_var_type(var) &
type is (character(len=*))
var_type = string
class default
call mpp_error(FATAL, "get_var_type:: The variable does not have a supported type. &
&The supported types are r4, r8, i4, i8 and string.")
! TODO Better error handling
! call mpp_error(FATAL, "get_var_type:: The variable does not have a supported type. &
! &The supported types are r4, r8, i4, i8 and string.")
end select
end function get_var_type

Expand Down
31 changes: 21 additions & 10 deletions diag_manager/fms_diag_axis_object.F90
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ module fms_diag_axis_object_mod
!! this is the axis ids of the structured axis
CHARACTER(len=:), ALLOCATABLE, private :: set_name !< Name of the axis set. This is to distinguish
!! two axis with the same name
integer , private :: compute_size !< Size of the compute domain

contains

Expand All @@ -191,6 +192,7 @@ module fms_diag_axis_object_mod
PROCEDURE :: get_set_name
PROCEDURE :: has_set_name
PROCEDURE :: is_x_or_y_axis
PROCEDURE :: set_compute_domain
! TO DO:
! Get/has/is subroutines as needed
END TYPE fmsDiagFullAxis_type
Expand Down Expand Up @@ -244,6 +246,10 @@ subroutine register_diag_axis_obj(this, axis_name, axis_data, units, cart_name,
& Currently only r4 and r8 data is supported.")
end select

this%domain_position = CENTER
if (present(domain_position)) this%domain_position = domain_position
call check_if_valid_domain_position(this%domain_position)

this%type_of_domain = NO_DOMAIN
if (present(Domain)) then
if (present(Domain2) .or. present(DomainU)) call mpp_error(FATAL, &
Expand All @@ -257,6 +263,7 @@ subroutine register_diag_axis_obj(this, axis_name, axis_data, units, cart_name,
"Check you diag_axis_init call for axis_name:"//trim(axis_name))
allocate(diagDomain2d_t :: this%axis_domain)
call this%axis_domain%set(Domain2=Domain2)
call this%set_compute_domain(Domain2)
this%type_of_domain = TWO_D_DOMAIN
else if (present(DomainU)) then
allocate(diagDomainUg_t :: this%axis_domain)
Expand All @@ -267,10 +274,6 @@ subroutine register_diag_axis_obj(this, axis_name, axis_data, units, cart_name,
this%tile_count = 1
if (present(tile_count)) this%tile_count = tile_count

this%domain_position = CENTER
if (present(domain_position)) this%domain_position = domain_position
call check_if_valid_domain_position(this%domain_position)

this%direction = 0
if (present(direction)) this%direction = direction
call check_if_valid_direction(this%direction)
Expand Down Expand Up @@ -606,16 +609,15 @@ end subroutine get_global_io_domain

!> @brief Get the length of the axis
!> @return axis length
function get_axis_length(this) &
pure function get_axis_length(this) &
result (axis_length)
class(fmsDiagFullAxis_type), intent(in) :: this !< diag_axis obj
integer :: axis_length

!< If the axis is domain decomposed axis_length will be set to the length for the current PE:
axis_length = this%length
if (allocated(this%axis_domain)) then
axis_length = this%axis_domain%length(this%cart_name, this%domain_position, this%length)
else
axis_length = this%length
if (this%cart_name .eq. "X" .or. this%cart_name .eq. "Y") axis_length = this%compute_size
endif

end function
Expand Down Expand Up @@ -870,13 +872,13 @@ end subroutine fill_subaxis

!> @brief Get the axis length of a subaxis
!> @return the axis length
function axis_length(this) &
pure function axis_length(this) &
result(res)
class(fmsDiagSubAxis_type) , INTENT(IN) :: this !< diag_sub_axis obj
integer :: res

res = this%ending_index - this%starting_index + 1
end function
end function

!> @brief Accesses its member starting_index
!! @return a copy of the starting_index
Expand Down Expand Up @@ -916,6 +918,15 @@ function get_ntiles(this) &
end select
end function get_ntiles

subroutine set_compute_domain(this, Domain)
class(fmsDiagFullAxis_type),INTENT(inout):: this !< Diag_axis obj
type(domain2d), intent(in) :: Domain

if (trim(this%cart_name) == "X") call mpp_get_compute_domain(domain, xsize=this%compute_size, position=this%domain_position)
if (trim(this%cart_name) == "Y") call mpp_get_compute_domain(domain, ysize=this%compute_size, position=this%domain_position)

end subroutine

!> @brief Get the length of a 2D domain
!> @return Length of the 2D domain
function get_length(this, cart_axis, domain_position, global_length) &
Expand Down
2 changes: 1 addition & 1 deletion diag_manager/fms_diag_bbox.F90
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ function set_bounds(this, field_data, lower_i, upper_i, lower_j, upper_j, lower_
end function set_bounds
!> @brief Reset the instance bounding box with the bounds determined from the
!! first three dimensions of the 5D "array" argument
SUBROUTINE reset_bounds_from_array_4D(this, array)
pure SUBROUTINE reset_bounds_from_array_4D(this, array)
CLASS (fmsDiagIbounds_type), INTENT(inout) :: this !< The instance of the bounding box.
class(*), INTENT( in), DIMENSION(:,:,:,:) :: array !< The 4D input array.
this%imin = LBOUND(array,1)
Expand Down
22 changes: 12 additions & 10 deletions diag_manager/fms_diag_field_object.F90
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ function get_send_data_time(this) &
end function get_send_data_time

!> @brief Prepare the input_data_buffer to do the reduction method
subroutine prepare_data_buffer(this)
pure subroutine prepare_data_buffer(this)
class (fmsDiagField_type) , intent(inout):: this !< The field object

if (.not. this%multiple_send_data) return
Expand All @@ -488,7 +488,7 @@ subroutine prepare_data_buffer(this)
end subroutine prepare_data_buffer

!> @brief Initialize the input_data_buffer
subroutine init_data_buffer(this)
pure subroutine init_data_buffer(this)
class (fmsDiagField_type) , intent(inout):: this !< The field object

if (.not. this%multiple_send_data) return
Expand Down Expand Up @@ -542,7 +542,7 @@ logical function allocate_data_buffer(this, input_data, diag_axis)
allocate_data_buffer = .true.
end function allocate_data_buffer
!> Sets the flag saying that the math functions need to be done
subroutine set_math_needs_to_be_done (this, math_needs_to_be_done)
pure subroutine set_math_needs_to_be_done (this, math_needs_to_be_done)
class (fmsDiagField_type) , intent(inout):: this
logical, intent (in) :: math_needs_to_be_done !< Flag saying that the math functions need to be done
this%math_needs_to_be_done = math_needs_to_be_done
Expand Down Expand Up @@ -714,7 +714,7 @@ end function diag_obj_is_static

!> @brief Determine if the field is a scalar
!! @return .True. if the field is a scalar
function is_scalar (this) result (rslt)
pure function is_scalar (this) result (rslt)
class(fmsDiagField_type), intent(in) :: this !< diag_field object
logical :: rslt
rslt = this%scalar
Expand Down Expand Up @@ -1353,16 +1353,16 @@ end subroutine write_coordinate_attribute

!> @brief Gets a fields data buffer
!! @return a pointer to the data buffer
function get_data_buffer (this) &
pure function get_data_buffer (this) &
result(rslt)
class (fmsDiagField_type), target, intent(in) :: this !< diag field
class(*),dimension(:,:,:,:), pointer :: rslt !< The field's data buffer
class(*),dimension(:,:,:,:), allocatable :: rslt !< The field's data buffer

if (.not. this%data_buffer_is_allocated) &
call mpp_error(FATAL, "The input data buffer for the field:"&
//trim(this%varname)//" was never allocated.")
if (.not. this%data_buffer_is_allocated) then
! TODO Better error handling
endif

rslt => this%input_data_buffer%get_buffer()
rslt = this%input_data_buffer%get_buffer()
end function get_data_buffer


Expand Down Expand Up @@ -1757,6 +1757,8 @@ end function get_starting_compute_domain
pure function get_file_ids(this)
class(fmsDiagField_type), intent(in) :: this
integer, allocatable :: get_file_ids(:) !< Ids of the FMS_diag_files the variable

allocate(get_file_ids(size(this%file_ids)))
get_file_ids = this%file_ids
end function

Expand Down
2 changes: 1 addition & 1 deletion diag_manager/fms_diag_file_object.F90
Original file line number Diff line number Diff line change
Expand Up @@ -1455,7 +1455,7 @@ logical function is_time_to_write(this, time_step, output_buffers, diag_fields,
end function is_time_to_write

!> \brief Determine if the current PE has data to write
logical function writing_on_this_pe(this)
pure logical function writing_on_this_pe(this)
class(fmsDiagFileContainer_type), intent(in), target :: this !< The file object

select type(diag_file => this%FMS_diag_file)
Expand Down
13 changes: 6 additions & 7 deletions diag_manager/fms_diag_input_buffer.F90
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ module fms_diag_input_buffer_mod

!> @brief Get the buffer from the input buffer object
!! @return a pointer to the buffer
function get_buffer(this) &
pure function get_buffer(this) &
result(buffer)
class(fmsDiagInputBuffer_t), target, intent(in) :: this !< input buffer object
class(*), pointer :: buffer(:,:,:,:)
class(*), allocatable :: buffer(:,:,:,:)

buffer => this%buffer
buffer = this%buffer
end function get_buffer


Expand Down Expand Up @@ -144,7 +144,7 @@ function allocate_input_buffer_object(this, input_data, axis_ids, diag_axis) &
end function allocate_input_buffer_object

!> @brief Initiliazes an input data buffer and the counter
subroutine init_input_buffer_object(this)
pure subroutine init_input_buffer_object(this)
class(fmsDiagInputBuffer_t), intent(inout) :: this !< input buffer object

select type(buffer=>this%buffer)
Expand Down Expand Up @@ -205,7 +205,7 @@ end function update_input_buffer_object

!> @brief Prepare the input data buffer to do the reduction methods (i.e divide by the number of times
!! send data has been called)
subroutine prepare_input_buffer_object(this, field_info)
pure subroutine prepare_input_buffer_object(this, field_info)
class(fmsDiagInputBuffer_t), intent(inout) :: this !< input buffer object
character(len=*), intent(in) :: field_info !< Field info to append to error message

Expand All @@ -215,8 +215,7 @@ subroutine prepare_input_buffer_object(this, field_info)
type is (real(kind=r8_kind))
input_data = input_data / this%counter(1,1,1,1)
class default
call mpp_error(FATAL, "prepare_input_buffer_object::"//trim(field_info)//&
" has only been implemented for real variables. Contact developers.")
!TODO very error handling
end select
end subroutine prepare_input_buffer_object

Expand Down
45 changes: 27 additions & 18 deletions diag_manager/fms_diag_object.F90
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ subroutine do_buffer_math(this)
logical :: math !< True if the math functions need to be called using the data buffer,
!! False if the math functions were done in accept_data
integer, dimension(:), allocatable :: file_field_ids !< Array of field IDs for a file
class(*), pointer :: input_data_buffer(:,:,:,:)
class(*), allocatable :: input_data_buffer(:,:,:,:)
character(len=128) :: error_string
type(fmsDiagIbounds_type) :: bounds
integer, dimension(:), allocatable :: file_ids !< Array of file IDs for a field
Expand All @@ -736,31 +736,39 @@ subroutine do_buffer_math(this)
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!! In the future, this may be parallelized for offloading
! loop through each field
field_loop: do ifield = 1, size(this%FMS_diag_fields)

field_loop: do concurrent (ifield = 1:size(this%FMS_diag_fields))
diag_field => this%FMS_diag_fields(ifield)

! Skip this field if it was never registered
if(.not. diag_field%is_registered()) cycle
if(DEBUG_SC) call mpp_error(NOTE, "fms_diag_send_complete:: var: "//diag_field%get_varname())
! get files the field is in
allocate (file_ids(size(diag_field%get_file_ids() )))

! Get the the file ids of the file the field is in
file_ids = diag_field%get_file_ids()
math = diag_field%get_math_needs_to_be_done()
! if doing math loop through each file for given field

! If doing math loop through each file the field is in
doing_math: if (size(file_ids) .ge. 1 .and. math) then
! Check if buffer alloc'd
has_input_buff: if (diag_field%has_input_data_buffer()) then
! If the buffer was registered with multiple_send_data calls, prepare the input buffer data
call diag_field%prepare_data_buffer()
input_data_buffer => diag_field%get_data_buffer()
! reset bounds, allocate output buffer, and update it with reduction

! Get the input data buffer
input_data_buffer = diag_field%get_data_buffer()

call bounds%reset_bounds_from_array_4D(input_data_buffer)
call this%allocate_diag_field_output_buffers(input_data_buffer, ifield)
error_string = this%fms_diag_do_reduction(input_data_buffer, ifield, &
diag_field%get_mask(), diag_field%get_weight(), &
bounds, .False., Time=diag_field%get_send_data_time())
call diag_field%init_data_buffer()
if (trim(error_string) .ne. "") call mpp_error(FATAL, "Field:"//trim(diag_field%get_varname()//&
" -"//trim(error_string)))

!TODO Better error handling
!if (trim(error_string) .ne. "") call mpp_error(FATAL, "Field:"//trim(diag_field%get_varname()//&
! " -"//trim(error_string)))
else
call mpp_error(FATAL, "diag_send_complete:: no input buffer allocated for field"//diag_field%get_longname())
!TODO Better error handling
! call mpp_error(FATAL, "diag_send_complete:: no input buffer allocated for field"//diag_field%get_longname())
endif has_input_buff
endif doing_math
call diag_field%set_math_needs_to_be_done(.False.)
Expand Down Expand Up @@ -1277,7 +1285,7 @@ END FUNCTION fms_get_domain2d

!> @brief Gets the length of the axis based on the axis_id
!> @return Axis_length
integer function fms_get_axis_length(this, axis_id)
pure integer function fms_get_axis_length(this, axis_id)
class(fmsDiagObject_type), intent (in) :: this !< The diag object
INTEGER, INTENT(in) :: axis_id !< Axis ID of the axis to the length of

Expand All @@ -1287,8 +1295,9 @@ integer function fms_get_axis_length(this, axis_id)
#else
fms_get_axis_length = 0

if (axis_id < 0 .and. axis_id > this%registered_axis) &
call mpp_error(FATAL, "fms_get_axis_length: The axis_id is not valid")
!TODO Better error handling
!if (axis_id < 0 .and. axis_id > this%registered_axis) &
! call mpp_error(FATAL, "fms_get_axis_length: The axis_id is not valid")

select type (axis => this%diag_axis(axis_id)%axis)
type is (fmsDiagFullAxis_type)
Expand Down Expand Up @@ -1384,7 +1393,7 @@ subroutine dump_diag_obj( filename )

!> @brief Allocates the output buffers of the fields corresponding to the registered variable
!! Input arguments are the field and its ID passed to routine fms_diag_accept_data()
subroutine allocate_diag_field_output_buffers(this, field_data, field_id)
pure subroutine allocate_diag_field_output_buffers(this, field_data, field_id)
class(fmsDiagObject_type), target, intent(inout) :: this !< diag object
class(*), dimension(:,:,:,:), intent(in) :: field_data !< field data
integer, intent(in) :: field_id !< Id of the field data
Expand All @@ -1395,7 +1404,7 @@ subroutine allocate_diag_field_output_buffers(this, field_data, field_id)
integer :: axes_length(4) !< Length of each axis
integer :: i, j !< For looping
class(fmsDiagOutputBuffer_type), pointer :: ptr_diag_buffer_obj !< Pointer to the buffer class
class(DiagYamlFilesVar_type), pointer :: ptr_diag_field_yaml !< Pointer to a field from yaml fields
class(DiagYamlFilesVar_type), allocatable :: ptr_diag_field_yaml !< Pointer to a field from yaml fields
integer, pointer :: axis_ids(:) !< Pointer to indices of axes of the field variable
integer :: var_type !< Stores type of the field data (r4, r8, i4, i8, and string) represented as an integer.
character(len=:), allocatable :: var_name !< Field name to initialize output buffers
Expand Down Expand Up @@ -1430,7 +1439,7 @@ subroutine allocate_diag_field_output_buffers(this, field_data, field_id)

yaml_id = this%FMS_diag_output_buffers(buffer_id)%get_yaml_id()

ptr_diag_field_yaml => diag_yaml%diag_fields(yaml_id)
ptr_diag_field_yaml = diag_yaml%diag_fields(yaml_id)
num_diurnal_samples = ptr_diag_field_yaml%get_n_diurnal() !< Get number of diurnal samples

axes_length = 1
Expand Down
Loading

0 comments on commit 1e3a044

Please sign in to comment.