diff --git a/.github/workflows/python-checks.yml b/.github/workflows/python-checks.yml index 16b778f1..c40350aa 100644 --- a/.github/workflows/python-checks.yml +++ b/.github/workflows/python-checks.yml @@ -105,8 +105,8 @@ jobs: src: "." - name: Typecheck with mypy run: | - mypy s3torchconnector/src - mypy s3torchconnectorclient/python/src + mypy --strict s3torchconnector/src + mypy --strict s3torchconnectorclient/python/src dependencies: name: Python dependencies checks diff --git a/doc/DEVELOPMENT.md b/doc/DEVELOPMENT.md index aac5b62f..9d853536 100644 --- a/doc/DEVELOPMENT.md +++ b/doc/DEVELOPMENT.md @@ -70,8 +70,8 @@ For Python code changes, run black --verbose . flake8 s3torchconnector/ --count --select=E9,F63,F7,F82 --show-source --statistics flake8 s3torchconnectorclient/python --count --select=E9,F63,F7,F82 --show-source --statistics -mypy s3torchconnector/src -mypy s3torchconnectorclient/python/src +mypy --strict s3torchconnector/src +mypy --strict s3torchconnectorclient/python/src ``` to lint. diff --git a/s3torchconnector/src/s3torchconnector/_s3_bucket_iterable.py b/s3torchconnector/src/s3torchconnector/_s3_bucket_iterable.py index cb3475ad..f43d7618 100644 --- a/s3torchconnector/src/s3torchconnector/_s3_bucket_iterable.py +++ b/s3torchconnector/src/s3torchconnector/_s3_bucket_iterable.py @@ -1,9 +1,11 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD +from __future__ import annotations + from functools import partial from itertools import chain -from typing import Iterator, List +from typing import Iterator, Dict, Any from s3torchconnectorclient._mountpoint_s3_client import ( ObjectInfo, @@ -43,13 +45,13 @@ def __init__(self, client: S3Client, bucket: str, prefix: str): self._client = client self._list_stream = iter(client.list_objects(bucket, prefix)) - def __iter__(self): + def __iter__(self) -> _PickleableListObjectStream: return self def __next__(self) -> ListObjectResult: return next(self._list_stream) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { "client": self._client, "bucket": self._list_stream.bucket, @@ -60,7 +62,7 @@ def __getstate__(self): "complete": self._list_stream.complete, } - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: self._client = state["client"] self._list_stream = ListObjectStream._from_state(**state) diff --git a/s3torchconnector/src/s3torchconnector/_user_agent.py b/s3torchconnector/src/s3torchconnector/_user_agent.py index 2ab254d4..b8ae3665 100644 --- a/s3torchconnector/src/s3torchconnector/_user_agent.py +++ b/s3torchconnector/src/s3torchconnector/_user_agent.py @@ -15,7 +15,7 @@ def __init__(self, comments: Optional[List[str]] = None): self._comments = comments or [] @property - def prefix(self): + def prefix(self) -> str: comments_str = "; ".join(filter(None, self._comments)) if comments_str: return f"{self._user_agent_prefix} ({comments_str})" diff --git a/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py b/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py index 6df2e27e..f30c98b6 100644 --- a/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py +++ b/s3torchconnector/src/s3torchconnector/s3iterable_dataset.py @@ -1,7 +1,18 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD +from __future__ import annotations + from functools import partial -from typing import Iterator, Any, Union, Iterable, Callable, Optional +from typing import ( + Iterator, + Union, + Iterable, + Callable, + Optional, + TypeVar, + cast, + overload, +) import logging import torch.utils.data @@ -17,8 +28,10 @@ log = logging.getLogger(__name__) +T_co = TypeVar("T_co", covariant=True) + -class S3IterableDataset(torch.utils.data.IterableDataset): +class S3IterableDataset(torch.utils.data.IterableDataset[T_co]): """An IterableStyle dataset created from S3 objects. To create an instance of S3IterableDataset, you need to use @@ -28,24 +41,26 @@ class S3IterableDataset(torch.utils.data.IterableDataset): def __init__( self, region: str, + *, get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]], endpoint: Optional[str] = None, - transform: Callable[[S3Reader], Any] = identity, + transform: Callable[[S3Reader], T_co], ): self._get_dataset_objects = get_dataset_objects self._transform = transform self._region = region self._endpoint = endpoint - self._client = None + self._client: Optional[S3Client] = None @property - def region(self): + def region(self) -> str: return self._region @property - def endpoint(self): + def endpoint(self) -> Optional[str]: return self._endpoint + @overload @classmethod def from_objects( cls, @@ -53,8 +68,28 @@ def from_objects( *, region: str, endpoint: Optional[str] = None, - transform: Callable[[S3Reader], Any] = identity, - ): + ) -> S3IterableDataset[S3Reader]: ... + + @overload + @classmethod + def from_objects( + cls, + object_uris: Union[str, Iterable[str]], + *, + region: str, + endpoint: Optional[str] = None, + transform: Callable[[S3Reader], T_co], + ) -> S3IterableDataset[T_co]: ... + + @classmethod + def from_objects( + cls, + object_uris: Union[str, Iterable[str]], + *, + region: str, + endpoint: Optional[str] = None, + transform: Callable[[S3Reader], T_co | S3Reader] = identity, + ) -> S3IterableDataset[T_co | S3Reader]: """Returns an instance of S3IterableDataset using the S3 URI(s) provided. Args: @@ -72,11 +107,12 @@ def from_objects( log.info(f"Building {cls.__name__} from_objects") return cls( region, - partial(get_objects_from_uris, object_uris), - endpoint, - transform=transform, + get_dataset_objects=partial(get_objects_from_uris, object_uris), + endpoint=endpoint, + transform=cast(Callable[[S3Reader], T_co], transform), ) + @overload @classmethod def from_prefix( cls, @@ -84,8 +120,28 @@ def from_prefix( *, region: str, endpoint: Optional[str] = None, - transform: Callable[[S3Reader], Any] = identity, - ): + ) -> S3IterableDataset[S3Reader]: ... + + @overload + @classmethod + def from_prefix( + cls, + s3_uri: str, + *, + region: str, + endpoint: Optional[str] = None, + transform: Callable[[S3Reader], T_co], + ) -> S3IterableDataset[T_co]: ... + + @classmethod + def from_prefix( + cls, + s3_uri: str, + *, + region: str, + endpoint: Optional[str] = None, + transform: Callable[[S3Reader], T_co | S3Reader] = identity, + ) -> S3IterableDataset[T_co | S3Reader]: """Returns an instance of S3IterableDataset using the S3 URI provided. Args: @@ -103,24 +159,24 @@ def from_prefix( log.info(f"Building {cls.__name__} from_prefix {s3_uri=}") return cls( region, - partial(get_objects_from_prefix, s3_uri), - endpoint, - transform=transform, + get_dataset_objects=partial(get_objects_from_prefix, s3_uri), + endpoint=endpoint, + transform=cast(Callable[[S3Reader], T_co], transform), ) - def _get_client(self): + def _get_client(self) -> S3Client: if self._client is None: self._client = S3Client(self.region, self.endpoint) return self._client - def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> Any: + def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> T_co: return self._transform( self._get_client().get_object( bucket_key.bucket, bucket_key.key, object_info=bucket_key.object_info ) ) - def __iter__(self) -> Iterator[Any]: + def __iter__(self) -> Iterator[T_co]: return map( self._get_transformed_object, self._get_dataset_objects(self._get_client()) ) diff --git a/s3torchconnector/src/s3torchconnector/s3map_dataset.py b/s3torchconnector/src/s3torchconnector/s3map_dataset.py index 163da2b0..88642caf 100644 --- a/s3torchconnector/src/s3torchconnector/s3map_dataset.py +++ b/s3torchconnector/src/s3torchconnector/s3map_dataset.py @@ -1,7 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD +from __future__ import annotations + from functools import partial -from typing import List, Any, Callable, Iterable, Union, Optional +from typing import List, Callable, Iterable, Union, Optional, TypeVar, overload, cast import logging import torch.utils.data @@ -19,7 +21,10 @@ log = logging.getLogger(__name__) -class S3MapDataset(torch.utils.data.Dataset): +T_co = TypeVar("T_co", covariant=True) + + +class S3MapDataset(torch.utils.data.Dataset[T_co]): """A Map-Style dataset created from S3 objects. To create an instance of S3MapDataset, you need to use @@ -29,23 +34,24 @@ class S3MapDataset(torch.utils.data.Dataset): def __init__( self, region: str, + *, get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]], endpoint: Optional[str] = None, - transform: Callable[[S3Reader], Any] = identity, + transform: Callable[[S3Reader], T_co], ): self._get_dataset_objects = get_dataset_objects self._transform = transform self._region = region self._endpoint = endpoint - self._client = None + self._client: Optional[S3Client] = None self._bucket_key_pairs: Optional[List[S3BucketKeyData]] = None @property - def region(self): + def region(self) -> str: return self._region @property - def endpoint(self): + def endpoint(self) -> Optional[str]: return self._endpoint @property @@ -55,6 +61,7 @@ def _dataset_bucket_key_pairs(self) -> List[S3BucketKeyData]: assert self._bucket_key_pairs is not None return self._bucket_key_pairs + @overload @classmethod def from_objects( cls, @@ -62,8 +69,28 @@ def from_objects( *, region: str, endpoint: Optional[str] = None, - transform: Callable[[S3Reader], Any] = identity, - ): + ) -> S3MapDataset[S3Reader]: ... + + @overload + @classmethod + def from_objects( + cls, + object_uris: Union[str, Iterable[str]], + *, + region: str, + endpoint: Optional[str] = None, + transform: Callable[[S3Reader], T_co], + ) -> S3MapDataset[T_co]: ... + + @classmethod + def from_objects( + cls, + object_uris: Union[str, Iterable[str]], + *, + region: str, + endpoint: Optional[str] = None, + transform: Callable[[S3Reader], T_co | S3Reader] = identity, + ) -> S3MapDataset[T_co | S3Reader]: """Returns an instance of S3MapDataset using the S3 URI(s) provided. Args: @@ -81,11 +108,12 @@ def from_objects( log.info(f"Building {cls.__name__} from_objects") return cls( region, - partial(get_objects_from_uris, object_uris), - endpoint, - transform=transform, + get_dataset_objects=partial(get_objects_from_uris, object_uris), + endpoint=endpoint, + transform=cast(Callable[[S3Reader], T_co], transform), ) + @overload @classmethod def from_prefix( cls, @@ -93,8 +121,28 @@ def from_prefix( *, region: str, endpoint: Optional[str] = None, - transform: Callable[[S3Reader], Any] = identity, - ): + ) -> S3MapDataset[S3Reader]: ... + + @overload + @classmethod + def from_prefix( + cls, + s3_uri: str, + *, + region: str, + endpoint: Optional[str] = None, + transform: Callable[[S3Reader], T_co], + ) -> S3MapDataset[T_co]: ... + + @classmethod + def from_prefix( + cls, + s3_uri: str, + *, + region: str, + endpoint: Optional[str] = None, + transform: Callable[[S3Reader], T_co | S3Reader] = identity, + ) -> S3MapDataset[T_co | S3Reader]: """Returns an instance of S3MapDataset using the S3 URI provided. Args: @@ -112,12 +160,12 @@ def from_prefix( log.info(f"Building {cls.__name__} from_prefix {s3_uri=}") return cls( region, - partial(get_objects_from_prefix, s3_uri), - endpoint, - transform=transform, + get_dataset_objects=partial(get_objects_from_prefix, s3_uri), + endpoint=endpoint, + transform=cast(Callable[[S3Reader], T_co], transform), ) - def _get_client(self): + def _get_client(self) -> S3Client: if self._client is None: self._client = S3Client(self.region, self.endpoint) return self._client @@ -128,8 +176,8 @@ def _get_object(self, i: int) -> S3Reader: bucket_key.bucket, bucket_key.key, object_info=bucket_key.object_info ) - def __getitem__(self, i: int) -> Any: + def __getitem__(self, i: int) -> T_co: return self._transform(self._get_object(i)) - def __len__(self): + def __len__(self) -> int: return len(self._dataset_bucket_key_pairs) diff --git a/s3torchconnector/src/s3torchconnector/s3reader.py b/s3torchconnector/src/s3torchconnector/s3reader.py index b7873270..ffe37889 100644 --- a/s3torchconnector/src/s3torchconnector/s3reader.py +++ b/s3torchconnector/src/s3torchconnector/s3reader.py @@ -1,15 +1,16 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD +from __future__ import annotations import io from functools import cached_property from io import SEEK_CUR, SEEK_END, SEEK_SET -from typing import Callable, Optional, Iterable, Iterator +from typing import Callable, Optional, IO, Iterator, List, Any, Iterable from s3torchconnectorclient._mountpoint_s3_client import ObjectInfo, GetObjectStream -class S3Reader(io.BufferedIOBase): +class S3Reader(io.BufferedIOBase, IO[bytes]): """A read-only, file like representation of a single object stored in S3.""" def __init__( @@ -32,15 +33,15 @@ def __init__( self._position = 0 @property - def bucket(self): + def bucket(self) -> str: return self._bucket @property - def key(self): + def key(self) -> str: return self._key @cached_property - def _object_info(self): + def _object_info(self) -> ObjectInfo: return self._get_object_info() def prefetch(self) -> None: @@ -187,3 +188,12 @@ def writable(self) -> bool: bool: Return whether object was opened for writing. """ return False + + def write(self, data: Any) -> int: + raise OSError("Not implemented") + + def writelines(self, data: Iterable[Any]) -> None: + raise OSError("Not implemented") + + def __enter__(self) -> S3Reader: + return self diff --git a/s3torchconnector/src/s3torchconnector/s3writer.py b/s3torchconnector/src/s3torchconnector/s3writer.py index 76dd4262..395939e9 100644 --- a/s3torchconnector/src/s3torchconnector/s3writer.py +++ b/s3torchconnector/src/s3torchconnector/s3writer.py @@ -1,25 +1,32 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD +from __future__ import annotations import io -from typing import Union +from types import TracebackType +from typing import Union, IO, Iterable, Optional from s3torchconnectorclient._mountpoint_s3_client import PutObjectStream -class S3Writer(io.BufferedIOBase): +class S3Writer(io.BufferedIOBase, IO[bytes]): """A write-only, file like representation of a single object stored in S3.""" def __init__(self, stream: PutObjectStream): self.stream = stream - def __enter__(self): + def __enter__(self) -> S3Writer: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.close() - def write( + def write( # type: ignore self, # Ignoring the type for this as we don't currently support the Buffer protocol data: Union[bytes, memoryview], # type: ignore @@ -40,7 +47,15 @@ def write( self.stream.write(data) return len(data) - def close(self): + def writelines( # type: ignore + self, + # Ignoring the type for this as we don't currently support the Buffer protocol + data: Iterable[bytes | memoryview], # type: ignore + ) -> None: + for line in data: + self.write(line) + + def close(self) -> None: """Close write-stream to S3. Ensures all bytes are written successfully. Raises: @@ -48,7 +63,7 @@ def close(self): """ self.stream.close() - def flush(self): + def flush(self) -> None: """No-op""" pass diff --git a/s3torchconnectorclient/python/src/s3torchconnectorclient/__init__.py b/s3torchconnectorclient/python/src/s3torchconnectorclient/__init__.py index d4ff3d15..e75e7e47 100644 --- a/s3torchconnectorclient/python/src/s3torchconnectorclient/__init__.py +++ b/s3torchconnectorclient/python/src/s3torchconnectorclient/__init__.py @@ -2,6 +2,7 @@ # // SPDX-License-Identifier: BSD import copyreg +from typing import Tuple, Type, Any from ._logger_patch import TRACE as LOG_TRACE from ._logger_patch import _install_trace_logging @@ -10,7 +11,7 @@ _install_trace_logging() -def _s3exception_reduce(exc: S3Exception): +def _s3exception_reduce(exc: S3Exception) -> Tuple[Type[S3Exception], Any]: return S3Exception, exc.args diff --git a/s3torchconnectorclient/python/src/s3torchconnectorclient/_logger_patch.py b/s3torchconnectorclient/python/src/s3torchconnectorclient/_logger_patch.py index 8a7d57a6..71d4b2f0 100644 --- a/s3torchconnectorclient/python/src/s3torchconnectorclient/_logger_patch.py +++ b/s3torchconnectorclient/python/src/s3torchconnectorclient/_logger_patch.py @@ -6,5 +6,5 @@ TRACE = 5 -def _install_trace_logging(): +def _install_trace_logging() -> None: logging.addLevelName(TRACE, "TRACE")