Skip to content

Commit

Permalink
octree: new exploration strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Mar 4, 2024
1 parent 311c03f commit 0fc8464
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 69 deletions.
29 changes: 26 additions & 3 deletions src/qseek/images/phase_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Literal

import numpy as np
from obspy import Stream
from pydantic import Field, PositiveFloat, PositiveInt, PrivateAttr
from pyrocko import obspy_compat
Expand Down Expand Up @@ -74,15 +75,37 @@ def search_phase_arrival(
inplace=False,
)
except NoData:
logger.warning("No data to pick phase arrival.")
return None

peak_idx, _ = signal.find_peaks(
search_trace.ydata,
height=threshold,
prominence=detection_blinding_seconds,
distance=detection_blinding_seconds,
prominence=threshold,
distance=int(detection_blinding_seconds * 1 / search_trace.deltat),
)
time_seconds, value = search_trace.max()
if False:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
time = search_trace.get_xdata()
std = np.std(search_trace.get_ydata())

ax.plot(time, search_trace.get_ydata())
ax.grid(alpha=0.3)
ax.axhline(threshold, color="r", linestyle="--", label="threshold")
ax.axhline(std, color="g", linestyle="--", label="std")
ax.axhline(3 * std, color="b", linestyle="dotted", label="3*std")
ax.axvline(
modelled_arrival.timestamp(),
color="k",
alpha=0.3,
label="modelled arrival",
)
if peak_idx.size:
ax.axvline(time[peak_idx], color="m", linestyle="--", label="peaks")
plt.show()

if not peak_idx.size:
return None

Expand Down
133 changes: 90 additions & 43 deletions src/qseek/octree.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ def split(self) -> tuple[Node, ...]:
raise NodeSplitError("Cannot split node below limit.")

if not self._children_cached:
half_size = self.size / 2
child_size = self.size / 2

