Skip to content

Commit

Permalink
added mypy types for files in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ta-hill committed Mar 1, 2024
1 parent 6c9ea41 commit d33db5e
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 26 deletions.
48 changes: 34 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import string
from typing import TYPE_CHECKING
from typing import Any

import numpy as np
import pytest
Expand All @@ -25,7 +26,7 @@


@pytest.fixture(scope="session")
def format_str_to_text_header() -> Callable:
def format_str_to_text_header() -> Callable[[str], str]:
"""Fixture wrapper around helper function to format text headers."""

def _format_str_to_text_header(text: str) -> str:
Expand All @@ -38,7 +39,12 @@ def _format_str_to_text_header(text: str) -> str:


@pytest.fixture(scope="module")
def make_header_field_descriptor() -> Callable:
def make_header_field_descriptor() -> (
Callable[
[str, list[str] | None, list[int] | None, Endianness],
dict[str, list[HeaderFieldDescriptor] | int],
]
):
"""Fixture wrapper around helper function to generate params for descriptors."""

def _make_header_field_descriptor(
Expand All @@ -65,10 +71,12 @@ def _make_header_field_descriptor(
)
temp_dt = np.dtype(dt_string)
item_size = temp_dt.itemsize
# becuase mypy doesn't like MappingProxy
temp_dt_field_values: list[tuple[Any, ...]] = list(temp_dt.fields.values()) # type: ignore[union-attr]
dt_offsets = (
offsets
if offsets is not None
else [field[-1] for field in temp_dt.fields.values()]
else [field[-1] for field in temp_dt_field_values]
)
header_fields = [
HeaderFieldDescriptor(
Expand All @@ -85,14 +93,19 @@ def _make_header_field_descriptor(


@pytest.fixture(scope="module")
def make_trace_header_descriptor(make_header_field_descriptor: Callable) -> Callable:
def make_trace_header_descriptor(
make_header_field_descriptor: Callable[
[str, list[str] | None, list[int] | None, str | Endianness],
dict[str, list[HeaderFieldDescriptor] | int],
],
) -> Callable[..., TraceHeaderDescriptor]:
"""Fixture wrapper for helper function to create TraceHeaderDescriptors."""

def _make_trace_header_descriptor(
dt_string: str = "i2",
names: list[str] | None = None,
offsets: list[int] | None = None,
endianness: str = Endianness.BIG,
endianness: str | Endianness = Endianness.BIG,
) -> TraceHeaderDescriptor:
"""Convenience function for creating TraceHeaderDescriptors.
Expand All @@ -105,9 +118,10 @@ def _make_trace_header_descriptor(
Returns:
TraceHeaderDescriptor: Descriptor object for TraceHeaderDescriptors
"""
head_field_desc = make_header_field_descriptor(
dt_string=dt_string, names=names, offsets=offsets, endianness=endianness
head_field_desc: dict[str, Any] = make_header_field_descriptor(
dt_string, names, offsets, endianness
)

return TraceHeaderDescriptor(
fields=head_field_desc["fields"],
item_size=head_field_desc["item_size"],
Expand All @@ -118,7 +132,7 @@ def _make_trace_header_descriptor(


@pytest.fixture(scope="module")
def make_trace_data_descriptor() -> Callable:
def make_trace_data_descriptor() -> Callable[..., TraceDataDescriptor]:
"""Fixture wrapper for helper function to create TraceDataDescriptor."""

def _make_trace_data_descriptor(
Expand Down Expand Up @@ -150,8 +164,9 @@ def _make_trace_data_descriptor(

@pytest.fixture(scope="module")
def make_trace_descriptor(
make_trace_header_descriptor: Callable, make_trace_data_descriptor: Callable
) -> Callable:
make_trace_header_descriptor: Callable[..., TraceHeaderDescriptor],
make_trace_data_descriptor: Callable[..., TraceDataDescriptor],
) -> Callable[..., TraceDescriptor]:
"""Fixture wrapper for helper function to create TraceDescriptors."""

def _make_trace_descriptor(
Expand All @@ -177,7 +192,12 @@ def _make_trace_descriptor(


@pytest.fixture(scope="module")
def make_binary_header_descriptor(make_header_field_descriptor: Callable) -> Callable:
def make_binary_header_descriptor(
make_header_field_descriptor: Callable[
[str, list[str] | None, list[int] | None, Endianness | str],
dict[str, list[HeaderFieldDescriptor] | int],
],
) -> Callable[..., BinaryHeaderDescriptor]:
"""Fixture wrapper around helper function for creating BinaryHeaderDescriptor."""

def _make_binary_header_descriptor(
Expand All @@ -197,8 +217,8 @@ def _make_binary_header_descriptor(
Returns:
BinaryHeaderDescriptor: Descriptor object for BinaryHeaderDescriptor
"""
head_field_desc = make_header_field_descriptor(
dt_string=dt_string, names=names, offsets=offsets, endianness=endianness
head_field_desc: dict[str, Any] = make_header_field_descriptor(
dt_string, names, offsets, endianness
)
return BinaryHeaderDescriptor(
fields=head_field_desc["fields"],
Expand All @@ -211,7 +231,7 @@ def _make_binary_header_descriptor(

def generate_unique_names(count: int) -> list[str]:
"""Helper function to create random unique names as placeholders during testing."""
names: set = set()
names: set[str] = set()
rng = np.random.default_rng()
while len(names) < count:
name_length = rng.integers(5, 10) # noqa: S311
Expand Down
8 changes: 4 additions & 4 deletions tests/main/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def create_mock_segy_rev0(
tmp_file: Path,
num_samples: int,
num_traces: int,
format_str_to_text_header: Callable,
format_str_to_text_header: Callable[[str], str],
) -> SegyFile:
"""Create a temporary file that mocks a segy Rev0 file structure."""
rev0_spec = registry.get_spec(SegyStandard.REV0)
Expand Down Expand Up @@ -50,7 +50,7 @@ def create_mock_segy_rev0(

@pytest.fixture(scope="session")
def mock_segy_rev0(
request: list[int], tmp_path: Path, format_str_to_text_header: Callable
request: list[int], tmp_path: Path, format_str_to_text_header: Callable[[str], str]
) -> SegyFile:
"""Returns a temp file that for rev0 SegyFile object."""
req_params = getattr(request, "param", [10, 10])
Expand All @@ -68,7 +68,7 @@ def create_mock_segy_rev1(
num_samples: int,
num_traces: int,
num_extended_headers: int,
format_str_to_text_header: Callable,
format_str_to_text_header: Callable[[str], str],
) -> SegyFile:
"""Create a temporary file that mocks a segy Rev1 file structure."""
rev1_spec = registry.get_spec(SegyStandard.REV1)
Expand Down Expand Up @@ -111,7 +111,7 @@ def create_mock_segy_rev1(

@pytest.fixture(scope="session")
def mock_segy_rev1(
request: list[int], tmp_path: Path, format_str_to_text_header: Callable
request: list[int], tmp_path: Path, format_str_to_text_header: Callable[[str], str]
) -> SegyFile:
"""Returns a temp file that for rev1 SegyFile object."""
req_params = getattr(request, "param", [10, 10, 2])
Expand Down
3 changes: 2 additions & 1 deletion tests/main/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from segy.indexing import merge_cat_file
from segy.indexing import trace_ibm2ieee_inplace
from segy.schema import Endianness
from segy.schema import TraceDescriptor

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -100,7 +101,7 @@ def test_trace_ibm2ieee_inplace(
header_params: dict[str, Any],
data_params: dict[str, Any],
float_vals: list[float],
make_trace_descriptor: Callable,
make_trace_descriptor: Callable[..., TraceDescriptor],
) -> None:
"""Test changing dtype of IBM32 values inplace."""
trace_descr = make_trace_descriptor(header_params, data_params)
Expand Down
11 changes: 7 additions & 4 deletions tests/schema/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
]
)
def binary_header_descriptors(
request: pytest.FixtureRequest, make_binary_header_descriptor: Callable
request: pytest.FixtureRequest,
make_binary_header_descriptor: Callable[..., BinaryHeaderDescriptor],
) -> BinaryHeaderDescriptor:
"""Generates BinaryHeaderDescriptor objects from parameters.
Expand All @@ -68,7 +69,8 @@ def binary_header_descriptors(
]
)
def trace_header_descriptors(
request: pytest.FixtureRequest, make_trace_header_descriptor: Callable
request: pytest.FixtureRequest,
make_trace_header_descriptor: Callable[..., TraceHeaderDescriptor],
) -> TraceHeaderDescriptor:
"""Generates TraceHeaderDescriptor objects from parameters.
Expand Down Expand Up @@ -119,7 +121,8 @@ def data_types(request: pytest.FixtureRequest) -> DataTypeDescriptor:
]
)
def trace_data_descriptors(
request: pytest.FixtureRequest, make_trace_data_descriptor: Callable
request: pytest.FixtureRequest,
make_trace_data_descriptor: Callable[..., TraceDataDescriptor],
) -> TraceDataDescriptor:
"""Fixture that creates TraceDataDescriptors of different data types and endianness."""
return make_trace_data_descriptor(
Expand Down Expand Up @@ -182,7 +185,7 @@ def trace_data_descriptors(

@pytest.fixture(params=[sample_text, sample_real_header_text])
def text_header_samples(
request: pytest.FixtureRequest, format_str_to_text_header: Callable
request: pytest.FixtureRequest, format_str_to_text_header: Callable[[str], str]
) -> str:
"""Fixture that generates fixed size text header test data from strings."""
return format_str_to_text_header(request.param)
2 changes: 1 addition & 1 deletion tests/schema/test_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_structured_data_type_descriptor(
) -> None:
"""This tests for creatin a StructuredDataTypeDescriptor for different component data types."""
new_sdtd = StructuredDataTypeDescriptor(
fields=fields, item_size=item_size, offset=offset
fields=list(fields), item_size=item_size, offset=offset
)
assert new_sdtd.dtype.names == tuple([f.name for f in fields])
assert new_sdtd.item_size == new_sdtd.dtype.itemsize
Expand Down
5 changes: 3 additions & 2 deletions tests/schema/test_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
import pytest

from segy.schema.header import BinaryHeaderDescriptor
Expand Down Expand Up @@ -83,7 +84,7 @@ def test_trace_header_descriptors(
)


def void_buffer(buff_size: int) -> np.ndarray:
def void_buffer(buff_size: int) -> npt.NDArray[np.void]:
"""Creates a new buffer of requested number of bytes with void(number_bytes) datatype.
Prefills with random bytes.
Expand All @@ -98,7 +99,7 @@ def void_buffer(buff_size: int) -> np.ndarray:
def get_dt_info(
dt: np.dtype[Any],
atrnames: list[str] | None = None,
) -> dict:
) -> dict[str, Any]:
"""Helper function to get info about a numpy dtype."""
if atrnames is None:
atrnames = [
Expand Down

0 comments on commit d33db5e

Please sign in to comment.