From 9a1c9e0032ac564c12e0b324000b0a45d944af64 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Thu, 26 Oct 2023 12:17:20 +0200 Subject: [PATCH] fixup --- lassie/apps/lassie.py | 46 ++++++++++++++++++++++++------------ lassie/images/__init__.py | 3 ++- lassie/images/base.py | 2 +- lassie/search.py | 4 +++- lassie/tracers/__init__.py | 19 +++++++++++---- lassie/tracers/cake.py | 28 +++++++++++++++------- lassie/waveforms/squirrel.py | 5 +--- pyproject.toml | 10 ++++---- 8 files changed, 77 insertions(+), 40 deletions(-) diff --git a/lassie/apps/lassie.py b/lassie/apps/lassie.py index f64c402d..a151d658 100644 --- a/lassie/apps/lassie.py +++ b/lassie/apps/lassie.py @@ -2,6 +2,7 @@ import argparse import asyncio +import json import logging import shutil from pathlib import Path @@ -24,14 +25,15 @@ def main() -> None: parser = argparse.ArgumentParser( prog="lassie", - description="The friendly earthquake detector - V2", + description="Lassie - The friendly earthquake detector 🐕", ) parser.add_argument( "--verbose", "-v", action="count", default=0, - help="increase verbosity of the log messages, default level is INFO", + help="increase verbosity of the log messages, repeat to increase. " + "Default level is INFO", ) parser.add_argument( "--version", @@ -40,11 +42,28 @@ def main() -> None: help="show version and exit", ) - subparsers = parser.add_subparsers(title="commands", required=True, dest="command") + subparsers = parser.add_subparsers( + title="commands", + required=True, + dest="command", + description="Available commands to run Lassie. Get command help with " + "`lassie --help`.", + ) + + init_project = subparsers.add_parser( + "init", + help="initialize a new Lassie project", + description="initialze a new project with a default configuration file. ", + ) + init_project.add_argument( + "folder", + type=Path, + help="folder to initialize project in", + ) run = subparsers.add_parser( "search", - help="start a search 🐕", + help="start a search", description="detect, localize and characterize earthquakes in a dataset", ) run.add_argument("config", type=Path, help="path to config file") @@ -62,14 +81,6 @@ def main() -> None: ) continue_run.add_argument("rundir", type=Path, help="existing runding to continue") - init_project = subparsers.add_parser( - "init", - help="initialize a new Lassie project", - ) - init_project.add_argument( - "folder", type=Path, help="folder to initialize project in" - ) - features = subparsers.add_parser( "feature-extraction", help="extract features from an existing run", @@ -100,11 +111,14 @@ def main() -> None: subparsers.add_parser( "clear-cache", help="clear the cach directory", + description="clear all data in the cache directory", ) dump_schemas = subparsers.add_parser( "dump-schemas", help="dump data models to json-schema (development)", + description="dump data models to json-schema, " + "this is for development purposes only", ) dump_schemas.add_argument("folder", type=Path, help="folder to dump schemas to") @@ -128,7 +142,7 @@ def main() -> None: logger.info("initialized new project in folder %s", folder) logger.info("start detection with: lassie run %s", config_file.name) - elif args.command == "run": + elif args.command == "search": search = Search.from_config(args.config) webserver = WebServer(search) @@ -192,10 +206,12 @@ async def extract() -> None: file = args.folder / "search.schema.json" print(f"writing JSON schemas to {args.folder}") - file.write_text(Search.model_json_schema(indent=2)) + file.write_text(json.dumps(Search.model_json_schema(), indent=2)) file = args.folder / "detections.schema.json" - file.write_text(EventDetections.model_json_schema(indent=2)) + file.write_text(json.dumps(EventDetections.model_json_schema(), indent=2)) + else: + parser.error(f"unknown command: {args.command}") if __name__ == "__main__": diff --git a/lassie/images/__init__.py b/lassie/images/__init__.py index 9ad3bcd3..24db78a9 100644 --- a/lassie/images/__init__.py +++ b/lassie/images/__init__.py @@ -9,6 +9,7 @@ from lassie.images.base import ImageFunction, PickedArrival from lassie.images.phase_net import PhaseNet, PhaseNetPick +from lassie.utils import PhaseDescription if TYPE_CHECKING: from datetime import timedelta @@ -51,7 +52,7 @@ async def process_traces(self, traces: list[Trace]) -> WaveformImages: return WaveformImages(root=images) - def get_phases(self) -> tuple[str, ...]: + def get_phases(self) -> tuple[PhaseDescription, ...]: """Get all phases that are available in the image functions. Returns: diff --git a/lassie/images/base.py b/lassie/images/base.py index 7f4fd503..5d539e59 100644 --- a/lassie/images/base.py +++ b/lassie/images/base.py @@ -35,7 +35,7 @@ def blinding(self) -> timedelta: """Blinding duration for the image function. Added to padded waveforms.""" raise NotImplementedError("must be implemented by subclass") - def get_provided_phases(self) -> tuple[str, ...]: + def get_provided_phases(self) -> tuple[PhaseDescription, ...]: ... diff --git a/lassie/search.py b/lassie/search.py index 6eaa54fe..5337dce3 100644 --- a/lassie/search.py +++ b/lassie/search.py @@ -160,7 +160,9 @@ def init_boundaries(self) -> None: self._distance_range = (distances.min(), distances.max()) # Timing ranges - for phase, tracer in self.ray_tracers.iter_phase_tracer(): + for phase, tracer in self.ray_tracers.iter_phase_tracer( + phases=self.image_functions.get_phases() + ): traveltimes = tracer.get_travel_times(phase, self.octree, self.stations) self._travel_time_ranges[phase] = ( timedelta(seconds=np.nanmin(traveltimes)), diff --git a/lassie/tracers/__init__.py b/lassie/tracers/__init__.py index 8a0af1e8..3f262064 100644 --- a/lassie/tracers/__init__.py +++ b/lassie/tracers/__init__.py @@ -41,10 +41,17 @@ async def prepare( stations: Stations, phases: tuple[PhaseDescription, ...], ) -> None: - logger.info("preparing ray tracers") + prepared_tracers = [] for phase in phases: tracer = self.get_phase_tracer(phase) + if tracer in prepared_tracers: + continue + phases = tracer.get_available_phases() + logger.info( + "preparing ray tracer %s for phase %s", tracer.tracer, ", ".join(phases) + ) await tracer.prepare(octree, stations) + prepared_tracers.append(tracer) def get_available_phases(self) -> tuple[str, ...]: phases = [] @@ -71,7 +78,9 @@ def get_phase_tracer(self, phase: str) -> RayTracer: def __iter__(self) -> Iterator[RayTracer]: yield from self.root - def iter_phase_tracer(self) -> Iterator[tuple[PhaseDescription, RayTracer]]: - for tracer in self: - for phase in tracer.get_available_phases(): - yield (phase, tracer) + def iter_phase_tracer( + self, phases: tuple[PhaseDescription, ...] + ) -> Iterator[tuple[PhaseDescription, RayTracer]]: + for phase in phases: + tracer = self.get_phase_tracer(phase) + yield (phase, tracer) diff --git a/lassie/tracers/cake.py b/lassie/tracers/cake.py index 3edee3fc..88e958c0 100644 --- a/lassie/tracers/cake.py +++ b/lassie/tracers/cake.py @@ -21,6 +21,7 @@ Field, FilePath, PrivateAttr, + ValidationError, constr, model_validator, ) @@ -114,7 +115,7 @@ class EarthModel(BaseModel): @model_validator(mode="after") def load_model(self) -> EarthModel: - if self.filename is not None: + if self.filename is not None and self.raw_file_data is None: logger.info("loading velocity model from %s", self.filename) self.raw_file_data = self.filename.read_text() @@ -260,7 +261,7 @@ def check_bounds(self, requested: tuple[float, float]) -> bool: return self[0] <= requested[0] and self[1] >= requested[1] return ( - str(self.earthmodel) == str(earthmodel.hash) + self.earthmodel.hash == earthmodel.hash and self.timing == timing and check_bounds(self.distance_bounds, distance_bounds) and check_bounds(self.source_depth_bounds, source_depth_bounds) @@ -331,7 +332,7 @@ def load(cls, file: Path) -> Self: Returns: Self: Loaded SPTreeModel """ - logger.debug("loading traveltimes from %s", file) + logger.debug("loading cached traveltimes from %s", file) with zipfile.ZipFile(file, "r") as archive: path = zipfile.Path(archive) model_file = path / "model.json" @@ -368,6 +369,7 @@ def _interpolate_traveltimes_sptree( ) def init_lut(self, octree: Octree, stations: Stations) -> None: + logger.debug("initializing LUT for %d stations", stations.n_stations) self._cached_stations = stations self._cached_station_indeces = { sta.pretty_nsl: idx for idx, sta in enumerate(stations) @@ -497,7 +499,7 @@ class CakeTracer(RayTracer): "cake:P": Timing(definition="P,p"), "cake:S": Timing(definition="S,s"), } - earthmodel: EarthModel = EarthModel() + earthmodel: EarthModel = Field(default_factory=EarthModel) trim_earth_model_depth: bool = Field( default=True, description="Trim earth model to max depth of the octree.", @@ -546,10 +548,7 @@ async def prepare(self, octree: Octree, stations: Stations) -> None: node_cache_fraction * 100, ) - cached_trees = [ - TravelTimeTree.load(file) for file in self.cache_dir.glob("*.sptree") - ] - logger.debug("loaded %d cached travel time trees", len(cached_trees)) + cached_trees = self._load_cached_trees() distances = surface_distances(octree, stations) source_depths = np.asarray(octree.depth_bounds) - octree.reference.elevation @@ -588,6 +587,19 @@ async def prepare(self, octree: Octree, stations: Stations) -> None: def _get_sptree_model(self, phase: str) -> TravelTimeTree: return self._traveltime_trees[phase] + def _load_cached_trees(self) -> list[TravelTimeTree]: + trees = [] + for file in self.cache_dir.glob("*.sptree"): + try: + tree = TravelTimeTree.load(file) + except ValidationError: + logger.warning("deleting invalid cached tree %s", file) + file.unlink() + continue + trees.append(tree) + logger.debug("loaded %d cached travel time trees", len(trees)) + return trees + def get_travel_time_location( self, phase: str, diff --git a/lassie/waveforms/squirrel.py b/lassie/waveforms/squirrel.py index 3728068c..bad6fe5e 100644 --- a/lassie/waveforms/squirrel.py +++ b/lassie/waveforms/squirrel.py @@ -104,12 +104,9 @@ def _validate_model(self) -> Self: return self @field_validator("waveform_dirs") - def check_dirs(self, dirs: list[Path]) -> list[Path]: + def check_dirs(cls, dirs: list[Path]) -> list[Path]: # noqa: N805 if not dirs: raise ValueError("no waveform directories provided!") - for data_dir in dirs: - if not data_dir.exists(): - raise ValueError(f"waveform directory {data_dir} does not exist") return dirs def get_squirrel(self) -> Squirrel: diff --git a/pyproject.toml b/pyproject.toml index afe571e1..e7393335 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,16 +41,16 @@ dependencies = [ ] classifiers = [ + "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Atmospheric Science", - "Topic :: Scientific/Engineering :: Image Recognition", "Topic :: Scientific/Engineering :: Physics", - "Topic :: Scientific/Engineering :: Visualization", - "Programming Language :: Python :: 3.7", - "Programming Language :: C", + "Topic :: Scientific/Engineering :: Information Analysis", "Operating System :: POSIX", "Operating System :: MacOS", + "Typing :: Typed", + "Programming Language :: Python :: 3 :: Only", + "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)", ] [project.optional-dependencies]