Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Oct 26, 2023
1 parent 737655c commit 9a1c9e0
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 40 deletions.
46 changes: 31 additions & 15 deletions lassie/apps/lassie.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import asyncio
import json
import logging
import shutil
from pathlib import Path
Expand All @@ -24,14 +25,15 @@
def main() -> None:
parser = argparse.ArgumentParser(
prog="lassie",
description="The friendly earthquake detector - V2",
description="Lassie - The friendly earthquake detector 🐕",
)
parser.add_argument(
"--verbose",
"-v",
action="count",
default=0,
help="increase verbosity of the log messages, default level is INFO",
help="increase verbosity of the log messages, repeat to increase. "
"Default level is INFO",
)
parser.add_argument(
"--version",
Expand All @@ -40,11 +42,28 @@ def main() -> None:
help="show version and exit",
)

subparsers = parser.add_subparsers(title="commands", required=True, dest="command")
subparsers = parser.add_subparsers(
title="commands",
required=True,
dest="command",
description="Available commands to run Lassie. Get command help with "
"`lassie <command> --help`.",
)

init_project = subparsers.add_parser(
"init",
help="initialize a new Lassie project",
description="initialze a new project with a default configuration file. ",
)
init_project.add_argument(
"folder",
type=Path,
help="folder to initialize project in",
)

run = subparsers.add_parser(
"search",
help="start a search 🐕",
help="start a search",
description="detect, localize and characterize earthquakes in a dataset",
)
run.add_argument("config", type=Path, help="path to config file")
Expand All @@ -62,14 +81,6 @@ def main() -> None:
)
continue_run.add_argument("rundir", type=Path, help="existing runding to continue")

init_project = subparsers.add_parser(
"init",
help="initialize a new Lassie project",
)
init_project.add_argument(
"folder", type=Path, help="folder to initialize project in"
)

