Skip to content

Commit

Permalink
tests: clean up test_ffi_gfstacking and make it pass
Browse files Browse the repository at this point in the history
  • Loading branch information
hvasbath committed Mar 13, 2024
1 parent 9bbf712 commit 85991f4
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 132 deletions.
3 changes: 1 addition & 2 deletions beat/ffi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,6 @@ def stack_all(
matrix : size (ntargets, nsamples)
option : tensor.batched_dot(sd.dimshuffle((1,0,2)), u).sum(axis=0)
"""

if targetidxs is None:
raise ValueError("Target indexes have to be defined!")

Expand Down Expand Up @@ -695,7 +694,7 @@ def stack_all(
s_st_floor_rt_floor,
],
axis=1,
).T #
) #

else:
raise NotImplementedError(
Expand Down
221 changes: 91 additions & 130 deletions test/test_ffi_gfstacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,40 @@
project_dir = Path("/home/vasyurhm/BEATS/LaquilaJointPonlyUPDATE_wide_kin3_v2")


# @mark.skip("Needs version 2.0.0 compliant setup")
def array_to_traces(synthetics, reference_times, deltat, targets, location_tag=None):
    """Convert a 2d synthetics array into a list of Trace objects.

    Row ``i`` of *synthetics* becomes one trace for ``targets[i]``,
    starting at ``reference_times[i]`` with sampling interval *deltat*.
    If *location_tag* is given, it overwrites the location code of
    every returned trace (useful to distinguish variants in a viewer).
    """
    out_traces = []
    for idx, target in enumerate(targets):
        new_trace = trace.Trace(
            ydata=synthetics[idx, :],
            tmin=reference_times[idx],
            deltat=deltat,
        )
        new_trace.set_codes(*target.codes)
        if location_tag is not None:
            new_trace.set_location(location_tag)

        out_traces.append(new_trace)

    return out_traces


def get_max_relative_and_absolute_errors(a, b):
    """Return the maximum absolute and relative errors between *a* and *b*.

    Parameters
    ----------
    a, b : array_like
        Arrays to compare; *b* serves as the reference for the relative
        error, so it should not contain zeros.

    Returns
    -------
    tuple
        ``(abs_err, rel_err)`` — both scalars: ``max |a - b|`` and
        ``max |(a - b) / b|``.
    """
    abs_err = num.abs(a - b).max()
    # Fix: the original line ended with a stray trailing comma, which
    # silently wrapped rel_err into a 1-tuple instead of a scalar.
    rel_err = num.abs((a - b) / b).max()
    print("absolute", abs_err)
    print("relative", rel_err)
    return abs_err, rel_err


def assert_traces(ref_traces, test_traces):
    """Assert that two lists of traces agree in samples and onset times.

    Raises AssertionError when the lists differ in length, or when any
    pair differs beyond tolerance (5e-6 for ydata, 1e-3 for tmin).
    """
    assert len(ref_traces) == len(test_traces)

    allclose = num.testing.assert_allclose
    for reference, candidate in zip(ref_traces, test_traces):
        allclose(reference.ydata, candidate.ydata, rtol=5e-6, atol=5e-6)
        allclose(reference.tmin, candidate.tmin, rtol=1e-3, atol=1e-3)


@mark.skipif(project_dir.is_dir() is False, reason="Needs project dir")
def test_gf_stacking():
# general
Expand Down Expand Up @@ -59,14 +92,9 @@ def test_gf_stacking():
+ time_shift
)

print(starttimes)

# defining distributed slip values for slip parallel and perpendicular directions
uparr = num.ones((npdip, npstrike)) * 2.0
# uparr[1:3, 3:7] = 1.5
uperp = num.zeros((npdip, npstrike))
# uperp[0,0] = 1.
# uperp[3,9] = 1.
uperp[1:3, 3:7] = 1.0

# define rupture durations on each patch
Expand All @@ -79,8 +107,6 @@ def test_gf_stacking():
"velocities": velocities.ravel(),
}

print("fault parameters", slips)

# update patches with distributed slip and STF values
for comp in components:
patches = fault.get_subfault_patches(0, datatype="seismic", component=comp)
Expand All @@ -96,9 +122,6 @@ def test_gf_stacking():

# synthetics generation
engine = gf.LocalEngine(store_superdirs=store_superdirs)

patchidx = fault.patchmap(index=0, dipidx=nuc_dip_idx, strikeidx=nuc_strike_idx) # noqa: F841

targets = sc.wavemaps[0].targets
filterer = sc.wavemaps[0].config.filterer
ntargets = len(targets)
Expand All @@ -109,155 +132,93 @@ def test_gf_stacking():
)
ats = gfs.reference_times - arrival_taper.b

if False:
traces, tmins = heart.seis_synthetics(
engine,
patches,
targets,
arrival_times=ats,
wavename="any_P",
arrival_taper=arrival_taper,
filterer=filterer,
outmode="stacked_traces",
)
# seismosizer engine --> reference
ref_traces, _ = heart.seis_synthetics(
engine,
patches,
targets,
arrival_times=ats,
wavename="any_P",
arrival_taper=arrival_taper,
filterer=filterer,
outmode="stacked_traces",
)

targetidxs = num.lib.index_tricks.s_[:]
targetidxs = num.atleast_2d(num.arange(ntargets)).T

if False:
# for station corrections maybe in the future?
station_corrections = num.zeros(len(traces))
station_corrections = num.zeros(len(ref_traces))
starttimes = (
num.tile(starttimes, ntargets)
+ num.repeat(station_corrections, fault.npatches)
).reshape(ntargets, fault.npatches)
targetidxs = num.atleast_2d(num.arange(ntargets)).T
elif True:
starttimes = num.tile(starttimes, ntargets).reshape((ntargets, uparr.size))

durations_dim2 = num.atleast_2d(durations.ravel())
patchidxs = num.arange(uparr.size, dtype="int")

# numpy stacking
gfs.set_stack_mode("numpy")
synthetics_nn = gfs.stack_all(
patchidxs=patchidxs,
targetidxs=targetidxs,
starttimes=starttimes,
starttimes=starttimes[:, patchidxs],
durations=durations_dim2,
slips=slips[components[0]],
interpolation="nearest_neighbor",
)

synthetics_ml = gfs.stack_all(
patchidxs=patchidxs,
targetidxs=targetidxs,
starttimes=starttimes,
starttimes=starttimes[:, patchidxs],
durations=durations_dim2,
slips=slips[components[0]],
interpolation="multilinear",
)

# Pytensor stacking
gfs.init_optimization()

if True:
synthetics_nn_t = gfs.stack_all(
targetidxs=targetidxs,
starttimes=starttimes,
durations=durations_dim2,
slips=slips[components[0]],
interpolation="nearest_neighbor",
).eval()

synthetics_ml_t = gfs.stack_all(
targetidxs=targetidxs,
starttimes=starttimes,
durations=durations_dim2,
slips=slips[components[0]],
interpolation="multilinear",
).eval()

synth_traces_nn = []
for i, target in enumerate(targets):
tr = trace.Trace(
ydata=synthetics_nn[i, :], tmin=gfs.reference_times[i], deltat=gfs.deltat
)
# print('trace tmin synthst', tr.tmin)
tr.set_codes(*target.codes)
tr.set_location("nn")
synth_traces_nn.append(tr)
synthetics_nn_t = gfs.stack_all(
targetidxs=targetidxs,
starttimes=starttimes,
durations=durations_dim2,
slips=slips[components[0]],
interpolation="nearest_neighbor",
).eval()

synth_traces_ml = []
for i, target in enumerate(targets):
tr = trace.Trace(
ydata=synthetics_ml[i, :], tmin=gfs.reference_times[i], deltat=gfs.deltat
synthetics_ml_t = gfs.stack_all(
targetidxs=targetidxs,
starttimes=starttimes,
durations=durations_dim2,
slips=slips[components[0]],
interpolation="multilinear",
).eval()

all_synth_traces = []
for test_synthetics, location_tag in zip(
[synthetics_nn, synthetics_ml, synthetics_nn_t, synthetics_ml_t],
["nn", "ml", "nn_t", "ml_t"],
):
test_traces = array_to_traces(
test_synthetics,
reference_times=gfs.reference_times,
deltat=gfs.deltat,
targets=targets,
location_tag=location_tag,
)
# print 'trace tmin synthst', tr.tmin
tr.set_codes(*target.codes)
tr.set_location("ml")
synth_traces_ml.append(tr)

if True:
synth_traces_nn_t = []
for i, target in enumerate(targets):
tr = trace.Trace(
ydata=synthetics_nn_t[i, :],
tmin=gfs.reference_times[i],
deltat=gfs.deltat,
)
# print('trace tmin synthst', tr.tmin)
tr.set_codes(*target.codes)
tr.set_location("nn_t")
synth_traces_nn_t.append(tr)

synth_traces_ml_t = []
for i, target in enumerate(targets):
tr = trace.Trace(
ydata=synthetics_ml_t[i, :],
tmin=gfs.reference_times[i],
deltat=gfs.deltat,
)
# print 'trace tmin synthst', tr.tmin
tr.set_codes(*target.codes)
tr.set_location("ml_t")
synth_traces_ml_t.append(tr)

# display to check
trace.snuffle(
# traces
synth_traces_nn + synth_traces_ml + synth_traces_nn_t + synth_traces_ml_t,
stations=sc.wavemaps[0].stations,
events=[event],
)

traces1, tmins = heart.seis_synthetics(
engine,
[patches[0]],
targets,
arrival_times=ats,
wavename="any_P",
arrival_taper=arrival_taper,
filterer=filterer,
outmode="stacked_traces",
)

gfs.set_stack_mode("numpy")
assert_traces(ref_traces, test_traces)
all_synth_traces.extend(test_traces)

synth_traces_ml1 = []
for i in range(1):
synthetics_ml1 = gfs.stack_all(
targetidxs=targetidxs,
patchidxs=[i],
starttimes=starttimes[0],
durations=num.atleast_2d(durations.ravel()[0]),
slips=num.atleast_1d(slips[components[0]][0]),
interpolation="multilinear",
if False:
# display to check
trace.snuffle(
ref_traces + all_synth_traces,
stations=sc.wavemaps[0].stations,
events=[event],
)

for i, target in enumerate(targets):
tr = trace.Trace(
ydata=synthetics_ml1[i, :],
tmin=gfs.reference_times[i],
deltat=gfs.deltat,
)
print("trace tmin synthst", tr.tmin)
# print(target.codes)
tr.set_codes(*target.codes)
tr.set_location("ml%i" % i)
synth_traces_ml1.append(tr)

trace.snuffle(
traces1 + synth_traces_ml1, stations=sc.wavemaps[0].stations, events=[event]
)

0 comments on commit 85991f4

Please sign in to comment.