Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
miili committed Oct 18, 2023
1 parent 91b830c commit 6dafdfc
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 20 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: "v0.0.291"
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.9.1
hooks:
- id: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
# language_version: python3.9
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: "v0.0.287"
hooks:
- id: ruff
6 changes: 3 additions & 3 deletions lassie/images/phase_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, Literal

from obspy import Stream
from pydantic import PositiveFloat, PositiveInt, PrivateAttr, conint
from pydantic import Field, PositiveFloat, PositiveInt, PrivateAttr
from pyrocko import obspy_compat
from seisbench import logger

Expand Down Expand Up @@ -84,10 +84,10 @@ def search_phase_arrival(
class PhaseNet(ImageFunction):
image: Literal["PhaseNet"] = "PhaseNet"
model: ModelName = "ethz"
window_overlap_samples: conint(ge=1000, le=3000) = 2000
window_overlap_samples: int = Field(default=2000, ge=1000, le=3000)
torch_use_cuda: bool = False
torch_cpu_threads: PositiveInt = 4
batch_size: conint(ge=64) = 64
batch_size: int = Field(default=64, ge=64)
stack_method: StackMethod = "avg"
phase_map: dict[PhaseName, str] = {
"P": "constant:P",
Expand Down
2 changes: 1 addition & 1 deletion lassie/models/station.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from pyrocko.model import dump_stations_yaml, load_stations

if TYPE_CHECKING:
from pyrocko.trace import Trace
from pyrocko.squirrel import Squirrel
from pyrocko.trace import Trace

from lassie.models.location import CoordSystem, Location

Expand Down
5 changes: 2 additions & 3 deletions lassie/octree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
Field,
PositiveFloat,
PrivateAttr,
confloat,
field_validator,
model_validator,
)
Expand Down Expand Up @@ -66,7 +65,7 @@ class Node(BaseModel):
semblance: float = 0.0

tree: Octree | None = Field(None, exclude=True)
children: tuple[Node, ...] = Field((), exclude=True)
children: tuple[Node, ...] = Field(default=(), exclude=True)

_hash: bytes | None = PrivateAttr(None)
_children_cached: tuple[Node, ...] = PrivateAttr(())
Expand Down Expand Up @@ -185,7 +184,7 @@ class Octree(BaseModel):
east_bounds: tuple[float, float] = (-10 * KM, 10 * KM)
north_bounds: tuple[float, float] = (-10 * KM, 10 * KM)
depth_bounds: tuple[float, float] = (0 * KM, 20 * KM)
absorbing_boundary: confloat(ge=0.0) = 1 * KM
absorbing_boundary: float = Field(default=1 * KM, ge=0.0)

_root_nodes: list[Node] = PrivateAttr([])
_cached_coordinates: dict[CoordSystem, np.ndarray] = PrivateAttr({})
Expand Down
2 changes: 1 addition & 1 deletion lassie/tracers/fast_marching/fast_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ async def _calculate_travel_times(

async def worker_station_travel_time(station: Station) -> None:
volume = await StationTravelTimeVolume.calculate_from_eikonal(
self._velocity_model, # noqa
self._velocity_model,
station,
save=cache_dir,
executor=executor,
Expand Down
39 changes: 35 additions & 4 deletions lassie/waveforms/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from pathlib import Path
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Literal

from pydantic import AwareDatetime, PositiveInt, PrivateAttr, constr, model_validator
from pydantic import (
AwareDatetime,
Field,
PositiveFloat,
PositiveInt,
PrivateAttr,
model_validator,
)
from pyrocko.squirrel import Squirrel
from typing_extensions import Self

Expand All @@ -22,17 +29,36 @@


class SquirrelPrefetcher:
def __init__(self, iterator: Iterator[Batch], queue_size: int = 4) -> None:
def __init__(
self,
iterator: Iterator[Batch],
queue_size: int = 4,
freq_min: float | None = None,
freq_max: float | None = None,
) -> None:
self.iterator = iterator
self.queue: asyncio.Queue[Batch | None] = asyncio.Queue(maxsize=queue_size)
self.freq_min = freq_min
self.freq_max = freq_max

self._task = asyncio.create_task(self.prefetch_worker())

async def prefetch_worker(self) -> None:
logger.info("start prefetching squirrel data")

def filter_freqs(batch: Batch) -> Batch:
if self.freq_min:
for tr in batch.traces:
tr.highpass(4, self.freq_min)
if self.freq_max:
for tr in batch.traces:
tr.lowpass(4, self.freq_max)
return batch

while True:
start = datetime_now()
batch = await asyncio.to_thread(lambda: next(self.iterator, None))
await asyncio.to_thread(filter_freqs, batch)
logger.debug("prefetched waveforms in %s", datetime_now() - start)
if batch is None:
logger.debug("squirrel prefetcher finished")
Expand All @@ -49,16 +75,21 @@ class PyrockoSquirrel(WaveformProvider):
start_time: AwareDatetime | None = None
end_time: AwareDatetime | None = None

channel_selector: constr(max_length=3) = "*"
freq_min: PositiveFloat | None = None
freq_max: PositiveFloat | None = None

channel_selector: str = Field(default="*", max_length=3)
async_prefetch_batches: PositiveInt = 4

_squirrel: Squirrel | None = PrivateAttr(None)
_stations: Stations = PrivateAttr()

@model_validator(mode="after")
def _validate_time_span(self) -> Self: # noqa: N805
def _validate_time_span(self) -> Self:
if self.start_time and self.end_time and self.start_time > self.end_time:
raise ValueError("start_time must be before end_time")
if self.freq_min and self.freq_max and self.freq_min > self.freq_max:
raise ValueError("freq_min must be less than freq_max")
return self

def get_squirrel(self) -> Squirrel:
Expand Down
18 changes: 16 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"scipy>=1.8.0",
"pyrocko>=2022.06.10",
"seisbench>=0.5.0",
"pydantic>=2.3",
"pydantic>=2.4.2",
"aiohttp>=3.8",
"aiohttp_cors>=0.7.0",
"typing-extensions>=4.6",
Expand Down Expand Up @@ -81,8 +81,22 @@ Issues = "https://git.pyrocko.org/pyrocko/lassie/issues"
[tool.setuptools_scm]

[tool.ruff]
extend-select = ['W', 'N', 'DTZ', 'FA', 'G', 'RET', 'SIM', 'B', 'RET', 'C4']
extend-select = [
'W',
'N',
'DTZ',
'FA',
'G',
'RET',
'SIM',
'B',
'RET',
'C4',
'I',
'RUF',
]
target-version = 'py310'
ignore = ["RUF012", "RUF009"]

[tool.pytest.ini_options]
markers = ["plot: plot figures in tests"]

0 comments on commit 6dafdfc

Please sign in to comment.