Skip to content

Commit

Permalink
Add derived_types handling : generating the pt on the data, and the p…
Browse files Browse the repository at this point in the history
…t on the field api object
  • Loading branch information
ecossevin committed Apr 25, 2024
1 parent cacb023 commit d1ece0e
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 3 deletions.
28 changes: 27 additions & 1 deletion transformations/tests/test_parallel_routine_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,30 @@ def test_parallel_routine_dispatch_decl_field_create_delete(here, frontend):
assert len(conditional) == 5
for cond in conditional:
assert fgen(cond) in field_delete
breakpoint()

@pytest.mark.parametrize('frontend', available_frontends(skip=[OMNI]))
def test_parallel_routine_dispatch_derived(here, frontend):

source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
routine = source['dispatch_routine']

transformation = ParallelRoutineDispatchTransformation()
transformation.apply(source['dispatch_routine'])

dcls = [fgen(dcl) for dcl in routine.spec.body[-13:-1]]

test_dcls=["REAL(KIND=JPRB), POINTER :: Z_YDVARS_U_T0(:, :, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_DM(:, :, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_GEOMETRY_GELAM_T0(:, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_T0(:, :, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_DL(:, :, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_V_T0(:, :, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_GEOMETRY_GEMU_T0(:, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_Q_T0(:, :, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDCPG_PHY0_XYB_RDELP(:, :, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_DM(:, :, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDCPG_DYN0_CTY_EVEL(:, :, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDMF_PHYS_SURF_GSD_VF_PZ0F(:, :)",
"REAL(KIND=JPRB), POINTER :: Z_YDVARS_CVGQ_DL(:, :, :)"]
for dcl in dcls:
assert dcl in test_dcls
Binary file added transformations/transformations/field_index.pkl
Binary file not shown.
63 changes: 61 additions & 2 deletions transformations/transformations/parallel_routine_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
from loki.transform import Transformation
from loki import (
FindVariables, DerivedType, SymbolAttributes,
Array, single_variable_declaration, Transformer
Array, single_variable_declaration, Transformer,
BasicType
)
import pickle
import os

__all__ = ['ParallelRoutineDispatchTransformation']

Expand All @@ -26,11 +29,15 @@ def __init__(self):
"KLON", "YDCPG_OPTS%KLON", "YDGEOMETRY%YRDIM%NPROMA",
"KPROMA", "YDDIM%NPROMA", "NPROMA"
]
#TODO : do smthg for opening field_index.pkl
with open(os.getcwd()+"/transformations/transformations/field_index.pkl", 'rb') as fp:
self.map_index = pickle.load(fp)
# CALL FIELD_NEW (YL_ZA, UBOUNDS=[KLON, KFLEVG, KGPBLKS], LBOUNDS=[1, 0, 1], PERSISTENT=.TRUE.)
self.new_calls = []
# IF (ASSOCIATED (YL_ZA)) CALL FIELD_DELETE (YL_ZA)
self.delete_calls = []
self.routine_map_temp = {}
self.routine_map_derived = {}

def transform_subroutine(self, routine, **kwargs):
with pragma_regions_attached(routine):
Expand All @@ -40,6 +47,7 @@ def transform_subroutine(self, routine, **kwargs):
single_variable_declaration(routine)
self.add_temp(routine)
self.add_field(routine)
self.add_derived(routine)
#call add_arrays etc...

def process_parallel_region(self, routine, region):
Expand All @@ -61,11 +69,15 @@ def process_parallel_region(self, routine, region):
region.append(dr_hook_calls[1])

region_map_temp= self.decl_local_array(routine, region)
region_map_derived= self.decl_derived_types(routine, region)

for var_name in region_map_temp:
if var_name not in self.routine_map_temp:
self.routine_map_temp[var_name]=region_map_temp[var_name]

for var_name in region_map_derived:
if var_name not in self.routine_map_derived:
self.routine_map_derived[var_name]=region_map_derived[var_name]


@staticmethod
Expand Down Expand Up @@ -178,4 +190,51 @@ def add_field(self, routine):
routine, cdname='DELETE_TEMPORARIES',
handle=sym.Variable(name='ZHOOK_HANDLE_FIELD_API', scope=routine)
)
routine.body.insert(-2,(dr_hook_calls[0], ir.Comment(text=''), *self.delete_calls, dr_hook_calls[1]))
routine.body.insert(-2,(dr_hook_calls[0], ir.Comment(text=''), *self.delete_calls, dr_hook_calls[1]))

def decl_derived_types(self, routine, region):
region_map_derived = {}
derived = [var for var in FindVariables().visit(region) if var.name_parts[0] in routine.arguments]
for var in derived :

key = f"{routine.variable_map[var.name_parts[0]].type.dtype.name}%{'%'.join(var.name_parts[1:])}"
if key in self.map_index:
value = self.map_index[key]
# Creating the pointer on the data : YL_A
data_name = f"Z_{var.name.replace('%', '_')}"
if "REAL" and "JPRB" in value[0]:
data_type = SymbolAttributes(
dtype=BasicType.REAL, kind=routine.symbol_map['JPRB'],
pointer=True
)
data_dim = value[2] + 1
data_shape = (sym.RangeIndex((None, None)),) * data_dim
ptr_var = sym.Variable(name=data_name, type=data_type, dimensions=data_shape, scope=routine)

else:
raise NotImplementedError("This type isn't implemented yet")

# Creating the pointer on the field api object : YL%FA, YL%F_A...
if routine.variable_map[var.name_parts[0]].type.dtype.name=="MF_PHYS_SURF_TYPE":
# YL%PA becomes YL%F_A
field_name = f"{'%'.join(var.name_parts[:-1])}%F_{var.name_parts[-1][1:]}"
elif routine.variable_map[var.name_parts[0]].type.dtype.name=="FIELD_VARIABLES":
# YL%A becomes YL%FA
field_name = f"{'%'.join(var.name_parts[:-1])}%F{var.name_parts[-1]}"
if var.name_parts[-1]=="P": #YL%FP = YL%FT0
field_name = f"{field_name[-1]}T0"
else:
# YL%A becomes YL%F_A
field_name = f"{'%'.join(var.name_parts[:-1])}%F_{var.name_parts[-1]}"
field_ptr_var = var.clone(name=field_name)
region_map_derived[var.name] = [field_ptr_var, ptr_var]
return(region_map_derived)

def add_derived(self, routine):
ptr_var=()
for value in self.routine_map_derived.values():
dcl = ir.VariableDeclaration(
symbols=(value[1],)
)
ptr_var += (dcl,)
routine.spec.append(ptr_var)

0 comments on commit d1ece0e

Please sign in to comment.