Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed May 27, 2024
1 parent 3ba1cd6 commit f06c178
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 29 deletions.
4 changes: 3 additions & 1 deletion src/qseek/apps/qseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import logging
import shutil
from pathlib import Path
from typing import TYPE_CHECKING

import nest_asyncio
from pkg_resources import get_distribution

from qseek.models.detection import EventDetection
if TYPE_CHECKING:
from qseek.models.detection import EventDetection

nest_asyncio.apply()

Expand Down
1 change: 1 addition & 0 deletions src/qseek/images/phase_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def search_phase_arrival(
Returns:
datetime | None: Time of arrival, None is none found.
"""
# TODO adapt threshold to the seisbench model
trace = self.traces[trace_idx]
window_length = timedelta(seconds=search_window_seconds)
try:
Expand Down
1 change: 1 addition & 0 deletions src/qseek/models/station.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None:
raise ValueError("no stations available, add waveforms to start detection")

def __iter__(self) -> Iterator[Station]:
# TODO: this is inefficient
return (sta for sta in self.stations if sta.nsl.pretty not in self.blacklist)

@property
Expand Down
24 changes: 13 additions & 11 deletions src/qseek/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from qseek.pre_processing.frequency_filters import Bandpass
from qseek.pre_processing.module import Downsample, PreProcessing
from qseek.signals import Signal
from qseek.station_weights import StationWeights
from qseek.spatial_weights import SpatialWeights
from qseek.stats import RuntimeStats, Stats
from qseek.tracers.tracers import RayTracer, RayTracers
from qseek.utils import (
Expand Down Expand Up @@ -184,8 +184,8 @@ def tts(duration: timedelta) -> str:
table.add_row(
"Resources",
f"CPU {self.cpu_percent:.1f}%, "
f"RAM {human_readable_bytes(self.memory_used)}"
f"/{self.memory_total.human_readable()}",
f"RAM {human_readable_bytes(self.memory_used, decimal=True)}"
f"/{self.memory_total.human_readable(decimal=True)}",
)
table.add_row(
"Progress ",
Expand Down Expand Up @@ -238,9 +238,9 @@ class Search(BaseModel):
default=RayTracers(root=[tracer() for tracer in RayTracer.get_subclasses()]),
description="List of ray tracers for travel time calculation.",
)
station_weights: StationWeights | None = Field(
default=StationWeights(),
description="Station weights for spatial weighting.",
spatial_weights: SpatialWeights | None = Field(
default=SpatialWeights(),
description="Spatial weights for distance weighting.",
)
station_corrections: StationCorrectionType | None = Field(
default=None,
Expand Down Expand Up @@ -463,8 +463,8 @@ async def prepare(self) -> None:
self.data_provider.prepare(self.stations)
await self.pre_processing.prepare()

if self.station_weights:
self.station_weights.prepare(self.stations, self.octree)
if self.spatial_weights:
self.spatial_weights.prepare(self.stations, self.octree)

if self.station_corrections:
await self.station_corrections.prepare(
Expand Down Expand Up @@ -722,8 +722,8 @@ async def calculate_semblance(
weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32)
weights[traveltimes_bad] = 0.0

if parent.station_weights:
weights *= await parent.station_weights.get_weights(octree, image.stations)
if parent.spatial_weights:
weights *= await parent.spatial_weights.get_weights(octree, image.stations)

with np.errstate(divide="ignore", invalid="ignore"):
weights /= weights.sum(axis=1, keepdims=True)
Expand Down Expand Up @@ -861,7 +861,9 @@ async def search(
except NodeSplitError:
continue
logger.info(
"energy detected, refined %d nodes, level %d",
"detected %d energy burst%s - refined %d nodes, lowest level %d",
detection_idx.size,
"s" if detection_idx.size > 1 else "",
len(refine_nodes),
new_level,
)
Expand Down
30 changes: 21 additions & 9 deletions src/qseek/station_weights.py → src/qseek/spatial_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,26 @@
logger = logging.getLogger(__name__)


class StationWeights(BaseModel):
class SpatialWeights(BaseModel):
exponent: float = Field(
default=0.5,
description="Exponent of the exponential decay function. Default is 1.5.",
default=3.0,
description="Exponent of the spatial decay function. Default is 3.",
ge=0.0,
le=3.0,
)
radius_meters: PositiveFloat = Field(
default=8000.0,
description="Radius in meters for the exponential decay function. "
"Default is 8000.",
description="Cutoff distance for the spatial decay function in meters."
" Default is 8000.",
)
waterlevel: float = Field(
default=0.0,
ge=0.0,
le=1.0,
description="Waterlevel for the exponential decay function. Default is 0.0.",
)
lut_cache_size: ByteSize = Field(
default=200 * MB,
description="Size of the LRU cache in bytes. Default is 1e9.",
description="Size of the LRU cache in bytes. Default is 200 MB.",
)

_node_lut: dict[bytes, np.ndarray] = PrivateAttr()
Expand All @@ -47,14 +52,21 @@ def get_distances(self, nodes: Iterable[Node]) -> np.ndarray:
self._station_coords_ecef - node_coords[:, np.newaxis], axis=2
)

def calc_weights(self, distances: np.ndarray) -> np.ndarray:
def calc_weights_exp(self, distances: np.ndarray) -> np.ndarray:
exp = self.exponent
# radius = distances.min(axis=1)[:, np.newaxis]
radius = self.radius_meters
return np.exp(-(distances**exp) / (radius**exp))

def calc_weights(self, distances: np.ndarray) -> np.ndarray:
exp = self.exponent
radius = self.radius_meters
return (1 - self.waterlevel) / (
1 + (distances / radius) ** exp
) + self.waterlevel

def prepare(self, stations: Stations, octree: Octree) -> None:
logger.info("preparing station weights")
logger.info("preparing spatial weights")

bytes_per_node = stations.n_stations * np.float32().itemsize
lru_cache_size = int(self.lut_cache_size / bytes_per_node)
Expand Down
17 changes: 11 additions & 6 deletions src/qseek/tracers/cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,9 +648,10 @@ def get_travel_time_location(
source: Location,
receiver: Location,
) -> float:
if phase not in self.phases:
raise ValueError(f"Phase {phase} is not defined.")
tree = self._get_sptree_model(phase)
try:
tree = self._get_sptree_model(phase)
except KeyError as exc:
raise ValueError(f"Phase {phase} is not defined.") from exc
return tree.get_travel_time(source, receiver)

async def get_travel_times(
Expand All @@ -659,9 +660,13 @@ async def get_travel_times(
octree: Octree,
stations: Stations,
) -> np.ndarray:
if phase not in self.phases:
raise ValueError(f"Phase {phase} is not defined.")
return await self._get_sptree_model(phase).get_travel_times(octree, stations)
try:
return await self._get_sptree_model(phase).get_travel_times(
octree,
stations,
)
except KeyError as exc:
raise ValueError(f"Phase {phase} is not defined.") from exc

def get_arrivals(
self,
Expand Down
6 changes: 4 additions & 2 deletions src/qseek/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,17 +386,19 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return wrapper


def human_readable_bytes(size: int | float) -> str:
def human_readable_bytes(size: int | float, decimal: bool = False) -> str:
"""Convert a size in bytes to a human-readable string representation.
Args:
size (int | float): The size in bytes.
decimal: If True, use decimal units (e.g. 1000 bytes per KB).
If False, use binary units (e.g. 1024 bytes per KiB).
Returns:
str: The human-readable string representation of the size.
"""
return ByteSize(size).human_readable()
return ByteSize.human_readable(size, decimal=decimal)


def datetime_now() -> datetime:
Expand Down

0 comments on commit f06c178

Please sign in to comment.