diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index df667eb2..2331cf87 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,8 +8,13 @@ repos: - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files + - repo: https://github.com/charliermarsh/ruff-pre-commit + # Ruff version. + rev: "v0.0.291" + hooks: + - id: ruff - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.9.1 hooks: - id: black # It is recommended to specify the latest version of Python @@ -17,8 +22,3 @@ repos: # pre-commit's default_language_version, see # https://pre-commit.com/#top_level-default_language_version # language_version: python3.9 - - repo: https://github.com/charliermarsh/ruff-pre-commit - # Ruff version. - rev: "v0.0.287" - hooks: - - id: ruff diff --git a/lassie/images/phase_net.py b/lassie/images/phase_net.py index 670d8bfe..daeed7da 100644 --- a/lassie/images/phase_net.py +++ b/lassie/images/phase_net.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Literal from obspy import Stream -from pydantic import PositiveFloat, PositiveInt, PrivateAttr, conint +from pydantic import Field, PositiveFloat, PositiveInt, PrivateAttr from pyrocko import obspy_compat from seisbench import logger @@ -84,10 +84,10 @@ def search_phase_arrival( class PhaseNet(ImageFunction): image: Literal["PhaseNet"] = "PhaseNet" model: ModelName = "ethz" - window_overlap_samples: conint(ge=1000, le=3000) = 2000 + window_overlap_samples: int = Field(default=2000, ge=1000, le=3000) torch_use_cuda: bool = False torch_cpu_threads: PositiveInt = 4 - batch_size: conint(ge=64) = 64 + batch_size: int = Field(default=64, ge=64) stack_method: StackMethod = "avg" phase_map: dict[PhaseName, str] = { "P": "constant:P", diff --git a/lassie/models/station.py b/lassie/models/station.py index 6e6054a2..88d4a3c4 100644 --- a/lassie/models/station.py +++ b/lassie/models/station.py @@ -11,8 +11,8 @@ from pyrocko.model import dump_stations_yaml, load_stations if TYPE_CHECKING: - from pyrocko.trace import Trace from pyrocko.squirrel import Squirrel + from pyrocko.trace import Trace from lassie.models.location import CoordSystem, Location diff --git a/lassie/octree.py b/lassie/octree.py index c43311e3..6101a662 100644 --- a/lassie/octree.py +++ b/lassie/octree.py @@ -15,7 +15,6 @@ Field, PositiveFloat, PrivateAttr, - confloat, field_validator, model_validator, ) @@ -66,7 +65,7 @@ class Node(BaseModel): semblance: float = 0.0 tree: Octree | None = Field(None, exclude=True) - children: tuple[Node, ...] = Field((), exclude=True) + children: tuple[Node, ...] = Field(default=(), exclude=True) _hash: bytes | None = PrivateAttr(None) _children_cached: tuple[Node, ...] = PrivateAttr(()) @@ -185,7 +184,7 @@ class Octree(BaseModel): east_bounds: tuple[float, float] = (-10 * KM, 10 * KM) north_bounds: tuple[float, float] = (-10 * KM, 10 * KM) depth_bounds: tuple[float, float] = (0 * KM, 20 * KM) - absorbing_boundary: confloat(ge=0.0) = 1 * KM + absorbing_boundary: float = Field(default=1 * KM, ge=0.0) _root_nodes: list[Node] = PrivateAttr([]) _cached_coordinates: dict[CoordSystem, np.ndarray] = PrivateAttr({}) diff --git a/lassie/tracers/fast_marching/fast_marching.py b/lassie/tracers/fast_marching/fast_marching.py index 32ce96d5..7f20101b 100644 --- a/lassie/tracers/fast_marching/fast_marching.py +++ b/lassie/tracers/fast_marching/fast_marching.py @@ -401,7 +401,7 @@ async def _calculate_travel_times( async def worker_station_travel_time(station: Station) -> None: volume = await StationTravelTimeVolume.calculate_from_eikonal( - self._velocity_model, # noqa + self._velocity_model, station, save=cache_dir, executor=executor, diff --git a/lassie/waveforms/squirrel.py b/lassie/waveforms/squirrel.py index 28b40b62..268de664 100644 --- a/lassie/waveforms/squirrel.py +++ b/lassie/waveforms/squirrel.py @@ -7,7 +7,14 @@ from pathlib import Path from typing import TYPE_CHECKING, AsyncIterator, Iterator, Literal -from pydantic import AwareDatetime, PositiveInt, PrivateAttr, constr, model_validator +from pydantic import ( + AwareDatetime, + Field, + PositiveFloat, + PositiveInt, + PrivateAttr, + model_validator, +) from pyrocko.squirrel import Squirrel from typing_extensions import Self @@ -22,17 +29,36 @@ class SquirrelPrefetcher: - def __init__(self, iterator: Iterator[Batch], queue_size: int = 4) -> None: + def __init__( + self, + iterator: Iterator[Batch], + queue_size: int = 4, + freq_min: float | None = None, + freq_max: float | None = None, + ) -> None: self.iterator = iterator self.queue: asyncio.Queue[Batch | None] = asyncio.Queue(maxsize=queue_size) + self.freq_min = freq_min + self.freq_max = freq_max self._task = asyncio.create_task(self.prefetch_worker()) async def prefetch_worker(self) -> None: logger.info("start prefetching squirrel data") + + def filter_freqs(batch: Batch) -> Batch: + if self.freq_min: + for tr in batch.traces: + tr.highpass(4, self.freq_min) + if self.freq_max: + for tr in batch.traces: + tr.lowpass(4, self.freq_max) + return batch + while True: start = datetime_now() batch = await asyncio.to_thread(lambda: next(self.iterator, None)) + await asyncio.to_thread(filter_freqs, batch) logger.debug("prefetched waveforms in %s", datetime_now() - start) if batch is None: logger.debug("squirrel prefetcher finished") @@ -49,16 +75,21 @@ class PyrockoSquirrel(WaveformProvider): start_time: AwareDatetime | None = None end_time: AwareDatetime | None = None - channel_selector: constr(max_length=3) = "*" + freq_min: PositiveFloat | None = None + freq_max: PositiveFloat | None = None + + channel_selector: str = Field(default="*", max_length=3) async_prefetch_batches: PositiveInt = 4 _squirrel: Squirrel | None = PrivateAttr(None) _stations: Stations = PrivateAttr() @model_validator(mode="after") - def _validate_time_span(self) -> Self: # noqa: N805 + def _validate_time_span(self) -> Self: if self.start_time and self.end_time and self.start_time > self.end_time: raise ValueError("start_time must be before end_time") + if self.freq_min and self.freq_max and self.freq_min > self.freq_max: + raise ValueError("freq_min must be less than freq_max") return self def get_squirrel(self) -> Squirrel: diff --git a/pyproject.toml b/pyproject.toml index a7614309..afe571e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "scipy>=1.8.0", "pyrocko>=2022.06.10", "seisbench>=0.5.0", - "pydantic>=2.3", + "pydantic>=2.4.2", "aiohttp>=3.8", "aiohttp_cors>=0.7.0", "typing-extensions>=4.6", @@ -81,8 +81,22 @@ Issues = "https://git.pyrocko.org/pyrocko/lassie/issues" [tool.setuptools_scm] [tool.ruff] -extend-select = ['W', 'N', 'DTZ', 'FA', 'G', 'RET', 'SIM', 'B', 'RET', 'C4'] +extend-select = [ + 'W', + 'N', + 'DTZ', + 'FA', + 'G', + 'RET', + 'SIM', + 'B', + 'RET', + 'C4', + 'I', + 'RUF', +] target-version = 'py310' +ignore = ["RUF012", "RUF009"] [tool.pytest.ini_options] markers = ["plot: plot figures in tests"]