features = subparsers.add_parser(
"feature-extraction",
help="extract features from an existing run",
Expand Down Expand Up @@ -100,11 +111,14 @@ def main() -> None:
subparsers.add_parser(
"clear-cache",
help="clear the cach directory",
description="clear all data in the cache directory",
)

dump_schemas = subparsers.add_parser(
"dump-schemas",
help="dump data models to json-schema (development)",
description="dump data models to json-schema, "
"this is for development purposes only",
)
dump_schemas.add_argument("folder", type=Path, help="folder to dump schemas to")

Expand All @@ -128,7 +142,7 @@ def main() -> None:
logger.info("initialized new project in folder %s", folder)
logger.info("start detection with: lassie run %s", config_file.name)

elif args.command == "run":
elif args.command == "search":
search = Search.from_config(args.config)

webserver = WebServer(search)
Expand Down Expand Up @@ -192,10 +206,12 @@ async def extract() -> None:

file = args.folder / "search.schema.json"
print(f"writing JSON schemas to {args.folder}")
file.write_text(Search.model_json_schema(indent=2))
file.write_text(json.dumps(Search.model_json_schema(), indent=2))

file = args.folder / "detections.schema.json"
file.write_text(EventDetections.model_json_schema(indent=2))
file.write_text(json.dumps(EventDetections.model_json_schema(), indent=2))
else:
parser.error(f"unknown command: {args.command}")


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion lassie/images/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from lassie.images.base import ImageFunction, PickedArrival
from lassie.images.phase_net import PhaseNet, PhaseNetPick
from lassie.utils import PhaseDescription

if TYPE_CHECKING:
from datetime import timedelta
Expand Down Expand Up @@ -51,7 +52,7 @@ async def process_traces(self, traces: list[Trace]) -> WaveformImages:

return WaveformImages(root=images)

def get_phases(self) -> tuple[str, ...]:
def get_phases(self) -> tuple[PhaseDescription, ...]:
"""Get all phases that are available in the image functions.
Returns:
Expand Down
2 changes: 1 addition & 1 deletion lassie/images/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def blinding(self) -> timedelta:
"""Blinding duration for the image function. Added to padded waveforms."""
raise NotImplementedError("must be implemented by subclass")

def get_provided_phases(self) -> tuple[str, ...]:
def get_provided_phases(self) -> tuple[PhaseDescription, ...]:
...


Expand Down
4 changes: 3 additions & 1 deletion lassie/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def init_boundaries(self) -> None:
self._distance_range = (distances.min(), distances.max())

# Timing ranges
for phase, tracer in self.ray_tracers.iter_phase_tracer():
for phase, tracer in self.ray_tracers.iter_phase_tracer(
phases=self.image_functions.get_phases()
):
traveltimes = tracer.get_travel_times(phase, self.octree, self.stations)
self._travel_time_ranges[phase] = (
timedelta(seconds=np.nanmin(traveltimes)),
Expand Down
19 changes: 14 additions & 5 deletions lassie/tracers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,17 @@ async def prepare(
stations: Stations,
phases: tuple[PhaseDescription, ...],
) -> None:
logger.info("preparing ray tracers")
prepared_tracers = []
for phase in phases:
tracer = self.get_phase_tracer(phase)
if tracer in prepared_tracers:
continue
phases = tracer.get_available_phases()
logger.info(
"preparing ray tracer %s for phase %s", tracer.tracer, ", ".join(phases)
)
await tracer.prepare(octree, stations)
prepared_tracers.append(tracer)

def get_available_phases(self) -> tuple[str, ...]:
phases = []
Expand All @@ -71,7 +78,9 @@ def get_phase_tracer(self, phase: str) -> RayTracer:
def __iter__(self) -> Iterator[RayTracer]:
yield from self.root

def iter_phase_tracer(self) -> Iterator[tuple[PhaseDescription, RayTracer]]:
for tracer in self:
for phase in tracer.get_available_phases():
yield (phase, tracer)
def iter_phase_tracer(
self, phases: tuple[PhaseDescription, ...]
) -> Iterator[tuple[PhaseDescription, RayTracer]]:
for phase in phases:
tracer = self.get_phase_tracer(phase)
yield (phase, tracer)
28 changes: 20 additions & 8 deletions lassie/tracers/cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Field,
FilePath,
PrivateAttr,
ValidationError,
constr,
model_validator,
)
Expand Down Expand Up @@ -114,7 +115,7 @@ class EarthModel(BaseModel):

@model_validator(mode="after")
def load_model(self) -> EarthModel:
if self.filename is not None:
if self.filename is not None and self.raw_file_data is None:
logger.info("loading velocity model from %s", self.filename)
self.raw_file_data = self.filename.read_text()

Expand Down Expand Up @@ -260,7 +261,7 @@ def check_bounds(self, requested: tuple[float, float]) -> bool:
return self[0] <= requested[0] and self[1] >= requested[1]

return (
str(self.earthmodel) == str(earthmodel.hash)
self.earthmodel.hash == earthmodel.hash
and self.timing == timing
and check_bounds(self.distance_bounds, distance_bounds)
and check_bounds(self.source_depth_bounds, source_depth_bounds)
Expand Down Expand Up @@ -331,7 +332,7 @@ def load(cls, file: Path) -> Self:
Returns:
Self: Loaded SPTreeModel
"""
logger.debug("loading traveltimes from %s", file)
logger.debug("loading cached traveltimes from %s", file)
with zipfile.ZipFile(file, "r") as archive:
path = zipfile.Path(archive)
model_file = path / "model.json"
Expand Down Expand Up @@ -368,6 +369,7 @@ def _interpolate_traveltimes_sptree(
)

def init_lut(self, octree: Octree, stations: Stations) -> None:
logger.debug("initializing LUT for %d stations", stations.n_stations)
self._cached_stations = stations
self._cached_station_indeces = {
sta.pretty_nsl: idx for idx, sta in enumerate(stations)
Expand Down Expand Up @@ -497,7 +499,7 @@ class CakeTracer(RayTracer):
"cake:P": Timing(definition="P,p"),
"cake:S": Timing(definition="S,s"),
}
earthmodel: EarthModel = EarthModel()
earthmodel: EarthModel = Field(default_factory=EarthModel)
trim_earth_model_depth: bool = Field(
default=True,
description="Trim earth model to max depth of the octree.",
Expand Down Expand Up @@ -546,10 +548,7 @@ async def prepare(self, octree: Octree, stations: Stations) -> None:
node_cache_fraction * 100,
)

cached_trees = [
TravelTimeTree.load(file) for file in self.cache_dir.glob("*.sptree")
]
logger.debug("loaded %d cached travel time trees", len(cached_trees))
cached_trees = self._load_cached_trees()

distances = surface_distances(octree, stations)
source_depths = np.asarray(octree.depth_bounds) - octree.reference.elevation
Expand Down Expand Up @@ -588,6 +587,19 @@ async def prepare(self, octree: Octree, stations: Stations) -> None:
def _get_sptree_model(self, phase: str) -> TravelTimeTree:
return self._traveltime_trees[phase]

def _load_cached_trees(self) -> list[TravelTimeTree]:
trees = []
for file in self.cache_dir.glob("*.sptree"):
try:
tree = TravelTimeTree.load(file)
except ValidationError:
logger.warning("deleting invalid cached tree %s", file)
file.unlink()
continue
trees.append(tree)
logger.debug("loaded %d cached travel time trees", len(trees))
return trees

def get_travel_time_location(
self,
phase: str,
Expand Down
5 changes: 1 addition & 4 deletions lassie/waveforms/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,9 @@ def _validate_model(self) -> Self:
return self

@field_validator("waveform_dirs")
def check_dirs(self, dirs: list[Path]) -> list[Path]:
def check_dirs(cls, dirs: list[Path]) -> list[Path]: # noqa: N805
if not dirs:
raise ValueError("no waveform directories provided!")
for data_dir in dirs:
if not data_dir.exists():
raise ValueError(f"waveform directory {data_dir} does not exist")
return dirs

def get_squirrel(self) -> Squirrel:
Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ dependencies = [
]

classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Atmospheric Science",
"Topic :: Scientific/Engineering :: Image Recognition",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Scientific/Engineering :: Visualization",
"Programming Language :: Python :: 3.7",
"Programming Language :: C",
"Topic :: Scientific/Engineering :: Information Analysis",
"Operating System :: POSIX",
"Operating System :: MacOS",
"Typing :: Typed",
"Programming Language :: Python :: 3 :: Only",
"License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)",
]

[project.optional-dependencies]
Expand Down

0 comments on commit 9a1c9e0

Please sign in to comment.