From a8b3924513b59f65de1f05c6c2f8a061e712d5e2 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Sun, 12 Nov 2023 15:40:36 +0100 Subject: [PATCH] stats: adding statistics --- lassie/search.py | 16 +++++++--------- lassie/stats.py | 34 +++++++++++++++++++++------------- lassie/utils.py | 8 +++++++- lassie/waveforms/squirrel.py | 2 +- pyproject.toml | 1 - 5 files changed, 36 insertions(+), 25 deletions(-) diff --git a/lassie/search.py b/lassie/search.py index 4048373c..fe39e802 100644 --- a/lassie/search.py +++ b/lassie/search.py @@ -229,12 +229,7 @@ class Search(BaseModel): _new_detection: Signal[EventDetection] = PrivateAttr(Signal()) _stats: SearchStats = PrivateAttr(SearchStats()) - _runtime_stats: RuntimeStats = PrivateAttr(default_factory=RuntimeStats.new) - - def model_post_init(self, *args) -> None: - self._runtime_stats.add_stats(self._stats) - self._runtime_stats.add_stats(self.data_provider._stats) - self._runtime_stats.add_stats(self.image_functions._stats) + _runtime_stats: RuntimeStats = PrivateAttr(None) def init_rundir(self, force: bool = False) -> None: rundir = ( @@ -326,7 +321,8 @@ def init_boundaries(self) -> None: f"window length {self.window_length} is too short for the " f"theoretical travel time range {self._shift_range} and " f"cummulative window padding of {self._window_padding}." - " Increase the window_length time." + " Increase the window_length time to at least " + f"{self._shift_range +2*self._window_padding }" ) logger.info("using trace window padding: %s", self._window_padding) @@ -354,7 +350,9 @@ async def start(self, force_rundir: bool = False) -> None: await self.prepare() logger.info("starting search...") + self._runtime_stats = RuntimeStats.new() stats = self._stats + batch_processing_start = datetime_now() processing_start = datetime_now() @@ -368,7 +366,7 @@ async def start(self, force_rundir: bool = False) -> None: min_length=2 * self._window_padding, ) - # console = asyncio.create_task(self._runtime_stats.live_view()) + console = asyncio.create_task(self._runtime_stats.live_view()) async for images, batch in self.image_functions.iter_images(waveform_iterator): images.set_stations(self.stations) @@ -400,7 +398,7 @@ async def start(self, force_rundir: bool = False) -> None: batch_processing_start = datetime_now() self.set_progress(batch.end_time) - # console.cancel() + console.cancel() self._detections.dump_detections(jitter_location=self.octree.size_limit) logger.info("finished search in %s", datetime_now() - processing_start) logger.info("found %d detections", self._detections.n_detections) diff --git a/lassie/stats.py b/lassie/stats.py index 4ad11284..753b403c 100644 --- a/lassie/stats.py +++ b/lassie/stats.py @@ -2,9 +2,10 @@ import asyncio import logging -from typing import Iterator, Type +from typing import Any, Iterator, Type +from weakref import WeakValueDictionary -from pydantic import BaseModel, create_model +from pydantic import BaseModel, PrivateAttr, create_model from pydantic.fields import ComputedFieldInfo, FieldInfo from rich.console import Group from rich.live import Live @@ -12,9 +13,12 @@ from rich.progress import Progress from rich.table import Table +from lassie.utils import CONSOLE + logger = logging.getLogger(__name__) STATS_CLASSES: set[Type[Stats]] = set() +STATS_INSTANCES: WeakValueDictionary[str, Stats] = WeakValueDictionary() PROGRESS = Progress() @@ -29,29 +33,28 @@ class RuntimeStats(BaseModel): def new(cls) -> RuntimeStats: return create_model( "RuntimeStats", - **{stats.__name__: (stats, None) for stats in STATS_CLASSES}, + **{ + stats.__name__: (stats, STATS_INSTANCES.get(stats.__name__, None)) + for stats in STATS_CLASSES + }, __base__=cls, )() def __rich__(self) -> Group: return Group( - *(getattr(self, stat_name) for stat_name in self.model_fields_set), + *( + getattr(self, stat_name) + for stat_name in self.model_fields + if getattr(self, stat_name, None) + ), PROGRESS, ) - def add_stats(self, stats: Stats) -> None: - logger.debug("Adding stats %s", stats.__class__.__name__) - if stats.__class__.__name__ not in self.model_fields: - raise ValueError(f"{stats.__class__.__name__} is not a valid stats name") - if stats.__class__.__name__ in self.model_fields_set: - raise ValueError(f"{stats.__class__.__name__} is already set") - setattr(self, stats.__class__.__name__, stats) - async def live_view(self): with Live( self, + console=CONSOLE, refresh_per_second=10, - screen=True, auto_refresh=True, redirect_stdout=True, redirect_stderr=True, @@ -61,9 +64,14 @@ async def live_view(self): class Stats(BaseModel): + _position: int = PrivateAttr(0) + def __init_subclass__(cls: Type[Stats], **kwargs) -> None: STATS_CLASSES.add(cls) + def model_post_init(self, __context: Any) -> None: + STATS_INSTANCES[self.__class__.__name__] = self + def populate_table(self, table: Table) -> None: for name, field in self.iter_fields(): title = field.title or titelify(name) diff --git a/lassie/utils.py b/lassie/utils.py index 0a377f1d..7e24943a 100644 --- a/lassie/utils.py +++ b/lassie/utils.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, constr from pyrocko.util import UnavailableDecimation +from rich.console import Console from rich.logging import RichHandler if TYPE_CHECKING: @@ -17,6 +18,8 @@ logger = logging.getLogger(__name__) FORMAT = "%(message)s" +CONSOLE = Console() + PhaseDescription = Annotated[str, constr(pattern=r"[a-zA-Z]*:[a-zA-Z]*")] @@ -41,7 +44,10 @@ class ANSI: def setup_rich_logging(level: int) -> None: logging.basicConfig( - level=level, format=FORMAT, datefmt="[%X]", handlers=[RichHandler()] + level=level, + format=FORMAT, + datefmt="[%X]", + handlers=[RichHandler()], ) diff --git a/lassie/waveforms/squirrel.py b/lassie/waveforms/squirrel.py index c7cb4b9c..1458c14b 100644 --- a/lassie/waveforms/squirrel.py +++ b/lassie/waveforms/squirrel.py @@ -108,7 +108,7 @@ class PyrockoSquirrel(WaveformProvider): default=Path("."), description="Path to a Squirrel environment.", ) - waveform_dirs: list[DirectoryPath] = Field( + waveform_dirs: list[Path] = Field( default=[], description="List of directories holding the waveform files.", ) diff --git a/pyproject.toml b/pyproject.toml index 212beb83..47acf54c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies = [ "rich>=13.4", "nest_asyncio>=1.5", "pyevtk>=1.6", - "pytorch>=2.1", ]