From 3c243d3935addde010d34e57bac510c45f9ae36d Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Sun, 26 Mar 2023 22:16:27 +0100 Subject: [PATCH] Retain source object for injected statement functions --- loki/frontend/util.py | 10 +++++++--- tests/test_sourcefile.py | 1 + tests/test_subroutine.py | 1 + 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/loki/frontend/util.py b/loki/frontend/util.py index 9415bee71..76c91ecd4 100644 --- a/loki/frontend/util.py +++ b/loki/frontend/util.py @@ -205,7 +205,7 @@ def inject_statement_functions(routine): def create_stmt_func(assignment): arguments = assignment.lhs.dimensions variable = assignment.lhs.clone(dimensions=None) - return StatementFunction(variable, arguments, assignment.rhs, variable.type) + return StatementFunction(variable, arguments, assignment.rhs, variable.type, source=assignment.source) def create_type(stmt_func): name = str(stmt_func.variable) @@ -253,7 +253,9 @@ def create_type(stmt_func): if variable.name.lower() in stmt_funcs: if isinstance(variable, Array): parameters = variable.dimensions - expr_map_spec[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters) + expr_map_spec[variable] = InlineCall( + variable.clone(dimensions=None), parameters=parameters, source=variable.source + ) elif not isinstance(variable, ProcedureSymbol): expr_map_spec[variable] = variable.clone() expr_map_body = {} @@ -261,7 +263,9 @@ def create_type(stmt_func): if variable.name.lower() in stmt_funcs: if isinstance(variable, Array): parameters = variable.dimensions - expr_map_body[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters) + expr_map_body[variable] = InlineCall( + variable.clone(dimensions=None), parameters=parameters, source=variable.source + ) elif not isinstance(variable, ProcedureSymbol): expr_map_body[variable] = variable.clone() diff --git a/tests/test_sourcefile.py b/tests/test_sourcefile.py index 4c4766a97..64714fc19 100644 --- a/tests/test_sourcefile.py +++ b/tests/test_sourcefile.py @@ -216,6 +216,7 @@ def test_sourcefile_cpp_stmt_func(here, frontend): assert isinstance(var, ProcedureSymbol) assert isinstance(var.type.dtype, ProcedureType) assert var.type.dtype.procedure is decl + assert decl.source is not None # Generate code and compile filepath = here/f'{module.name}.f90' diff --git a/tests/test_subroutine.py b/tests/test_subroutine.py index 726582d6e..fb08c049d 100644 --- a/tests/test_subroutine.py +++ b/tests/test_subroutine.py @@ -1520,6 +1520,7 @@ def test_subroutine_stmt_func(here, frontend): assert isinstance(var, ProcedureSymbol) assert isinstance(var.type.dtype, ProcedureType) assert var.type.dtype.procedure is stmt_func_decls[var] + assert stmt_func_decls[var].source is not None # Make sure this produces the correct result filepath = here/f'{routine.name}.f90'