Skip to content

Commit

Permalink
backend: update extract_bounds_from_summary to new summary format
Browse files Browse the repository at this point in the history
  • Loading branch information
hvasbath committed Mar 14, 2024
1 parent 3df4de5 commit c90d43a
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
24 changes: 20 additions & 4 deletions beat/apps/beat.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,11 @@ def setup(parser):
param = wmap.time_shifts_id

new_bounds[param] = extract_bounds_from_summary(
summarydf, varname=param, shape=(wmap.hypersize,), roundto=0
summarydf,
varname=param,
shape=(wmap.hypersize,),
roundto=0,
alpha=0.06,
)
new_bounds[param].append(point[param])

Expand Down Expand Up @@ -692,7 +696,11 @@ def setup(parser):
new_bounds = {}
for param in ["time"]:
new_bounds[param] = extract_bounds_from_summary(
summarydf, varname=param, shape=(n_sources[0],), roundto=0
summarydf,
varname=param,
shape=(n_sources[0],),
roundto=0,
alpha=0.06,
)
new_bounds[param].append(point[param])

Expand All @@ -710,7 +718,11 @@ def setup(parser):
shape = (n_sources,)

new_bounds[param] = extract_bounds_from_summary(
summarydf, varname=param, shape=shape, roundto=1
summarydf,
varname=param,
shape=shape,
roundto=1,
alpha=0.06,
)
new_bounds[param].append(point[param])

Expand All @@ -733,7 +745,11 @@ def setup(parser):
for param in common_source_params:
try:
new_bounds[param] = extract_bounds_from_summary(
summarydf, varname=param, shape=(n_sources,), roundto=0
summarydf,
varname=param,
shape=(n_sources,),
roundto=0,
alpha=0.06,
)
new_bounds[param].append(point[param])
except KeyError:
Expand Down
15 changes: 12 additions & 3 deletions beat/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ def _create_flat_names(varname, shape):
return ["{}__{}".format(varname, "_".join(idxs)) for idxs in zip(*labels)]


def _create_flat_names_summary(varname, shape):
if not shape or sum(shape) == 1:
return [varname]

labels = (num.ravel(xs).tolist() for xs in num.indices(shape))
labels = (map(str, [xs]) for xs in labels)
return ["{}{}".format(varname, "".join(idxs)) for idxs in zip(*labels)]


def _create_shape(flat_names):
"""Determine shape from `_create_flat_names` output."""
try:
Expand Down Expand Up @@ -1350,9 +1359,9 @@ def extract_bounds_from_summary(summary, varname, shape, roundto=None, alpha=0.0
def do_nothing(value):
return value

indexes = _create_flat_names(varname, shape)
lower_quant = "hpd_{0:g}".format(100 * alpha / 2)
upper_quant = "hpd_{0:g}".format(100 * (1 - alpha / 2))
indexes = _create_flat_names_summary(varname, shape)
lower_quant = "hdi_{0:g}%".format(100 * alpha / 2)
upper_quant = "hdi_{0:g}%".format(100 * (1 - alpha / 2))

bounds = []
for quant in [lower_quant, upper_quant]:
Expand Down
5 changes: 4 additions & 1 deletion beat/plotting/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,10 @@ def draw_3d_slip_distribution(problem, po):
os.path.join(problem.outfolder, "summary.txt"), sep=r"\s+"
)
bounds = extract_bounds_from_summary(
summarydf, varname="uparr", shape=(fault.npatches,)
summarydf,
varname="uparr",
shape=(fault.npatches,),
alpha=0.06,
)
reference["slip_variation"] = bounds[1] - bounds[0]
slip_units = "m"
Expand Down

0 comments on commit c90d43a

Please sign in to comment.