Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add strict support for mypy #169

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/python-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ jobs:
src: "."
- name: Typecheck with mypy
run: |
mypy s3torchconnector/src
mypy s3torchconnectorclient/python/src
mypy --strict s3torchconnector/src
mypy --strict s3torchconnectorclient/python/src

dependencies:
name: Python dependencies checks
Expand Down
4 changes: 2 additions & 2 deletions doc/DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ For Python code changes, run
black --verbose .
flake8 s3torchconnector/ --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 s3torchconnectorclient/python --count --select=E9,F63,F7,F82 --show-source --statistics
mypy s3torchconnector/src
mypy s3torchconnectorclient/python/src
mypy --strict s3torchconnector/src
mypy --strict s3torchconnectorclient/python/src
```
to lint.

Expand Down
10 changes: 6 additions & 4 deletions s3torchconnector/src/s3torchconnector/_s3_bucket_iterable.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from __future__ import annotations


from functools import partial
from itertools import chain
from typing import Iterator, List
from typing import Iterator, Dict, Any

from s3torchconnectorclient._mountpoint_s3_client import (
ObjectInfo,
Expand Down Expand Up @@ -43,13 +45,13 @@ def __init__(self, client: S3Client, bucket: str, prefix: str):
self._client = client
self._list_stream = iter(client.list_objects(bucket, prefix))

def __iter__(self):
def __iter__(self) -> _PickleableListObjectStream:
return self

def __next__(self) -> ListObjectResult:
return next(self._list_stream)

def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
return {
"client": self._client,
"bucket": self._list_stream.bucket,
Expand All @@ -60,7 +62,7 @@ def __getstate__(self):
"complete": self._list_stream.complete,
}

def __setstate__(self, state):
def __setstate__(self, state: Dict[str, Any]) -> None:
self._client = state["client"]
self._list_stream = ListObjectStream._from_state(**state)

Expand Down
2 changes: 1 addition & 1 deletion s3torchconnector/src/s3torchconnector/_user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, comments: Optional[List[str]] = None):
self._comments = comments or []

@property
def prefix(self):
def prefix(self) -> str:
comments_str = "; ".join(filter(None, self._comments))
if comments_str:
return f"{self._user_agent_prefix} ({comments_str})"
Expand Down
94 changes: 75 additions & 19 deletions s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
from __future__ import annotations

from functools import partial
from typing import Iterator, Any, Union, Iterable, Callable, Optional
from typing import (
Iterator,
Union,
Iterable,
Callable,
Optional,
TypeVar,
cast,
overload,
)
import logging

import torch.utils.data
Expand All @@ -17,8 +28,10 @@

log = logging.getLogger(__name__)

T_co = TypeVar("T_co", covariant=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



class S3IterableDataset(torch.utils.data.IterableDataset):
class S3IterableDataset(torch.utils.data.IterableDataset[T_co]):
"""An IterableStyle dataset created from S3 objects.

To create an instance of S3IterableDataset, you need to use
Expand All @@ -28,33 +41,55 @@ class S3IterableDataset(torch.utils.data.IterableDataset):
def __init__(
self,
region: str,
*,
get_dataset_objects: Callable[[S3Client], Iterable[S3BucketKeyData]],
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
transform: Callable[[S3Reader], T_co],
):
self._get_dataset_objects = get_dataset_objects
self._transform = transform
self._region = region
self._endpoint = endpoint
self._client = None
self._client: Optional[S3Client] = None

@property
def region(self):
def region(self) -> str:
return self._region

@property
def endpoint(self):
def endpoint(self) -> Optional[str]:
return self._endpoint

@overload
@classmethod
def from_objects(
cls,
object_uris: Union[str, Iterable[str]],
*,
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
):
) -> S3IterableDataset[S3Reader]: ...

@overload
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to introduce this overload just for typing?

@classmethod
def from_objects(
cls,
object_uris: Union[str, Iterable[str]],
*,
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], T_co],
) -> S3IterableDataset[T_co]: ...

@classmethod
def from_objects(
cls,
object_uris: Union[str, Iterable[str]],
*,
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], T_co | S3Reader] = identity,
) -> S3IterableDataset[T_co | S3Reader]:
"""Returns an instance of S3IterableDataset using the S3 URI(s) provided.

Args:
Expand All @@ -72,20 +107,41 @@ def from_objects(
log.info(f"Building {cls.__name__} from_objects")
return cls(
region,
partial(get_objects_from_uris, object_uris),
endpoint,
transform=transform,
get_dataset_objects=partial(get_objects_from_uris, object_uris),
endpoint=endpoint,
transform=cast(Callable[[S3Reader], T_co], transform),
)

@overload
@classmethod
def from_prefix(
cls,
s3_uri: str,
*,
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
):
) -> S3IterableDataset[S3Reader]: ...

@overload
@classmethod
def from_prefix(
cls,
s3_uri: str,
*,
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], T_co],
) -> S3IterableDataset[T_co]: ...

@classmethod
def from_prefix(
cls,
s3_uri: str,
*,
region: str,
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], T_co | S3Reader] = identity,
) -> S3IterableDataset[T_co | S3Reader]:
"""Returns an instance of S3IterableDataset using the S3 URI provided.

Args:
Expand All @@ -103,24 +159,24 @@ def from_prefix(
log.info(f"Building {cls.__name__} from_prefix {s3_uri=}")
return cls(
region,
partial(get_objects_from_prefix, s3_uri),
endpoint,
transform=transform,
get_dataset_objects=partial(get_objects_from_prefix, s3_uri),
endpoint=endpoint,
transform=cast(Callable[[S3Reader], T_co], transform),
)

def _get_client(self):
def _get_client(self) -> S3Client:
if self._client is None:
self._client = S3Client(self.region, self.endpoint)
return self._client

def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> Any:
def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> T_co:
return self._transform(
self._get_client().get_object(
bucket_key.bucket, bucket_key.key, object_info=bucket_key.object_info
)
)

def __iter__(self) -> Iterator[Any]:
def __iter__(self) -> Iterator[T_co]:
return map(
self._get_transformed_object, self._get_dataset_objects(self._get_client())
)
Loading
Loading