self._children_cached = tuple(
Node(
east=self.east + east * half_size / 2,
north=self.north + north * half_size / 2,
depth=self.depth + depth * half_size / 2,
size=half_size,
east=self.east + east * child_size / 2,
north=self.north + north * child_size / 2,
depth=self.depth + depth * child_size / 2,
size=child_size,
tree=self.tree,
parent=self,
level=self.level + 1,
Expand All @@ -119,11 +119,12 @@ def split(self) -> tuple[Node, ...]:
def coordinates(self) -> tuple[float, float, float]:
    """Return the node's coordinates.

    Returns:
        tuple[float, float, float]: The ``east``, ``north`` and ``depth``
            attributes of the node, in that order.
    """
    position = (self.east, self.north, self.depth)
    return position

def get_distance_border(self, include_top: bool = False) -> float:
def get_distance_border(self, trough: bool = False) -> float:
"""Distance to the closest EW, NS or bottom border of the tree.
!!! note
Surface distance is excluded.
Args:
trough (bool, optional): If True, the distance to the closest border
within the trough (open top) is returned. Defaults to False.
Raises:
AttributeError: If the parent tree is not set.
Expand All @@ -141,19 +142,23 @@ def get_distance_border(self, include_top: bool = False) -> float:
tree.east_bounds[1] - self.east,
tree.depth_bounds[1] - self.depth,
)
if not include_top:
if trough:
return trough_distance
return min(trough_distance, self.depth - tree.depth_bounds[0])

def is_inside_border(self, ignore_top: bool = False) -> bool:
def is_inside_border(self, trough: bool = False) -> bool:
"""Check if the node is within the root node border.
Args:
trough (bool, optional): If True, the node is considered inside the
trough (open top). Defaults to False.
Returns:
bool: True if the node is inside the root tree's border.
"""
if self.tree is None:
raise AttributeError("parent tree not set")
return self.get_distance_border(ignore_top) <= self.tree.root_node_size
return self.get_distance_border(trough) <= self.tree.root_node_size

def can_split(self) -> bool:
"""Check if the node can be split.
Expand Down Expand Up @@ -195,6 +200,15 @@ def distance_to_location(self, location: Location) -> float:
"""
return location.distance_to(self.as_location())

def semblance_density(self) -> float:
    """Semblance of this node normalised by its volume.

    Returns:
        float: The node's ``semblance`` divided by ``size`` cubed.
    """
    node_volume = self.size**3
    return self.semblance / node_volume

def as_location(self) -> Location:
"""Returns the location of the node.
Expand All @@ -215,6 +229,51 @@ def as_location(self) -> Location:
)
return self._location

def collides(self, other: Node) -> bool:
    """Check whether this node touches or overlaps another node.

    Two nodes collide when, along each of the east, north and depth axes,
    their coordinates differ by no more than half the sum of both node sizes.
    Nodes that merely share a face, edge or corner therefore count as
    colliding.

    Args:
        other (Node): Other node to check for collision.

    Returns:
        bool: True if the nodes collide.
    """
    reach = (self.size + other.size) / 2
    offsets = (
        self.east - other.east,
        self.north - other.north,
        self.depth - other.depth,
    )
    return all(abs(offset) <= reach for offset in offsets)

def get_neighbours(self) -> list[Node]:
    """Get the direct neighbours of the node from the parent tree.

    Every node yielded by the parent tree's ``iter_nodes()`` that collides
    with this node (see ``collides``) is returned; the node itself is excluded.

    Returns:
        list[Node]: List of direct neighbours.

    Raises:
        AttributeError: If the parent tree is not set.
    """
    # Explicit None check, consistent with the other tree guards in this
    # file; a tree object that happens to be falsy is still a set tree.
    if self.tree is None:
        raise AttributeError("parent tree not set")

    return [
        node
        for node in self.tree.iter_nodes()
        # Cheap identity test first, so the node is never checked against itself.
        if node is not self and self.collides(node)
    ]

def distance_to(self, other: Node) -> float:
    """Euclidean distance between this node's and another node's coordinates.

    Args:
        other (Node): Other node to calculate distance to.

    Returns:
        float: Distance to other node.
    """
    d_east = self.east - other.east
    d_north = self.north - other.north
    d_depth = self.depth - other.depth
    return np.sqrt(d_east**2 + d_north**2 + d_depth**2)

def __iter__(self) -> Iterator[Node]:
if self.children:
for child in self.children:
Expand Down Expand Up @@ -299,6 +358,7 @@ def check_limits(self) -> Octree:
def model_post_init(self, __context: Any) -> None:
"""Initialize octree. This method is called by the pydantic model"""
self._root_nodes = self.get_root_nodes(self.root_node_size)

logger.info(
"initializing octree volume with %d nodes and %.1f km³,"
" smallest node size: %.1f m",
Expand Down Expand Up @@ -343,6 +403,20 @@ def volume(self) -> float:
"""Volume of the octree in cubic meters"""
return reduce(mul, self.extent())

def iter_nodes(self, level: int | None = None) -> Iterator[Node]:
    """Iterate over the nodes of the tree, optionally restricted to one level.

    Args:
        level (int, optional): Only nodes whose ``level`` equals this value
            are yielded. Defaults to None, which yields every node the tree
            iterates over.

    Yields:
        Iterator[Node]: Node iterator.
    """
    if level is None:
        yield from self
        return

    yield from (node for node in self if node.level == level)

def __iter__(self) -> Iterator[Node]:
for node in self._root_nodes:
yield from node
Expand Down Expand Up @@ -520,25 +594,6 @@ def get_nodes_level(self, level: int = 0):
"""
return [node for node in self if node.level <= level]

def is_node_in_bounds(self, node: Node) -> bool:
    """Check if a node lies further from the border than the absorbing boundary.

    Args:
        node (Node): Node to check.

    Returns:
        bool: True if the node's ``distance_border`` is strictly greater than
            the tree's ``absorbing_boundary``.
    """
    border_distance = node.distance_border
    return border_distance > self.absorbing_boundary

def n_levels(self) -> int:
    """Returns the number of the deepest level in the octree.

    Computed as the floor of the base-2 logarithm of ``size_initial`` divided
    by ``size_limit``.

    Returns:
        int: Index of deepest octree level.
    """
    size_ratio = self.size_initial / self.size_limit
    deepest_level = np.floor(np.log2(size_ratio))
    return int(deepest_level)

def smallest_node_size(self) -> float:
"""Returns the smallest possible node size.
Expand All @@ -558,7 +613,6 @@ def total_number_nodes(self) -> int:
async def interpolate_max_location(
self,
peak_node: Node,
n_neighbors: int = 5,
) -> Location:
"""Interpolate the location of the maximum semblance value.
Expand All @@ -576,21 +630,14 @@ async def interpolate_max_location(
if self._semblance is None:
raise AttributeError("no semblance values set")

node_coords = self.get_coordinates(system="raw")
node_distances = await asyncio.to_thread(
np.linalg.norm,
node_coords - peak_node.coordinates,
axis=1,
)
sorted_idx = await asyncio.to_thread(np.argsort, node_distances)
neighbor_nodes = self.get_nodes(sorted_idx[: (n_neighbors**3)])

neighbor_nodes = peak_node.get_neighbours()
neighbor_coords = np.array(
[(n.east, n.north, n.depth, n.semblance) for n in neighbor_nodes]
[
(n.east, n.north, n.depth, n.semblance)
for n in [peak_node, *neighbor_nodes]
]
)

# np.save("/tmp/neighbor_coords.npy", neighbor_coords)

neighbor_semblance = neighbor_coords[:, 3]
rbf = scipy.interpolate.RBFInterpolator(
neighbor_coords[:, :3],
Expand Down
37 changes: 14 additions & 23 deletions src/qseek/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,6 @@ class Search(BaseModel):
default=0.05,
description="Detection threshold for semblance.",
)
node_split_percentile: float = Field(
default=5.0,
gt=0.0,
lt=100.0,
description="Percentiles of octree nodes to split,"
" relative to the maximum detected semblance."
" Higher percentiles are more explorative, lower values are more conservative."
" A more explorative seach will refine more nodes and consume more RAM.",
)
ignore_boundary_nodes: Literal[False, "trough", "volume"] = Field(
default=False,
description="Ignore events that are inside the first root node layer of"
Expand Down Expand Up @@ -773,26 +764,25 @@ async def search(

# Collect the octree nodes to split for all detections in this frame, so the
# octree is refined once per frame.
maxima_node_idx = await semblance.maxima_node_idx()
maxima_node_indices = await semblance.maxima_node_idx()
refine_nodes: set[Node] = set()
for time_idx, semblance_detection in zip(
detection_idx, detection_semblance, strict=True
):
for time_idx in detection_idx:
octree.map_semblance(semblance.get_semblance(time_idx))
node_idx = maxima_node_idx[time_idx]
node_idx = maxima_node_indices[time_idx]
source_node = octree[node_idx]
if not source_node.can_split():
continue

if parent.ignore_boundary_nodes and source_node.is_inside_border(
ignore_top=parent.ignore_boundary_nodes == "trough"
trough=parent.ignore_boundary_nodes == "trough"
):
continue
refine_nodes.update(source_node)
refine_nodes.update(source_node.get_neighbours())

split_nodes = octree.get_nodes_by_threshold(
semblance_detection * (1.0 - parent.node_split_percentile / 100.0)
)
refine_nodes.update(split_nodes)
densest_node = max(octree, key=lambda node: node.semblance_density())
refine_nodes.add(densest_node)
refine_nodes.update(densest_node.get_neighbours())

refine_nodes = {node for node in refine_nodes if node.can_split()}

# refine_nodes is empty when all sources fall into smallest octree nodes
if refine_nodes:
Expand Down Expand Up @@ -823,7 +813,7 @@ async def search(
octree.map_semblance(semblance_event)
source_node = octree[node_idx]
if parent.ignore_boundary_nodes and source_node.is_inside_border(
ignore_top=parent.ignore_boundary_nodes == "trough"
trough=parent.ignore_boundary_nodes == "trough"
):
continue

Expand Down Expand Up @@ -863,7 +853,8 @@ async def search(
station_delays.append(timedelta(seconds=delay))

arrivals_observed = image.search_phase_arrivals(
modelled_arrivals=[arr if arr else None for arr in arrival_times]
modelled_arrivals=[arr if arr else None for arr in arrival_times],
threshold=parent.detection_threshold,
)

phase_detections = [
Expand Down
26 changes: 26 additions & 0 deletions src/qseek/tracers/cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, Literal, Sequence

import matplotlib.pyplot as plt
import numpy as np
from lru import LRU
from pydantic import (
Expand All @@ -30,6 +31,7 @@
from pyrocko import spit
from pyrocko.cake import LayeredModel, PhaseDef, load_model, m2d
from pyrocko.gf import meta
from pyrocko.plot.cake_plot import my_model_plot as earthmodel_plot

from qseek.octree import get_node_coordinates
from qseek.stats import PROGRESS
Expand Down Expand Up @@ -149,6 +151,22 @@ def get_profile_vp(self) -> np.ndarray:
def get_profile_vs(self) -> np.ndarray:
    """Return the ``"vs"`` profile of the layered model.

    Returns:
        np.ndarray: Result of ``layered_model.profile("vs")``.
    """
    model = self.layered_model
    return model.profile("vs")

def save_plot(self, filename: Path) -> None:
    """Plot the layered model and save the figure to a file.

    If the model has a ``filename`` set, it is used as the plot title.

    Args:
        filename (Path): The path to save the figure.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        earthmodel_plot(self.layered_model, axes=ax)
        # The title has to be set before saving, otherwise it never reaches
        # the written file.
        if self.filename:
            ax.set_title(f"File: {self.filename}")
        fig.savefig(filename, dpi=300)
    finally:
        # pyplot keeps figures it created alive until they are closed;
        # release it even if plotting or saving fails.
        plt.close(fig)

    logger.info("saved earth model plot to %s", filename)

@cached_property
def hash(self) -> str:
model_serialised = BytesIO()
Expand Down Expand Up @@ -600,6 +618,14 @@ async def prepare(
await tree.init_lut(octree, stations)
self._travel_time_trees[phase_descr] = tree

if rundir:
cake_plots = rundir / "cake"
cake_plots.mkdir(exist_ok=True)
for phase, tree in self._travel_time_trees.items():
tree.earthmodel.save_plot(
cake_plots / f"earthmodel_{phase.replace(':', '_')}.png",
)

def _get_sptree_model(self, phase: str) -> TravelTimeTree:
return self._travel_time_trees[phase]

Expand Down

0 comments on commit 0fc8464

Please sign in to comment.