From 1a1531ab2404905cbe54caa5a22192a3dcb9a931 Mon Sep 17 00:00:00 2001
From: Aidan Chalk <3043914+LonelyCat124@users.noreply.github.com>
Date: Fri, 10 Jan 2025 14:45:49 +0000
Subject: [PATCH] Missing coverage for other types of read

---
 .../transformations/scalarization_trans.py    | 17 +++--
 .../scalarization_trans_test.py               | 63 ++++++++++++++++++-
 2 files changed, 70 insertions(+), 10 deletions(-)

diff --git a/src/psyclone/psyir/transformations/scalarization_trans.py b/src/psyclone/psyir/transformations/scalarization_trans.py
index 463a0ef705..1e7f6dceee 100644
--- a/src/psyclone/psyir/transformations/scalarization_trans.py
+++ b/src/psyclone/psyir/transformations/scalarization_trans.py
@@ -106,9 +106,14 @@ def _value_unused_after_loop(sig, node, var_accesses):
             if isinstance(next_access, (CodeBlock, Call, Kern)):
                 return False
 
-            # If next access is an IfBlock then it reads the value.
-            if isinstance(next_access, IfBlock):
-                return False
+            # If next access is in an IfBlock condition then it reads the
+            # value.
+            ancestor_ifblock = next_access.ancestor(IfBlock)
+            if ancestor_ifblock:
+                conditions = ancestor_ifblock.condition.walk(Node)
+                for node in conditions:
+                    if node is next_access:
+                        return False
 
             # If next access has an ancestor WhileLoop, and its in the
             # condition then it reads the value.
@@ -201,8 +206,6 @@ def apply(self, node, options=None):
                 ScalarizationTrans._have_same_unmodified_index(sig,
                                                                var_accesses),
                 potential_targets)
-#        potential_targets = self._find_potential_scalarizable_array_symbols(
-#                node, var_accesses)
 
         # Now we need to check the first access is a write and remove those
         # that aren't.
@@ -211,8 +214,6 @@ def apply(self, node, options=None):
                 ScalarizationTrans._check_first_access_is_write(sig,
                                                                 var_accesses),
                 potential_targets)
-#        potential_targets = self._check_first_access_is_write(
-#                node, var_accesses, potential_targets)
 
         # Check the values written to these arrays are not used after this loop
         finalised_targets = filter(
@@ -220,8 +221,6 @@ def apply(self, node, options=None):
                 ScalarizationTrans._value_unused_after_loop(sig, node,
                                                             var_accesses),
                 potential_targets)
-#        finalised_targets = self._check_valid_following_access(
-#                node, var_accesses, potential_targets)
 
         routine_table = node.ancestor(Routine).symbol_table
         # For each finalised target we can replace them with a scalarized
diff --git a/src/psyclone/tests/psyir/transformations/scalarization_trans_test.py b/src/psyclone/tests/psyir/transformations/scalarization_trans_test.py
index 399e619dc0..aaaa059b1c 100644
--- a/src/psyclone/tests/psyir/transformations/scalarization_trans_test.py
+++ b/src/psyclone/tests/psyir/transformations/scalarization_trans_test.py
@@ -381,12 +381,73 @@ def test_scalarizationtrans_value_unused_after_loop(fortran_reader):
     assert not ScalarizationTrans._value_unused_after_loop(keys[2],
                                                            node,
                                                            var_accesses)
-    # Test b
+    # Test c
     assert var_accesses[keys[3]].var_name == "c"
     assert not ScalarizationTrans._value_unused_after_loop(keys[3],
                                                            node,
                                                            var_accesses)
 
+    # Test being a symbol in a Codeblock counts as used
+    code = '''subroutine test()
+        use my_mod
+        integer :: i
+        integer :: k
+        integer, dimension(1:100) :: arr
+        integer, dimension(1:100) :: b
+        integer, dimension(1:100) :: c
+        integer, dimension(1:100, 1:100) :: d
+
+          do i = 1, 100
+           arr(i) = exp(arr(i))
+           b(i) = arr(i) * 3
+           c(i) = i
+          end do
+          do i = 1, 100
+            print *, arr(i)
+          end do
+        end subroutine test
+        '''
+    psyir = fortran_reader.psyir_from_source(code)
+    node = psyir.children[0].children[0]
+    var_accesses = VariablesAccessInfo(nodes=node.loop_body)
+    keys = list(var_accesses.keys())
+    # Test arr
+    assert var_accesses[keys[1]].var_name == "arr"
+    assert not ScalarizationTrans._value_unused_after_loop(keys[1],
+                                                           node,
+                                                           var_accesses)
+
+    # Test being in an IfBlock condition counts as used.
+    code = '''subroutine test()
+        use my_mod
+        integer :: i
+        integer :: k
+        integer, dimension(1:100) :: arr
+        integer, dimension(1:100) :: b
+        integer, dimension(1:100) :: c
+        integer, dimension(1:100, 1:100) :: d
+
+          do i = 1, 100
+           arr(i) = exp(arr(i))
+           b(i) = arr(i) * 3
+           c(i) = i
+          end do
+          do i = 1, 100
+            if(arr(i) == 1) then
+                print *, b(i)
+            end if
+          end do
+        end subroutine test
+        '''
+    psyir = fortran_reader.psyir_from_source(code)
+    node = psyir.children[0].children[0]
+    var_accesses = VariablesAccessInfo(nodes=node.loop_body)
+    keys = list(var_accesses.keys())
+    # Test arr
+    assert var_accesses[keys[1]].var_name == "arr"
+    assert not ScalarizationTrans._value_unused_after_loop(keys[1],
+                                                           node,
+                                                           var_accesses)
 
 def test_scalarization_trans_apply(fortran_reader, fortran_writer, tmpdir):
     ''' Test the application of the scalarization transformation.'''