Skip to content

Commit

Permalink
Fix remaining mypy issues mostly ignoring them
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
gokulavasan committed Mar 26, 2024
1 parent 2b831b3 commit 9dda532
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
dispatching_dp = find_lca_round_robin_sharding_dp(graph)
# TODO(ejguan): When the last DataPipe is round_robin_sharding, use InPrcoessReadingService
if dispatching_dp is not None:
dummy_dp = _DummyIterDataPipe()
dummy_dp = _DummyIterDataPipe() # type: ignore
graph = replace_dp(graph, dispatching_dp, dummy_dp) # type: ignore[arg-type]
datapipe = list(graph.values())[0][0]
# TODO(ejguan): Determine buffer_size at runtime or use unlimited buffer
Expand Down
3 changes: 2 additions & 1 deletion torchdata/datapipes/iter/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from torch.utils.data import DataChunk, IterableDataset, default_collate
from torch.utils.data.datapipes._typing import _DataPipeMeta
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union, Hashable
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, TypeVar, Union, Hashable

try:
import torcharrow
Expand All @@ -24,6 +24,7 @@ except ImportError:

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
ForkIterDataPipeCopyOptions = Literal["shallow", "deep"]

class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
functions: Dict[str, Callable] = ...
Expand Down
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int =
return [datapipe]

datapipe = datapipe.enumerate()
container = _RoundRobinDemultiplexerIterDataPipe(datapipe, num_instances, buffer_size=buffer_size)
container = _RoundRobinDemultiplexerIterDataPipe(datapipe, num_instances, buffer_size=buffer_size) # type: ignore
return [_ChildDataPipe(container, i).map(_drop_index) for i in range(num_instances)]


Expand Down Expand Up @@ -357,7 +357,7 @@ def __new__(
)

# The implementation basically uses Forker but only yields a specific element within the sequence
container = _UnZipperIterDataPipe(source_datapipe, instance_ids, buffer_size) # type: ignore[arg-type]
container = _UnZipperIterDataPipe(source_datapipe, instance_ids, buffer_size) # type: ignore
return [_ChildDataPipe(container, i) for i in range(len(instance_ids))]


Expand Down
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/randomsplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __new__(
"RandomSplitter needs `total_length`, but it is unable to infer it from "
f"the `source_datapipe`: {source_datapipe}."
)
container = _RandomSplitterIterDataPipe(source_datapipe, total_length, weights, seed)
container = _RandomSplitterIterDataPipe(source_datapipe, total_length, weights, seed) # type: ignore
if target is None:
return [SplitterIterator(container, k) for k in list(weights.keys())]
else:
Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(
self._rng = random.Random(self._seed)
self._lengths: List[int] = []

def draw(self) -> T:
def draw(self) -> T: # type: ignore
selected_key = self._rng.choices(self.keys, self.weights)[0]
index = self.key_to_index[selected_key]
self.weights[index] -= 1
Expand Down

0 comments on commit 9dda532

Please sign in to comment.