From 92ed89c5d8272f77638ad57b53fae309a37d04b8 Mon Sep 17 00:00:00 2001 From: miili Date: Fri, 15 Mar 2024 21:44:30 +0100 Subject: [PATCH 1/6] search: fixes --- src/qseek/images/images.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/qseek/images/images.py b/src/qseek/images/images.py index 74771ded..9b3eb44e 100644 --- a/src/qseek/images/images.py +++ b/src/qseek/images/images.py @@ -112,6 +112,10 @@ async def worker() -> None: "start pre-processing images, queue size %d", self._queue.maxsize ) async for batch in batch_iterator: + if batch.is_empty(): + logger.debug("empty batch, skipping") + continue + start_time = datetime_now() images = await self.process_traces(batch.traces) stats.time_per_batch = datetime_now() - start_time From ea1122340c7630cf8cdd0cbfd7f689ec5513d759 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Mon, 18 Mar 2024 10:32:54 +0000 Subject: [PATCH 2/6] bugfixes --- src/qseek/images/phase_net.py | 9 ++++----- src/qseek/models/semblance.py | 9 +++------ src/qseek/models/station.py | 17 ++++++++++++----- src/qseek/octree.py | 21 +++++++++++++++++++++ 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/phase_net.py index ef5c2e33..60b03c34 100644 --- a/src/qseek/images/phase_net.py +++ b/src/qseek/images/phase_net.py @@ -113,16 +113,15 @@ def search_phase_arrival( peak_delay = peak_times - event_time.timestamp() # Limit to post-event peaks - post_event_peaks = peak_delay > 0.0 - peak_idx = peak_idx[post_event_peaks] - peak_times = peak_times[post_event_peaks] - peak_delay = peak_delay[post_event_peaks] + after_event_peaks = peak_delay > 0.0 + peak_idx = peak_idx[after_event_peaks] + peak_times = peak_times[after_event_peaks] + peak_delay = peak_delay[after_event_peaks] if not peak_idx.size: return None peak_values = search_trace.get_ydata()[peak_idx] - closest_peak_idx = np.argmin(peak_delay) return ObservedArrival( diff --git a/src/qseek/models/semblance.py b/src/qseek/models/semblance.py index 136eeef1..aa26868f 100644 --- a/src/qseek/models/semblance.py +++ b/src/qseek/models/semblance.py @@ -74,13 +74,10 @@ def _populate_table(self, table: Table) -> None: ) table.add_row( "Semblance size", - f"{human_readable_bytes(self.semblance_size_bytes)}" + f"{human_readable_bytes(self.semblance_size_bytes)}/" + f"{human_readable_bytes(self.semblance_allocation_bytes)}" f" ({self.last_nodes_stacked} nodes)", ) - table.add_row( - "Memory allocated", - f"{human_readable_bytes(self.semblance_allocation_bytes)}", - ) class SemblanceCache(dict[bytes, np.ndarray]): @@ -240,7 +237,7 @@ async def apply_cache(self, cache: SemblanceCache) -> None: self.semblance_unpadded, data, mask, - nthreads=1, + nthreads=8, ) def maximum_node_semblance(self) -> np.ndarray: diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index 10742598..ded665d2 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Iterator import numpy as np -from pydantic import BaseModel, Field, FilePath, constr +from pydantic import BaseModel, DirectoryPath, Field, FilePath, constr from pyrocko.io.stationxml import load_xml from pyrocko.model import Station as PyrockoStation from pyrocko.model import dump_stations_yaml, load_stations @@ -76,7 +76,7 @@ class Stations(BaseModel): description="List of [Pyrocko station YAML]" "(https://pyrocko.org/docs/current/formats/yaml.html) files.", ) - station_xmls: list[FilePath] = Field( + station_xmls: 
list[FilePath | DirectoryPath] = Field( default=[], description="List of StationXML files.", ) @@ -93,9 +93,16 @@ def model_post_init(self, __context: Any) -> None: for file in self.pyrocko_station_yamls: loaded_stations += load_stations(filename=str(file.expanduser())) - for file in self.station_xmls: - station_xml = load_xml(filename=str(file.expanduser())) - loaded_stations += station_xml.get_pyrocko_stations() + for path in self.station_xmls: + if path.is_dir(): + station_xmls = path.glob("*.xml") + elif path.is_file(): + station_xmls = [path] + else: + continue + for file in station_xmls: + station_xml = load_xml(filename=str(file.expanduser())) + loaded_stations += station_xml.get_pyrocko_stations() for sta in loaded_stations: sta = Station.from_pyrocko_station(sta) diff --git a/src/qseek/octree.py b/src/qseek/octree.py index 1030165a..291c1c4a 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -709,6 +709,27 @@ def save_pickle(self, filename: Path) -> None: with filename.open("wb") as f: pickle.dump(self, f) + def get_corners(self) -> list[Location]: + """Get the corners of the octree. + + Returns: + list[Location]: List of locations. + """ + reference = self.location + return [ + Location( + lat=reference.lat, + lon=reference.lon, + elevation=reference.elevation, + east_shift=reference.east_shift + east, + north_shift=reference.north_shift + north, + depth=reference.depth + depth, + ) + for east in (self.east_bounds.min, self.east_bounds.max) + for north in (self.north_bounds.min, self.north_bounds.max) + for depth in (self.depth_bounds.min, self.depth_bounds.max) + ] + def __hash__(self) -> int: return hash( ( From 72f70da3e070833438a00ec036bd0cbf7eaca07a Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Thu, 21 Mar 2024 11:48:18 +0000 Subject: [PATCH 3/6] bugfixes --- pyproject.toml | 16 +- src/qseek/apps/qseek.py | 67 ++- src/qseek/corrections/base.py | 9 +- src/qseek/images/base.py | 8 +- src/qseek/images/images.py | 3 +- src/qseek/images/phase_net.py | 2 +- src/qseek/magnitudes/base.py | 17 +- src/qseek/magnitudes/local_magnitude.py | 8 +- src/qseek/magnitudes/local_magnitude_model.py | 2 +- src/qseek/magnitudes/moment_magnitude.py | 61 ++- .../magnitudes/moment_magnitude_store.py | 440 ++++++++++++------ src/qseek/models/catalog.py | 102 +++- src/qseek/models/detection.py | 180 ++++--- src/qseek/models/detection_uncertainty.py | 9 +- src/qseek/models/location.py | 7 +- src/qseek/models/semblance.py | 16 +- src/qseek/models/station.py | 5 +- src/qseek/octree.py | 27 +- src/qseek/pre_processing/base.py | 18 +- src/qseek/pre_processing/module.py | 2 +- src/qseek/search.py | 64 ++- src/qseek/tracers/cake.py | 7 +- src/qseek/utils.py | 55 +-- test/test_moment_magnitude_store.py | 47 +- 24 files changed, 773 insertions(+), 399 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 403fb1e3..a43ccfb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,13 +106,27 @@ extend-select = [ 'I', 'RUF', 'T20', + 'D', ] -ignore = ["RUF012", "RUF009"] +ignore = [ + "RUF012", + "RUF009", + "D100", + "D101", + "D102", + "D103", + "D104", + "D105", + "D107", +] [tool.ruff] target-version = 'py311' +[tool.ruff.lint.pydocstyle] +convention = "google" + [tool.pytest.ini_options] markers = ["plot: plot figures in tests"] diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py index 3d6ddb70..72343179 100644 --- a/src/qseek/apps/qseek.py +++ b/src/qseek/apps/qseek.py @@ -11,6 +11,9 @@ import nest_asyncio from pkg_resources import get_distribution +from 
qseek.models.detection import EventDetection +from qseek.utils import get_cpu_count + nest_asyncio.apply() logger = logging.getLogger(__name__) @@ -128,6 +131,12 @@ type=Path, help="path of existing run", ) +features_extract.add_argument( + "--recalculate", + action="store_true", + default=False, + help="recalculate all magnitudes", +) modules = subparsers.add_parser( "modules", @@ -256,35 +265,55 @@ async def run() -> None: case "feature-extraction": search = Search.load_rundir(args.rundir) search.data_provider.prepare(search.stations) + recalculate_magnitudes = args.recalculate + + tasks = [] + + def console_status(task: asyncio.Task[EventDetection]): + detection = task.result() + if detection.magnitudes: + console.print( + f"Event {str(detection.time).split('.')[0]}:", + ", ".join( + f"[bold]{m.magnitude}[/bold] {m.average:.2f}±{m.error:.2f}" + for m in detection.magnitudes + ), + ) + else: + console.print(f"Event {detection.time}: No magnitudes") - async def extract() -> None: + async def worker() -> None: for magnitude in search.magnitudes: await magnitude.prepare(search.octree, search.stations) - - iterator = asyncio.as_completed( - tuple( - search.add_magnitude_and_features(detection) - for detection in search._catalog + await search.catalog.check(repair=True) + + sem = asyncio.Semaphore(get_cpu_count()) + for detection in track( + search.catalog, + description="Calculating magnitudes", + total=search.catalog.n_events, + console=console, + ): + await sem.acquire() + task = asyncio.create_task( + search.add_magnitude_and_features( + detection, + recalculate=recalculate_magnitudes, + ) ) - ) + tasks.append(task) + task.add_done_callback(lambda _: sem.release()) + task.add_done_callback(tasks.remove) + task.add_done_callback(console_status) - for result in track( - iterator, - description="Extracting features", - total=search._catalog.n_events, - ): - event = await result - if event.magnitudes: - for mag in event.magnitudes: - print(f"{mag.magnitude} {mag.average:.2f}±{mag.error:.2f}") # noqa: T201 - print("--") # noqa: T201 + await asyncio.gather(*tasks) await search._catalog.save() await search._catalog.export_detections( jitter_location=search.octree.smallest_node_size() ) - asyncio.run(extract(), debug=loop_debug) + asyncio.run(worker(), debug=loop_debug) case "corrections": import json @@ -391,7 +420,7 @@ def is_insight(module: type) -> bool: raise EnvironmentError(f"folder {args.folder} does not exist") file = args.folder / "search.schema.json" - print(f"writing JSON schemas to {args.folder}") # noqa: T201 + console.print(f"writing JSON schemas to {args.folder}") file.write_text(json.dumps(Search.model_json_schema(), indent=2)) file = args.folder / "detections.schema.json" diff --git a/src/qseek/corrections/base.py b/src/qseek/corrections/base.py index 10b614e5..7c5543e3 100644 --- a/src/qseek/corrections/base.py +++ b/src/qseek/corrections/base.py @@ -80,10 +80,11 @@ async def prepare( """Prepare the station for the corrections. Args: - station: The station to prepare. - octree: The octree to use for the preparation. - phases: The phases to prepare the station for. - rundir: The rundir to use for the delay. Defaults to None. + stations (Stations): The station to prepare. + octree (Octree): The octree to use for the preparation. + phases (Iterable[PhaseDescription]): The phases to prepare the station for. + rundir (Path | None, optional): The rundir to use for the delay. + Defaults to None. """ ... 
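The reworked feature-extraction command in apps/qseek.py caps concurrent magnitude tasks with an asyncio.Semaphore sized by get_cpu_count(). A minimal, self-contained sketch of that throttling pattern, stated outside qseek (process_event and N_SLOTS are hypothetical placeholders, not qseek API):

import asyncio

N_SLOTS = 4  # stand-in for get_cpu_count()

async def process_event(event: int) -> int:
    # placeholder for search.add_magnitude_and_features(detection)
    await asyncio.sleep(0.01)
    return event

async def run_throttled(events: list[int]) -> None:
    sem = asyncio.Semaphore(N_SLOTS)
    tasks: list[asyncio.Task[int]] = []
    for event in events:
        await sem.acquire()  # wait for a free slot before scheduling the next task
        task = asyncio.create_task(process_event(event))
        tasks.append(task)
        task.add_done_callback(lambda _t: sem.release())  # free the slot when done
        task.add_done_callback(tasks.remove)  # drop finished tasks from the list
    await asyncio.gather(*tasks)  # wait for the tasks still in flight

asyncio.run(run_throttled(list(range(16))))

Acquiring the semaphore before create_task keeps at most N_SLOTS detections in flight at any time.
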
diff --git a/src/qseek/images/base.py b/src/qseek/images/base.py index c95ff365..a19621d5 100644 --- a/src/qseek/images/base.py +++ b/src/qseek/images/base.py @@ -33,8 +33,7 @@ def name(self) -> str: return self.__class__.__name__ def get_blinding(self, sampling_rate: float) -> timedelta: - """ - Blinding duration for the image function. Added to padded waveforms. + """Blinding duration for the image function. Added to padded waveforms. Args: sampling_rate (float): The sampling rate of the waveform. @@ -73,6 +72,7 @@ def set_stations(self, stations: Stations) -> None: def resample(self, sampling_rate: float, max_normalize: bool = False) -> None: """Resample traces in-place. + Args: sampling_rate (float): Desired sampling rate in Hz. max_normalize (bool): Normalize by maximum value to keep the scale of the @@ -137,7 +137,7 @@ def search_phase_arrival( trace_idx (int): Index of the trace. event_time (datetime): Time of the event. modelled_arrival (datetime): Time to search around. - search_length_seconds (float, optional): Total search length in seconds + search_window_seconds (float, optional): Total search length in seconds around modelled arrival time. Defaults to 5. threshold (float, optional): Threshold for detection. Defaults to 0.1. @@ -158,7 +158,7 @@ def search_phase_arrivals( Args: event_time (datetime): Time of the event. modelled_arrivals (list[datetime]): Time to search around. - search_length_seconds (float, optional): Total search length in seconds + search_window_seconds (float, optional): Total search length in seconds around modelled arrival time. Defaults to 5. threshold (float, optional): Threshold for detection. Defaults to 0.1. diff --git a/src/qseek/images/images.py b/src/qseek/images/images.py index 9b3eb44e..a9f3a77f 100644 --- a/src/qseek/images/images.py +++ b/src/qseek/images/images.py @@ -99,12 +99,11 @@ async def iter_images( """Iterate over images from batches. Args: - batches (AsyncIterator[Batch]): Async iterator over batches. + batch_iterator (AsyncIterator[Batch]): Async iterator over batches. Yields: AsyncIterator[WaveformImages]: Async iterator over images. """ - stats = self._stats async def worker() -> None: diff --git a/src/qseek/images/phase_net.py b/src/qseek/images/phase_net.py index 60b03c34..925552b8 100644 --- a/src/qseek/images/phase_net.py +++ b/src/qseek/images/phase_net.py @@ -59,7 +59,7 @@ def search_phase_arrival( trace_idx (int): Index of the trace. event_time (datetime): Time of the event. modelled_arrival (datetime): Time to search around. - search_length_seconds (float, optional): Total search length in seconds + search_window_seconds (float, optional): Total search length in seconds around modelled arrival time. Defaults to 5. threshold (float, optional): Threshold for detection. Defaults to 0.1. detection_blinding_seconds (float, optional): Blinding time in seconds for diff --git a/src/qseek/magnitudes/base.py b/src/qseek/magnitudes/base.py index 71567758..81ff900b 100644 --- a/src/qseek/magnitudes/base.py +++ b/src/qseek/magnitudes/base.py @@ -110,13 +110,23 @@ def get_subclasses(cls) -> tuple[type[EventMagnitudeCalculator], ...]: """ return tuple(cls.__subclasses__()) + def has_magnitude(self, event: EventDetection) -> bool: + """Check if the given event has a magnitude. + + Args: + event (EventDetection): The event to check. + + Returns: + bool: True if the event has a magnitude, False otherwise. 
+ """ + raise NotImplementedError + async def add_magnitude( self, squirrel: Squirrel, event: EventDetection, ) -> None: - """ - Adds a magnitude to the squirrel for the given event. + """Adds a magnitude to the squirrel for the given event. Args: squirrel (Squirrel): The squirrel object to add the magnitude to. @@ -132,8 +142,7 @@ async def prepare( octree: Octree, stations: Stations, ) -> None: - """ - Prepare the magnitudes calculation by initializing necessary data structures. + """Prepare the magnitudes calculation by initializing necessary data structures. Args: octree (Octree): The octree containing seismic event data. diff --git a/src/qseek/magnitudes/local_magnitude.py b/src/qseek/magnitudes/local_magnitude.py index 2721cb72..3b3fff3f 100644 --- a/src/qseek/magnitudes/local_magnitude.py +++ b/src/qseek/magnitudes/local_magnitude.py @@ -164,6 +164,12 @@ def validate_model(self) -> Self: self._model = LocalMagnitudeModel.get_subclass_by_name(self.model)() return self + def has_magnitude(self, event: EventDetection) -> bool: + for mag in event.magnitudes: + if type(mag) is LocalMagnitude and mag.model == self.model: + return True + return False + async def add_magnitude(self, squirrel: Squirrel, event: EventDetection) -> None: model = self._model @@ -180,7 +186,7 @@ async def add_magnitude(self, squirrel: Squirrel, event: EventDetection) -> None cut_off_fade=cut_off_fade, quantity=model.restitution_quantity, phase=None, - remove_clipped=True, + filter_clipped=True, ) if not traces: logger.warning("No restituted traces found for event %s", event.time) diff --git a/src/qseek/magnitudes/local_magnitude_model.py b/src/qseek/magnitudes/local_magnitude_model.py index 9ae9b04e..db30b10b 100644 --- a/src/qseek/magnitudes/local_magnitude_model.py +++ b/src/qseek/magnitudes/local_magnitude_model.py @@ -159,7 +159,7 @@ def get_station_magnitude( try: traces = _COMPONENT_MAP[self.component](traces) except KeyError: - logger.warning("Could not get channels for %s", receiver.nsl.pretty) + logger.debug("Could not get channels for %s", receiver.nsl.pretty) return None if not traces: return None diff --git a/src/qseek/magnitudes/moment_magnitude.py b/src/qseek/magnitudes/moment_magnitude.py index ff15276e..edbd7af0 100644 --- a/src/qseek/magnitudes/moment_magnitude.py +++ b/src/qseek/magnitudes/moment_magnitude.py @@ -51,8 +51,7 @@ def norm_traces(traces: list[Trace]) -> np.ndarray: - """ - Normalizes the traces to their maximum absolute value. + """Normalizes the traces to their maximum absolute value. Args: traces (list[Trace]): The traces to normalize. @@ -79,13 +78,12 @@ class PeakAmplitudeDefinition(PeakAmplitudesBase): description="The epicentral distance range of the stations.", ) frequency_range: Range = Field( - default=Range(min=2.0, max=6.0), + default=Range(min=2.0, max=20.0), description="The frequency range in Hz to filter the traces.", ) def filter_receivers_by_nsl(self, receivers: Iterable[Receiver]) -> set[Receiver]: - """ - Filters the list of receivers based on the NSL ID. + """Filters the list of receivers based on the NSL ID. Args: receivers (list[Receiver]): The list of receivers to filter. @@ -108,8 +106,7 @@ def filter_receivers_by_range( receivers: Iterable[Receiver], event: EventDetection, ) -> set[Receiver]: - """ - Filters the list of receivers based on the distance range. + """Filters the list of receivers based on the distance range. Args: receivers (Iterable[Receiver]): The list of receivers to filter. 
@@ -127,11 +124,9 @@ def filter_receivers_by_range( class StationMomentMagnitude(NamedTuple): - # quantity: MeasurementUnit distance_epi: float magnitude: float error: float - peak: float @@ -153,11 +148,15 @@ def m0(self) -> float: @property def n_stations(self) -> int: - """ - Number of stations used for calculating the moment magnitude. - """ + """Number of stations used for calculating the moment magnitude.""" return len(self.stations_magnitudes) + def csv_row(self) -> dict[str, float]: + return { + "Mw": self.average, + "Mw-error": self.error, + } + async def add_traces( self, store: PeakAmplitudesStore, @@ -190,7 +189,7 @@ async def add_traces( continue try: - model = await store.get_amplitude( + model = await store.get_amplitude_model( source_depth=event.effective_depth, distance=station.distance_epi, n_amplitudes=25, @@ -201,9 +200,13 @@ async def add_traces( logger.warning("No modelled amplitude for receiver %s", receiver.nsl) continue - magnitude = model.get_magnitude(station.peak) - error_upper = model.get_magnitude(station.peak + station.noise) - magnitude - error_lower = model.get_magnitude(station.peak - station.noise) - magnitude + magnitude = model.estimate_magnitude(station.peak) + error_upper = ( + model.estimate_magnitude(station.peak + station.noise) - magnitude + ) + error_lower = ( + model.estimate_magnitude(station.peak - station.noise) - magnitude + ) if not np.isfinite(error_lower): error_lower = error_upper @@ -278,6 +281,11 @@ async def prepare(self, octree: Octree, stations: Stations) -> None: depth_delta=definition.source_depth_delta, ) + def has_magnitude(self, event: EventDetection) -> bool: + if not event.magnitudes: + return False + return any(type(mag) is MomentMagnitude for mag in event.magnitudes) + async def add_magnitude( self, squirrel: Squirrel, @@ -298,7 +306,11 @@ async def add_magnitude( logger.info("No receivers in range for peak amplitude") continue if not store.source_depth_range.inside(event.effective_depth): - logger.info("Event depth outside of store depth range.") + logger.info( + "Event depth %.1f outside of magnitude store range (%.1f - %.1f).", + event.effective_depth, + *store.source_depth_range, + ) continue traces = await event.receivers.get_waveforms_restituted( @@ -310,7 +322,7 @@ async def add_magnitude( demean=True, seconds_fade=self.padding_seconds, cut_off_fade=False, - remove_clipped=True, + filter_clipped=True, ) if not traces: continue @@ -318,14 +330,23 @@ async def add_magnitude( for tr in traces: if store.frequency_range.min != 0.0: await asyncio.to_thread( - tr.highpass, 4, store.frequency_range.min, demean=True + tr.highpass, + 4, + store.frequency_range.min, + demean=False, ) await asyncio.to_thread( - tr.lowpass, 4, store.frequency_range.max, demean=True + tr.lowpass, + 4, + store.frequency_range.max, + demean=False, ) tr.chop(tr.tmin + self.padding_seconds, tr.tmax - self.padding_seconds) if self.processed_mseed_export is not None: + logger.debug( + "saving processed mseed traces to %s", self.processed_mseed_export + ) io.save(traces, str(self.processed_mseed_export), append=True) grouped_traces = [] diff --git a/src/qseek/magnitudes/moment_magnitude_store.py b/src/qseek/magnitudes/moment_magnitude_store.py index 9adbd10f..527b1ac7 100644 --- a/src/qseek/magnitudes/moment_magnitude_store.py +++ b/src/qseek/magnitudes/moment_magnitude_store.py @@ -5,6 +5,7 @@ import itertools import logging import struct +from collections import defaultdict from functools import cached_property from pathlib import Path from typing 
import ( @@ -32,9 +33,9 @@ from pyrocko import gf from pyrocko.guts import Float from pyrocko.trace import FrequencyResponse -from rich.progress import track from typing_extensions import Self +from qseek.stats import PROGRESS from qseek.utils import ( ChannelSelector, ChannelSelectors, @@ -72,11 +73,11 @@ class MTSourceCircularCrack(gf.MTSource): duration = Float.T() stress_drop = Float.T() radius = Float.T() + magnitude = Float.T() def _get_target(targets: list[gf.Target], nsl: tuple[str, str, str]) -> gf.Target: - """ - Get the target from the list of targets based on the given NSL codes. + """Get the target from the list of targets based on the given NSL codes. Args: targets (list[gf.Target]): List of targets to search from. @@ -95,12 +96,11 @@ def _get_target(targets: list[gf.Target], nsl: tuple[str, str, str]) -> gf.Targe def trace_amplitude(traces: list[Trace], channel_selector: ChannelSelector) -> float: - """ - Normalize traces channels. + """Normalize traces channels. Args: traces (list[Trace]): A list of traces to normalize. - components (str): The components to normalize. + channel_selector (ChannelSelector): The channel selector to use. Returns: Trace: The normalized trace. @@ -141,7 +141,7 @@ class PeakAmplitudesBase(BaseModel): default=1.0, ge=-1.0, le=8.0, - description="Reference magnitude in Mw.", + description="Reference moment magnitude in Mw.", ) rupture_velocities: Range = Field( default=Range(0.8, 0.9), @@ -158,14 +158,21 @@ class PeakAmplitudesBase(BaseModel): class SiteAmplitude(NamedTuple): + magnitude: float distance_epi: float peak_horizontal: float peak_vertical: float peak_absolute: float @classmethod - def from_traces(cls, receiver: gf.Receiver, traces: list[Trace]) -> Self: + def from_traces( + cls, + receiver: gf.Receiver, + traces: list[Trace], + magnitude: float, + ) -> Self: return cls( + magnitude=magnitude, distance_epi=np.sqrt(receiver.north_shift**2 + receiver.east_shift**2), peak_horizontal=trace_amplitude(traces, ChannelSelectors.Horizontal), peak_vertical=trace_amplitude(traces, ChannelSelectors.Vertical), @@ -174,7 +181,7 @@ def from_traces(cls, receiver: gf.Receiver, traces: list[Trace]) -> Self: class ModelledAmplitude(NamedTuple): - reference_magnitude: float + magnitude: float quantity: MeasurementUnit peak_amplitude: PeakAmplitude distance_epi: float @@ -188,11 +195,10 @@ def combine( other: ModelledAmplitude, weight: float = 1.0, ) -> ModelledAmplitude: - """ - Combines with another ModelledAmplitude using a weighted average. + """Combines with another ModelledAmplitude using a weighted average. Args: - amplitude (ModelledAmplitude): The ModelledAmplitude to be combined with. + other (ModelledAmplitude): The ModelledAmplitude to be combined with. weight (float, optional): The weight of the amplitude being combined. Defaults to 1.0. 
@@ -210,13 +216,13 @@ def combine( raise ValueError("Cannot add amplitudes with different distances") if self.quantity != other.quantity: raise ValueError("Cannot add amplitudes with different quantities ") - if self.reference_magnitude != other.reference_magnitude: + if self.magnitude != other.magnitude: raise ValueError("Cannot add amplitudes with different reference magnitude") if self.peak_amplitude != other.peak_amplitude: raise ValueError("Cannot add amplitudes with different peak amplitudes ") rcp_weight = 1.0 - weight return ModelledAmplitude( - reference_magnitude=self.reference_magnitude, + magnitude=self.magnitude, peak_amplitude=self.peak_amplitude, quantity=self.quantity, distance_epi=self.distance_epi, @@ -226,9 +232,8 @@ def combine( mad=self.mad * rcp_weight + other.mad * weight, ) - def get_magnitude(self, observed_amplitude: float) -> float: - """ - Get the moment magnitude for the given observed amplitude. + def estimate_magnitude(self, observed_amplitude: float) -> float: + """Get the moment magnitude for the given observed amplitude. Args: observed_amplitude (float): The observed amplitude. @@ -236,13 +241,13 @@ def get_magnitude(self, observed_amplitude: float) -> float: Returns: float: The moment magnitude. """ - return self.reference_magnitude + np.log10(observed_amplitude / self.median) + with np.errstate(divide="ignore", invalid="ignore"): + return self.magnitude + np.log10(observed_amplitude / self.average) class SiteAmplitudesCollection(BaseModel): source_depth: float quantity: MeasurementUnit - reference_magnitude: float rupture_velocities: Range stress_drop: Range gf_store_id: str @@ -258,32 +263,34 @@ def wrapped(self) -> np.ndarray: return wrapped _distances = cached_property[np.ndarray](_get_numpy_array("distance_epi")) + _magnitudes = cached_property[np.ndarray](_get_numpy_array("magnitude")) _vertical = cached_property[np.ndarray](_get_numpy_array("peak_vertical")) _absolute = cached_property[np.ndarray](_get_numpy_array("peak_absolute")) _horizontal = cached_property[np.ndarray](_get_numpy_array("peak_horizontal")) def _clear_cache(self) -> None: - self.__dict__.pop("_distances", None) - self.__dict__.pop("_horizontal", None) - self.__dict__.pop("_vertical", None) - self.__dict__.pop("_absolute", None) + keys = {"_distances", "_horizontal", "_vertical", "_absolute", "_magnitudes"} + for key in keys: + self.__dict__.pop(key, None) - def get_amplitude( + def get_amplitude_model( self, distance: float, n_amplitudes: int, - max_distance: float = 0.0, + distance_cutoff: float = 0.0, + reference_magnitude: float = 1.0, peak_amplitude: PeakAmplitude = "absolute", ) -> ModelledAmplitude: - """ - Get the amplitudes for a given distance. + """Get the amplitudes for a given distance. Args: distance (float): The epicentral distance to retrieve the amplitudes for. n_amplitudes (int): The number of amplitudes to retrieve. - max_distance (float): The maximum distance allowed for + distance_cutoff (float): The maximum distance allowed for the retrieved amplitudes. If 0.0, no maximum distance is applied and the number of amplitudes will be exactly n_amplitudes. Defaults to 0.0. + reference_magnitude (float, optional): The reference magnitude to retrieve + the amplitudes for. Defaults to 1.0. peak_amplitude (PeakAmplitude, optional): The type of peak amplitude to retrieve. Defaults to "absolute". @@ -294,28 +301,35 @@ def get_amplitude( ValueError: If there are not enough amplitudes in the specified range. ValueError: If the peak amplitude type is unknown. 
""" - site_distances = np.abs(self._distances - distance) + magnitude_idx = np.where(self._magnitudes == reference_magnitude)[0] + if not magnitude_idx.size: + raise ValueError(f"No amplitudes for magnitude {reference_magnitude}.") + + site_distances = np.abs(self._distances[magnitude_idx] - distance) distance_idx = np.argsort(site_distances) + idx = distance_idx[:n_amplitudes] + distances = site_distances[idx] - if max_distance and distances.max() > max_distance: + if distance_cutoff and distances.max() > distance_cutoff: raise ValueError( - f"Not enough amplitudes at distance {distance} and range {max_distance}" + f"Not enough amplitudes at distance {distance}" + f" at cutoff {distance_cutoff}" ) match peak_amplitude: case "horizontal": - amplitudes = self._horizontal[idx] + amplitudes = self._horizontal[magnitude_idx][idx] case "vertical": - amplitudes = self._vertical[idx] + amplitudes = self._vertical[magnitude_idx][idx] case "absolute": - amplitudes = self._absolute[idx] + amplitudes = self._absolute[magnitude_idx][idx] case _: raise ValueError(f"Unknown peak amplitude type {peak_amplitude}.") median = float(np.median(amplitudes)) return ModelledAmplitude( - reference_magnitude=self.reference_magnitude, + magnitude=reference_magnitude, peak_amplitude=peak_amplitude, quantity=self.quantity, distance_epi=distance, @@ -325,14 +339,22 @@ def get_amplitude( mad=float(np.median(np.abs(amplitudes - median))), ) - def fill(self, receivers: list[gf.Receiver], traces: list[list[Trace]]) -> None: - for receiver, rcv_traces in zip(receivers, traces, strict=True): - self.site_amplitudes.append(SiteAmplitude.from_traces(receiver, rcv_traces)) + def fill( + self, + receivers: list[gf.Receiver], + traces: list[list[Trace]], + magnitudes: list[float], + ) -> None: + for receiver, rcv_traces, magnitude in zip( + receivers, traces, magnitudes, strict=True + ): + self.site_amplitudes.append( + SiteAmplitude.from_traces(receiver, rcv_traces, magnitude) + ) self._clear_cache() def distance_range(self) -> Range: - """ - Get the distance range of the site amplitudes. + """Get the distance range of the site amplitudes. Returns: Range: The distance range. @@ -341,8 +363,7 @@ def distance_range(self) -> Range: @property def n_amplitudes(self) -> int: - """ - Get the number of amplitudes in the collection. + """Get the number of amplitudes in the collection. Returns: int: The number of amplitudes. 
@@ -352,6 +373,7 @@ def n_amplitudes(self) -> int: def plot( self, axes: Axes | None = None, + reference_magnitude: float = 1.0, peak_amplitude: PeakAmplitude = "absolute", ) -> None: from matplotlib.ticker import FuncFormatter @@ -371,10 +393,11 @@ def plot( interp_amplitudes: list[ModelledAmplitude] = [] for distance in np.arange(*self.distance_range(), 250.0): interp_amplitudes.append( - self.get_amplitude( + self.get_amplitude_model( distance=distance, n_amplitudes=50, peak_amplitude=peak_amplitude, + reference_magnitude=reference_magnitude, ) ) @@ -417,7 +440,7 @@ def plot( 0.025, 0.025, f"""n={self.n_amplitudes} -$M_w^r$={self.reference_magnitude} +$M_w^{{ref}}$={reference_magnitude} $z$={self.source_depth / KM} km $v_r$=[{self.rupture_velocities.min}, {self.rupture_velocities.max}]$\\cdot v_s$ $\\Delta\\sigma$=[{self.stress_drop.min / 1e6}, {self.stress_drop.max / 1e6}] MPa @@ -448,7 +471,7 @@ class PeakAmplitudesStore(PeakAmplitudesBase): default_factory=uuid4, description="Unique ID of the amplitude store.", ) - site_amplitudes: list[SiteAmplitudesCollection] = Field( + amplitude_collections: list[SiteAmplitudesCollection] = Field( default_factory=list, description="Site amplitudes per source depth.", ) @@ -460,8 +483,15 @@ class PeakAmplitudesStore(PeakAmplitudesBase): default="", description="Hash of the GF store configuration.", ) + magnitude_range: Range = Field( + default=Range(0.0, 6.0), + description="Range of moment magnitudes for the seismic sources.", + ) _rng: np.random.Generator = PrivateAttr(default_factory=np.random.default_rng) + _access_locks: dict[int, asyncio.Lock] = PrivateAttr( + default_factory=lambda: defaultdict(asyncio.Lock) + ) _engine: ClassVar[gf.LocalEngine | None] = None _cache_dir: ClassVar[Path | None] = None @@ -469,8 +499,7 @@ class PeakAmplitudesStore(PeakAmplitudesBase): @classmethod def set_engine(cls, engine: gf.LocalEngine) -> None: - """ - Set the GF engine for the store. + """Set the GF engine for the store. Args: engine (gf.LocalEngine): The engine to use. @@ -479,8 +508,7 @@ def set_engine(cls, engine: gf.LocalEngine) -> None: @classmethod def set_cache_dir(cls, cache_dir: Path) -> None: - """ - Set the cache directory for the store. + """Set the cache directory for the store. Args: cache_dir (Path): The cache directory to use. @@ -489,8 +517,7 @@ def set_cache_dir(cls, cache_dir: Path) -> None: @classmethod def from_selector(cls, selector: PeakAmplitudesBase) -> Self: - """ - Create a new PeakAmplitudesStore from the given selector. + """Create a new PeakAmplitudesStore from the given selector. Args: selector (PeakAmplitudesSelector): The selector to use. @@ -498,7 +525,6 @@ def from_selector(cls, selector: PeakAmplitudesBase) -> Self: Returns: PeakAmplitudesStore: The newly created store. """ - if cls._engine is None: raise EnvironmentError( "No GF engine available to determine frequency range." @@ -525,12 +551,11 @@ def from_selector(cls, selector: PeakAmplitudesBase) -> Self: @property def source_depth_range(self) -> Range: - return Range.from_list([sa.source_depth for sa in self.site_amplitudes]) + return Range.from_list([sa.source_depth for sa in self.amplitude_collections]) @property def gf_store_depth_range(self) -> Range: - """ - Get the depth range of the GF store. + """Get the depth range of the GF store. Returns: Range: The depth range. @@ -540,8 +565,7 @@ def gf_store_depth_range(self) -> Range: @property def gf_store_distance_range(self) -> Range: - """ - Returns the distance range for the ground motion store. 
+ """Returns the distance range for the ground motion store. The distance range is determined by the minimum and maximum distances specified in the store's configuration. If the maximum distance exceeds @@ -557,10 +581,20 @@ def gf_store_distance_range(self) -> Range: max=min(store.config.distance_max, self.max_distance), ) - def get_store(self) -> gf.Store: - """ - Load the GF store for the given store ID. + def get_lock(self, source_depth: float, reference_magnitude: float) -> asyncio.Lock: + """Get the lock for the given source depth and reference magnitude. + + Args: + source_depth (float): The source depth. + reference_magnitude (float): The reference magnitude. + + Returns: + asyncio.Lock: The lock for the given source depth and reference magnitude. """ + return self._access_locks[hash((source_depth, reference_magnitude))] + + def get_store(self) -> gf.Store: + """Load the GF store for the given store ID.""" if self._engine is None: raise EnvironmentError("No GF engine available.") @@ -581,13 +615,18 @@ def get_store(self) -> gf.Store: return store def _get_random_source( - self, depth: float, stf: Type[gf.STF] | None = None + self, + depth: float, + reference_magnitude: float, + stf: Type[gf.STF] | None = None, ) -> MTSourceCircularCrack: - """ - Generates a random seismic source with the given depth. + """Generates a random seismic source with the given depth. Args: depth (float): The depth of the seismic source. + reference_magnitude (float): The reference moment magnitude. + stf (Type[gf.STF], optional): The source time function to use. + Defaults to None. Returns: gf.MTSource: A random moment tensor source. @@ -601,17 +640,18 @@ def _get_random_source( rupture_velocity = rng.uniform(*self.rupture_velocities) * vs radius = ( - pmt.magnitude_to_moment(self.reference_magnitude) * (7 / 16) / stress_drop + pmt.magnitude_to_moment(reference_magnitude) * (7 / 16) / stress_drop ) ** (1 / 3) duration = 1.5 * radius / rupture_velocity - moment_tensor = pmt.MomentTensor.random_dc(magnitude=self.reference_magnitude) + moment_tensor = pmt.MomentTensor.random_dc(magnitude=reference_magnitude) return MTSourceCircularCrack( + magnitude=reference_magnitude, + stress_drop=stress_drop, + radius=radius, m6=moment_tensor.m6(), depth=depth, duration=duration, - stress_drop=stress_drop, - radius=radius, - stf=stf(duration=duration) if stf else None, + stf=stf(effective_duration=duration) if stf else None, ) def _get_random_targets( @@ -619,18 +659,22 @@ def _get_random_targets( distance_range: Range, n_receivers: int, ) -> list[gf.Target]: - """ - Generate a list of receivers with random angles and distances. + """Generate a list of receivers with random angles and distances. Args: + distance_range (Range): The range of distances to generate the + receivers for. n_receivers (int): The number of receivers to generate. Returns: list[gf.Receiver]: A list of receivers with random angles and distances. 
""" rng = self._rng + _distance_range = np.array(distance_range) + _distance_range[_distance_range <= 0.0] = 1.0 # Add an epsilon + angles = rng.uniform(0.0, 360.0, size=n_receivers) - distances = np.exp(rng.uniform(*np.log(distance_range), size=n_receivers)) + distances = np.exp(rng.uniform(*np.log(_distance_range), size=n_receivers)) targets: list[gf.Receiver] = [] for i_receiver, (angle, distance) in enumerate( @@ -651,20 +695,22 @@ def _get_random_targets( targets.append(target) return targets # type: ignore - async def fill_source_depth( + async def compute_site_amplitudes( self, source_depth: float, + reference_magnitude: float, n_sources: int = 200, n_targets_per_source: int = 20, ) -> SiteAmplitudesCollection: - """ - Fills the moment magnitude store with amplitudes calculated - for a specific source depth. + """Fills the moment magnitude store. + + Calculates the amplitudes for a given source depth and reference magnitude. Args: source_depth (float): The depth of the seismic source. - n_targets (int, optional): The number of target locations to calculate - amplitudes for. Defaults to 20. + reference_magnitude (float): The reference moment magnitude. + n_targets_per_source (int, optional): The number of target locations to + calculate amplitudes for. Defaults to 20. n_sources (int, optional): The number of source locations to generate random sources from. Defaults to 100. """ @@ -677,20 +723,23 @@ async def fill_source_depth( target_distances = self.gf_store_distance_range logger.info( - "calculating %d amplitudes for depth %f", + "calculating %d %s amplitudes for Mw %.1f at depth %.1f", n_sources * n_targets_per_source, + self.quantity, + reference_magnitude, source_depth, ) receivers = [] receiver_traces = [] - for _ in track( - range(n_sources), + magnitudes = [] + status = PROGRESS.add_task( + f"Calculating Mw {reference_magnitude} amplitudes for depth {source_depth}", total=n_sources, - description=f"calculating amplitudes for depth {source_depth}", - ): + ) + for _ in range(n_sources): targets = self._get_random_targets(target_distances, n_targets_per_source) - source = self._get_random_source(source_depth) + source = self._get_random_source(source_depth, reference_magnitude) response = await asyncio.to_thread(engine.process, source, targets) traces: list[Trace] = response.pyrocko_traces() @@ -707,12 +756,15 @@ async def fill_source_depth( ): receivers.append(_get_target(targets, nsl)) receiver_traces.append(list(grp_traces)) + magnitudes.append(reference_magnitude) + PROGRESS.update(status, advance=1) + PROGRESS.remove_task(status) try: collection = self.get_collection(source_depth) except KeyError: collection = self.new_collection(source_depth) - collection.fill(receivers, receiver_traces) + collection.fill(receivers, receiver_traces, magnitudes) self.save() return collection @@ -724,8 +776,7 @@ async def fill_source_depth_range( n_sources: int = 400, n_targets_per_source: int = 20, ) -> None: - """ - Fills the source depth range with seismic data. + """Fills the source depth range with seismic data. Args: depth_min (float): The minimum depth of the source in meters. 
@@ -756,7 +807,7 @@ async def fill_source_depth_range( depths = np.arange(gf_depth_min, gf_depth_max, depth_delta) calculate_depths = depths[(depths >= depth_min) & (depths <= depth_max)] - stored_depths = [sa.source_depth for sa in self.site_amplitudes] + stored_depths = [sa.source_depth for sa in self.amplitude_collections] logger.debug("filling source depths %s", calculate_depths) for depth in calculate_depths: if depth in stored_depths: @@ -764,95 +815,108 @@ async def fill_source_depth_range( self.remove_collection(depth) else: continue - await self.fill_source_depth( - source_depth=depth, - n_sources=n_sources, - n_targets_per_source=n_targets_per_source, - ) + async with self.get_lock(depth, self.reference_magnitude): + await self.compute_site_amplitudes( + reference_magnitude=self.reference_magnitude, + source_depth=depth, + n_sources=n_sources, + n_targets_per_source=n_targets_per_source, + ) def get_collection(self, source_depth: float) -> SiteAmplitudesCollection: - """ - Get the site amplitudes collection for the given source depth. + """Get the site amplitudes collection for the given source depth. Args: - depth (float): The source depth. + source_depth (float): The source depth. Returns: SiteAmplitudesCollection: The site amplitudes collection. """ - for site_amplitudes in self.site_amplitudes: + for site_amplitudes in self.amplitude_collections: if site_amplitudes.source_depth == source_depth: return site_amplitudes raise KeyError(f"No site amplitudes for depth {source_depth}.") - def new_collection(self, depth: float) -> SiteAmplitudesCollection: - """ - Creates a new SiteAmplitudesCollection object for the given depth and - adds it to the list of site amplitudes. + def new_collection(self, source_depth: float) -> SiteAmplitudesCollection: + """Creates a new SiteAmplitudesCollection object. + + For the given depth and add it to the list of site amplitudes. Args: - depth (float): The depth for which the site amplitudes collection is + source_depth (float): The depth for which the site amplitudes collection is created. Returns: SiteAmplitudesCollection: The newly created SiteAmplitudesCollection object. """ - logger.debug("creating new site amplitudes for depth %f", depth) - self.remove_collection(depth) + logger.debug("creating new site amplitudes for depth %f", source_depth) + self.remove_collection(source_depth) collection = SiteAmplitudesCollection( - source_depth=depth, + source_depth=source_depth, **self.model_dump(exclude={"site_amplitudes"}), ) - self.site_amplitudes.append(collection) + self.amplitude_collections.append(collection) return collection def remove_collection(self, depth: float) -> None: - """ - Removes the site amplitudes collection for the given depth. + """Removes the site amplitudes collection for the given depth. Args: depth (float): The depth for which the site amplitudes collection is removed. 
""" - logger.debug("removing site amplitudes for depth %f", depth) try: collection = self.get_collection(depth) - self.site_amplitudes.remove(collection) + self.amplitude_collections.remove(collection) + logger.debug("removed site amplitudes for depth %f", depth) except KeyError: pass - async def get_amplitude( + async def get_amplitude_model( self, source_depth: float, distance: float, n_amplitudes: int = 25, - max_distance: float = 0.0, + distance_cutoff: float = 0.0, + reference_magnitude: float | None = None, peak_amplitude: PeakAmplitude = "absolute", auto_fill: bool = True, interpolation: Literal["nearest", "linear"] = "linear", ) -> ModelledAmplitude: - """ - Retrieves the amplitude for a given depth and distance. + """Retrieves the amplitude for a given depth and distance. Args: - depth (float): The depth of the event. + source_depth (float): The depth of the event. distance (float): The epicentral distance from the event. n_amplitudes (int, optional): The number of amplitudes to retrieve. Defaults to 10. - max_distance (float, optional): The maximum distance to consider in [m]. - Defaults to 1000.0. + distance_cutoff (float, optional): The maximum distance allowed for + the retrieved amplitudes. If 0.0, no maximum distance is applied and the + number of amplitudes will be exactly n_amplitudes. Defaults to 0.0. + reference_magnitude (float, optional): The reference moment magnitude + for the amplitudes. Defaults to 1.0. peak_amplitude (PeakAmplitude, optional): The type of peak amplitude to retrieve. Defaults to "absolute". - auto_fill (bool, optional): If True, the site amplitudes are calculated + auto_fill (bool, optional): If True, the site amplitudes for + depth-reference magnitude combinations are calculated if they are not available. Defaults to True. + interpolation (Literal["nearest", "linear"], optional): The depth + interpolation method to use. Defaults to "linear". Returns: - ModelledAmplitude: The modelled amplitude for the given depth and distance. + ModelledAmplitude: The modelled amplitude for the given depth, distance and + reference magnitude. 
""" if not self.source_depth_range.inside(source_depth): raise ValueError(f"Source depth {source_depth} outside range.") - source_depths = np.array([sa.source_depth for sa in self.site_amplitudes]) + source_depths = np.array([sa.source_depth for sa in self.amplitude_collections]) + reference_magnitude = ( + self.reference_magnitude + if reference_magnitude is None + else reference_magnitude + ) + match interpolation: case "nearest": idx = [np.abs(source_depths - source_depth).argmin()] @@ -861,32 +925,43 @@ async def get_amplitude( case _: raise ValueError(f"Unknown interpolation method {interpolation}.") - collections = [self.site_amplitudes[i] for i in idx] + collections = [self.amplitude_collections[i] for i in idx] amplitudes: list[ModelledAmplitude] = [] + for collection in collections: + lock = self.get_lock(collection.source_depth, reference_magnitude) try: - amplitude = collection.get_amplitude( + await lock.acquire() + amplitude = collection.get_amplitude_model( distance=distance, n_amplitudes=n_amplitudes, - max_distance=max_distance, + distance_cutoff=distance_cutoff, peak_amplitude=peak_amplitude, + reference_magnitude=reference_magnitude, ) amplitudes.append(amplitude) - except ValueError: + except ValueError as e: + logger.exception(e) if auto_fill: - await self.fill_source_depth(source_depth) - logger.info("auto-filling amplitudes for depth %f", source_depth) - return await self.get_amplitude( + await self.compute_site_amplitudes( + source_depth=collection.source_depth, + reference_magnitude=reference_magnitude, + ) + lock.release() + return await self.get_amplitude_model( source_depth=source_depth, distance=distance, n_amplitudes=n_amplitudes, - max_distance=max_distance, + reference_magnitude=reference_magnitude, + distance_cutoff=distance_cutoff, peak_amplitude=peak_amplitude, interpolation=interpolation, auto_fill=True, ) + lock.release() raise + lock.release() if not amplitudes: raise ValueError(f"No site amplitudes for depth {source_depth}.") @@ -896,11 +971,8 @@ async def get_amplitude( amplitude = amplitudes[0] case "linear": - if len(amplitudes) != 2: - raise ValueError( - f"Cannot interpolate amplitudes with {len(amplitudes)} " - f" source depths." - ) + if len(amplitudes) == 1: + return amplitudes[0] depths = source_depths[idx] weight = abs((source_depth - depths[0]) / abs(depths[1] - depths[0])) amplitude = amplitudes[0].combine(amplitudes[1], weight=weight) @@ -910,9 +982,85 @@ async def get_amplitude( raise ValueError(f"Median amplitude is zero for depth {source_depth}.") return amplitude - def hash(self) -> str: + async def find_moment_magnitude( + self, + source_depth: float, + distance: float, + observed_amplitude: float, + n_amplitudes: int = 25, + distance_cutoff: float = 0.0, + initial_reference_magnitude: float = 1.0, + peak_amplitude: PeakAmplitude = "absolute", + interpolation: Literal["nearest", "linear"] = "linear", + ) -> tuple[float, ModelledAmplitude]: + """Get the moment magnitude for the given observed amplitude. + + Args: + source_depth (float): The depth of the event. + distance (float): The epicentral distance from the event. + observed_amplitude (float): The observed amplitude. + n_amplitudes (int, optional): The number of amplitudes to retrieve. + Defaults to 10. + initial_reference_magnitude (float, optional): The initial reference + moment magnitude to use. Defaults to 1.0. + distance_cutoff (float, optional): The maximum distance allowed for + the retrieved amplitudes. 
If 0.0, no maximum distance is applied and the + number of amplitudes will be exactly n_amplitudes. Defaults to 0.0. + peak_amplitude (PeakAmplitude, optional): The type of peak amplitude to + retrieve. Defaults to "absolute". + interpolation (Literal["nearest", "linear"], optional): The depth + interpolation method to use. Defaults to "linear". + + Returns: + float: The moment magnitude. """ - Calculate the hash of the store from store parameters. + cache: list[tuple[float, float, ModelledAmplitude]] = [] + + def get_cache(reference_magnitude: float) -> tuple[float, ModelledAmplitude]: + for mag, est, model in cache: + if mag == reference_magnitude: + return est, model + raise KeyError(f"No estimate for magnitude {reference_magnitude}.") + + async def estimate_magnitude( + reference_magnitude: float, + ) -> tuple[float, ModelledAmplitude]: + try: + return get_cache(reference_magnitude) + except KeyError: + model = await self.get_amplitude_model( + reference_magnitude=reference_magnitude, + source_depth=source_depth, + distance=distance, + n_amplitudes=n_amplitudes, + distance_cutoff=distance_cutoff, + peak_amplitude=peak_amplitude, + interpolation=interpolation, + ) + est_magnitude = model.estimate_magnitude(observed_amplitude) + cache.append((reference_magnitude, est_magnitude, model)) + return est_magnitude, model + + reference_mag = initial_reference_magnitude + for _ in range(3): + est_magnitude, _ = await estimate_magnitude(reference_mag) + rounded_mag = np.round(est_magnitude, 0) + explore_mags = np.array([rounded_mag - 1, rounded_mag, rounded_mag + 1]) + + predictions = [await estimate_magnitude(mag) for mag in explore_mags] + predicted_mags = np.array([mag for mag, _ in predictions]) + models = [model for _, model in predictions] + + magnitude_differences = np.abs(predicted_mags - explore_mags) + min_diff = np.argmin(magnitude_differences) + + if min_diff == 1: + return predicted_mags[1], models[1] + reference_mag = explore_mags[min_diff] + return predicted_mags[min_diff], models[min_diff] + + def hash(self) -> str: + """Calculate the hash of the store from store parameters. Returns: str: The hash of the store. @@ -931,8 +1079,7 @@ def hash(self) -> str: return hashlib.sha1(data).hexdigest() def is_suited(self, selector: PeakAmplitudesBase) -> bool: - """ - Check if the given selector is suited for this store. + """Check if the given selector is suited for this store. Args: selector (PeakAmpliutdesSelector): The selector to check. @@ -957,8 +1104,7 @@ def __hash__(self) -> int: return hash(self.hash()) def save(self, path: Path | None = None) -> None: - """ - Save the site amplitudes to a JSON file. + """Save the site amplitudes to a JSON file. The site amplitudes are saved in a directory called 'site_amplitudes' within the cache directory. The file name is generated based on the store ID and @@ -995,8 +1141,7 @@ def __init__(self, cache_dir: Path, engine: gf.LocalEngine | None = None) -> Non PeakAmplitudesStore.set_cache_dir(cache_dir) def clear_cache(self): - """ - Clear the cache directory. + """Clear the cache directory. This method deletes all files in the cache directory. """ @@ -1005,8 +1150,7 @@ def clear_cache(self): file.unlink() def clean_cache(self, keep_files: int = 100) -> None: - """ - Clean the cache directory. + """Clean the cache directory. 
Args: keep_files (int, optional): The number of most recent files to keep in the @@ -1020,8 +1164,7 @@ def clean_cache(self, keep_files: int = 100) -> None: file.unlink() def cache_stats(self) -> CacheStats: - """ - Get the cache statistics. + """Get the cache statistics. Returns: CacheStats: The cache statistics. @@ -1036,8 +1179,7 @@ def cache_stats(self) -> CacheStats: def get_cached_stores( self, store_id: str, quantity: MeasurementUnit ) -> list[PeakAmplitudesStore]: - """ - Get the cached peak amplitude stores for the given store ID and quantity. + """Get the cached peak amplitude stores for the given store ID and quantity. Args: store_id (str): The store ID. @@ -1065,9 +1207,9 @@ def get_cached_stores( return stores def get_store(self, selector: PeakAmplitudesBase) -> PeakAmplitudesStore: - """ - Get a peak amplitude store for the given selector, either from the cache - or by creating a new store. + """Get a peak amplitude store for the given selector. + + Either from the cache or by creating a new store. Args: selector (PeakAmplitudesSelector): The selector to use. diff --git a/src/qseek/models/catalog.py b/src/qseek/models/catalog.py index 7d27439a..e25c7f64 100644 --- a/src/qseek/models/catalog.py +++ b/src/qseek/models/catalog.py @@ -1,14 +1,16 @@ from __future__ import annotations +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any, Iterator import aiofiles -from pydantic import BaseModel, PrivateAttr +from pydantic import BaseModel, PrivateAttr, computed_field from pyrocko import io from pyrocko.gui import marker from pyrocko.model import Event, dump_events from pyrocko.trace import Trace +from rich.progress import track from rich.table import Table from qseek.console import console @@ -31,6 +33,29 @@ class EventCatalogStats(Stats): max_semblance: float = 0.0 _position: int = 2 + _catalog: EventCatalog = PrivateAttr() + + def set_catalog(self, catalog: EventCatalog) -> None: + self._catalog = catalog + self.n_detections = catalog.n_events + + @property + def magnitudes(self) -> list[float]: + return [det.magnitude.average for det in self._catalog if det.magnitude] + + @computed_field + def mean_semblance(self) -> float: + return ( + sum(detection.semblance for detection in self._catalog) / self.n_detections + ) + + @computed_field + def magnitude_min(self) -> float: + return min(self.magnitudes) if self.magnitudes else 0.0 + + @computed_field + def magnitude_max(self) -> float: + return max(self.magnitudes) if self.magnitudes else 0.0 def new_detection(self, detection: EventDetection): self.n_detections += 1 @@ -51,7 +76,7 @@ def model_post_init(self, __context: Any) -> None: @property def n_events(self) -> int: - """Number of detections""" + """Number of detections.""" return len(self.events) @property @@ -66,6 +91,29 @@ def csv_dir(self) -> Path: dir.mkdir(exist_ok=True) return dir + async def filter_events_by_time( + self, + start_time: datetime | None, + end_time: datetime | None, + ) -> None: + """Filter the detections based on the given time range. + + Args: + start_time (datetime | None): Start time of the time range. + end_time (datetime | None): End time of the time range. 
+ """ + events = [] + if start_time is not None and min(det.time for det in self.events) < start_time: + logger.info("filtering detections after start time %s", start_time) + events = [det for det in self.events if det.time >= start_time] + if end_time is not None and max(det.time for det in self.events) > end_time: + logger.info("filtering detections before end time %s", end_time) + events = [det for det in self.events if det.time <= end_time] + if events: + self.events = events + self._stats.n_detections = len(self.events) + await self.save() + async def add(self, detection: EventDetection) -> None: detection.set_index(self.n_events) @@ -126,10 +174,48 @@ def load_rundir(cls, rundir: Path) -> EventCatalog: stats = catalog._stats stats.n_detections = catalog.n_events - if catalog: + if catalog and catalog.n_events: stats.max_semblance = max(detection.semblance for detection in catalog) return catalog + async def check(self, repair: bool = True) -> None: + """Check the catalog for errors and inconsistencies. + + Args: + repair (bool, optional): If True, attempt to repair the catalog. + Defaults to True. + """ + logger.info("checking catalog...") + found_bad = 0 + found_duplicates = 0 + event_uids = set() + for detection in track( + self.events.copy(), + description=f"checking {self.n_events} events...", + ): + try: + _ = detection.receivers + except ValueError: + found_bad += 1 + if repair: + self.events.remove(detection) + + if detection.uid in event_uids: + found_duplicates += 1 + if repair: + self.events.remove(detection) + + event_uids.add(detection.uid) + + if found_bad or found_duplicates: + logger.info("found %d detections with invalid receivers", found_bad) + logger.info("found %d duplicate detections", found_duplicates) + if repair: + logger.info("repairing catalog") + await self.save() + else: + logger.info("all detections are ok") + async def save(self) -> None: """Save catalog to current rundir.""" logger.debug("saving %d detections", self.n_events) @@ -148,8 +234,7 @@ async def save(self) -> None: await f.writelines(lines_recv) async def export_detections(self, jitter_location: float = 0.0) -> None: - """ - Export detections to CSV and Pyrocko event lists in the current rundir. + """Export detections to CSV and Pyrocko event lists in the current rundir. Args: jitter_location (float): The amount of jitter in [m] to apply @@ -178,6 +263,7 @@ async def export_csv(self, file: Path, jitter_location: float = 0.0) -> None: jitter_location (float, optional): Randomize the location of each detection by this many meters. Defaults to 0.0. """ + logger.info("saving event CSV to %s", file) header = [] if jitter_location: @@ -225,10 +311,12 @@ def get_pyrocko_markers(self) -> list[EventMarker | PhaseMarker]: def export_pyrocko_events( self, filename: Path, jitter_location: float = 0.0 ) -> None: - """Export Pyrocko events for all detections to a file + """Export Pyrocko events for all detections to a file. Args: filename (Path): output filename + jitter_location (float, optional): Randomize the location of each detection + by this many meters. Defaults to 0.0. """ logger.info("saving Pyrocko events to %s", filename) detections = self.events @@ -241,7 +329,7 @@ def export_pyrocko_events( ) def export_pyrocko_markers(self, filename: Path) -> None: - """Export Pyrocko markers for all detections to a file + """Export Pyrocko markers for all detections to a file. 
Args: filename (Path): output filename diff --git a/src/qseek/models/detection.py b/src/qseek/models/detection.py index f0996235..bc24b70e 100644 --- a/src/qseek/models/detection.py +++ b/src/qseek/models/detection.py @@ -51,6 +51,7 @@ FILENAME_RECEIVERS = "detections_receivers.json" UPDATE_LOCK = asyncio.Lock() +SQUIRREL_SEM = asyncio.Semaphore(64) class ReceiverCache: @@ -70,11 +71,40 @@ def load(self) -> None: self.lines = self.file.read_text().splitlines() self.mtime = self.file.stat().st_mtime - def get_row(self, row_index: int) -> str: + def _check_mtime(self) -> None: if self.mtime is None or self.mtime != self.file.stat().st_mtime: self.load() + + def get_line(self, row_index: int) -> str: + """Retrieves the line at the specified row index. + + Args: + row_index (int): The index of the row to retrieve. + + Returns: + str: The line at the specified row index. + """ + self._check_mtime() return self.lines[row_index] + def find_uid(self, uid: UUID) -> tuple[int, str]: + """Find the given UID in the lines and return its index and value. + + get_line should be prefered over this method. + + Args: + uid (UUID): The UID to search for. + + Returns: + tuple[int, str]: A tuple containing the index and value of the found UID. + """ + self._check_mtime() + find_uid = str(uid) + for iline, line in enumerate(self.lines): + if find_uid in line: + return iline, line + raise KeyError + class PhaseDetection(BaseModel): phase: PhaseDescription @@ -114,8 +144,7 @@ def _get_csv_dict(self) -> dict[str, Any]: return csv_dict def as_pyrocko_markers(self) -> list[marker.PhaseMarker]: - """ - Convert the observed and modeled arrivals to a list of Pyrocko PhaseMarkers. + """Convert the observed and modeled arrivals to a list of Pyrocko PhaseMarkers. Returns: list[marker.PhaseMarker]: List of Pyrocko PhaseMarker objects representing @@ -151,8 +180,7 @@ def add_phase_detection(self, arrival: PhaseDetection) -> None: self.phase_arrivals[arrival.phase] = arrival def as_pyrocko_markers(self) -> list[marker.PhaseMarker]: - """ - Convert the phase arrivals to Pyrocko markers. + """Convert the phase arrivals to Pyrocko markers. Returns: A list of Pyrocko PhaseMarker objects. @@ -168,8 +196,7 @@ def as_pyrocko_markers(self) -> list[marker.PhaseMarker]: def get_arrivals_time_window( self, phase: PhaseDescription | None = None ) -> tuple[datetime, datetime]: - """ - Get the time window for phase arrivals. + """Get the time window for phase arrivals. Args: phase (PhaseDescription | None): Optional phase description. @@ -198,18 +225,18 @@ class EventReceivers(BaseModel): @property def n_receivers(self) -> int: - """Number of receivers in the receiver set""" + """Number of receivers in the receiver set.""" return len(self.receivers) def n_observations(self, phase: PhaseDescription) -> int: - """Number of observations for a given phase""" + """Number of observations for a given phase.""" n_observations = 0 for receiver in self: if (arrival := receiver.phase_arrivals.get(phase)) and arrival.observed: n_observations += 1 return n_observations - def get_waveforms( + async def get_waveforms( self, squirrel: Squirrel, seconds_before: float = 3.0, @@ -217,8 +244,7 @@ def get_waveforms( phase: PhaseDescription | None = None, receivers: Iterable[Receiver] | None = None, ) -> list[Trace]: - """ - Retrieves and restitutes waveforms for a given squirrel. + """Retrieves and restitutes waveforms for a given squirrel. Args: squirrel (Squirrel): The squirrel waveform organizer. 
@@ -247,13 +273,15 @@ def get_waveforms( tmin = min(times).timestamp() - seconds_before tmax = max(times).timestamp() + seconds_after nslc_ids = [(*receiver.nsl, "*") for receiver in receivers] - traces = squirrel.get_waveforms( - codes=nslc_ids, - tmin=tmin, - tmax=tmax, - accessor_id=accessor_id, - want_incomplete=False, - ) + async with SQUIRREL_SEM: + traces = await asyncio.to_thread( + squirrel.get_waveforms, + codes=nslc_ids, + tmin=tmin, + tmax=tmax, + accessor_id=accessor_id, + want_incomplete=False, + ) squirrel.advance_accessor(accessor_id, cache_id="waveform") for tr in traces: @@ -274,12 +302,11 @@ async def get_waveforms_restituted( phase: PhaseDescription | None = None, quantity: MeasurementUnit = "velocity", demean: bool = True, - remove_clipped: bool = False, + filter_clipped: bool = False, freqlimits: tuple[float, float, float, float] = (0.01, 0.1, 25.0, 35.0), receivers: Iterable[Receiver] | None = None, ) -> list[Trace]: - """ - Retrieves and restitutes waveforms for a given squirrel. + """Retrieves and restitutes waveforms for a given squirrel. Args: squirrel (Squirrel): The squirrel waveform organizer. @@ -302,20 +329,20 @@ async def get_waveforms_restituted( The frequency limits. Defaults to (0.01, 0.1, 25.0, 35.0). receivers (list[Receiver] | None, optional): The receivers to retrieve waveforms for. If None, all receivers are retrieved. Defaults to None. + filter_clipped (bool, optional): Whether to filter clipped traces. + Defaults to False. Returns: list[Trace]: The restituted waveforms. """ - traces = await asyncio.to_thread( - self.get_waveforms, + traces = await self.get_waveforms( squirrel, phase=phase, seconds_after=seconds_after + seconds_fade, seconds_before=seconds_before + seconds_fade, receivers=receivers, ) - traces = filter_clipped_traces(traces) if remove_clipped else traces - + traces = filter_clipped_traces(traces) if filter_clipped else traces if not traces: return [] @@ -356,8 +383,7 @@ def get_response(tr: Trace) -> Any: return restituted_traces def get_receiver(self, nsl: NSL) -> Receiver: - """ - Get the receiver object based on given NSL tuple. + """Get the receiver object based on given NSL tuple. Args: nsl (tuple[str, str, str]): The network, station, and location tuple. @@ -378,7 +404,7 @@ def add( stations: Stations, phase_arrivals: list[PhaseDetection | None], ) -> None: - """Add receivers to the receiver set + """Add receivers to the receiver set. Args: stations: List of stations @@ -398,8 +424,7 @@ def add( receiver.add_phase_detection(arrival) def get_by_nsl(self, nsl: NSL) -> Receiver: - """ - Retrieves a receiver object by its NSL (network, station, location) tuple. + """Retrieves a receiver object by its NSL (network, station, location) tuple. Args: nsl (NSL): The NSL tuple representing @@ -417,8 +442,7 @@ def get_by_nsl(self, nsl: NSL) -> Receiver: raise KeyError(f"cannot find station {nsl.pretty}") def get_pyrocko_markers(self) -> list[marker.PhaseMarker]: - """ - Get a list of Pyrocko phase markers from all receivers. + """Get a list of Pyrocko phase markers from all receivers. Returns: A list of Pyrocko phase markers. @@ -486,8 +510,7 @@ def migrate_features(cls, v: Any) -> list[EventFeaturesType]: @classmethod def set_rundir(cls, rundir: Path) -> None: - """ - Set the rundir for the detection model. + """Set the rundir for the detection model. Args: rundir (Path): The path to the rundir. 
@@ -497,22 +520,21 @@ def set_rundir(cls, rundir: Path) -> None: @property def magnitude(self) -> EventMagnitude | None: - """ - Returns the magnitude of the event. + """Returns the magnitude of the event. If there are no magnitudes available, returns None. """ return self.magnitudes[0] if self.magnitudes else None async def save(self, file: Path | None = None, update: bool = False) -> None: - """ - Dump the detection data to a file. + """Dump the detection data to a file. After the detection is dumped, the receivers are dumped to a separate file and the receivers cache is cleared. Args: - directory (Path): The directory where the file will be saved. + file (Path|None): The file to dump the detection to. + If None, the rundir is used. Defaults to None. update (bool): Whether to update an existing detection or append a new one. Raises: @@ -539,31 +561,35 @@ async def save(self, file: Path | None = None, update: bool = False) -> None: await asyncio.shield(f.writelines(lines)) else: logger.debug("appending detection %d", self._detection_idx) - async with aiofiles.open(file, "a") as f: - await f.write(f"{json_data}\n") + async with UPDATE_LOCK: + async with aiofiles.open(file, "a") as f: + await f.write(f"{json_data}\n") - receiver_file = self._rundir / FILENAME_RECEIVERS - async with aiofiles.open(receiver_file, "a") as f: - await asyncio.shield(f.write(f"{self.receivers.model_dump_json()}\n")) + receiver_file = self._rundir / FILENAME_RECEIVERS + async with aiofiles.open(receiver_file, "a") as f: + await asyncio.shield( + f.write(f"{self.receivers.model_dump_json()}\n") + ) self._receivers = None # Free the memory - def set_index(self, index: int) -> None: - """ - Set the index of the detection. + def set_index(self, index: int, force: bool = False) -> None: + """Set the index of the detection. Args: index (int): The index to set. + force (bool, optional): Whether to force the index to be set. + Defaults to False. Returns: None """ - if self._detection_idx is not None: + if not force and self._detection_idx is not None: raise ValueError("cannot set index twice") self._detection_idx = index def set_uncertainty(self, uncertainty: DetectionUncertainty) -> None: - """Set detection uncertainty + """Set detection uncertainty. Args: uncertainty (DetectionUncertainty): detection uncertainty @@ -571,11 +597,14 @@ def set_uncertainty(self, uncertainty: DetectionUncertainty) -> None: self.uncertainty = uncertainty def add_magnitude(self, magnitude: EventMagnitude) -> None: - """Add magnitude to detection + """Add magnitude to detection. Args: magnitude (EventMagnitudeType): magnitude """ + for mag in self.magnitudes.copy(): + if type(magnitude) is type(mag): + self.magnitudes.remove(mag) self.magnitudes.append(magnitude) def add_feature(self, feature: EventFeature) -> None: @@ -589,8 +618,7 @@ def add_feature(self, feature: EventFeature) -> None: @computed_field @property def receivers(self) -> EventReceivers: - """ - Retrieves the event receivers associated with the detection. + """Retrieves the event receivers associated with the detection. Returns: EventReceivers: The event receivers associated with the detection. 
@@ -607,20 +635,30 @@ def receivers(self) -> EventReceivers: elif self._rundir and self._detection_idx is not None: if self._receiver_cache is None: raise ValueError("cannot fetch receivers without set rundir") - logger.debug("fetching receiver information from cache") - row = self._receiver_cache.get_row(self._detection_idx) - receivers = EventReceivers.model_validate_json(row) - if receivers.event_uid != self.uid: - raise ValueError(f"uid mismatch: {receivers.event_uid} != {self.uid}") + try: + line = self._receiver_cache.get_line(self._detection_idx) + receivers = EventReceivers.model_validate_json(line) + except IndexError: + receivers = None + + if not receivers or receivers.event_uid != self.uid: + logger.warning("event %s uid mismatch, using brute search", self.time) + try: + idx, line = self._receiver_cache.find_uid(self.uid) + receivers = EventReceivers.model_validate_json(line) + self.set_index(idx, force=True) + except KeyError: + raise ValueError(f"uid mismatch for event {self.time}") from None + self._receivers = receivers else: raise ValueError("cannot fetch receivers without set rundir and index") return self._receivers def as_pyrocko_event(self) -> Event: - """Get detection as Pyrocko event + """Get detection as Pyrocko event. Returns: Event: Pyrocko event @@ -640,7 +678,7 @@ def as_pyrocko_event(self) -> Event: ) def get_csv_dict(self) -> dict[str, Any]: - """Get detection as CSV line + """Get detection as CSV line. Returns: dict[str, Any]: CSV line @@ -653,7 +691,6 @@ def get_csv_dict(self) -> dict[str, Any]: "east_shift": round(self.east_shift, 2), "north_shift": round(self.north_shift, 2), "distance_border": round(self.distance_border, 2), - "in_bounds": self.in_bounds, "semblance": self.semblance, } for magnitude in self.magnitudes: @@ -661,7 +698,7 @@ def get_csv_dict(self) -> dict[str, Any]: return csv_line def get_pyrocko_markers(self) -> list[marker.EventMarker | marker.PhaseMarker]: - """Get detections as Pyrocko markers + """Get detections as Pyrocko markers. Returns: list[marker.EventMarker | marker.PhaseMarker]: Pyrocko markers @@ -677,7 +714,7 @@ def get_pyrocko_markers(self) -> list[marker.EventMarker | marker.PhaseMarker]: return pyrocko_markers def export_pyrocko_markers(self, filename: Path) -> None: - """Save detection's Pyrocko markers to file + """Save detection's Pyrocko markers to file. Args: filename (Path): path to marker file @@ -686,7 +723,7 @@ def export_pyrocko_markers(self, filename: Path) -> None: marker.save_markers(self.get_pyrocko_markers(), str(filename)) def jitter_location(self, meters: float) -> Self: - """Randomize detection location + """Randomize detection location. Args: meters (float): maximum randomization in meters @@ -702,8 +739,12 @@ def jitter_location(self, meters: float) -> Self: detection._cached_lat_lon = None return detection - def snuffle(self, squirrel: Squirrel, restituted: bool = False) -> None: - """Open snuffler for detection + def snuffle( + self, + squirrel: Squirrel, + restituted: bool | MeasurementUnit = False, + ) -> None: + """Open snuffler for detection. 
Args: squirrel (Squirrel): The squirrel, holding the data @@ -711,10 +752,13 @@ def snuffle(self, squirrel: Squirrel, restituted: bool = False) -> None: """ from pyrocko.trace import snuffle + restitute_unit = "velocity" if restituted is True else restituted traces = ( self.receivers.get_waveforms(squirrel) - if not restituted - else self.receivers.get_waveforms_restituted(squirrel) + if not restitute_unit + else self.receivers.get_waveforms_restituted( + squirrel, quantity=restitute_unit + ) ) snuffle( traces, diff --git a/src/qseek/models/detection_uncertainty.py b/src/qseek/models/detection_uncertainty.py index fc173fab..5b2b386c 100644 --- a/src/qseek/models/detection_uncertainty.py +++ b/src/qseek/models/detection_uncertainty.py @@ -32,13 +32,12 @@ class DetectionUncertainty(BaseModel): def from_event( cls, source_node: Node, octree: Octree, percentile: float = PERCENTILE ) -> Self: - """ - Calculate the uncertainty of an event detection. + """Calculate the uncertainty of an event detection. Args: - event: The event detection to calculate the uncertainty for. - octree: The octree to use for the calculation. - percentile: The percentile to use for the calculation. + source_node (Node): The source node of the event. + octree (Octree): The octree to use for the calculation. + percentile (float): The percentile to use for the calculation. Defaults to 0.02 (2%). Returns: diff --git a/src/qseek/models/location.py b/src/qseek/models/location.py index e1f5e4a8..5f367df2 100644 --- a/src/qseek/models/location.py +++ b/src/qseek/models/location.py @@ -91,7 +91,6 @@ def surface_distance_to(self, other: Location) -> float: Returns: float: The surface distance in [m]. """ - if self._same_origin(other): return math.sqrt( (self.north_shift - other.north_shift) ** 2 @@ -129,7 +128,7 @@ def distance_to(self, other: Location) -> float: return math.sqrt((sx - ox) ** 2 + (sy - oy) ** 2 + (sz - oz) ** 2) def offset_from(self, other: Location) -> tuple[float, float, float]: - """Return offset vector (east, north, depth) from other location in [m] + """Return offset vector (east, north, depth) from other location in [m]. Args: other (Location): The other location. @@ -185,9 +184,7 @@ def shift(self, east: float, north: float, elevation: float) -> Self: return shifted def origin(self) -> Location: - """ - Returns the origin location based on the latitude, longitude, - and effective elevation. + """Get the origin location. Returns: Location: The origin location. diff --git a/src/qseek/models/semblance.py b/src/qseek/models/semblance.py index aa26868f..8ac988fc 100644 --- a/src/qseek/models/semblance.py +++ b/src/qseek/models/semblance.py @@ -200,8 +200,7 @@ def get_time_from_index(self, index: int) -> datetime: return self.start_time + timedelta(seconds=index / self.sampling_rate) def get_semblance(self, time_idx: int) -> np.ndarray: - """ - Get the semblance values at a specific time index. + """Get the semblance values at a specific time index. Parameters: time_idx (int): The index of the desired time. @@ -212,8 +211,7 @@ def get_semblance(self, time_idx: int) -> np.ndarray: return self.semblance[:, time_idx] async def apply_cache(self, cache: SemblanceCache) -> None: - """ - Applies the cached data to the `semblance_unpadded` array. + """Applies the cached data to the `semblance_unpadded` array. Args: cache (SemblanceCache): The cache containing the cached data. @@ -255,7 +253,7 @@ async def maxima_semblance( Args: trim_padding (bool, optional): Trim padded data in post-processing. 
- nparallel (int, optional): Number of threads for calculation. + nthreads (int, optional): Number of threads for calculation. Defaults to 12. Returns: @@ -282,7 +280,9 @@ async def maxima_node_idx( """Indices of maximum semblance at any time step. Args: - nparallel (int, optional): Number of threads for calculation. + trim_padding (bool, optional): Trim padded data in post-processing. + Defaults to True. + nthreads (int, optional): Number of threads for calculation. Defaults to 12. Returns: @@ -330,6 +330,8 @@ async def find_peaks( distance (float): Minium distance of a peak to other peaks. trim_padding (bool, optional): Trim padded data in post-processing. Defaults to True. + nthreads (int, optional): Number of threads for calculation. + Defaults to 12. Returns: tuple[np.ndarray, np.ndarray]: Indices of peaks and peak values. @@ -427,6 +429,8 @@ def normalize( Args: factor (int | float): Normalization factor. + semblance_cache (SemblanceCache | None, optional): Cache of the semblance. + Defaults to None. """ if factor == 1.0: return diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index ded665d2..1bfd7145 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -179,7 +179,7 @@ def select_from_traces(self, traces: Iterable[Trace]) -> Stations: """Select stations by NSL code. Args: - selection (Iterable[Trace]): Iterable of Pyrocko Traces + traces (Iterable[Trace]): Iterable of Pyrocko Traces Returns: Stations: Containing only selected stations. @@ -216,8 +216,7 @@ def get_coordinates(self, system: CoordSystem = "geographic") -> np.ndarray: ) def as_pyrocko_stations(self) -> list[PyrockoStation]: - """ - Convert the stations to PyrockoStation objects. + """Convert the stations to PyrockoStation objects. Returns: A list of PyrockoStation objects. diff --git a/src/qseek/octree.py b/src/qseek/octree.py index 291c1c4a..f2a89e6b 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -84,7 +84,7 @@ class Node: _location: Location | None = None def split(self) -> tuple[Node, ...]: - """Split the node into 8 children""" + """Split the node into 8 children.""" if not self.tree: raise EnvironmentError("Parent tree is not set.") @@ -149,8 +149,8 @@ def is_inside_border(self, with_surface: bool = False) -> bool: """Check if the node is within the root node border. Args: - trough (bool, optional): If True, the node is considered inside the - trough (open top). Defaults to False. + with_surface (bool, optional): If True, the surface is considered + as a border. Defaults to False. Returns: bool: True if the node is inside the root tree's border. @@ -200,8 +200,7 @@ def distance_to_location(self, location: Location) -> float: return location.distance_to(self.as_location()) def semblance_density(self) -> float: - """ - Calculate the semblance density of the octree. + """Calculate the semblance density of the octree. Returns: The semblance density of the octree. @@ -355,7 +354,7 @@ def check_limits(self) -> Octree: return self def model_post_init(self, __context: Any) -> None: - """Initialize octree. This method is called by the pydantic model""" + """Initialize octree. 
This method is called by the pydantic model.""" self._root_nodes = self.get_root_nodes(self.root_node_size) logger.info( @@ -394,12 +393,12 @@ def get_root_nodes(self, length: float) -> list[Node]: @cached_property def n_nodes(self) -> int: - """Number of nodes in the octree""" + """Number of nodes in the octree.""" return sum(1 for _ in self) @property def volume(self) -> float: - """Volume of the octree in cubic meters""" + """Volume of the octree in cubic meters.""" return reduce(mul, self.extent()) def iter_nodes(self, level: int | None = None) -> Iterator[Node]: @@ -433,7 +432,7 @@ def _clear_cache(self) -> None: del self.n_nodes def reset(self) -> Self: - """Reset the octree to its initial state""" + """Reset the octree to its initial state.""" logger.debug("resetting tree") self._clear_cache() self._root_nodes = self.get_root_nodes(self.root_node_size) @@ -459,11 +458,14 @@ def reduce_axis( self, surface: Literal["NE", "ED", "ND"] = "NE", max_level: int = -1, - accumulator: Callable = np.max, + accumulator: Callable[np.ndarray] = np.max, ) -> np.ndarray: - """Reduce the octree's nodes to the surface + """Reduce the octree's nodes to the surface. Args: + surface (Literal["NE", "ED", "ND"], optional): Surface to reduce to. + Defaults to "NE". + max_level (int, optional): Maximum level to reduce to. Defaults to -1. accumulator (Callable, optional): Accumulator function. Defaults to np.max. Returns: @@ -553,8 +555,7 @@ def distances_stations_surface(self, stations: Stations) -> np.ndarray: ).reshape(-1, stations.n_stations) def get_nodes(self, indices: Iterable[int]) -> list[Node]: - """ - Retrieves a list of nodes from the octree based on the given indices. + """Retrieves a list of nodes from the octree based on the given indices. Args: indices (Iterable[int]): The indices of the nodes to retrieve. diff --git a/src/qseek/pre_processing/base.py b/src/qseek/pre_processing/base.py index 8f4a3ac3..9b04fb68 100644 --- a/src/qseek/pre_processing/base.py +++ b/src/qseek/pre_processing/base.py @@ -31,17 +31,14 @@ def validate_stations(cls, v) -> set[NSL]: @classmethod def get_subclasses(cls) -> tuple[type[BatchPreProcessing], ...]: - """ - Returns a tuple of all the subclasses of BasePreProcessing. - """ + """Returns a tuple of all the subclasses of BasePreProcessing.""" return tuple(cls.__subclasses__()) def select_traces(self, batch: WaveformBatch) -> list[Trace]: - """ - Selects traces from the given list based on the stations specified. + """Selects traces from the given list based on the stations specified. Args: - traces (list[Trace]): The list of traces to select from. + batch (WaveformBatch): The batch of traces to select from. Returns: list[Trace]: The selected traces. @@ -57,17 +54,14 @@ def select_traces(self, batch: WaveformBatch) -> list[Trace]: return traces async def prepare(self) -> None: - """ - Prepare the pre-processing module. - """ + """Prepare the pre-processing module.""" pass async def process_batch(self, batch: WaveformBatch) -> WaveformBatch: - """ - Process a list of traces. + """Process a list of traces. Args: - traces (list[Trace]): The list of traces to be processed. + batch (WaveformBatch): The batch of traces to process. Returns: list[Trace]: The processed list of traces. 
diff --git a/src/qseek/pre_processing/module.py b/src/qseek/pre_processing/module.py index a6a2cf5b..689f2f9e 100644 --- a/src/qseek/pre_processing/module.py +++ b/src/qseek/pre_processing/module.py @@ -103,11 +103,11 @@ async def worker() -> None: start_time = datetime_now() for process in self: batch = await process.process_batch(batch) - await self._queue.put(batch) stats.time_per_batch = datetime_now() - start_time stats.bytes_per_second = ( batch.cumulative_bytes / stats.time_per_batch.total_seconds() ) + await self._queue.put(batch) await self._queue.put(None) diff --git a/src/qseek/search.py b/src/qseek/search.py index bf28db6e..3d0f852b 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -38,6 +38,7 @@ PhaseDescription, alog_call, datetime_now, + get_cpu_count, human_readable_bytes, time_to_path, ) @@ -90,8 +91,7 @@ def time_remaining(self) -> timedelta: @computed_field @property def processing_rate(self) -> float: - """ - Calculate the processing rate of the search. + """Calculate the processing rate of the search. Returns: float: The processing rate in bytes per second. @@ -110,8 +110,7 @@ def processing_speed(self) -> timedelta: @computed_field @property def processed_percent(self) -> float: - """ - Calculate the percentage of processed batches. + """Calculate the percentage of processed batches. Returns: float: The percentage of processed batches. @@ -129,7 +128,7 @@ def add_processed_batch( duration: timedelta, show_log: bool = False, ) -> None: - self.batch_count = batch.i_batch + self.batch_count = batch.i_batch + 1 self.batch_count_total = batch.n_batches self.batch_time = batch.end_time self.processed_bytes += batch.cumulative_bytes @@ -166,7 +165,7 @@ def tts(duration: timedelta) -> str: table.add_row( "Progress ", f"[bold]{self.processed_percent:.1f}%[/bold]" - f" ([bold]{self.batch_count+1}[/bold]/{self.batch_count_total or '?'}," + f" ([bold]{self.batch_count}[/bold]/{self.batch_count_total or '?'}," f' {self.batch_time.strftime("%Y-%m-%d %H:%M:%S")})', ) table.add_row( @@ -300,7 +299,9 @@ class Search(BaseModel): _config_stem: str = PrivateAttr("") _rundir: Path = PrivateAttr() - _feature_semaphore: asyncio.Semaphore = PrivateAttr(asyncio.Semaphore(16)) + _compute_semaphore: asyncio.Semaphore = PrivateAttr( + asyncio.Semaphore(max(1, get_cpu_count() - 4)) + ) # Signals _new_detection: Signal[EventDetection] = PrivateAttr(Signal()) @@ -415,8 +416,7 @@ async def init_boundaries(self) -> None: ) async def prepare(self) -> None: - """ - Prepares the search by initializing necessary components and data. + """Prepares the search by initializing necessary components and data. This method prepares the search by performing the following steps: 1. Prepares the data provider with the given stations. @@ -466,6 +466,10 @@ async def start(self, force_rundir: bool = False) -> None: if self._progress.time_progress: logger.info("continuing search from %s", self._progress.time_progress) + await self._catalog.filter_events_by_time( + start_time=None, + end_time=self._progress.time_progress, + ) batches = self.data_provider.iter_batches( window_increment=self.window_length, @@ -509,38 +513,45 @@ async def start(self, force_rundir: bool = False) -> None: ) console.cancel() logger.info("finished search in %s", datetime_now() - processing_start) - logger.info("found %d detections", self._catalog.n_events) + logger.info("detected %d events", self._catalog.n_events) async def new_detections(self, detections: list[EventDetection]) -> None: - """ - Process new detections. 
+ """Process new detections. Args: detections (list[EventDetection]): List of new event detections. """ + catalog = self.catalog await asyncio.gather( *(self.add_magnitude_and_features(det) for det in detections) ) for detection in detections: - await self._catalog.add(detection) + await catalog.add(detection) await self._new_detection.emit(detection) - if ( - self._catalog.n_events - and self._catalog.n_events - self._last_detection_export > 100 - ): - await self._catalog.export_detections( + if not catalog.n_events: + return + + threshold = np.floor(np.log10(catalog.n_events)) - 1 + new_threshold = max(10, 10**threshold) + if catalog.n_events - self._last_detection_export > new_threshold: + await catalog.export_detections( jitter_location=self.octree.smallest_node_size() ) - self._last_detection_export = self._catalog.n_events + self._last_detection_export = catalog.n_events - async def add_magnitude_and_features(self, event: EventDetection) -> EventDetection: - """ - Adds magnitude and features to the given event. + async def add_magnitude_and_features( + self, + event: EventDetection, + recalculate: bool = True, + ) -> EventDetection: + """Adds magnitude and features to the given event. Args: event (EventDetection): The event to add magnitude and features to. + recalculate (bool, optional): Whether to overwrite existing magnitudes and + features. Defaults to True. """ if not event.in_bounds: return event @@ -550,8 +561,10 @@ async def add_magnitude_and_features(self, event: EventDetection) -> EventDetect except NotImplementedError: return event - async with self._feature_semaphore: + async with self._compute_semaphore: for mag_calculator in self.magnitudes: + if not recalculate and mag_calculator.has_magnitude(event): + continue logger.debug("adding magnitude from %s", mag_calculator.magnitude) await mag_calculator.add_magnitude(squirrel, event) @@ -688,8 +701,7 @@ async def calculate_semblance( ) async def get_images(self, sampling_rate: float | None = None) -> WaveformImages: - """ - Retrieves waveform images for the specified sampling rate. + """Retrieves waveform images for the specified sampling rate. Args: sampling_rate (float | None, optional): The desired sampling rate in Hz. @@ -719,6 +731,8 @@ async def search( Args: octree (Octree | None, optional): The octree to use for the search. Defaults to None. + semblance_cache (SemblanceCache | None, optional): The semblance cache to + use for the search. Defaults to None. Returns: tuple[list[EventDetection], Trace]: The event detections and the diff --git a/src/qseek/tracers/cake.py b/src/qseek/tracers/cake.py index b8d74ad8..f916ec55 100644 --- a/src/qseek/tracers/cake.py +++ b/src/qseek/tracers/cake.py @@ -151,8 +151,7 @@ def get_profile_vs(self) -> np.ndarray: return self.layered_model.profile("vs") def save_plot(self, filename: Path) -> None: - """ - Plot the layered model and save the figure to a file. + """Plot the layered model and save the figure to a file. Args: filename (Path): The path to save the figure. @@ -312,7 +311,7 @@ def save(self, path: Path) -> Path: """Save the model and traveltimes to an .sptree archive. Args: - folder (Path): Folder or file to save tree into. If path is a folder a + path (Path): Folder or file to save tree into. 
If path is a folder a native name from the model's hash is used Returns: @@ -398,7 +397,7 @@ async def init_lut(self, octree: Octree, stations: Stations) -> None: self._node_lut[node.hash()] = traveltimes.astype(np.float32) def lut_fill_level(self) -> float: - """Return the fill level of the LUT as a float between 0.0 and 1.0""" + """Return the fill level of the LUT as a float between 0.0 and 1.0.""" return len(self._node_lut) / self._node_lut.get_size() async def fill_lut(self, nodes: Sequence[Node]) -> None: diff --git a/src/qseek/utils.py b/src/qseek/utils.py index afc8ac61..ec7f9173 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -101,8 +101,7 @@ def pretty(self) -> str: return ".".join(self) def match(self, other: NSL) -> bool: - """ - Check if the current NSL object matches another NSL object. + """Check if the current NSL object matches another NSL object. Args: other (NSL): The NSL object to compare with. @@ -118,8 +117,7 @@ def match(self, other: NSL) -> bool: @classmethod def parse(cls, nsl: str) -> NSL: - """ - Parse the given NSL string and return an NSL object. + """Parse the given NSL string and return an NSL object. Args: nsl (str): The NSL string to parse. @@ -148,8 +146,7 @@ class _Range(NamedTuple): max: float def inside(self, value: float) -> bool: - """ - Check if a value is inside the range. + """Check if a value is inside the range. Args: value (float): The value to check. @@ -161,8 +158,7 @@ def inside(self, value: float) -> bool: @classmethod def from_list(cls, array: np.ndarray | list[float]) -> _Range: - """ - Create a Range object from a numpy array. + """Create a Range object from a numpy array. Parameters: - array: numpy.ndarray @@ -184,8 +180,7 @@ def _range_validator(v: _Range) -> _Range: def time_to_path(datetime: datetime) -> str: - """ - Converts a datetime object to a string representation of a file path. + """Converts a datetime object to a string representation of a file path. Args: datetime (datetime): The datetime object to convert. @@ -197,8 +192,7 @@ def time_to_path(datetime: datetime) -> str: def as_array(iterable: Iterable[float], dtype: np.dtype = float) -> np.ndarray: - """ - Convert an iterable of floats into a NumPy array. + """Convert an iterable of floats into a NumPy array. Parameters: iterable (Iterable[float]): An iterable containing float values. @@ -210,8 +204,7 @@ def as_array(iterable: Iterable[float], dtype: np.dtype = float) -> np.ndarray: def weighted_median(data: np.ndarray, weights: np.ndarray | None = None) -> float: - """ - Calculate the weighted median of an array/list using numpy. + """Calculate the weighted median of an array/list using numpy. Parameters: data (np.ndarray): The input array/list. @@ -254,8 +247,7 @@ def weighted_median(data: np.ndarray, weights: np.ndarray | None = None) -> floa async def async_weighted_median( data: np.ndarray, weights: np.ndarray | None = None ) -> float: - """ - Asynchronously calculate the weighted median of an array/list using numpy. + """Asynchronously calculate the weighted median of an array/list using numpy. Parameters: data (np.ndarray): The input array/list. @@ -296,8 +288,7 @@ async def async_weighted_median( def to_datetime(time: float) -> datetime: - """ - Convert a UNIX timestamp to a datetime object in UTC timezone. + """Convert a UNIX timestamp to a datetime object in UTC timezone. Args: time (float): The UNIX timestamp to convert. 
@@ -309,8 +300,7 @@ def to_datetime(time: float) -> datetime: def resample(trace: Trace, sampling_rate: float) -> None: - """ - Downsamples the given trace to the specified sampling rate in-place. + """Downsamples the given trace to the specified sampling rate in-place. Args: trace (Trace): The trace to be downsampled. @@ -361,8 +351,7 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def human_readable_bytes(size: int | float) -> str: - """ - Convert a size in bytes to a human-readable string representation. + """Convert a size in bytes to a human-readable string representation. Args: size (int | float): The size in bytes. @@ -375,8 +364,7 @@ def human_readable_bytes(size: int | float) -> str: def datetime_now() -> datetime: - """ - Get the current datetime in UTC timezone. + """Get the current datetime in UTC timezone. Returns: datetime: The current datetime in UTC timezone. @@ -385,8 +373,7 @@ def datetime_now() -> datetime: def get_cpu_count() -> int: - """ - Get the number of CPUs available for the current job/task. + """Get the number of CPUs available for the current job/task. The function first checks if the environment variable SLURM_CPUS_PER_TASK is set. If it is set, the value is returned as the number of CPUs. @@ -417,8 +404,7 @@ def filter_clipped_traces( counts_threshold: int = 20, max_bits: tuple[int, ...] = (24, 32), ) -> list[Trace]: - """ - Filters out clipped traces from the given list of traces. + """Filters out clipped traces from the given list of traces. Args: traces (list[Trace]): The list of traces to filter. @@ -455,8 +441,7 @@ def filter_clipped_traces( def camel_case_to_snake_case(name: str) -> str: - """ - Converts a camel case string to snake case. + """Converts a camel case string to snake case. Args: name (str): The camel case string to be converted. @@ -472,8 +457,7 @@ def camel_case_to_snake_case(name: str) -> str: def load_insights() -> None: - """ - Imports the qseek.insights package if available. + """Imports the qseek.insights package if available. This function attempts to import the qseek.insights package and logs a debug message indicating whether the package was successfully imported or not. @@ -503,11 +487,10 @@ class ChannelSelector: normalize: bool = False def get_traces(self, traces_flt: list[Trace]) -> list[Trace]: - """ - Filter and normalize a list of traces based on the specified channels. + """Filter and normalize a list of traces based on the specified channels. Args: - traces (list[Trace]): The list of traces to filter. + traces_flt (list[Trace]): The list of traces to filter. Returns: list[Trace]: The filtered and normalized list of traces. 
@@ -562,7 +545,7 @@ class ChannelSelectors: def generate_docs(model: BaseModel, exclude: dict | set | None = None) -> str: - """Takes model and dumps markdown for documentation""" + """Takes model and dumps markdown for documentation.""" def generate_submodel(model: BaseModel) -> list[str]: lines = [] diff --git a/test/test_moment_magnitude_store.py b/test/test_moment_magnitude_store.py index 5076ee2e..f1216dbd 100644 --- a/test/test_moment_magnitude_store.py +++ b/test/test_moment_magnitude_store.py @@ -45,17 +45,35 @@ async def test_peak_amplitudes(engine: gf.LocalEngine) -> None: ) PeakAmplitudesStore.set_engine(engine) store = PeakAmplitudesStore.from_selector(peak_amplitudes) - await store.fill_source_depth(source_depth=2 * KM) - await store.get_amplitude( + await store.compute_site_amplitudes(source_depth=2 * KM, reference_magnitude=1.0) + await store.get_amplitude_model( source_depth=2 * KM, distance=10 * KM, n_amplitudes=10, - max_distance=1 * KM, + distance_cutoff=1 * KM, auto_fill=False, interpolation="nearest", ) +@pytest.mark.asyncio +async def test_peak_amplitude_estimation(engine: gf.LocalEngine) -> None: + store_id = "reykjanes_qseis" + peak_amplitudes = PeakAmplitudesBase( + gf_store_id=store_id, + quantity="displacement", + ) + PeakAmplitudesStore.set_engine(engine) + store = PeakAmplitudesStore.from_selector(peak_amplitudes) + await store.compute_site_amplitudes(source_depth=2 * KM, reference_magnitude=1.0) + + await store.find_moment_magnitude( + source_depth=2 * KM, + distance=10 * KM, + observed_amplitude=0.0001, + ) + + @pytest.mark.plot @pytest.mark.asyncio async def test_peak_amplitude_plot(engine: gf.LocalEngine) -> None: @@ -69,17 +87,27 @@ async def test_peak_amplitude_plot(engine: gf.LocalEngine) -> None: PeakAmplitudesStore.set_engine(engine) store = PeakAmplitudesStore.from_selector(peak_amplitudes) - collection = await store.fill_source_depth(source_depth=2 * KM) + collection = await store.compute_site_amplitudes( + source_depth=2 * KM, reference_magnitude=1.0 + ) collection.plot(peak_amplitude=plot_amplitude) + await store.find_moment_magnitude( + source_depth=2 * KM, + distance=10 * KM, + observed_amplitude=0.01, + ) + peak_amplitudes = PeakAmplitudesBase( gf_store_id=store_id, quantity="velocity", ) store = PeakAmplitudesStore.from_selector(peak_amplitudes) - collection = await store.fill_source_depth(source_depth=2 * KM) - collection.plot(peak_amplitude=plot_amplitude) + collection = await store.compute_site_amplitudes( + source_depth=2 * KM, reference_magnitude=2.0 + ) + collection.plot(peak_amplitude=plot_amplitude, reference_magnitude=2.0) peak_amplitudes = PeakAmplitudesBase( gf_store_id=store_id, @@ -87,7 +115,9 @@ async def test_peak_amplitude_plot(engine: gf.LocalEngine) -> None: ) store = PeakAmplitudesStore.from_selector(peak_amplitudes) - collection = await store.fill_source_depth(source_depth=2 * KM) + collection = await store.compute_site_amplitudes( + source_depth=2 * KM, reference_magnitude=1.0 + ) collection.plot(peak_amplitude=plot_amplitude) @@ -116,10 +146,11 @@ async def test_peak_amplitude_surface(engine: gf.LocalEngine) -> None: amplitudes: list[ModelledAmplitude] = [] for dist in distances: amplitudes.append( - await store.get_amplitude( + await store.get_amplitude_model( source_depth=depth, distance=dist, n_amplitudes=25, + reference_magnitude=1.0, peak_amplitude=plot_amplitude, auto_fill=False, ) From 7a6e17574960c1ffcbdcab3cf5aa98466e1211a2 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Fri, 22 Mar 2024 16:04:16 +0000 Subject: 
[PATCH 4/6] update --- src/qseek/apps/qseek.py | 28 +++++++++++++++++++--------- src/qseek/models/detection.py | 10 ++++++++-- src/qseek/pre_processing/module.py | 1 + 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py index 72343179..1cc66462 100644 --- a/src/qseek/apps/qseek.py +++ b/src/qseek/apps/qseek.py @@ -12,7 +12,6 @@ from pkg_resources import get_distribution from qseek.models.detection import EventDetection -from qseek.utils import get_cpu_count nest_asyncio.apply() @@ -137,6 +136,12 @@ default=False, help="recalculate all magnitudes", ) +features_extract.add_argument( + "--nparallel", + type=int, + default=32, + help="number of parallel tasks for feature extraction", +) modules = subparsers.add_parser( "modules", @@ -203,7 +208,7 @@ def main() -> None: load_insights() from rich import box - from rich.progress import track + from rich.progress import Progress from rich.prompt import IntPrompt from rich.table import Table @@ -282,18 +287,20 @@ def console_status(task: asyncio.Task[EventDetection]): else: console.print(f"Event {detection.time}: No magnitudes") + progress = Progress() + tracker = progress.add_task( + "Calculating magnitudes", + total=search.catalog.n_events, + console=console, + ) + async def worker() -> None: for magnitude in search.magnitudes: await magnitude.prepare(search.octree, search.stations) await search.catalog.check(repair=True) - sem = asyncio.Semaphore(get_cpu_count()) - for detection in track( - search.catalog, - description="Calculating magnitudes", - total=search.catalog.n_events, - console=console, - ): + sem = asyncio.Semaphore(args.nparallel) + for detection in search.catalog: await sem.acquire() task = asyncio.create_task( search.add_magnitude_and_features( @@ -305,6 +312,9 @@ async def worker() -> None: task.add_done_callback(lambda _: sem.release()) task.add_done_callback(tasks.remove) task.add_done_callback(console_status) + task.add_done_callback( + lambda _: progress.update(tracker, advance=1) + ) await asyncio.gather(*tasks) diff --git a/src/qseek/models/detection.py b/src/qseek/models/detection.py index bc24b70e..e1d4a5bd 100644 --- a/src/qseek/models/detection.py +++ b/src/qseek/models/detection.py @@ -274,8 +274,7 @@ async def get_waveforms( tmax = max(times).timestamp() + seconds_after nslc_ids = [(*receiver.nsl, "*") for receiver in receivers] async with SQUIRREL_SEM: - traces = await asyncio.to_thread( - squirrel.get_waveforms, + traces = await squirrel.get_waveforms_async( codes=nslc_ids, tmin=tmin, tmax=tmax, @@ -526,6 +525,13 @@ def magnitude(self) -> EventMagnitude | None: """ return self.magnitudes[0] if self.magnitudes else None + async def update(self) -> None: + """Update detection in database. + + Doing this often requires a lot of I/O. + """ + await self.save(update=True) + async def save(self, file: Path | None = None, update: bool = False) -> None: """Dump the detection data to a file. 
diff --git a/src/qseek/pre_processing/module.py b/src/qseek/pre_processing/module.py index 689f2f9e..ec1f504e 100644 --- a/src/qseek/pre_processing/module.py +++ b/src/qseek/pre_processing/module.py @@ -103,6 +103,7 @@ async def worker() -> None: start_time = datetime_now() for process in self: batch = await process.process_batch(batch) + await asyncio.sleep(0.0) stats.time_per_batch = datetime_now() - start_time stats.bytes_per_second = ( batch.cumulative_bytes / stats.time_per_batch.total_seconds() From c9ae75753ae7e9d2ae6b7d2db9f10bcb0c0c9ffc Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Fri, 22 Mar 2024 16:04:50 +0000 Subject: [PATCH 5/6] squirrel: async loading --- src/qseek/waveforms/squirrel.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/qseek/waveforms/squirrel.py b/src/qseek/waveforms/squirrel.py index ebda8a7b..253dd892 100644 --- a/src/qseek/waveforms/squirrel.py +++ b/src/qseek/waveforms/squirrel.py @@ -5,7 +5,7 @@ import logging from datetime import datetime, timedelta from pathlib import Path -from typing import TYPE_CHECKING, AsyncIterator, Iterator, Literal +from typing import TYPE_CHECKING, AsyncIterator, Literal from pydantic import ( AwareDatetime, @@ -42,36 +42,30 @@ class SquirrelPrefetcher: def __init__( self, - iterator: Iterator[Batch], + iterator: AsyncIterator[Batch], queue_size: int = 8, ) -> None: self.iterator = iterator self.queue = asyncio.Queue(maxsize=queue_size) self._load_queue = asyncio.Queue(maxsize=queue_size) self._fetched_batches = 0 - self._task = asyncio.create_task(self.prefetch_worker()) async def prefetch_worker(self) -> None: logger.info( - "start prefetching data, queue size %d", + "start pre-fetching data, queue size %d", self.queue.maxsize, ) - async def load_data() -> None | Batch: - while True: - start = datetime_now() - batch = await asyncio.to_thread(next, self.iterator, None) - if batch is None: - await self.queue.put(None) - return - logger.debug("read waveform batch in %s", datetime_now() - start) - self._fetched_batches += 1 - self.load_time = datetime_now() - start - await self.queue.put(batch) + start = datetime_now() + async for batch in self.iterator: + self.load_time = datetime_now() - start + self._fetched_batches += 1 + logger.debug("read waveform batch in %s", self.load_time) + start = datetime_now() + await self.queue.put(batch) - await asyncio.create_task(load_data()) - logger.debug("loading waveform batches to finish") + await self.queue.put(None) class SquirrelStats(Stats): @@ -209,7 +203,7 @@ async def iter_batches( end_time - start_time, ) - iterator = squirrel.chopper_waveforms( + iterator = squirrel.chopper_waveforms_async( tmin=(start_time + window_padding).timestamp(), tmax=(end_time - window_padding).timestamp(), tinc=window_increment.total_seconds(), From d2ea9298eec65e074baf7d9cda9a183c93b2ba45 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Mon, 25 Mar 2024 13:35:59 +0000 Subject: [PATCH 6/6] bugfixes --- src/qseek/testing.py | 165 ++++++++++++++++++++++++++++ src/qseek/utils.py | 13 ++- test/conftest.py | 165 +--------------------------- test/test_moment_magnitude_store.py | 4 + test/test_utils.py | 25 +++++ 5 files changed, 205 insertions(+), 167 deletions(-) create mode 100644 src/qseek/testing.py create mode 100644 test/test_utils.py diff --git a/src/qseek/testing.py b/src/qseek/testing.py new file mode 100644 index 00000000..b623e254 --- /dev/null +++ b/src/qseek/testing.py @@ -0,0 +1,165 @@ +import asyncio +import random +from datetime import 
timedelta +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Generator + +import aiohttp +import numpy as np +import pytest +from rich.progress import Progress + +from qseek.models.catalog import EventCatalog +from qseek.models.detection import EventDetection +from qseek.models.location import Location +from qseek.models.station import Station, Stations +from qseek.octree import Octree +from qseek.tracers.cake import EarthModel, Timing, TravelTimeTree +from qseek.utils import Range, datetime_now + +DATA_DIR = Path(__file__).parent / "data" + +DATA_URL = "https://data.pyrocko.org/testing/lassie-v2/" +DATA_FILES = { + "FORGE_3D_5_large.P.mod.hdr", + "FORGE_3D_5_large.P.mod.buf", + "FORGE_3D_5_large.S.mod.hdr", + "FORGE_3D_5_large.S.mod.buf", +} + +KM = 1e3 + + +async def download_test_data() -> None: + request_files = [ + DATA_DIR / filename + for filename in DATA_FILES + if not (DATA_DIR / filename).exists() + ] + + if not request_files: + return + + async with aiohttp.ClientSession() as session: + for file in request_files: + url = DATA_URL + file.name + with Progress() as progress: + async with session.get(url) as response: + task = progress.add_task( + f"Downloading {url}", + total=response.content_length, + ) + with file.open("wb") as f: + while True: + chunk = await response.content.read(1024) + if not chunk: + break + f.write(chunk) + progress.advance(task, len(chunk)) + + +def pytest_addoption(parser) -> None: + parser.addoption("--plot", action="store_true", default=False) + + +@pytest.fixture(scope="session") +def plot(pytestconfig) -> bool: + return pytestconfig.getoption("plot") + + +@pytest.fixture(scope="session") +def travel_time_tree() -> TravelTimeTree: + return TravelTimeTree.new( + earthmodel=EarthModel(), + distance_bounds=(0 * KM, 15 * KM), + receiver_depth_bounds=(0 * KM, 0 * KM), + source_depth_bounds=(0 * KM, 10 * KM), + spatial_tolerance=100, + time_tolerance=0.05, + timing=Timing(definition="P,p"), + ) + + +@pytest.fixture(scope="session") +def data_dir() -> Path: + if not DATA_DIR.exists(): + DATA_DIR.mkdir() + + asyncio.run(download_test_data()) + return DATA_DIR + + +@pytest.fixture(scope="session") +def octree() -> Octree: + return Octree( + location=Location( + lat=10.0, + lon=10.0, + elevation=1.0 * KM, + ), + root_node_size=2 * KM, + n_levels=3, + east_bounds=Range(-10 * KM, 10 * KM), + north_bounds=Range(-10 * KM, 10 * KM), + depth_bounds=Range(0 * KM, 10 * KM), + absorbing_boundary=1 * KM, + ) + + +@pytest.fixture(scope="session") +def stations() -> Stations: + n_stations = 20 + stations: list[Station] = [] + for i_sta in range(n_stations): + station = Station( + network="XX", + station="STA%02d" % i_sta, + lat=10.0, + lon=10.0, + elevation=random.uniform(0, 0.8) * KM, + depth=random.uniform(0, 0.2) * KM, + north_shift=random.uniform(-10, 10) * KM, + east_shift=random.uniform(-10, 10) * KM, + ) + stations.append(station) + return Stations(stations=stations) + + +@pytest.fixture(scope="session") +def fixed_stations() -> Stations: + n_stations = 20 + rng = np.random.RandomState(0) + stations: list[Station] = [] + for i_sta in range(n_stations): + station = Station( + network="FX", + station="STA%02d" % i_sta, + lat=10.0, + lon=10.0, + elevation=rng.uniform(0, 1) * KM, + north_shift=rng.uniform(-10, 10) * KM, + east_shift=rng.uniform(-10, 10) * KM, + ) + stations.append(station) + return Stations(stations=stations) + + +@pytest.fixture(scope="session") +def detections() -> Generator[EventCatalog, None, None]: + n_detections 
= 2000 + detections: list[EventDetection] = [] + for _ in range(n_detections): + time = datetime_now() - timedelta(days=random.uniform(0, 365)) + detection = EventDetection( + lat=10.0, + lon=10.0, + east_shift=random.uniform(-10, 10) * KM, + north_shift=random.uniform(-10, 10) * KM, + distance_border=1000.0, + semblance=random.uniform(0, 1), + time=time, + ) + detections.append(detection) + with TemporaryDirectory() as tmpdir: + yield EventCatalog(rundir=Path(tmpdir), events=detections) diff --git a/src/qseek/utils.py b/src/qseek/utils.py index ec7f9173..8176f1e3 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -26,7 +26,7 @@ ) import numpy as np -from pydantic import AfterValidator, BaseModel, ByteSize, constr +from pydantic import AfterValidator, BaseModel, BeforeValidator, ByteSize, constr from pyrocko.util import UnavailableDecimation from rich.logging import RichHandler @@ -91,7 +91,7 @@ async def wait_all(cls) -> None: await asyncio.gather(*cls.tasks) -class NSL(NamedTuple): +class _NSL(NamedTuple): network: str station: str location: str @@ -116,7 +116,7 @@ def match(self, other: NSL) -> bool: return self.network == other.network @classmethod - def parse(cls, nsl: str) -> NSL: + def parse(cls, nsl: str | NSL) -> NSL: """Parse the given NSL string and return an NSL object. Args: @@ -130,6 +130,10 @@ def parse(cls, nsl: str) -> NSL: """ if not nsl: raise ValueError("invalid empty NSL") + if type(nsl) is _NSL: + return nsl + if not isinstance(nsl, str): + raise ValueError(f"invalid NSL {nsl}") parts = nsl.split(".") n_parts = len(parts) if n_parts >= 3: @@ -141,6 +145,9 @@ def parse(cls, nsl: str) -> NSL: raise ValueError(f"invalid NSL {nsl}") +NSL = Annotated[_NSL, BeforeValidator(_NSL.parse)] + + class _Range(NamedTuple): min: float max: float diff --git a/test/conftest.py b/test/conftest.py index d198370d..3fc91ea7 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,164 +1 @@ -import asyncio -import random -from datetime import timedelta -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Generator - -import aiohttp -import numpy as np -import pytest -from qseek.models.catalog import EventCatalog -from qseek.models.detection import EventDetection -from qseek.models.location import Location -from qseek.models.station import Station, Stations -from qseek.octree import Octree -from qseek.tracers.cake import EarthModel, Timing, TravelTimeTree -from qseek.utils import Range, datetime_now -from rich.progress import Progress - -DATA_DIR = Path(__file__).parent / "data" - -DATA_URL = "https://data.pyrocko.org/testing/lassie-v2/" -DATA_FILES = { - "FORGE_3D_5_large.P.mod.hdr", - "FORGE_3D_5_large.P.mod.buf", - "FORGE_3D_5_large.S.mod.hdr", - "FORGE_3D_5_large.S.mod.buf", -} - -KM = 1e3 - - -async def download_test_data() -> None: - request_files = [ - DATA_DIR / filename - for filename in DATA_FILES - if not (DATA_DIR / filename).exists() - ] - - if not request_files: - return - - async with aiohttp.ClientSession() as session: - for file in request_files: - url = DATA_URL + file.name - with Progress() as progress: - async with session.get(url) as response: - task = progress.add_task( - f"Downloading {url}", - total=response.content_length, - ) - with file.open("wb") as f: - while True: - chunk = await response.content.read(1024) - if not chunk: - break - f.write(chunk) - progress.advance(task, len(chunk)) - - -def pytest_addoption(parser) -> None: - parser.addoption("--plot", action="store_true", default=False) - - 
-@pytest.fixture(scope="session") -def plot(pytestconfig) -> bool: - return pytestconfig.getoption("plot") - - -@pytest.fixture(scope="session") -def travel_time_tree() -> TravelTimeTree: - return TravelTimeTree.new( - earthmodel=EarthModel(), - distance_bounds=(0 * KM, 15 * KM), - receiver_depth_bounds=(0 * KM, 0 * KM), - source_depth_bounds=(0 * KM, 10 * KM), - spatial_tolerance=100, - time_tolerance=0.05, - timing=Timing(definition="P,p"), - ) - - -@pytest.fixture(scope="session") -def data_dir() -> Path: - if not DATA_DIR.exists(): - DATA_DIR.mkdir() - - asyncio.run(download_test_data()) - return DATA_DIR - - -@pytest.fixture(scope="session") -def octree() -> Octree: - return Octree( - location=Location( - lat=10.0, - lon=10.0, - elevation=1.0 * KM, - ), - root_node_size=2 * KM, - n_levels=3, - east_bounds=Range(-10 * KM, 10 * KM), - north_bounds=Range(-10 * KM, 10 * KM), - depth_bounds=Range(0 * KM, 10 * KM), - absorbing_boundary=1 * KM, - ) - - -@pytest.fixture(scope="session") -def stations() -> Stations: - n_stations = 20 - stations: list[Station] = [] - for i_sta in range(n_stations): - station = Station( - network="XX", - station="STA%02d" % i_sta, - lat=10.0, - lon=10.0, - elevation=random.uniform(0, 0.8) * KM, - depth=random.uniform(0, 0.2) * KM, - north_shift=random.uniform(-10, 10) * KM, - east_shift=random.uniform(-10, 10) * KM, - ) - stations.append(station) - return Stations(stations=stations) - - -@pytest.fixture(scope="session") -def fixed_stations() -> Stations: - n_stations = 20 - rng = np.random.RandomState(0) - stations: list[Station] = [] - for i_sta in range(n_stations): - station = Station( - network="FX", - station="STA%02d" % i_sta, - lat=10.0, - lon=10.0, - elevation=rng.uniform(0, 1) * KM, - north_shift=rng.uniform(-10, 10) * KM, - east_shift=rng.uniform(-10, 10) * KM, - ) - stations.append(station) - return Stations(stations=stations) - - -@pytest.fixture(scope="session") -def detections() -> Generator[EventCatalog, None, None]: - n_detections = 2000 - detections: list[EventDetection] = [] - for _ in range(n_detections): - time = datetime_now() - timedelta(days=random.uniform(0, 365)) - detection = EventDetection( - lat=10.0, - lon=10.0, - east_shift=random.uniform(-10, 10) * KM, - north_shift=random.uniform(-10, 10) * KM, - distance_border=1000.0, - semblance=random.uniform(0, 1), - time=time, - ) - detections.append(detection) - with TemporaryDirectory() as tmpdir: - yield EventCatalog(rundir=Path(tmpdir), events=detections) +pytest_plugins = ["qseek.testing"] diff --git a/test/test_moment_magnitude_store.py b/test/test_moment_magnitude_store.py index f1216dbd..7903c580 100644 --- a/test/test_moment_magnitude_store.py +++ b/test/test_moment_magnitude_store.py @@ -56,6 +56,10 @@ async def test_peak_amplitudes(engine: gf.LocalEngine) -> None: ) +@pytest.mark.skipif( + not has_store("reykjanes_qseis"), + reason="reykjanes_qseis not available", +) @pytest.mark.asyncio async def test_peak_amplitude_estimation(engine: gf.LocalEngine) -> None: store_id = "reykjanes_qseis" diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..30d55623 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel +from qseek.utils import NSL + + +def test_nsl(): + nsl_id = "6E.TE234." 
+ nsl = NSL(*nsl_id.split(".")) + + assert nsl.network == "6E" + assert nsl.station == "TE234" + assert nsl.location == "" + + class Model(BaseModel): + nsl: NSL + nsl_list: list[NSL] + + Model(nsl=nsl, nsl_list=[nsl, nsl, nsl]) + + json = """ + { + "nsl": "6E.TE234.", + "nsl_list": ["6E.TE234.", "6E.TE234.", "6E.TE234."] + } + """ + Model.model_validate_json(json)
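The final test exercises the new Annotated NSL type from src/qseek/utils.py, where BeforeValidator(_NSL.parse) lets pydantic models accept either "NET.STA.LOC" strings or ready-made _NSL tuples. A minimal standalone sketch of that validator pattern, with a deliberately simplified parse and a hypothetical StationBlacklist model standing in for qseek's real configuration classes:

    from typing import Annotated, NamedTuple

    from pydantic import BaseModel, BeforeValidator


    class _NSL(NamedTuple):
        network: str
        station: str
        location: str

        @classmethod
        def parse(cls, value):
            # Simplified stand-in for qseek's _NSL.parse: pass instances
            # through, split "NET.STA.LOC" strings otherwise.
            if isinstance(value, _NSL):
                return value
            if not isinstance(value, str) or not value:
                raise ValueError(f"invalid NSL {value}")
            network, station, location = value.split(".", maxsplit=2)
            return cls(network, station, location)


    NSL = Annotated[_NSL, BeforeValidator(_NSL.parse)]


    class StationBlacklist(BaseModel):
        # Hypothetical model for illustration; not part of qseek.
        stations: list[NSL] = []


    config = StationBlacklist.model_validate({"stations": ["6E.TE234."]})
    assert config.stations[0].station == "TE234"

Attaching the validator to the annotated type rather than to individual models means every field declared as NSL or list[NSL] gets the same string coercion for free, which is what the JSON round-trip in test_utils.py checks.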