diff --git a/src/qseek/images/base.py b/src/qseek/images/base.py index 49c1d06e..4d28062d 100644 --- a/src/qseek/images/base.py +++ b/src/qseek/images/base.py @@ -89,15 +89,16 @@ def resample(self, sampling_rate: float, max_normalize: bool = False) -> None: max_normalize (bool): Normalize by maximum value to keep the scale of the maximum detection. Defaults to False. """ + if self.sampling_rate == sampling_rate: + return + downsample = self.sampling_rate > sampling_rate for tr in self.traces: - if max_normalize: - # We can use maximum here since the PhaseNet output is single-sided - _, max_value = tr.max() resample(tr, sampling_rate) if max_normalize and downsample: + _, max_value = tr.max() tr.ydata /= tr.ydata.max() tr.ydata *= max_value diff --git a/src/qseek/images/images.py b/src/qseek/images/images.py index a9f3a77f..d4451004 100644 --- a/src/qseek/images/images.py +++ b/src/qseek/images/images.py @@ -12,7 +12,7 @@ from qseek.images.base import ImageFunction from qseek.images.phase_net import PhaseNet from qseek.stats import Stats -from qseek.utils import PhaseDescription, datetime_now, human_readable_bytes +from qseek.utils import QUEUE_SIZE, PhaseDescription, datetime_now, human_readable_bytes if TYPE_CHECKING: from pyrocko.trace import Trace @@ -72,7 +72,9 @@ def _populate_table(self, table: Table) -> None: class ImageFunctions(RootModel): root: list[ImageFunctionType] = [PhaseNet()] - _queue: asyncio.Queue[Tuple[WaveformImages, WaveformBatch] | None] = PrivateAttr() + _queue: asyncio.Queue[Tuple[WaveformImages, WaveformBatch] | None] = PrivateAttr( + asyncio.Queue(maxsize=QUEUE_SIZE) + ) _processed_images: int = PrivateAttr(0) _stats: ImageFunctionsStats = PrivateAttr(default_factory=ImageFunctionsStats) @@ -81,7 +83,6 @@ def model_post_init(self, __context: Any) -> None: phases = self.get_phases() if len(set(phases)) != len(phases): raise ValueError("A phase was provided twice") - self._queue = asyncio.Queue(maxsize=16) self._stats.set_queue(self._queue) async def process_traces(self, traces: list[Trace]) -> WaveformImages: diff --git a/src/qseek/models/catalog.py b/src/qseek/models/catalog.py index a997b921..630a26b4 100644 --- a/src/qseek/models/catalog.py +++ b/src/qseek/models/catalog.py @@ -62,7 +62,7 @@ def new_detection(self, detection: EventDetection): self.max_semblance = max(self.max_semblance, detection.semblance) def _populate_table(self, table: Table) -> None: - table.add_row("No. Detections", f"[bold]{self.n_detections} :dim_button:") + table.add_row("No. Detections", f"[bold]{self.n_detections} :fire:") table.add_row("Maximum semblance", f"{self.max_semblance:.4f}") diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index 2b9dc6e0..dc60f607 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -165,7 +165,6 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None: raise ValueError("no stations available, add waveforms to start detection") def __iter__(self) -> Iterator[Station]: - # TODO: this is inefficient return (sta for sta in self.stations if sta.nsl.pretty not in self.blacklist) @property @@ -186,14 +185,18 @@ def select_from_traces(self, traces: Iterable[Trace]) -> Stations: Returns: Stations: Containing only selected stations. """ + available_stations = tuple(self) + available_nsls = tuple(sta.nsl for sta in available_stations) + selected_stations = [] - for nsl in ((tr.network, tr.station, tr.location) for tr in traces): - for sta in self: - if sta.nsl == nsl: - selected_stations.append(sta) - break - else: - raise ValueError(f"could not find a station for {'.'.join(nsl)} ") + for nsl in {(tr.network, tr.station, tr.location) for tr in traces}: + try: + sta_idx = available_nsls.index(nsl) + selected_stations.append(available_stations[sta_idx]) + except ValueError as exc: + raise ValueError( + f"could not find a station for {'.'.join(nsl)} " + ) from exc return Stations.model_construct(stations=selected_stations) def get_centroid(self) -> Location: diff --git a/src/qseek/pre_processing/module.py b/src/qseek/pre_processing/module.py index ec1f504e..5cabb57b 100644 --- a/src/qseek/pre_processing/module.py +++ b/src/qseek/pre_processing/module.py @@ -16,7 +16,7 @@ Lowpass, ) from qseek.stats import Stats -from qseek.utils import datetime_now, human_readable_bytes +from qseek.utils import QUEUE_SIZE, datetime_now, human_readable_bytes if TYPE_CHECKING: from rich.table import Table @@ -71,7 +71,9 @@ class PreProcessing(RootModel): "The first module is the first to be applied.", ) - _queue: asyncio.Queue[WaveformBatch | None] = asyncio.Queue(maxsize=12) + _queue: asyncio.Queue[WaveformBatch | None] = PrivateAttr( + asyncio.Queue(maxsize=QUEUE_SIZE) + ) _stats: PreProcessingStats = PrivateAttr(default_factory=PreProcessingStats) def model_post_init(self, __context: Any) -> None: diff --git a/src/qseek/search.py b/src/qseek/search.py index a27b142a..3b5d48fe 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -500,12 +500,6 @@ async def start(self, force_rundir: bool = False) -> None: await self.prepare() - logger.info("starting search") - stats = self._stats - stats.reset_start_time() - - processing_start = datetime_now() - if self._progress.time_progress: logger.info("continuing search from %s", self._progress.time_progress) await self._catalog.check(repair=True) @@ -513,6 +507,8 @@ async def start(self, force_rundir: bool = False) -> None: start_time=None, end_time=self._progress.time_progress, ) + else: + logger.info("starting search") batches = self.data_provider.iter_batches( window_increment=self.window_length, @@ -522,6 +518,10 @@ async def start(self, force_rundir: bool = False) -> None: ) processed_batches = self.pre_processing.iter_batches(batches) + stats = self._stats + stats.reset_start_time() + + processing_start = datetime_now() console = asyncio.create_task(RuntimeStats.live_view()) async for images, batch in self.image_functions.iter_images(processed_batches): diff --git a/src/qseek/utils.py b/src/qseek/utils.py index ed40411a..80ba1e10 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -41,6 +41,7 @@ PhaseDescription = Annotated[str, constr(pattern=r"[a-zA-Z]*:[a-zA-Z]*")] +QUEUE_SIZE = 16 CACHE_DIR = Path.home() / ".cache" / "qseek" if not CACHE_DIR.exists(): logger.info("creating cache dir %s", CACHE_DIR) diff --git a/src/qseek/waveforms/squirrel.py b/src/qseek/waveforms/squirrel.py index 3fcb0d47..87398f65 100644 --- a/src/qseek/waveforms/squirrel.py +++ b/src/qseek/waveforms/squirrel.py @@ -22,7 +22,7 @@ from qseek.models.station import Stations from qseek.stats import Stats -from qseek.utils import datetime_now, human_readable_bytes, to_datetime +from qseek.utils import QUEUE_SIZE, datetime_now, human_readable_bytes, to_datetime from qseek.waveforms.base import WaveformBatch, WaveformProvider if TYPE_CHECKING: @@ -40,14 +40,10 @@ class SquirrelPrefetcher: _fetched_batches: int _task: asyncio.Task[None] - def __init__( - self, - iterator: Iterator[Batch], - queue_size: int = 8, - ) -> None: + def __init__(self, iterator: Iterator[Batch]) -> None: self.iterator = iterator - self.queue = asyncio.Queue(maxsize=queue_size) - self._load_queue = asyncio.Queue(maxsize=queue_size) + self.queue = asyncio.Queue(maxsize=QUEUE_SIZE) + self._load_queue = asyncio.Queue(maxsize=QUEUE_SIZE) self._fetched_batches = 0 self._task = asyncio.create_task(self.prefetch_worker()) @@ -143,10 +139,6 @@ class PyrockoSquirrel(WaveformProvider): description="Channel selector for waveforms, " "e.g. `['HH', 'EN']`.", ) ) - async_prefetch_batches: PositiveInt = Field( - default=10, - description="Queue size for asynchronous pre-fetcher.", - ) n_threads: PositiveInt = Field( default=8, description="Number of threads for loading waveforms.", @@ -227,10 +219,7 @@ async def iter_batches( codes=[(*nsl, "*") for nsl in self._stations.get_all_nsl()], channel_priorities=self.channel_selector, ) - prefetcher = SquirrelPrefetcher( - iterator, - queue_size=self.async_prefetch_batches, - ) + prefetcher = SquirrelPrefetcher(iterator) stats.set_queue(prefetcher.queue) while True: