Skip to content

Commit

Permalink
tests: ruffing and cleanup fastsweep, sources, backend, distributed, …
Browse files Browse the repository at this point in the history
…proposals, paripool
  • Loading branch information
hvasbath committed Mar 13, 2024
1 parent 85991f4 commit 5095907
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 52 deletions.
10 changes: 4 additions & 6 deletions beat/fast_sweeping/fast_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
S 0025-5718(04)01678-3
"""

import fast_sweep_ext
import numpy as num
import pytensor
import pytensor.tensor as tt
from pytensor.ifelse import ifelse

import fast_sweep_ext
from pytensor.scan.utils import until

km = 1000.0

Expand Down Expand Up @@ -293,9 +293,7 @@ def upwind(dip_ind, str_ind, StartTimes, slownesses, patch_size):
# xnew = |
# |0.5 * [ a+b+sqrt( 2*f^2*h^2 - (a-b)^2 ) ], |a-b| < f*h
start_new = ifelse(
tt.le(
slownesses[dip_ind, str_ind] * patch_size, tt.abs_(ST_xmin - ST_ymin)
),
tt.le(slownesses[dip_ind, str_ind] * patch_size, tt.abs(ST_xmin - ST_ymin)),
tt.min((ST_xmin, ST_ymin)) + slownesses[dip_ind, str_ind] * patch_size,
(
ST_xmin
Expand Down Expand Up @@ -332,7 +330,7 @@ def loop_upwind(StartTimes, PreviousTimes, err_val, iteration, epsilon):
PreviousTimes = StartTimes.copy()
return (
(StartTimes, PreviousTimes, err_val, iteration + 1),
pytensor.scan_module.until(err_val < epsilon),
until(err_val < epsilon),
)

# while loop until err < epsilon
Expand Down
13 changes: 5 additions & 8 deletions beat/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from beat.utility import get_rotation_matrix


# MTQT constants
pi = num.pi
pi4 = pi / 4.0
Expand Down Expand Up @@ -348,7 +347,6 @@ def extent_source(

@classmethod
def from_kite_source(cls, source, kwargs):

d = dict(
lat=source.lat,
lon=source.lon,
Expand All @@ -362,7 +360,7 @@ def from_kite_source(cls, source, kwargs):
rake=source.rake,
slip=source.slip,
anchor="top",
**kwargs
**kwargs,
)

if hasattr(source, "decimation_factor"):
Expand Down Expand Up @@ -445,7 +443,6 @@ class MTQTSource(gf.SourceWithMagnitude):
)

def __init__(self, **kwargs):

self.R = get_rotation_matrix()
self.roty_pi4 = self.R["y"](-pi4)
self.rotx_pi = self.R["x"](pi)
Expand Down Expand Up @@ -555,7 +552,7 @@ def discretize_basesource(self, store, target=None):
)
return meta.DiscretizedMTSource(
m6s=self.m6[num.newaxis, :] * amplitudes[:, num.newaxis],
**self._dparams_base_repeated(times)
**self._dparams_base_repeated(times),
)

def pyrocko_moment_tensor(self, store=None, target=None):
Expand Down Expand Up @@ -620,7 +617,7 @@ class MTSourceWithMagnitude(gf.SourceWithMagnitude):

def __init__(self, **kwargs):
if "m6" in kwargs:
for (k, v) in zip("mnn mee mdd mne mnd med".split(), kwargs.pop("m6")):
for k, v in zip("mnn mee mdd mne mnd med".split(), kwargs.pop("m6")):
kwargs[k] = float(v)

Source.__init__(self, **kwargs)
Expand Down Expand Up @@ -664,7 +661,7 @@ def discretize_basesource(self, store, target=None):
m6s = self.scaled_m6 * m0
return meta.DiscretizedMTSource(
m6s=m6s[num.newaxis, :] * amplitudes[:, num.newaxis],
**self._dparams_base_repeated(times)
**self._dparams_base_repeated(times),
)

def pyrocko_moment_tensor(self):
Expand All @@ -676,7 +673,7 @@ def pyrocko_event(self, **kwargs):
self,
moment_tensor=self.pyrocko_moment_tensor(),
magnitude=float(mt.moment_magnitude()),
**kwargs
**kwargs,
)

@classmethod
Expand Down
1 change: 0 additions & 1 deletion test/pt_toy_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def metrop_select(m1, m2):


def master_process(comm, size, tags, status):

num_workers = size - 1
tasks = range(num_workers)
chain = []
Expand Down
2 changes: 0 additions & 2 deletions test/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def two_gaussians(x):
self.expected_chain_data[data_key] = num.array(data)

def test_text_chain(self):

textchain = TextChain(dir_path=self.test_dir_path, model=self.PT_test)
textchain.setup(10, 0, overwrite=True)

Expand All @@ -111,7 +110,6 @@ def test_text_chain(self):
)

def test_chain_bin(self):

numpy_chain = NumpyChain(dir_path=self.test_dir_path, model=self.PT_test)
numpy_chain.setup(10, 0, overwrite=True)
print(numpy_chain)
Expand Down
2 changes: 0 additions & 2 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(self, *args, **kwargs):
self.beatpath = project_root

def test_mpi_runner(self):

logger.info("testing")
runner = MPIRunner()
runner.run(self.beatpath + "/test/pt_toy_example.py", n_jobs=self.n_jobs)
Expand All @@ -29,6 +28,5 @@ def test_arg_passing(self):


if __name__ == "__main__":

util.setup_logging("test_distributed", "info")
unittest.main()
47 changes: 22 additions & 25 deletions test/test_fastsweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from time import time

import numpy as num
import theano.tensor as tt
import pytensor.tensor as tt
from pyrocko import util
from theano import function
from pytensor import function

from beat import theanof
from beat import pytensorf
from beat.fast_sweeping import fast_sweep

km = 1000.0
Expand All @@ -31,7 +31,6 @@ def get_slownesses(self):
return 1.0 / velocities

def _numpy_implementation(self):

slownesses = self.get_slownesses()

t0 = time()
Expand All @@ -43,14 +42,13 @@ def _numpy_implementation(self):
self.nuc_x,
self.nuc_y,
)
print("np", numpy_start_times)
# print("np", numpy_start_times)
t1 = time()

logger.info("done numpy fast_sweeping in %f" % (t1 - t0))
return numpy_start_times

def _theano_implementation(self):

def _pytensor_implementation(self):
Slownesses = self.get_slownesses()

slownesses = tt.dmatrix("slownesses")
Expand All @@ -64,22 +62,21 @@ def _theano_implementation(self):

patch_size = tt.cast(self.patch_size / km, "float64")

theano_start_times = fast_sweep.get_rupture_times_theano(
pytensor_start_times = fast_sweep.get_rupture_times_pytensor(
slownesses, patch_size, nuc_x, nuc_y
)

t0 = time()
f = function([slownesses, nuc_x, nuc_y], theano_start_times)
f = function([slownesses, nuc_x, nuc_y], pytensor_start_times)
t1 = time()
theano_start_times = f(Slownesses, self.nuc_x, self.nuc_y)
pytensor_start_times = f(Slownesses, self.nuc_x, self.nuc_y)
t2 = time()

logger.info("Theano compile time %f" % (t1 - t0))
logger.info("done Theano fast_sweeping in %f" % (t2 - t1))
return theano_start_times

def _theano_c_wrapper(self):
logger.info("pytensor compile time %f" % (t1 - t0))
logger.info("done pytensor fast_sweeping in %f" % (t2 - t1))
return pytensor_start_times

def _pytensor_c_wrapper(self):
Slownesses = self.get_slownesses()

slownesses = tt.dvector("slownesses")
Expand All @@ -91,7 +88,7 @@ def _theano_c_wrapper(self):
nuc_y = tt.lscalar("nuc_y")
nuc_y.tag.test_value = self.nuc_y

cleanup = theanof.Sweeper(
cleanup = pytensorf.Sweeper(
self.patch_size / km, self.n_patch_dip, self.n_patch_strike, "c"
)

Expand All @@ -100,13 +97,13 @@ def _theano_c_wrapper(self):
t0 = time()
f = function([slownesses, nuc_y, nuc_x], start_times)
t1 = time()
theano_c_wrap_start_times = f(Slownesses.flatten(), self.nuc_y, self.nuc_x)
print("tc", theano_c_wrap_start_times)
pytensor_c_wrap_start_times = f(Slownesses.flatten(), self.nuc_y, self.nuc_x)
# print("tc", pytensor_c_wrap_start_times)
t2 = time()
logger.info("Theano C wrapper compile time %f" % (t1 - t0))
logger.info("done theano C wrapper fast_sweeping in %f" % (t2 - t1))
print("Theano C wrapper compile time %f" % (t1 - t0))
return theano_c_wrap_start_times
logger.info("pytensor C wrapper compile time %f", (t1 - t0))
logger.info("done pytensor C wrapper fast_sweeping in %f", (t2 - t1))
logger.info("pytensor C wrapper compile time %f", (t1 - t0))
return pytensor_c_wrap_start_times

def _c_implementation(self):
slownesses = self.get_slownesses()
Expand All @@ -121,15 +118,15 @@ def _c_implementation(self):
self.nuc_y,
)
t1 = time()
print("c", c_start_times)
# print("c", c_start_times)
logger.info("done c fast_sweeping in %f" % (t1 - t0))
return c_start_times

def test_differences(self):
np_i = self._numpy_implementation().flatten()
t_i = self._theano_implementation().flatten()
t_i = self._pytensor_implementation().flatten()
c_i = self._c_implementation()
tc_i = self._theano_c_wrapper()
tc_i = self._pytensor_c_wrapper()

num.testing.assert_allclose(np_i, t_i, rtol=0.0, atol=1e-6)
num.testing.assert_allclose(np_i, c_i, rtol=0.0, atol=1e-6)
Expand Down
5 changes: 2 additions & 3 deletions test/test_paripool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as num
from pyrocko import util

from beat import paripool
from beat.parallel import paripool

logger = logging.getLogger("test_paripool")

Expand All @@ -22,9 +22,8 @@ def __init__(self, *args, **kwargs):
self.factors = num.array([0, 1, 2, 3, 2, 1, 0])

def test_pool(self):

featureClass = [[k, 1] for k in self.factors] # list of arguments
p = paripool.paripool(add, featureClass, chunksize=2, nprocs=4, timeout=3)
p = paripool(add, featureClass, chunksize=2, nprocs=4, timeout=3)

ref_values = (self.factors + 1).tolist()
ref_values[3] = None
Expand Down
2 changes: 0 additions & 2 deletions test/test_proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(self, *args, **kwargs):
self.draws = 10

def test_proposals(self):

for proposal in available_proposals():
if proposal in multivariate_proposals:
scale = num.eye(2) * 0.5
Expand All @@ -28,6 +27,5 @@ def test_proposals(self):


if __name__ == "__main__":

util.setup_logging("test_proposals", "info")
unittest.main()
9 changes: 6 additions & 3 deletions test/test_sources.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import unittest
from importlib.util import find_spec

import numpy as num
import pyrocko.moment_tensor as mtm
from numpy.testing import assert_allclose
from pyrocko import util
from pytest import mark

from beat.sources import MTQTSource

Expand All @@ -17,7 +19,6 @@ def __init__(self, *args, **kwargs):
unittest.TestCase.__init__(self, *args, **kwargs)

def test_MTSourceQT(self):

# from Tape & Tape 2015 Appendix A:
(u, v, kappa, sigma, h) = (
3.0 / 8.0 * pi,
Expand Down Expand Up @@ -58,10 +59,13 @@ def test_MTSourceQT(self):
print("M9 NEED", mt.m9)
print("M9 NWU", mt.m9_nwu)

@mark.skipif(
(find_spec("mtpar") is None), reason="Test needs 'mtpar' to be installed"
)
def test_vs_mtpar(self):
try:
import mtpar
except (ImportError):
except ImportError:
logger.warning(
"This test needs mtpar to be installed: "
"https://github.com/rmodrak/mtpar/"
Expand Down Expand Up @@ -122,6 +126,5 @@ def test_vs_mtpar(self):


if __name__ == "__main__":

util.setup_logging("test_sources", "info")
unittest.main()

0 comments on commit 5095907

Please sign in to comment.