From d33db5eb23ff3b682fef9f39ad3d277de9670ca5 Mon Sep 17 00:00:00 2001 From: "Thomas.Hill" Date: Fri, 1 Mar 2024 10:32:43 -0600 Subject: [PATCH] added mypy types for files in tests --- tests/conftest.py | 48 ++++++++++++++++++++++++---------- tests/main/conftest.py | 8 +++--- tests/main/test_indexing.py | 3 ++- tests/schema/conftest.py | 11 +++++--- tests/schema/test_data_type.py | 2 +- tests/schema/test_header.py | 5 ++-- 6 files changed, 51 insertions(+), 26 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1f20477..5f29ca5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ import string from typing import TYPE_CHECKING +from typing import Any import numpy as np import pytest @@ -25,7 +26,7 @@ @pytest.fixture(scope="session") -def format_str_to_text_header() -> Callable: +def format_str_to_text_header() -> Callable[[str], str]: """Fixture wrapper around helper function to format text headers.""" def _format_str_to_text_header(text: str) -> str: @@ -38,7 +39,12 @@ def _format_str_to_text_header(text: str) -> str: @pytest.fixture(scope="module") -def make_header_field_descriptor() -> Callable: +def make_header_field_descriptor() -> ( + Callable[ + [str, list[str] | None, list[int] | None, Endianness], + dict[str, list[HeaderFieldDescriptor] | int], + ] +): """Fixture wrapper around helper function to generate params for descriptors.""" def _make_header_field_descriptor( @@ -65,10 +71,12 @@ def _make_header_field_descriptor( ) temp_dt = np.dtype(dt_string) item_size = temp_dt.itemsize + # becuase mypy doesn't like MappingProxy + temp_dt_field_values: list[tuple[Any, ...]] = list(temp_dt.fields.values()) # type: ignore[union-attr] dt_offsets = ( offsets if offsets is not None - else [field[-1] for field in temp_dt.fields.values()] + else [field[-1] for field in temp_dt_field_values] ) header_fields = [ HeaderFieldDescriptor( @@ -85,14 +93,19 @@ def _make_header_field_descriptor( @pytest.fixture(scope="module") -def make_trace_header_descriptor(make_header_field_descriptor: Callable) -> Callable: +def make_trace_header_descriptor( + make_header_field_descriptor: Callable[ + [str, list[str] | None, list[int] | None, str | Endianness], + dict[str, list[HeaderFieldDescriptor] | int], + ], +) -> Callable[..., TraceHeaderDescriptor]: """Fixture wrapper for helper function to create TraceHeaderDescriptors.""" def _make_trace_header_descriptor( dt_string: str = "i2", names: list[str] | None = None, offsets: list[int] | None = None, - endianness: str = Endianness.BIG, + endianness: str | Endianness = Endianness.BIG, ) -> TraceHeaderDescriptor: """Convenience function for creating TraceHeaderDescriptors. @@ -105,9 +118,10 @@ def _make_trace_header_descriptor( Returns: TraceHeaderDescriptor: Descriptor object for TraceHeaderDescriptors """ - head_field_desc = make_header_field_descriptor( - dt_string=dt_string, names=names, offsets=offsets, endianness=endianness + head_field_desc: dict[str, Any] = make_header_field_descriptor( + dt_string, names, offsets, endianness ) + return TraceHeaderDescriptor( fields=head_field_desc["fields"], item_size=head_field_desc["item_size"], @@ -118,7 +132,7 @@ def _make_trace_header_descriptor( @pytest.fixture(scope="module") -def make_trace_data_descriptor() -> Callable: +def make_trace_data_descriptor() -> Callable[..., TraceDataDescriptor]: """Fixture wrapper for helper function to create TraceDataDescriptor.""" def _make_trace_data_descriptor( @@ -150,8 +164,9 @@ def _make_trace_data_descriptor( @pytest.fixture(scope="module") def make_trace_descriptor( - make_trace_header_descriptor: Callable, make_trace_data_descriptor: Callable -) -> Callable: + make_trace_header_descriptor: Callable[..., TraceHeaderDescriptor], + make_trace_data_descriptor: Callable[..., TraceDataDescriptor], +) -> Callable[..., TraceDescriptor]: """Fixture wrapper for helper function to create TraceDescriptors.""" def _make_trace_descriptor( @@ -177,7 +192,12 @@ def _make_trace_descriptor( @pytest.fixture(scope="module") -def make_binary_header_descriptor(make_header_field_descriptor: Callable) -> Callable: +def make_binary_header_descriptor( + make_header_field_descriptor: Callable[ + [str, list[str] | None, list[int] | None, Endianness | str], + dict[str, list[HeaderFieldDescriptor] | int], + ], +) -> Callable[..., BinaryHeaderDescriptor]: """Fixture wrapper around helper function for creating BinaryHeaderDescriptor.""" def _make_binary_header_descriptor( @@ -197,8 +217,8 @@ def _make_binary_header_descriptor( Returns: BinaryHeaderDescriptor: Descriptor object for BinaryHeaderDescriptor """ - head_field_desc = make_header_field_descriptor( - dt_string=dt_string, names=names, offsets=offsets, endianness=endianness + head_field_desc: dict[str, Any] = make_header_field_descriptor( + dt_string, names, offsets, endianness ) return BinaryHeaderDescriptor( fields=head_field_desc["fields"], @@ -211,7 +231,7 @@ def _make_binary_header_descriptor( def generate_unique_names(count: int) -> list[str]: """Helper function to create random unique names as placeholders during testing.""" - names: set = set() + names: set[str] = set() rng = np.random.default_rng() while len(names) < count: name_length = rng.integers(5, 10) # noqa: S311 diff --git a/tests/main/conftest.py b/tests/main/conftest.py index 297bc7f..eb0e7a0 100644 --- a/tests/main/conftest.py +++ b/tests/main/conftest.py @@ -17,7 +17,7 @@ def create_mock_segy_rev0( tmp_file: Path, num_samples: int, num_traces: int, - format_str_to_text_header: Callable, + format_str_to_text_header: Callable[[str], str], ) -> SegyFile: """Create a temporary file that mocks a segy Rev0 file structure.""" rev0_spec = registry.get_spec(SegyStandard.REV0) @@ -50,7 +50,7 @@ def create_mock_segy_rev0( @pytest.fixture(scope="session") def mock_segy_rev0( - request: list[int], tmp_path: Path, format_str_to_text_header: Callable + request: list[int], tmp_path: Path, format_str_to_text_header: Callable[[str], str] ) -> SegyFile: """Returns a temp file that for rev0 SegyFile object.""" req_params = getattr(request, "param", [10, 10]) @@ -68,7 +68,7 @@ def create_mock_segy_rev1( num_samples: int, num_traces: int, num_extended_headers: int, - format_str_to_text_header: Callable, + format_str_to_text_header: Callable[[str], str], ) -> SegyFile: """Create a temporary file that mocks a segy Rev1 file structure.""" rev1_spec = registry.get_spec(SegyStandard.REV1) @@ -111,7 +111,7 @@ def create_mock_segy_rev1( @pytest.fixture(scope="session") def mock_segy_rev1( - request: list[int], tmp_path: Path, format_str_to_text_header: Callable + request: list[int], tmp_path: Path, format_str_to_text_header: Callable[[str], str] ) -> SegyFile: """Returns a temp file that for rev1 SegyFile object.""" req_params = getattr(request, "param", [10, 10, 2]) diff --git a/tests/main/test_indexing.py b/tests/main/test_indexing.py index d01cbdb..0baa32a 100644 --- a/tests/main/test_indexing.py +++ b/tests/main/test_indexing.py @@ -15,6 +15,7 @@ from segy.indexing import merge_cat_file from segy.indexing import trace_ibm2ieee_inplace from segy.schema import Endianness +from segy.schema import TraceDescriptor if TYPE_CHECKING: from collections.abc import Callable @@ -100,7 +101,7 @@ def test_trace_ibm2ieee_inplace( header_params: dict[str, Any], data_params: dict[str, Any], float_vals: list[float], - make_trace_descriptor: Callable, + make_trace_descriptor: Callable[..., TraceDescriptor], ) -> None: """Test changing dtype of IBM32 values inplace.""" trace_descr = make_trace_descriptor(header_params, data_params) diff --git a/tests/schema/conftest.py b/tests/schema/conftest.py index cef883f..365c0a3 100644 --- a/tests/schema/conftest.py +++ b/tests/schema/conftest.py @@ -44,7 +44,8 @@ ] ) def binary_header_descriptors( - request: pytest.FixtureRequest, make_binary_header_descriptor: Callable + request: pytest.FixtureRequest, + make_binary_header_descriptor: Callable[..., BinaryHeaderDescriptor], ) -> BinaryHeaderDescriptor: """Generates BinaryHeaderDescriptor objects from parameters. @@ -68,7 +69,8 @@ def binary_header_descriptors( ] ) def trace_header_descriptors( - request: pytest.FixtureRequest, make_trace_header_descriptor: Callable + request: pytest.FixtureRequest, + make_trace_header_descriptor: Callable[..., TraceHeaderDescriptor], ) -> TraceHeaderDescriptor: """Generates TraceHeaderDescriptor objects from parameters. @@ -119,7 +121,8 @@ def data_types(request: pytest.FixtureRequest) -> DataTypeDescriptor: ] ) def trace_data_descriptors( - request: pytest.FixtureRequest, make_trace_data_descriptor: Callable + request: pytest.FixtureRequest, + make_trace_data_descriptor: Callable[..., TraceDataDescriptor], ) -> TraceDataDescriptor: """Fixture that creates TraceDataDescriptors of different data types and endianness.""" return make_trace_data_descriptor( @@ -182,7 +185,7 @@ def trace_data_descriptors( @pytest.fixture(params=[sample_text, sample_real_header_text]) def text_header_samples( - request: pytest.FixtureRequest, format_str_to_text_header: Callable + request: pytest.FixtureRequest, format_str_to_text_header: Callable[[str], str] ) -> str: """Fixture that generates fixed size text header test data from strings.""" return format_str_to_text_header(request.param) diff --git a/tests/schema/test_data_type.py b/tests/schema/test_data_type.py index c0614af..a6c3eb1 100644 --- a/tests/schema/test_data_type.py +++ b/tests/schema/test_data_type.py @@ -128,7 +128,7 @@ def test_structured_data_type_descriptor( ) -> None: """This tests for creatin a StructuredDataTypeDescriptor for different component data types.""" new_sdtd = StructuredDataTypeDescriptor( - fields=fields, item_size=item_size, offset=offset + fields=list(fields), item_size=item_size, offset=offset ) assert new_sdtd.dtype.names == tuple([f.name for f in fields]) assert new_sdtd.item_size == new_sdtd.dtype.itemsize diff --git a/tests/schema/test_header.py b/tests/schema/test_header.py index d5608b7..7b9cd1c 100644 --- a/tests/schema/test_header.py +++ b/tests/schema/test_header.py @@ -8,6 +8,7 @@ from typing import Any import numpy as np +import numpy.typing as npt import pytest from segy.schema.header import BinaryHeaderDescriptor @@ -83,7 +84,7 @@ def test_trace_header_descriptors( ) -def void_buffer(buff_size: int) -> np.ndarray: +def void_buffer(buff_size: int) -> npt.NDArray[np.void]: """Creates a new buffer of requested number of bytes with void(number_bytes) datatype. Prefills with random bytes. @@ -98,7 +99,7 @@ def void_buffer(buff_size: int) -> np.ndarray: def get_dt_info( dt: np.dtype[Any], atrnames: list[str] | None = None, -) -> dict: +) -> dict[str, Any]: """Helper function to get info about a numpy dtype.""" if atrnames is None: atrnames = [