From 78f12410013a01334a2943e0aebb04e3e12c1087 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Sat, 5 Oct 2024 04:08:19 +0000 Subject: [PATCH] Transformations: Recurse to array dimes and types in ResolveAssoc --- loki/transformations/sanitise.py | 11 ++++++++++- loki/transformations/tests/test_sanitise.py | 8 +++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/loki/transformations/sanitise.py b/loki/transformations/sanitise.py index fa8b22f14..33229f05a 100644 --- a/loki/transformations/sanitise.py +++ b/loki/transformations/sanitise.py @@ -110,7 +110,16 @@ def map_variable_symbol(self, expr, *args, **kwargs): def map_array(self, expr, *args, **kwargs): """ Special case for arrys: we need to preserve the dimensions """ new = self.map_variable_symbol(expr, *args, **kwargs) - return new.clone(dimensions=expr.dimensions) + + # Recurse over the type's shape + _type = expr.type + if expr.type.shape: + new_shape = self.rec(expr.type.shape, *args, **kwargs) + _type = expr.type.clone(shape=new_shape) + + # Recurse over array dimensions + new_dims = self.rec(expr.dimensions, *args, **kwargs) + return new.clone(dimensions=new_dims, type=_type) map_scalar = map_variable_symbol map_deferred_type_symbol = map_variable_symbol diff --git a/loki/transformations/tests/test_sanitise.py b/loki/transformations/tests/test_sanitise.py index 97426f677..f7881f9da 100644 --- a/loki/transformations/tests/test_sanitise.py +++ b/loki/transformations/tests/test_sanitise.py @@ -114,8 +114,10 @@ def test_transform_associates_array_call(frontend): integer :: i real :: local_var + real, allocatable :: local_arr(:) - associate (some_array => some_obj%some_array) + associate (some_array => some_obj%some_array, a => some_obj%a) + allocate(local_arr(a%n)) do i=1, 5 call another_routine(i, n=some_array(i)%n) @@ -131,6 +133,7 @@ def test_transform_associates_array_call(frontend): call = FindNodes(ir.CallStatement).visit(routine.body)[0] assert call.kwarguments[0][1] == 'some_array(i)%n' assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED + assert routine.variable_map['local_arr'].type.shape == ('a%n',) # Now apply the association resolver resolve_associates(routine) @@ -142,6 +145,9 @@ def test_transform_associates_array_call(frontend): assert call.kwarguments[0][1].scope == routine assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED + # Test the special case of shapes derived from allocations + assert routine.variable_map['local_arr'].type.shape == ('some_obj%a%n',) + @pytest.mark.parametrize('frontend', available_frontends( xfail=[(OMNI, 'OMNI does not handle missing type definitions')]