Skip to content

Commit

Permalink
stats: adding statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Nov 12, 2023
1 parent 2a5cd7b commit a8b3924
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 25 deletions.
16 changes: 7 additions & 9 deletions lassie/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,7 @@ class Search(BaseModel):
_new_detection: Signal[EventDetection] = PrivateAttr(Signal())

_stats: SearchStats = PrivateAttr(SearchStats())
_runtime_stats: RuntimeStats = PrivateAttr(default_factory=RuntimeStats.new)

def model_post_init(self, *args) -> None:
self._runtime_stats.add_stats(self._stats)
self._runtime_stats.add_stats(self.data_provider._stats)
self._runtime_stats.add_stats(self.image_functions._stats)
_runtime_stats: RuntimeStats = PrivateAttr(None)

def init_rundir(self, force: bool = False) -> None:
rundir = (
Expand Down Expand Up @@ -326,7 +321,8 @@ def init_boundaries(self) -> None:
f"window length {self.window_length} is too short for the "
f"theoretical travel time range {self._shift_range} and "
f"cummulative window padding of {self._window_padding}."
" Increase the window_length time."
" Increase the window_length time to at least "
f"{self._shift_range +2*self._window_padding }"
)

logger.info("using trace window padding: %s", self._window_padding)
Expand Down Expand Up @@ -354,7 +350,9 @@ async def start(self, force_rundir: bool = False) -> None:
await self.prepare()

logger.info("starting search...")
self._runtime_stats = RuntimeStats.new()
stats = self._stats

batch_processing_start = datetime_now()
processing_start = datetime_now()

Expand All @@ -368,7 +366,7 @@ async def start(self, force_rundir: bool = False) -> None:
min_length=2 * self._window_padding,
)

# console = asyncio.create_task(self._runtime_stats.live_view())
console = asyncio.create_task(self._runtime_stats.live_view())

async for images, batch in self.image_functions.iter_images(waveform_iterator):
images.set_stations(self.stations)
Expand Down Expand Up @@ -400,7 +398,7 @@ async def start(self, force_rundir: bool = False) -> None:
batch_processing_start = datetime_now()
self.set_progress(batch.end_time)

# console.cancel()
console.cancel()
self._detections.dump_detections(jitter_location=self.octree.size_limit)
logger.info("finished search in %s", datetime_now() - processing_start)
logger.info("found %d detections", self._detections.n_detections)
Expand Down
34 changes: 21 additions & 13 deletions lassie/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@

import asyncio
import logging
from typing import Iterator, Type
from typing import Any, Iterator, Type
from weakref import WeakValueDictionary

from pydantic import BaseModel, create_model
from pydantic import BaseModel, PrivateAttr, create_model
from pydantic.fields import ComputedFieldInfo, FieldInfo
from rich.console import Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress
from rich.table import Table

from lassie.utils import CONSOLE

logger = logging.getLogger(__name__)

STATS_CLASSES: set[Type[Stats]] = set()
STATS_INSTANCES: WeakValueDictionary[str, Stats] = WeakValueDictionary()


PROGRESS = Progress()
Expand All @@ -29,29 +33,28 @@ class RuntimeStats(BaseModel):
def new(cls) -> RuntimeStats:
return create_model(
"RuntimeStats",
**{stats.__name__: (stats, None) for stats in STATS_CLASSES},
**{
stats.__name__: (stats, STATS_INSTANCES.get(stats.__name__, None))
for stats in STATS_CLASSES
},
__base__=cls,
)()

def __rich__(self) -> Group:
return Group(
*(getattr(self, stat_name) for stat_name in self.model_fields_set),
*(
getattr(self, stat_name)
for stat_name in self.model_fields
if getattr(self, stat_name, None)
),
PROGRESS,
)

def add_stats(self, stats: Stats) -> None:
logger.debug("Adding stats %s", stats.__class__.__name__)
if stats.__class__.__name__ not in self.model_fields:
raise ValueError(f"{stats.__class__.__name__} is not a valid stats name")
if stats.__class__.__name__ in self.model_fields_set:
raise ValueError(f"{stats.__class__.__name__} is already set")
setattr(self, stats.__class__.__name__, stats)

async def live_view(self):
with Live(
self,
console=CONSOLE,
refresh_per_second=10,
screen=True,
auto_refresh=True,
redirect_stdout=True,
redirect_stderr=True,
Expand All @@ -61,9 +64,14 @@ async def live_view(self):


class Stats(BaseModel):
_position: int = PrivateAttr(0)

def __init_subclass__(cls: Type[Stats], **kwargs) -> None:
STATS_CLASSES.add(cls)

def model_post_init(self, __context: Any) -> None:
STATS_INSTANCES[self.__class__.__name__] = self

def populate_table(self, table: Table) -> None:
for name, field in self.iter_fields():
title = field.title or titelify(name)
Expand Down
8 changes: 7 additions & 1 deletion lassie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pydantic import BaseModel, constr
from pyrocko.util import UnavailableDecimation
from rich.console import Console
from rich.logging import RichHandler

if TYPE_CHECKING:
Expand All @@ -17,6 +18,8 @@
logger = logging.getLogger(__name__)
FORMAT = "%(message)s"

CONSOLE = Console()


PhaseDescription = Annotated[str, constr(pattern=r"[a-zA-Z]*:[a-zA-Z]*")]

Expand All @@ -41,7 +44,10 @@ class ANSI:

def setup_rich_logging(level: int) -> None:
logging.basicConfig(
level=level, format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]
level=level,
format=FORMAT,
datefmt="[%X]",
handlers=[RichHandler()],
)


Expand Down
2 changes: 1 addition & 1 deletion lassie/waveforms/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class PyrockoSquirrel(WaveformProvider):
default=Path("."),
description="Path to a Squirrel environment.",
)
waveform_dirs: list[DirectoryPath] = Field(
waveform_dirs: list[Path] = Field(
default=[],
description="List of directories holding the waveform files.",
)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ dependencies = [
"rich>=13.4",
"nest_asyncio>=1.5",
"pyevtk>=1.6",
"pytorch>=2.1",

]

Expand Down

0 comments on commit a8b3924

Please sign in to comment.