From f06c178e41893435c6b8f8399b9c43f70345ea85 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Mon, 27 May 2024 15:10:17 +0000 Subject: [PATCH] upd --- src/qseek/apps/qseek.py | 4 ++- src/qseek/images/phase_net.py | 1 + src/qseek/models/station.py | 1 + src/qseek/search.py | 24 ++++++++------- ...{station_weights.py => spatial_weights.py} | 30 +++++++++++++------ src/qseek/tracers/cake.py | 17 +++++++---- src/qseek/utils.py | 6 ++-- 7 files changed, 54 insertions(+), 29 deletions(-) rename src/qseek/{station_weights.py => spatial_weights.py} (85%) diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py index 1cc66462..55561abb 100644 --- a/src/qseek/apps/qseek.py +++ b/src/qseek/apps/qseek.py @@ -7,11 +7,13 @@ import logging import shutil from pathlib import Path +from typing import TYPE_CHECKING import nest_asyncio from pkg_resources import get_distribution -from qseek.models.detection import EventDetection +if TYPE_CHECKING: + from qseek.models.detection import EventDetection nest_asyncio.apply() diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/phase_net.py index 46ccc4c9..c658c3b1 100644 --- a/src/qseek/images/phase_net.py +++ b/src/qseek/images/phase_net.py @@ -68,6 +68,7 @@ def search_phase_arrival( Returns: datetime | None: Time of arrival, None is none found. """ + # TODO adapt threshold to the seisbench model trace = self.traces[trace_idx] window_length = timedelta(seconds=search_window_seconds) try: diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index f2d8e3e3..2b9dc6e0 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -165,6 +165,7 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None: raise ValueError("no stations available, add waveforms to start detection") def __iter__(self) -> Iterator[Station]: + # TODO: this is inefficient return (sta for sta in self.stations if sta.nsl.pretty not in self.blacklist) @property diff --git a/src/qseek/search.py b/src/qseek/search.py index 1cef0d03..2b5e29de 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -34,7 +34,7 @@ from qseek.pre_processing.frequency_filters import Bandpass from qseek.pre_processing.module import Downsample, PreProcessing from qseek.signals import Signal -from qseek.station_weights import StationWeights +from qseek.spatial_weights import SpatialWeights from qseek.stats import RuntimeStats, Stats from qseek.tracers.tracers import RayTracer, RayTracers from qseek.utils import ( @@ -184,8 +184,8 @@ def tts(duration: timedelta) -> str: table.add_row( "Resources", f"CPU {self.cpu_percent:.1f}%, " - f"RAM {human_readable_bytes(self.memory_used)}" - f"/{self.memory_total.human_readable()}", + f"RAM {human_readable_bytes(self.memory_used, decimal=True)}" + f"/{self.memory_total.human_readable(decimal=True)}", ) table.add_row( "Progress ", @@ -238,9 +238,9 @@ class Search(BaseModel): default=RayTracers(root=[tracer() for tracer in RayTracer.get_subclasses()]), description="List of ray tracers for travel time calculation.", ) - station_weights: StationWeights | None = Field( - default=StationWeights(), - description="Station weights for spatial weighting.", + spatial_weights: SpatialWeights | None = Field( + default=SpatialWeights(), + description="Spatial weights for distance weighting.", ) station_corrections: StationCorrectionType | None = Field( default=None, @@ -463,8 +463,8 @@ async def prepare(self) -> None: self.data_provider.prepare(self.stations) await self.pre_processing.prepare() - if self.station_weights: - self.station_weights.prepare(self.stations, self.octree) + if self.spatial_weights: + self.spatial_weights.prepare(self.stations, self.octree) if self.station_corrections: await self.station_corrections.prepare( @@ -722,8 +722,8 @@ async def calculate_semblance( weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32) weights[traveltimes_bad] = 0.0 - if parent.station_weights: - weights *= await parent.station_weights.get_weights(octree, image.stations) + if parent.spatial_weights: + weights *= await parent.spatial_weights.get_weights(octree, image.stations) with np.errstate(divide="ignore", invalid="ignore"): weights /= weights.sum(axis=1, keepdims=True) @@ -861,7 +861,9 @@ async def search( except NodeSplitError: continue logger.info( - "energy detected, refined %d nodes, level %d", + "detected %d energy burst%s - refined %d nodes, lowest level %d", + detection_idx.size, + "s" if detection_idx.size > 1 else "", len(refine_nodes), new_level, ) diff --git a/src/qseek/station_weights.py b/src/qseek/spatial_weights.py similarity index 85% rename from src/qseek/station_weights.py rename to src/qseek/spatial_weights.py index 3500a252..bb241230 100644 --- a/src/qseek/station_weights.py +++ b/src/qseek/spatial_weights.py @@ -19,21 +19,26 @@ logger = logging.getLogger(__name__) -class StationWeights(BaseModel): +class SpatialWeights(BaseModel): exponent: float = Field( - default=0.5, - description="Exponent of the exponential decay function. Default is 1.5.", + default=3.0, + description="Exponent of the spatial decay function. Default is 3.", ge=0.0, - le=3.0, ) radius_meters: PositiveFloat = Field( default=8000.0, - description="Radius in meters for the exponential decay function. " - "Default is 8000.", + description="Cutoff distance for the spatial decay function in meters." + " Default is 8000.", + ) + waterlevel: float = Field( + default=0.0, + ge=0.0, + le=1.0, + description="Waterlevel for the exponential decay function. Default is 0.0.", ) lut_cache_size: ByteSize = Field( default=200 * MB, - description="Size of the LRU cache in bytes. Default is 1e9.", + description="Size of the LRU cache in bytes. Default is 200 MB.", ) _node_lut: dict[bytes, np.ndarray] = PrivateAttr() @@ -47,14 +52,21 @@ def get_distances(self, nodes: Iterable[Node]) -> np.ndarray: self._station_coords_ecef - node_coords[:, np.newaxis], axis=2 ) - def calc_weights(self, distances: np.ndarray) -> np.ndarray: + def calc_weights_exp(self, distances: np.ndarray) -> np.ndarray: exp = self.exponent # radius = distances.min(axis=1)[:, np.newaxis] radius = self.radius_meters return np.exp(-(distances**exp) / (radius**exp)) + def calc_weights(self, distances: np.ndarray) -> np.ndarray: + exp = self.exponent + radius = self.radius_meters + return (1 - self.waterlevel) / ( + 1 + (distances / radius) ** exp + ) + self.waterlevel + def prepare(self, stations: Stations, octree: Octree) -> None: - logger.info("preparing station weights") + logger.info("preparing spatial weights") bytes_per_node = stations.n_stations * np.float32().itemsize lru_cache_size = int(self.lut_cache_size / bytes_per_node) diff --git a/src/qseek/tracers/cake.py b/src/qseek/tracers/cake.py index f6e0ee3f..1c7594aa 100644 --- a/src/qseek/tracers/cake.py +++ b/src/qseek/tracers/cake.py @@ -648,9 +648,10 @@ def get_travel_time_location( source: Location, receiver: Location, ) -> float: - if phase not in self.phases: - raise ValueError(f"Phase {phase} is not defined.") - tree = self._get_sptree_model(phase) + try: + tree = self._get_sptree_model(phase) + except KeyError as exc: + raise ValueError(f"Phase {phase} is not defined.") from exc return tree.get_travel_time(source, receiver) async def get_travel_times( @@ -659,9 +660,13 @@ async def get_travel_times( octree: Octree, stations: Stations, ) -> np.ndarray: - if phase not in self.phases: - raise ValueError(f"Phase {phase} is not defined.") - return await self._get_sptree_model(phase).get_travel_times(octree, stations) + try: + return await self._get_sptree_model(phase).get_travel_times( + octree, + stations, + ) + except KeyError as exc: + raise ValueError(f"Phase {phase} is not defined.") from exc def get_arrivals( self, diff --git a/src/qseek/utils.py b/src/qseek/utils.py index b5f25642..ed40411a 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -386,17 +386,19 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: return wrapper -def human_readable_bytes(size: int | float) -> str: +def human_readable_bytes(size: int | float, decimal: bool = False) -> str: """Convert a size in bytes to a human-readable string representation. Args: size (int | float): The size in bytes. + decimal: If True, use decimal units (e.g. 1000 bytes per KB). + If False, use binary units (e.g. 1024 bytes per KiB). Returns: str: The human-readable string representation of the size. """ - return ByteSize(size).human_readable() + return ByteSize.human_readable(size, decimal=decimal) def datetime_now() -> datetime: