Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve data getter functionality #189

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 44 additions & 35 deletions src/segy/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from segy.schema import TraceSpec
from segy.schema.base import BaseDataType

IntDType = np.signedinteger[Any]


def merge_cat_file(
fs: AbstractFileSystem,
Expand Down Expand Up @@ -63,7 +65,7 @@ def merge_cat_file(
return bytearray(b"".join(buffer_bytes))


def bounds_check(indices: list[int], max_: int, type_: str) -> None:
def bounds_check(indices: NDArray[IntDType], max_: int, type_: str) -> None:
"""Check if indices are out of bounds (negative, or more than max).

Wrapping negative indices is not supported yet. The `type_` argument
Expand All @@ -77,14 +79,11 @@ def bounds_check(indices: list[int], max_: int, type_: str) -> None:
Raises:
IndexError: If any of the indices are negative or exceed the maximum value.
"""
negative_indices = [index for index in indices if index < 0]
out_of_range_indices = [index for index in indices if index >= max_]

outliers = negative_indices + out_of_range_indices
oob_indices = np.where((indices < 0) | (indices >= max_))[0]

if outliers:
if len(oob_indices) > 0:
msg = (
f"Requested {type_} indices {outliers} are out of bounds. SEG-Y "
f"Requested {type_} indices {oob_indices} are out of bounds. SEG-Y "
f"file has {max_} traces. Valid indices are "
f"[0, {max_ - 1})."
)
Expand Down Expand Up @@ -125,7 +124,9 @@ def __init__( # noqa: PLR0913
)

@abstractmethod
def indices_to_byte_ranges(self, indices: list[int]) -> tuple[list[int], list[int]]:
def indices_to_byte_ranges(
self, indices: NDArray[IntDType]
) -> tuple[NDArray[IntDType], NDArray[IntDType]]:
"""Logic to calculate start/end bytes."""

@abstractmethod
Expand All @@ -136,28 +137,23 @@ def post_process(self, data: NDArray[Any]) -> NDArray[Any]:
"""Apply transforms to the data after decoding."""
return self.transform_pipeline.apply(data)

def __getitem__(self, item: int | list[int] | slice) -> Any: # noqa: ANN401
def __getitem__(self, item: int | list[int] | NDArray[IntDType] | slice) -> Any: # noqa: ANN401
"""Operator for integers, lists, and slices with bounds checking."""
indices = None

if isinstance(item, int):
indices = [item]
bounds_check(indices, self.max_value, self.kind)

elif isinstance(item, list):
indices = item
bounds_check(indices, self.max_value, self.kind)

elif isinstance(item, slice):
if isinstance(item, slice):
if item.step == 0:
msg = "Step of 0 is invalid for slicing."
raise ValueError(msg)

start = item.start or 0
stop = item.stop or self.max_value
start_stop = np.asarray([start, stop - 1])

bounds_check([start, stop - 1], self.max_value, self.kind)
indices = list(range(*item.indices(self.max_value)))
bounds_check(start_stop, self.max_value, self.kind)
indices = np.arange(*item.indices(self.max_value))

else: # int, list, or ndarray case
indices = np.atleast_1d(item)
bounds_check(indices, self.max_value, self.kind)

if len(indices) == 0:
msg = "Couldn't parse request. Please ensure it is a valid index."
Expand All @@ -166,7 +162,7 @@ def __getitem__(self, item: int | list[int] | slice) -> Any: # noqa: ANN401
data = self.fetch(indices)
return self.post_process(data)

def fetch(self, indices: list[int]) -> NDArray[Any]:
def fetch(self, indices: NDArray[IntDType]) -> NDArray[Any]:
"""Fetches and decodes binary data from the given indices.

Args:
Expand All @@ -179,12 +175,19 @@ def fetch(self, indices: list[int]) -> NDArray[Any]:
- This method internally converts the indices to byte ranges using
the 'indices_to_byte_ranges' method.
- The byte ranges are used to fetch the corresponding data from the
file specified by the 'url' parameter.
file specified by the 'url' parameter. However, this is fastest
if minimize the amount of reads. Here we combine starts and
stops that are adjacent to each other. This requires a sort.
- The indices users request may be out of order, so we ensure we
save the index order and then use it to sort the read buffer back
to user's requested shape.
- The fetched data is then decoded and squeezed before being returned.
"""
index_order = np.argsort(indices)
starts, ends = self.indices_to_byte_ranges(indices)
buffer = merge_cat_file(self.fs, self.url, starts, ends)
return self.decode(buffer).squeeze()
buffer = merge_cat_file(self.fs, self.url, starts.tolist(), ends.tolist())
array = self.decode(buffer)
return array[index_order].squeeze()


class TraceIndexer(AbstractIndexer):
Expand All @@ -197,7 +200,9 @@ class TraceIndexer(AbstractIndexer):
spec: TraceSpec
kind: str = "trace"

def indices_to_byte_ranges(self, indices: list[int]) -> tuple[list[int], list[int]]:
def indices_to_byte_ranges(
self, indices: NDArray[IntDType]
) -> tuple[NDArray[IntDType], NDArray[IntDType]]:
"""Convert trace indices to byte ranges."""
if self.spec.offset is None:
msg = "Trace starting offset must be specified."
Expand All @@ -206,8 +211,8 @@ def indices_to_byte_ranges(self, indices: list[int]) -> tuple[list[int], list[in
start_offset = self.spec.offset
trace_itemsize = self.spec.dtype.itemsize

starts = [start_offset + i * trace_itemsize for i in indices]
ends = [start + trace_itemsize for start in starts]
starts = start_offset + indices * trace_itemsize
ends = starts + trace_itemsize

return starts, ends

Expand All @@ -227,7 +232,9 @@ class HeaderIndexer(AbstractIndexer):
spec: TraceSpec
kind: str = "header"

def indices_to_byte_ranges(self, indices: list[int]) -> tuple[list[int], list[int]]:
def indices_to_byte_ranges(
self, indices: NDArray[IntDType]
) -> tuple[NDArray[IntDType], NDArray[IntDType]]:
"""Convert header indices to byte ranges (without trace data)."""
trace_itemsize = self.spec.dtype.itemsize
header_itemsize = self.spec.header.itemsize
Expand All @@ -238,8 +245,8 @@ def indices_to_byte_ranges(self, indices: list[int]) -> tuple[list[int], list[in

start_offset = self.spec.offset

starts = [start_offset + i * trace_itemsize for i in indices]
ends = [start + header_itemsize for start in starts]
starts = start_offset + indices * trace_itemsize
ends = starts + header_itemsize

return starts, ends

Expand All @@ -258,7 +265,9 @@ class DataIndexer(AbstractIndexer):
spec: TraceSpec
kind: str = "data"

def indices_to_byte_ranges(self, indices: list[int]) -> tuple[list[int], list[int]]:
def indices_to_byte_ranges(
self, indices: NDArray[IntDType]
) -> tuple[NDArray[IntDType], NDArray[IntDType]]:
"""Convert data indices to byte ranges (without trace headers)."""
trace_itemsize = self.spec.itemsize
data_itemsize = self.spec.data.itemsize
Expand All @@ -270,8 +279,8 @@ def indices_to_byte_ranges(self, indices: list[int]) -> tuple[list[int], list[in

start_offset = self.spec.offset + header_itemsize

starts = [start_offset + i * trace_itemsize for i in indices]
ends = [start + data_itemsize for start in starts]
starts = start_offset + indices * trace_itemsize
ends = starts + data_itemsize

return starts, ends

Expand Down
5 changes: 3 additions & 2 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from typing import TYPE_CHECKING

import numpy as np
import pytest

from segy.indexing import bounds_check
Expand All @@ -31,7 +32,7 @@ class TestBoundsCheckIndices:
)
def test_in_bounds(self, indices: list[int], size: int) -> None:
"""Test the case where indices are in bounds."""
bounds_check(indices, size, "trace")
bounds_check(np.asarray(indices), size, "trace")

@pytest.mark.parametrize(
("indices", "size"),
Expand All @@ -43,7 +44,7 @@ def test_in_bounds(self, indices: list[int], size: int) -> None:
def test_out_of_bounds(self, indices: list[int], size: int) -> None:
"""Test the case where indices out of bounds or negative."""
with pytest.raises(IndexError, match="out of bounds"):
bounds_check(indices, size, "")
bounds_check(np.asarray(indices), size, "")


class TestMergeCatFile:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_segy_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,18 @@ def test_trace_accessor(
assert_array_equal(traces.header, test_config.expected_headers)
assert_array_almost_equal(traces.sample, test_config.expected_samples)

# Test random access
index = [0, 2, 4]
traces = segy_file.trace[index]
assert_array_equal(traces.header, test_config.expected_headers[index])
assert_array_almost_equal(traces.sample, test_config.expected_samples[index])

# Test reverse order random access
index = [5, 3, 0]
traces = segy_file.trace[index]
assert_array_equal(traces.header, test_config.expected_headers[index])
assert_array_almost_equal(traces.sample, test_config.expected_samples[index])


class TestSegyFileExceptions:
"""Test exceptions for SegyFile."""
Expand Down
Loading