Skip to content

Commit

Permalink
Tidy up new kernel tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jfowkes authored Dec 6, 2024
1 parent c928d44 commit 0b31028
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions test/accelerate_tests/base_tests/po_update_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_init(self):
['pr_update', 'ob_update'],
err_msg='PoUpdateKernel does not have the correct functions registered.')

def prepare_arrays(self, wavefield=False):
def prepare_arrays(self):
B = 5 # frame size y
C = 5 # frame size x

Expand Down Expand Up @@ -88,12 +88,7 @@ def prepare_arrays(self, wavefield=False):
for idx in range(D):
probe_denominator[idx] = np.ones((E, F)) * (5 * idx + 2) # + 1j * np.ones((E, F)) * (5 * idx + 2)

if wavefield:
object_array_wavefield = np.empty_like(object_array, dtype=FLOAT_TYPE)
probe_wavefield = np.empty_like(probe, dtype=FLOAT_TYPE)
return addr, object_array, object_array_denominator, object_array_wavefield, probe, exit_wave, probe_denominator, probe_wavefield
else:
return addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator
return addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator

def test_ob_update(self):
# setup
Expand Down Expand Up @@ -249,7 +244,8 @@ def test_ob_update_ML(self):

def test_pr_update_ML_wavefield(self):
# setup
addr, object_array, object_array_denominator, object_array_wavefield, probe, exit_wave, probe_denominator, probe_wavefield = self.prepare_arrays(wavefield=True)
addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator = self.prepare_arrays()
probe_wavefield = np.empty_like(probe, dtype=FLOAT_TYPE)

# test
POUK = PoUpdateKernel()
Expand Down Expand Up @@ -290,7 +286,8 @@ def test_pr_update_ML_wavefield(self):

def test_ob_update_ML_wavefield(self):
# setup
addr, object_array, object_array_denominator, object_array_wavefield, probe, exit_wave, probe_denominator, probe_wavefield = self.prepare_arrays(wavefield=True)
addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator = self.prepare_arrays()
object_array_wavefield = np.empty_like(object_array, dtype=FLOAT_TYPE)

# test
POUK = PoUpdateKernel()
Expand Down

0 comments on commit 0b31028

Please sign in to comment.