From 85991f4f911dad3c502b25541342f21b53b1f38b Mon Sep 17 00:00:00 2001 From: hvasbath Date: Wed, 13 Mar 2024 22:48:12 +0100 Subject: [PATCH] tests: test_ffi_gfstacking cleanup and working --- beat/ffi/base.py | 3 +- test/test_ffi_gfstacking.py | 221 +++++++++++++++--------------------- 2 files changed, 92 insertions(+), 132 deletions(-) diff --git a/beat/ffi/base.py b/beat/ffi/base.py index b19b5c7a..1ae2403b 100644 --- a/beat/ffi/base.py +++ b/beat/ffi/base.py @@ -626,7 +626,6 @@ def stack_all( matrix : size (ntargets, nsamples) option : tensor.batched_dot(sd.dimshuffle((1,0,2)), u).sum(axis=0) """ - if targetidxs is None: raise ValueError("Target indexes have to be defined!") @@ -695,7 +694,7 @@ def stack_all( s_st_floor_rt_floor, ], axis=1, - ).T # + ) # else: raise NotImplementedError( diff --git a/test/test_ffi_gfstacking.py b/test/test_ffi_gfstacking.py index 63305cda..35fc03e3 100644 --- a/test/test_ffi_gfstacking.py +++ b/test/test_ffi_gfstacking.py @@ -24,7 +24,40 @@ project_dir = Path("/home/vasyurhm/BEATS/LaquilaJointPonlyUPDATE_wide_kin3_v2") -# @mark.skip("Needs version 2.0.0 compliant setup") +def array_to_traces(synthetics, reference_times, deltat, targets, location_tag=None): + synth_traces = [] + for i, target in enumerate(targets): + tr = trace.Trace(ydata=synthetics[i, :], tmin=reference_times[i], deltat=deltat) + + tr.set_codes(*target.codes) + if location_tag is not None: + tr.set_location(location_tag) + + synth_traces.append(tr) + + return synth_traces + + +def get_max_relative_and_absolute_errors(a, b): + abs_err = num.abs(a - b).max() + rel_err = (num.abs((a - b) / b).max(),) + print("absolute", abs_err) + print("relative", rel_err) + return abs_err, rel_err + + +def assert_traces(ref_traces, test_traces): + assert len(ref_traces) == len(test_traces) + + for ref_trace, test_trace in zip(ref_traces, test_traces): + num.testing.assert_allclose( + ref_trace.ydata, test_trace.ydata, rtol=5e-6, atol=5e-6 + ) + num.testing.assert_allclose( + ref_trace.tmin, test_trace.tmin, rtol=1e-3, atol=1e-3 + ) + + @mark.skipif(project_dir.is_dir() is False, reason="Needs project dir") def test_gf_stacking(): # general @@ -59,14 +92,9 @@ def test_gf_stacking(): + time_shift ) - print(starttimes) - # defining distributed slip values for slip parallel and perpendicular directions uparr = num.ones((npdip, npstrike)) * 2.0 - # uparr[1:3, 3:7] = 1.5 uperp = num.zeros((npdip, npstrike)) - # uperp[0,0] = 1. - # uperp[3,9] = 1. uperp[1:3, 3:7] = 1.0 # define rupture durations on each patch @@ -79,8 +107,6 @@ def test_gf_stacking(): "velocities": velocities.ravel(), } - print("fault parameters", slips) - # update patches with distributed slip and STF values for comp in components: patches = fault.get_subfault_patches(0, datatype="seismic", component=comp) @@ -96,9 +122,6 @@ def test_gf_stacking(): # synthetics generation engine = gf.LocalEngine(store_superdirs=store_superdirs) - - patchidx = fault.patchmap(index=0, dipidx=nuc_dip_idx, strikeidx=nuc_strike_idx) # noqa: F841 - targets = sc.wavemaps[0].targets filterer = sc.wavemaps[0].config.filterer ntargets = len(targets) @@ -109,155 +132,93 @@ def test_gf_stacking(): ) ats = gfs.reference_times - arrival_taper.b - if False: - traces, tmins = heart.seis_synthetics( - engine, - patches, - targets, - arrival_times=ats, - wavename="any_P", - arrival_taper=arrival_taper, - filterer=filterer, - outmode="stacked_traces", - ) + # seismosizer engine --> reference + ref_traces, _ = heart.seis_synthetics( + engine, + patches, + targets, + arrival_times=ats, + wavename="any_P", + arrival_taper=arrival_taper, + filterer=filterer, + outmode="stacked_traces", + ) - targetidxs = num.lib.index_tricks.s_[:] + targetidxs = num.atleast_2d(num.arange(ntargets)).T if False: # for station corrections maybe in the future? - station_corrections = num.zeros(len(traces)) + station_corrections = num.zeros(len(ref_traces)) starttimes = ( num.tile(starttimes, ntargets) + num.repeat(station_corrections, fault.npatches) ).reshape(ntargets, fault.npatches) targetidxs = num.atleast_2d(num.arange(ntargets)).T + elif True: + starttimes = num.tile(starttimes, ntargets).reshape((ntargets, uparr.size)) durations_dim2 = num.atleast_2d(durations.ravel()) + patchidxs = num.arange(uparr.size, dtype="int") + + # numpy stacking gfs.set_stack_mode("numpy") synthetics_nn = gfs.stack_all( + patchidxs=patchidxs, targetidxs=targetidxs, - starttimes=starttimes, + starttimes=starttimes[:, patchidxs], durations=durations_dim2, slips=slips[components[0]], interpolation="nearest_neighbor", ) synthetics_ml = gfs.stack_all( + patchidxs=patchidxs, targetidxs=targetidxs, - starttimes=starttimes, + starttimes=starttimes[:, patchidxs], durations=durations_dim2, slips=slips[components[0]], interpolation="multilinear", ) + # Pytensor stacking gfs.init_optimization() - if True: - synthetics_nn_t = gfs.stack_all( - targetidxs=targetidxs, - starttimes=starttimes, - durations=durations_dim2, - slips=slips[components[0]], - interpolation="nearest_neighbor", - ).eval() - - synthetics_ml_t = gfs.stack_all( - targetidxs=targetidxs, - starttimes=starttimes, - durations=durations_dim2, - slips=slips[components[0]], - interpolation="multilinear", - ).eval() - - synth_traces_nn = [] - for i, target in enumerate(targets): - tr = trace.Trace( - ydata=synthetics_nn[i, :], tmin=gfs.reference_times[i], deltat=gfs.deltat - ) - # print('trace tmin synthst', tr.tmin) - tr.set_codes(*target.codes) - tr.set_location("nn") - synth_traces_nn.append(tr) + synthetics_nn_t = gfs.stack_all( + targetidxs=targetidxs, + starttimes=starttimes, + durations=durations_dim2, + slips=slips[components[0]], + interpolation="nearest_neighbor", + ).eval() - synth_traces_ml = [] - for i, target in enumerate(targets): - tr = trace.Trace( - ydata=synthetics_ml[i, :], tmin=gfs.reference_times[i], deltat=gfs.deltat + synthetics_ml_t = gfs.stack_all( + targetidxs=targetidxs, + starttimes=starttimes, + durations=durations_dim2, + slips=slips[components[0]], + interpolation="multilinear", + ).eval() + + all_synth_traces = [] + for test_synthetics, location_tag in zip( + [synthetics_nn, synthetics_ml, synthetics_nn_t, synthetics_ml_t], + ["nn", "ml", "nn_t", "ml_t"], + ): + test_traces = array_to_traces( + test_synthetics, + reference_times=gfs.reference_times, + deltat=gfs.deltat, + targets=targets, + location_tag=location_tag, ) - # print 'trace tmin synthst', tr.tmin - tr.set_codes(*target.codes) - tr.set_location("ml") - synth_traces_ml.append(tr) - - if True: - synth_traces_nn_t = [] - for i, target in enumerate(targets): - tr = trace.Trace( - ydata=synthetics_nn_t[i, :], - tmin=gfs.reference_times[i], - deltat=gfs.deltat, - ) - # print('trace tmin synthst', tr.tmin) - tr.set_codes(*target.codes) - tr.set_location("nn_t") - synth_traces_nn_t.append(tr) - - synth_traces_ml_t = [] - for i, target in enumerate(targets): - tr = trace.Trace( - ydata=synthetics_ml_t[i, :], - tmin=gfs.reference_times[i], - deltat=gfs.deltat, - ) - # print 'trace tmin synthst', tr.tmin - tr.set_codes(*target.codes) - tr.set_location("ml_t") - synth_traces_ml_t.append(tr) - - # display to check - trace.snuffle( - # traces - synth_traces_nn + synth_traces_ml + synth_traces_nn_t + synth_traces_ml_t, - stations=sc.wavemaps[0].stations, - events=[event], - ) - traces1, tmins = heart.seis_synthetics( - engine, - [patches[0]], - targets, - arrival_times=ats, - wavename="any_P", - arrival_taper=arrival_taper, - filterer=filterer, - outmode="stacked_traces", - ) - - gfs.set_stack_mode("numpy") + assert_traces(ref_traces, test_traces) + all_synth_traces.extend(test_traces) - synth_traces_ml1 = [] - for i in range(1): - synthetics_ml1 = gfs.stack_all( - targetidxs=targetidxs, - patchidxs=[i], - starttimes=starttimes[0], - durations=num.atleast_2d(durations.ravel()[0]), - slips=num.atleast_1d(slips[components[0]][0]), - interpolation="multilinear", + if False: + # display to check + trace.snuffle( + ref_traces + all_synth_traces, + stations=sc.wavemaps[0].stations, + events=[event], ) - - for i, target in enumerate(targets): - tr = trace.Trace( - ydata=synthetics_ml1[i, :], - tmin=gfs.reference_times[i], - deltat=gfs.deltat, - ) - print("trace tmin synthst", tr.tmin) - # print(target.codes) - tr.set_codes(*target.codes) - tr.set_location("ml%i" % i) - synth_traces_ml1.append(tr) - - trace.snuffle( - traces1 + synth_traces_ml1, stations=sc.wavemaps[0].stations, events=[event] - )