diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2331cf87..451c1ad0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,23 +2,24 @@
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.2.0
+    rev: v4.5.0
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer
       - id: check-yaml
       - id: check-added-large-files
+      - id: mixed-line-ending
   - repo: https://github.com/charliermarsh/ruff-pre-commit
     # Ruff version.
-    rev: "v0.0.291"
+    rev: "v0.1.1"
    hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 23.9.1
+    rev: 23.10.0
     hooks:
       - id: black
     # It is recommended to specify the latest version of Python
     # supported by your project here, or alternatively use
     # pre-commit's default_language_version, see
     # https://pre-commit.com/#top_level-default_language_version
-    # language_version: python3.9
+    # language_version: python3.10
diff --git a/lassie/models/detection.py b/lassie/models/detection.py
index d2fe8b4d..527d4c25 100644
--- a/lassie/models/detection.py
+++ b/lassie/models/detection.py
@@ -366,8 +366,7 @@ def as_pyrocko_event(self) -> Event:
             lon=self.lon,
             east_shift=self.east_shift,
             north_shift=self.north_shift,
-            depth=self.depth,
-            elevation=self.elevation,
+            depth=self.effective_depth,
             magnitude=self.magnitude or self.semblance,
             magnitude_type=self.magnitude_type,
         )
@@ -535,7 +534,7 @@ def save_csv(self, file: Path, jitter_location: float = 0.0) -> None:
                 detection = detection.jitter_location(jitter_location)
             lat, lon = detection.effective_lat_lon
             lines.append(
-                f"{lat:.5f}, {lon:.5f}, {-detection.effective_elevation:.1f},"
+                f"{lat:.5f}, {lon:.5f}, {detection.effective_depth:.1f},"
                 f" {detection.semblance}, {detection.time}, {detection.distance_border}"
             )
         file.write_text("\n".join(lines))
diff --git a/lassie/models/location.py b/lassie/models/location.py
index 391adb60..d1ab0c5a 100644
--- a/lassie/models/location.py
+++ b/lassie/models/location.py
@@ -5,7 +5,7 @@
 import struct
 from typing import TYPE_CHECKING, Iterable, Literal, TypeVar

-from pydantic import BaseModel, PrivateAttr
+from pydantic import BaseModel, Field, PrivateAttr
 from pyrocko import orthodrome as od
 from typing_extensions import Self

@@ -18,10 +18,22 @@ class Location(BaseModel):
     lat: float
     lon: float
-    east_shift: float = 0.0
-    north_shift: float = 0.0
-    elevation: float = 0.0
-    depth: float = 0.0
+    east_shift: float = Field(
+        default=0.0,
+        description="east shift towards geographical reference in meters.",
+    )
+    north_shift: float = Field(
+        default=0.0,
+        description="north shift towards geographical reference in meters.",
+    )
+    elevation: float = Field(
+        default=0.0,
+        description="elevation in meters.",
+    )
+    depth: float = Field(
+        default=0.0,
+        description="depth in meters, positive is down.",
+    )

     _cached_lat_lon: tuple[float, float] | None = PrivateAttr(None)

diff --git a/lassie/search.py b/lassie/search.py
index a9a47bbc..6eaa54fe 100644
--- a/lassie/search.py
+++ b/lassie/search.py
@@ -253,7 +253,7 @@ async def start(self, force_rundir: bool = False) -> None:
                 self._detections.add(detection)
                 await self._new_detection.emit(detection)

-            if batch.i_batch % 50 == 0:
+            if self._detections.n_detections % 50 == 0:
                 self._detections.dump_detections(jitter_location=self.octree.size_limit)

             processing_time = datetime_now() - batch_processing_start
@@ -283,7 +283,9 @@ async def start(self, force_rundir: bool = False) -> None:
             batch_processing_start = datetime_now()
             self.set_progress(batch.end_time)

+        self._detections.dump_detections(jitter_location=self.octree.size_limit)
         logger.info("finished search in %s", datetime_now() - processing_start)
+        logger.info("found %d detections", self._detections.n_detections)

     async def add_features(self, event: EventDetection) -> None:
         try:
@@ -326,6 +328,7 @@ def from_config(
         return model

     def __del__(self) -> None:
+        # FIXME: Replace with signal observer?
         if hasattr(self, "_detections"):
             with contextlib.suppress(Exception):
                 self._detections.dump_detections(jitter_location=self.octree.size_limit)
@@ -379,10 +382,10 @@ async def calculate_semblance(
         traveltimes_bad = np.isnan(traveltimes)
         traveltimes[traveltimes_bad] = 0.0

-        station_contribution = (~traveltimes_bad).sum(axis=1, dtype=float)
+        station_contribution = (~traveltimes_bad).sum(axis=1, dtype=np.float32)

         shifts = np.round(-traveltimes / image.delta_t).astype(np.int32)
-        weights = np.full_like(shifts, fill_value=image.weight, dtype=float)
+        weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32)

         # Normalize by number of station contribution
         with np.errstate(divide="ignore", invalid="ignore"):
diff --git a/lassie/tracers/cake.py b/lassie/tracers/cake.py
index d3e3bc9b..3edee3fc 100644
--- a/lassie/tracers/cake.py
+++ b/lassie/tracers/cake.py
@@ -2,6 +2,7 @@

 import logging
 import re
+import struct
 import zipfile
 from datetime import datetime, timedelta
 from functools import cached_property
@@ -91,21 +92,21 @@ class CakeArrival(ModelledArrival):


 class EarthModel(BaseModel):
     filename: FilePath | None = Field(
-        DEFAULT_VELOCITY_MODEL_FILE,
+        default=DEFAULT_VELOCITY_MODEL_FILE,
         description="Path to velocity model.",
     )
     format: Literal["nd", "hyposat"] = Field(
-        "nd",
+        default="nd",
         description="Format of the velocity model. nd or hyposat is supported.",
     )
     crust2_profile: constr(to_upper=True) | tuple[float, float] = Field(
-        "",
+        default="",
         description="Crust2 profile name or a tuple of (lat, lon) coordinates.",
     )
     raw_file_data: str | None = Field(
-        None,
-        description="Raw .nd file data.",
+        default=None,
+        description="Raw `.nd` file data.",
     )

     _layered_model: LayeredModel = PrivateAttr()
@@ -173,6 +174,27 @@ def id(self) -> str:
         return re.sub(r"[\,\s\;]", "", self.definition)


+def surface_distances(nodes: Sequence[Node], stations: Stations) -> np.ndarray:
+    """Returns the surface distance from all nodes to all stations.
+
+    Args:
+        nodes (Sequence[Node]): Nodes to calculate distance from.
+        stations (Stations): Stations to calculate distance to.
+
+    Returns:
+        np.ndarray: Distances in shape (n-nodes, n-stations).
+ """ + node_coords = get_node_coordinates(nodes, system="geographic") + n_nodes = node_coords.shape[0] + + node_coords = np.repeat(node_coords, stations.n_stations, axis=0) + sta_coords = np.vstack(n_nodes * [stations.get_coordinates(system="geographic")]) + + return od.distance_accurate50m_numpy( + node_coords[:, 0], node_coords[:, 1], sta_coords[:, 0], sta_coords[:, 1] + ).reshape(-1, stations.n_stations) + + class TravelTimeTree(BaseModel): earthmodel: EarthModel timing: Timing @@ -234,11 +256,11 @@ def is_suited( time_tolerance: float, spatial_tolerance: float, ) -> bool: - def check_bounds(self, requested) -> bool: + def check_bounds(self, requested: tuple[float, float]) -> bool: return self[0] <= requested[0] and self[1] >= requested[1] return ( - str(self.earthmodel.layered_model) == str(earthmodel.layered_model) + str(self.earthmodel) == str(earthmodel.hash) and self.timing == timing and check_bounds(self.distance_bounds, distance_bounds) and check_bounds(self.source_depth_bounds, source_depth_bounds) @@ -249,7 +271,18 @@ def check_bounds(self, requested) -> bool: @property def filename(self) -> Path: - return Path(f"{self.timing.id}-{self.earthmodel.hash}.sptree") + hash = sha1(self.earthmodel.hash.encode()) + hash.update( + struct.pack( + "dddddddd", + *self.distance_bounds, + *self.source_depth_bounds, + *self.receiver_depth_bounds, + self.time_tolerance, + self.spatial_tolerance, + ) + ) + return Path(f"{self.timing.id}-{hash.hexdigest()}.sptree") @classmethod def new(cls, **data) -> Self: @@ -302,7 +335,6 @@ def load(cls, file: Path) -> Self: with zipfile.ZipFile(file, "r") as archive: path = zipfile.Path(archive) model_file = path / "model.json" - print(model_file.read_text()) model = cls.model_validate_json(model_file.read_text()) model._file = file return model @@ -345,38 +377,34 @@ def init_lut(self, octree: Octree, stations: Stations) -> None: for node, traveltimes in zip(octree, station_traveltimes, strict=True): self._node_lut[node.hash()] = traveltimes.astype(np.float32) + def lut_fill_level(self) -> float: + """Return the fill level of the LUT as a float between 0.0 and 1.0""" + return len(self._node_lut) / self._node_lut.get_size() + def fill_lut(self, nodes: Sequence[Node]) -> None: logger.debug("filling traveltimes LUT for %d nodes", len(nodes)) stations = self._cached_stations - node_coords = get_node_coordinates(nodes, system="geographic") - sta_coords = stations.get_coordinates(system="geographic") - - sta_coords = np.array(od.geodetic_to_ecef(*sta_coords.T)).T - node_coords = np.array(od.geodetic_to_ecef(*node_coords.T)).T - - receiver_distances = np.linalg.norm( - sta_coords - node_coords[:, np.newaxis], axis=2 - ) - traveltimes = self._interpolate_travel_times( - receiver_distances, + surface_distances(nodes, stations), np.array([sta.effective_depth for sta in stations]), - np.array([node.depth for node in nodes]), + np.array([node.as_location().effective_depth for node in nodes]), ) for node, times in zip(nodes, traveltimes, strict=True): self._node_lut[node.hash()] = times.astype(np.float32) - def lut_fill_level(self) -> float: - """Return the fill level of the LUT as a float between 0.0 and 1.0""" - return len(self._node_lut) / self._node_lut.get_size() - def get_travel_times(self, octree: Octree, stations: Stations) -> np.ndarray: - station_indices = np.fromiter( - (self._cached_station_indeces[sta.pretty_nsl] for sta in stations), - dtype=int, - ) + try: + station_indices = np.fromiter( + (self._cached_station_indeces[sta.pretty_nsl] for sta in 
+                (self._cached_station_indeces[sta.pretty_nsl] for sta in stations),
+                dtype=int,
+            )
+        except KeyError as exc:
+            raise ValueError(
+                "stations not found in cached stations, "
+                "was the LUT initialized with `TravelTimeTree.init_lut`?"
+            ) from exc

         stations_traveltimes = []
         fill_nodes = []
@@ -407,9 +435,11 @@ def interpolate_travel_times(
         octree: Octree,
         stations: Stations,
     ) -> np.ndarray:
-        receiver_distances = octree.distances_stations(stations)
+        receiver_distances = surface_distances(octree, stations)
         receiver_depths = np.array([sta.effective_depth for sta in stations])
-        source_depths = np.array([node.depth for node in octree])
+        source_depths = np.array(
+            [node.as_location().effective_depth for node in octree]
+        )

         return self._interpolate_travel_times(
             receiver_distances, receiver_depths, source_depths
@@ -452,7 +482,7 @@ def get_travel_time(self, source: Location, receiver: Location) -> float:
         coordinates = [
             receiver.effective_depth,
             source.effective_depth,
-            receiver.distance_to(source),
+            receiver.surface_distance_to(source),
         ]
         try:
             traveltime = self._get_sptree().interpolate(coordinates) or np.nan
@@ -521,16 +551,15 @@ async def prepare(self, octree: Octree, stations: Stations) -> None:
         ]
         logger.debug("loaded %d cached travel time trees", len(cached_trees))

-        distances = octree.distances_stations(stations)
-        source_depths = np.asarray(octree.depth_bounds)
+        distances = surface_distances(octree, stations)
+        source_depths = np.asarray(octree.depth_bounds) - octree.reference.elevation
         receiver_depths = np.fromiter((sta.effective_depth for sta in stations), float)

-        receiver_depths_bounds = (receiver_depths.min(), receiver_depths.max())
-        source_depth_bounds = (source_depths.min(), source_depths.max())
         distance_bounds = (distances.min(), distances.max())
+        source_depth_bounds = (source_depths.min(), source_depths.max())
+        receiver_depths_bounds = (receiver_depths.min(), receiver_depths.max())

         # FIXME: Time tolerance is too hardcoded. Is 5x a good value?
         time_tolerance = octree.smallest_node_size() / (self.get_vmin() * 5.0)
-
         # if self.trim_earth_model_depth:
         #     self.earthmodel.trim(-source_depth_bounds[1])
diff --git a/lassie/tracers/fast_marching/velocity_models.py b/lassie/tracers/fast_marching/velocity_models.py
index d64dcc24..b04bcb7b 100644
--- a/lassie/tracers/fast_marching/velocity_models.py
+++ b/lassie/tracers/fast_marching/velocity_models.py
@@ -4,6 +4,7 @@
 import re
 from hashlib import sha1
 from pathlib import Path
+from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Annotated, Any, Literal, Union

 import numpy as np
@@ -16,6 +17,7 @@
     model_validator,
 )
 from pydantic.dataclasses import dataclass
+from pyrocko.cake import LayeredModel, load_model
 from scipy.interpolate import RegularGridInterpolator
 from typing_extensions import Self

@@ -87,6 +89,18 @@ def velocity_model(self) -> np.ndarray:
             raise ValueError("Velocity model not set.")
         return self._velocity_model

+    @property
+    def east_coords(self) -> np.ndarray:
+        return self._east_coords
+
+    @property
+    def north_coords(self) -> np.ndarray:
+        return self._north_coords
+
+    @property
+    def depth_coords(self) -> np.ndarray:
+        return self._depth_coords
+
     def hash(self) -> str:
         """Return hash of velocity model.

@@ -259,6 +273,8 @@ def get_model(self, octree: Octree) -> VelocityModel3D:

 @dataclass
 class NonLinLocHeader:
+    """Helper class representing a NonLinLoc header file."""
+
     origin: Location
     nx: int
     ny: int
@@ -335,10 +351,12 @@ def from_header_file(

     @property
     def dtype(self) -> np.dtype:
+        """dtype of the grid."""
         return DTYPE_MAP[self.grid_dtype]

     @property
     def grid_spacing(self) -> float:
+        """grid spacing, homogeneous in three directions."""
         return self.delta_x

     @property
@@ -456,7 +474,77 @@ def get_model(self, octree: Octree) -> VelocityModel3D:
         return velocity_model.resample(grid_spacing, self.interpolation)


+class VelocityModel2D(VelocityModelFactory):
+    # For mere testing purposes of the 3D tracer against Pyrocko cake 2D travel times
+    model: Literal["VelocityModel2D"] = "VelocityModel2D"
+    velocity: Literal["vp", "vs"] = Field(
+        default="vp",
+        description="velocity to extract from the 2D model, choose from 'vp' or 'vs'.",
+    )
+    format: Literal["nd", "hyposat"] = Field(
+        default="nd",
+        description="Format of the velocity model. nd or hyposat is supported.",
+    )
+    filename: FilePath = Field(
+        ...,
+        description="Path to `.nd` file holding the 2D velocity model information.",
+    )
+    raw_file_data: str | None = Field(
+        default=None,
+        description="Raw `.nd` file data.",
+    )
+
+    _layered_model: LayeredModel = PrivateAttr()
+
+    @model_validator(mode="after")
+    def load_model(self) -> VelocityModel2D:
+        if self.filename is not None:
+            logger.info("loading velocity model from %s", self.filename)
+            self.raw_file_data = self.filename.read_text()
+
+        if self.raw_file_data is not None:
+            with NamedTemporaryFile("w") as tmpfile:
+                tmpfile.write(self.raw_file_data)
+                tmpfile.flush()
+                self._layered_model = load_model(
+                    tmpfile.name,
+                    format=self.format,
+                )
+        else:
+            raise AttributeError("No velocity model or crust2 profile defined.")
+        return self
+
+    def get_model(self, octree: Octree) -> VelocityModel3D:
+        if self.grid_spacing == "octree":
+            grid_spacing = octree.smallest_node_size()
+        else:
+            grid_spacing = self.grid_spacing
+
+        model = VelocityModel3D(
+            center=octree.reference,
+            grid_spacing=grid_spacing,
+            east_bounds=octree.east_bounds,
+            north_bounds=octree.north_bounds,
+            depth_bounds=octree.depth_bounds,
+        )
+
+        velocities = []
+        for depth in model.depth_coords:
+            material = self._layered_model.material(z=depth)
+            if self.velocity == "vp":
+                velocities.append(material.vp)
+            elif self.velocity == "vs":
+                velocities.append(material.vs)
+            else:
+                raise ValueError(f"Invalid velocity {self.velocity}")
+
+        velocities = np.array(velocities)
+
+        model.velocity_model[:, :, :] = velocities[np.newaxis, np.newaxis, :]
+        return model
+
+
 VelocityModels = Annotated[
-    Union[Constant3DVelocityModel, NonLinLocVelocityModel],
+    Union[Constant3DVelocityModel, NonLinLocVelocityModel, VelocityModel2D],
     Field(..., discriminator="model"),
 ]
diff --git a/lassie/waveforms/squirrel.py b/lassie/waveforms/squirrel.py
index 268de664..a0801748 100644
--- a/lassie/waveforms/squirrel.py
+++ b/lassie/waveforms/squirrel.py
@@ -44,26 +44,36 @@ def __init__(
         self._task = asyncio.create_task(self.prefetch_worker())

     async def prefetch_worker(self) -> None:
-        logger.info("start prefetching squirrel data")
+        logger.info("start prefetching data, queue size %d", self.queue.maxsize)

         def filter_freqs(batch: Batch) -> Batch:
+            # Filter traces in-place
+            start = None
             if self.freq_min:
+                start = datetime_now()
                 for tr in batch.traces:
-                    tr.highpass(4, self.freq_min)
+                    tr.highpass(4, corner=self.freq_min)
             if self.freq_max:
+                start = start or datetime_now()
                 for tr in batch.traces:
-                    tr.lowpass(4, self.freq_max)
+                    tr.lowpass(4, corner=self.freq_max)
+            if start:
+                logger.debug("filtered traces in %s", datetime_now() - start)
             return batch

         while True:
             start = datetime_now()
             batch = await asyncio.to_thread(lambda: next(self.iterator, None))
-            await asyncio.to_thread(filter_freqs, batch)
-            logger.debug("prefetched waveforms in %s", datetime_now() - start)
             if batch is None:
                 logger.debug("squirrel prefetcher finished")
                 await self.queue.put(None)
                 break
+
+            await asyncio.to_thread(filter_freqs, batch)
+            logger.debug("prefetched waveforms in %s", datetime_now() - start)
+            if self.queue.empty():
+                logger.warning("queue ran empty, prefetching is too slow")
+
             await self.queue.put(batch)

@@ -146,7 +156,12 @@ async def iter_batches(
                 (*nsl, self.channel_selector) for nsl in self._stations.get_all_nsl()
             ],
         )
-        prefetcher = SquirrelPrefetcher(iterator, self.async_prefetch_batches)
+        prefetcher = SquirrelPrefetcher(
+            iterator,
+            self.async_prefetch_batches,
+            self.freq_min,
+            self.freq_max,
+        )

         while True:
             batch = await prefetcher.queue.get()
diff --git a/test/conftest.py b/test/conftest.py
index 29e32c23..11f9189c 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -116,7 +116,8 @@ def stations() -> Stations:
             station="STA%02d" % i_sta,
             lat=10.0,
             lon=10.0,
-            elevation=random.uniform(0, 1) * KM,
+            elevation=random.uniform(0, 0.8) * KM,
+            depth=random.uniform(0, 0.2) * KM,
             north_shift=random.uniform(-10, 10) * KM,
             east_shift=random.uniform(-10, 10) * KM,
         )
diff --git a/test/test_cake.py b/test/test_cake.py
index cccb4079..314af37c 100644
--- a/test/test_cake.py
+++ b/test/test_cake.py
@@ -1,19 +1,57 @@
 from __future__ import annotations

+import logging
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING

 import numpy as np
+import pytest

 from lassie.models.location import Location
-from lassie.tracers.cake import TravelTimeTree
-
-if TYPE_CHECKING:
-    from lassie.models.station import Stations
-    from lassie.octree import Octree
+from lassie.models.station import Station, Stations
+from lassie.octree import Octree
+from lassie.tracers.cake import CakeTracer, EarthModel, Timing, TravelTimeTree
+from lassie.tracers.constant_velocity import ConstantVelocityTracer

 KM = 1e3
+CONSTANT_VELOCITY = 5 * KM
+
+
+@pytest.fixture(scope="session")
+def small_octree() -> Octree:
+    return Octree(
+        reference=Location(
+            lat=10.0,
+            lon=10.0,
+            elevation=0.2 * KM,
+        ),
+        size_initial=0.5 * KM,
+        size_limit=50,
+        east_bounds=(-2 * KM, 2 * KM),
+        north_bounds=(-2 * KM, 2 * KM),
+        depth_bounds=(0 * KM, 2 * KM),
+        absorbing_boundary=1 * KM,
+    )
+
+
+@pytest.fixture(scope="session")
+def small_stations() -> Stations:
+    rng = np.random.default_rng(1232)
+    n_stations = 20
+    stations: list[Station] = []
+    for i_sta in range(n_stations):
+        station = Station(
+            network="XX",
+            station="STA%02d" % i_sta,
+            lat=10.0,
+            lon=10.0,
+            elevation=rng.uniform(0, 0.1) * KM,
+            depth=rng.uniform(0, 0.1) * KM,
+            north_shift=rng.uniform(-2, 2) * KM,
+            east_shift=rng.uniform(-2, 2) * KM,
+        )
+        stations.append(station)
+    return Stations(stations=stations)


 def test_sptree_model(travel_time_tree: TravelTimeTree):
@@ -68,3 +106,37 @@ def test_lut(
     traveltimes_tree = model.interpolate_travel_times(octree, stations_selection)
     traveltimes_lut = model.get_travel_times(octree, stations_selection)
     np.testing.assert_equal(traveltimes_tree, traveltimes_lut)
+
+
+@pytest.mark.asyncio
+async def test_travel_times_constant_velocity(
+    small_octree: Octree,
+    small_stations: Stations,
+):
+    octree = small_octree
+    stations = small_stations
+    octree.size_limit = 200
+    cake_tracer = CakeTracer(
+        phases={"cake:P": Timing(definition="P,p")},
+        earthmodel=EarthModel(
+            filename=None,
+            raw_file_data=f"""
+ -2.0  {CONSTANT_VELOCITY/KM:.1f}  2.0  2.7
+ 12.0  {CONSTANT_VELOCITY/KM:.1f}  2.0  2.7
+""",
+        ),
+    )
+    constant = ConstantVelocityTracer(
+        velocity=CONSTANT_VELOCITY,
+    )
+
+    await cake_tracer.prepare(octree, stations)
+
+    cake_travel_times = cake_tracer.get_travel_times("cake:P", octree, stations)
+    constant_traveltimes = constant.get_travel_times("constant:P", octree, stations)
+
+    nan_mask = np.isnan(cake_travel_times)
+    logging.warning("percent nan: %.1f", (nan_mask.sum() / nan_mask.size) * 100)
+
+    constant_traveltimes[nan_mask] = np.nan
+    np.testing.assert_almost_equal(cake_travel_times, constant_traveltimes, decimal=2)
diff --git a/test/test_fast_marching.py b/test/test_fast_marching.py
index 59d75206..7c506816 100644
--- a/test/test_fast_marching.py
+++ b/test/test_fast_marching.py
@@ -8,6 +8,12 @@

 from lassie.models.station import Station, Stations
 from lassie.octree import Octree
+from lassie.tracers.cake import (
+    DEFAULT_VELOCITY_MODEL_FILE,
+    CakeTracer,
+    EarthModel,
+    Timing,
+)
 from lassie.tracers.fast_marching.fast_marching import (
     FastMarchingTracer,
     StationTravelTimeVolume,
@@ -15,6 +21,7 @@
 from lassie.tracers.fast_marching.velocity_models import (
     Constant3DVelocityModel,
     NonLinLocVelocityModel,
+    VelocityModel2D,
     VelocityModel3D,
 )
 from lassie.utils import datetime_now
@@ -88,7 +95,7 @@ async def test_load_save(


 @pytest.mark.asyncio
-async def test_travel_time_interpolation(
+async def test_travel_times_constant_model(
     station_travel_times: StationTravelTimeVolume,
     octree: Octree,
 ) -> None:
@@ -127,6 +134,39 @@ async def test_travel_time_interpolation(
     )


+@pytest.mark.asyncio
+async def test_travel_times_cake(
+    octree: Octree,
+    fixed_stations: Stations,
+):
+    tracer = FastMarchingTracer(
+        phase="fm:P",
+        velocity_model=VelocityModel2D(
+            grid_spacing=200.0,
+            velocity="vp",
+            filename=DEFAULT_VELOCITY_MODEL_FILE,
+        ),
+    )
+    await tracer.prepare(octree, fixed_stations)
+
+    cake_tracer = CakeTracer(
+        phases={"cake:P": Timing(definition="P,p")},
+        earthmodel=EarthModel(
+            filename=DEFAULT_VELOCITY_MODEL_FILE,
+        ),
+    )
+    await cake_tracer.prepare(octree, fixed_stations)
+
+    travel_times_fast_marching = tracer.get_travel_times("fm:P", octree, fixed_stations)
+    travel_times_cake = cake_tracer.get_travel_times("cake:P", octree, fixed_stations)
+
+    nan_mask = np.isnan(travel_times_cake)
+    travel_times_fast_marching[nan_mask] = np.nan
+    np.testing.assert_almost_equal(
+        travel_times_fast_marching, travel_times_cake, decimal=1
+    )
+
+
 @pytest.mark.asyncio
 async def test_fast_marching_phase_tracer(
     octree: Octree, fixed_stations: Stations
diff --git a/test/test_plot.py b/test/test_plot.py
deleted file mode 100644
index 46d77fe9..00000000
--- a/test/test_plot.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from __future__ import annotations
-
-from pathlib import Path
-from typing import TYPE_CHECKING
-
-import numpy as np
-import pytest
-from matplotlib import pyplot as plt
-
-from lassie.models.detection import EventDetection
-from lassie.plot.detections import plot_detections
-from lassie.plot.octree import plot_octree_surface_tiles
-from lassie.utils import datetime_now
-
-if TYPE_CHECKING:
-    from lassie.models.detection import EventDetections
-    from lassie.octree import Octree
-
-
-@pytest.mark.plot
-def test_octree_2d(octree: Octree) -> None:
-    semblance = np.random.uniform(size=octree.n_nodes)
-    octree.map_semblance(semblance)
-    plot_octree_surface_tiles(octree, filename=Path("/tmp/test.png"))
-
-    detection = EventDetection(
-        lat=0.0,
-        lon=0.0,
-        east_shift=0.0,
-        north_shift=0.0,
-        distance_border=1000.0,
-        semblance=1.0,
-        time=datetime_now(),
-    )
-
-    fig = plt.figure()
-    ax = fig.gca()
-
-    plot_octree_surface_tiles(octree, axes=ax, detections=[detection])
-    plt.show()
-
-
-@pytest.mark.plot
-def test_detections_semblance(detections: EventDetections) -> None:
-    plot_detections(detections, axes=None)
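
Note on the new surface_distances() helper in lassie/tracers/cake.py: it pairs every octree node with every station by repeating the node coordinates n_stations times and stacking the station block n_nodes times, then reshapes the flat distance vector back to (n-nodes, n-stations). Below is a minimal sketch of that pairing and reshape logic in plain NumPy; the function name and the equirectangular stand-in for od.distance_accurate50m_numpy are illustrative only and not part of the change.

import numpy as np


def pairwise_surface_distances(node_latlon: np.ndarray, sta_latlon: np.ndarray) -> np.ndarray:
    """Sketch of the node/station pairing used by surface_distances().

    node_latlon: (n_nodes, 2) lat/lon in degrees, sta_latlon: (n_stations, 2).
    The distance formula below is a placeholder for od.distance_accurate50m_numpy.
    """
    n_nodes, n_stations = node_latlon.shape[0], sta_latlon.shape[0]

    # Repeat each node n_stations times and tile the station block n_nodes times,
    # so a row-major reshape recovers the (n_nodes, n_stations) layout.
    nodes = np.repeat(node_latlon, n_stations, axis=0)
    stations = np.vstack(n_nodes * [sta_latlon])

    # Equirectangular approximation in meters (placeholder distance only).
    r_earth = 6_371_000.0
    dlat = np.radians(stations[:, 0] - nodes[:, 0])
    dlon = np.radians(stations[:, 1] - nodes[:, 1]) * np.cos(np.radians(nodes[:, 0]))
    return (r_earth * np.hypot(dlat, dlon)).reshape(n_nodes, n_stations)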