Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/async squirrel #16

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,27 @@ extend-select = [
'I',
'RUF',
'T20',
'D',
]

ignore = ["RUF012", "RUF009"]
ignore = [
"RUF012",
"RUF009",
"D100",
"D101",
"D102",
"D103",
"D104",
"D105",
"D107",
]

[tool.ruff]
target-version = 'py311'

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.pytest.ini_options]
markers = ["plot: plot figures in tests"]

Expand Down
79 changes: 59 additions & 20 deletions src/qseek/apps/qseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import nest_asyncio
from pkg_resources import get_distribution

from qseek.models.detection import EventDetection

nest_asyncio.apply()

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -128,6 +130,18 @@
type=Path,
help="path of existing run",
)
features_extract.add_argument(
"--recalculate",
action="store_true",
default=False,
help="recalculate all magnitudes",
)
features_extract.add_argument(
"--nparallel",
type=int,
default=32,
help="number of parallel tasks for feature extraction",
)

modules = subparsers.add_parser(
"modules",
Expand Down Expand Up @@ -194,7 +208,7 @@ def main() -> None:

load_insights()
from rich import box
from rich.progress import track
from rich.progress import Progress
from rich.prompt import IntPrompt
from rich.table import Table

Expand Down Expand Up @@ -256,35 +270,60 @@ async def run() -> None:
case "feature-extraction":
search = Search.load_rundir(args.rundir)
search.data_provider.prepare(search.stations)
recalculate_magnitudes = args.recalculate

tasks = []

def console_status(task: asyncio.Task[EventDetection]):
    """Report a finished magnitude-extraction task on the rich console.

    Done-callback for the per-event extraction tasks: prints each computed
    magnitude (name, average and error) for the event, or a short notice
    when the event ended up with no magnitudes at all.
    """
    event = task.result()
    if not event.magnitudes:
        console.print(f"Event {event.time}: No magnitudes")
        return
    summary = ", ".join(
        f"[bold]{mag.magnitude}[/bold] {mag.average:.2f}±{mag.error:.2f}"
        for mag in event.magnitudes
    )
    # Trailing '.' split drops sub-second digits from the event timestamp.
    console.print(f"Event {str(event.time).split('.')[0]}:", summary)

async def extract() -> None:
progress = Progress()
tracker = progress.add_task(
"Calculating magnitudes",
total=search.catalog.n_events,
console=console,
)

async def worker() -> None:
for magnitude in search.magnitudes:
await magnitude.prepare(search.octree, search.stations)

iterator = asyncio.as_completed(
tuple(
search.add_magnitude_and_features(detection)
for detection in search._catalog
await search.catalog.check(repair=True)

sem = asyncio.Semaphore(args.nparallel)
for detection in search.catalog:
await sem.acquire()
task = asyncio.create_task(
search.add_magnitude_and_features(
detection,
recalculate=recalculate_magnitudes,
)
)
tasks.append(task)
task.add_done_callback(lambda _: sem.release())
task.add_done_callback(tasks.remove)
task.add_done_callback(console_status)
task.add_done_callback(
lambda _: progress.update(tracker, advance=1)
)
)

for result in track(
iterator,
description="Extracting features",
total=search._catalog.n_events,
):
event = await result
if event.magnitudes:
for mag in event.magnitudes:
print(f"{mag.magnitude} {mag.average:.2f}±{mag.error:.2f}") # noqa: T201
print("--") # noqa: T201
await asyncio.gather(*tasks)

await search._catalog.save()
await search._catalog.export_detections(
jitter_location=search.octree.smallest_node_size()
)

asyncio.run(extract(), debug=loop_debug)
asyncio.run(worker(), debug=loop_debug)

case "corrections":
import json
Expand Down Expand Up @@ -391,7 +430,7 @@ def is_insight(module: type) -> bool:
raise EnvironmentError(f"folder {args.folder} does not exist")

file = args.folder / "search.schema.json"
print(f"writing JSON schemas to {args.folder}") # noqa: T201
console.print(f"writing JSON schemas to {args.folder}")
file.write_text(json.dumps(Search.model_json_schema(), indent=2))

file = args.folder / "detections.schema.json"
Expand Down
9 changes: 5 additions & 4 deletions src/qseek/corrections/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ async def prepare(
"""Prepare the station for the corrections.

Args:
station: The station to prepare.
octree: The octree to use for the preparation.
phases: The phases to prepare the station for.
rundir: The rundir to use for the delay. Defaults to None.
stations (Stations): The station to prepare.
octree (Octree): The octree to use for the preparation.
phases (Iterable[PhaseDescription]): The phases to prepare the station for.
rundir (Path | None, optional): The rundir to use for the delay.
Defaults to None.
"""
...

Expand Down
8 changes: 4 additions & 4 deletions src/qseek/images/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def name(self) -> str:
return self.__class__.__name__

def get_blinding(self, sampling_rate: float) -> timedelta:
"""
Blinding duration for the image function. Added to padded waveforms.
"""Blinding duration for the image function. Added to padded waveforms.

Args:
sampling_rate (float): The sampling rate of the waveform.
Expand Down Expand Up @@ -73,6 +72,7 @@ def set_stations(self, stations: Stations) -> None:

def resample(self, sampling_rate: float, max_normalize: bool = False) -> None:
"""Resample traces in-place.

Args:
sampling_rate (float): Desired sampling rate in Hz.
max_normalize (bool): Normalize by maximum value to keep the scale of the
Expand Down Expand Up @@ -137,7 +137,7 @@ def search_phase_arrival(
trace_idx (int): Index of the trace.
event_time (datetime): Time of the event.
modelled_arrival (datetime): Time to search around.
search_length_seconds (float, optional): Total search length in seconds
search_window_seconds (float, optional): Total search length in seconds
around modelled arrival time. Defaults to 5.
threshold (float, optional): Threshold for detection. Defaults to 0.1.

Expand All @@ -158,7 +158,7 @@ def search_phase_arrivals(
Args:
event_time (datetime): Time of the event.
modelled_arrivals (list[datetime]): Time to search around.
search_length_seconds (float, optional): Total search length in seconds
search_window_seconds (float, optional): Total search length in seconds
around modelled arrival time. Defaults to 5.
threshold (float, optional): Threshold for detection. Defaults to 0.1.

Expand Down
7 changes: 5 additions & 2 deletions src/qseek/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,22 @@ async def iter_images(
"""Iterate over images from batches.

Args:
batches (AsyncIterator[Batch]): Async iterator over batches.
batch_iterator (AsyncIterator[Batch]): Async iterator over batches.

Yields:
AsyncIterator[WaveformImages]: Async iterator over images.
"""

stats = self._stats

async def worker() -> None:
logger.info(
"start pre-processing images, queue size %d", self._queue.maxsize
)
async for batch in batch_iterator:
if batch.is_empty():
logger.debug("empty batch, skipping")
continue

start_time = datetime_now()
images = await self.process_traces(batch.traces)
stats.time_per_batch = datetime_now() - start_time
Expand Down
11 changes: 5 additions & 6 deletions src/qseek/images/phase_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def search_phase_arrival(
trace_idx (int): Index of the trace.
event_time (datetime): Time of the event.
modelled_arrival (datetime): Time to search around.
search_length_seconds (float, optional): Total search length in seconds
search_window_seconds (float, optional): Total search length in seconds
around modelled arrival time. Defaults to 5.
threshold (float, optional): Threshold for detection. Defaults to 0.1.
detection_blinding_seconds (float, optional): Blinding time in seconds for
Expand Down Expand Up @@ -113,16 +113,15 @@ def search_phase_arrival(
peak_delay = peak_times - event_time.timestamp()

# Limit to post-event peaks
post_event_peaks = peak_delay > 0.0
peak_idx = peak_idx[post_event_peaks]
peak_times = peak_times[post_event_peaks]
peak_delay = peak_delay[post_event_peaks]
after_event_peaks = peak_delay > 0.0
peak_idx = peak_idx[after_event_peaks]
peak_times = peak_times[after_event_peaks]
peak_delay = peak_delay[after_event_peaks]

if not peak_idx.size:
return None

peak_values = search_trace.get_ydata()[peak_idx]

closest_peak_idx = np.argmin(peak_delay)

return ObservedArrival(
Expand Down
17 changes: 13 additions & 4 deletions src/qseek/magnitudes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,23 @@ def get_subclasses(cls) -> tuple[type[EventMagnitudeCalculator], ...]:
"""
return tuple(cls.__subclasses__())

def has_magnitude(self, event: EventDetection) -> bool:
    """Tell whether this calculator already produced a magnitude for *event*.

    Subclasses must override this; the base implementation is abstract.

    Args:
        event (EventDetection): Event to inspect for an existing magnitude.

    Returns:
        bool: True when the event already carries a magnitude from this
            calculator, False otherwise.

    Raises:
        NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError

async def add_magnitude(
self,
squirrel: Squirrel,
event: EventDetection,
) -> None:
"""
Adds a magnitude to the squirrel for the given event.
"""Adds a magnitude to the squirrel for the given event.

Args:
squirrel (Squirrel): The squirrel object to add the magnitude to.
Expand All @@ -132,8 +142,7 @@ async def prepare(
octree: Octree,
stations: Stations,
) -> None:
"""
Prepare the magnitudes calculation by initializing necessary data structures.
"""Prepare the magnitudes calculation by initializing necessary data structures.

Args:
octree (Octree): The octree containing seismic event data.
Expand Down
8 changes: 7 additions & 1 deletion src/qseek/magnitudes/local_magnitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ def validate_model(self) -> Self:
self._model = LocalMagnitudeModel.get_subclass_by_name(self.model)()
return self

def has_magnitude(self, event: EventDetection) -> bool:
    """Check whether *event* already has a LocalMagnitude from this model.

    Args:
        event (EventDetection): Event whose attached magnitudes are scanned.

    Returns:
        bool: True if a LocalMagnitude computed with this calculator's
            model is already attached, False otherwise.
    """
    # Exact type check (not isinstance) deliberately excludes subclasses,
    # matching the original behavior.
    return any(
        type(magnitude) is LocalMagnitude and magnitude.model == self.model
        for magnitude in event.magnitudes
    )

async def add_magnitude(self, squirrel: Squirrel, event: EventDetection) -> None:
model = self._model

Expand All @@ -180,7 +186,7 @@ async def add_magnitude(self, squirrel: Squirrel, event: EventDetection) -> None
cut_off_fade=cut_off_fade,
quantity=model.restitution_quantity,
phase=None,
remove_clipped=True,
filter_clipped=True,
)
if not traces:
logger.warning("No restituted traces found for event %s", event.time)
Expand Down
2 changes: 1 addition & 1 deletion src/qseek/magnitudes/local_magnitude_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def get_station_magnitude(
try:
traces = _COMPONENT_MAP[self.component](traces)
except KeyError:
logger.warning("Could not get channels for %s", receiver.nsl.pretty)
logger.debug("Could not get channels for %s", receiver.nsl.pretty)
return None
if not traces:
return None
Expand Down
Loading
Loading