Skip to content

Commit

Permalink
[BC-breaking] Remove MapIterator (#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Dec 26, 2024
1 parent 138861b commit dabc1a1
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 68 deletions.
49 changes: 1 addition & 48 deletions src/spdl/dataloader/_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
Callable,
Iterable,
Iterator,
Mapping,
Sequence,
)
from typing import Any, TypeVar

__all__ = ["run_in_subprocess", "MapIterator", "MergeIterator"]
__all__ = ["run_in_subprocess", "MergeIterator"]

T = TypeVar("T")

Expand Down Expand Up @@ -335,49 +334,3 @@ def __iter__(self) -> Iterator[T]:
yield from _stocastic_iter(
iterators, self.weights, self.stop_after, self.seed
)


################################################################################
# MapIterator
################################################################################


class MapIterator(Iterable[V]):
    """Combine a mapping and an optional key sampler to iterate over mapped values.

    Args:
        mapping: Object implementing the :py:class:`~collections.abc.Mapping`
            interface.
        sampler: **Optional** Iterable that yields keys for the mapping.
            Used to specify the iteration order over the mapping and/or to
            sample from a subset of the mapping. When ``None`` (the default),
            the mapping itself is iterated to obtain the keys. An empty
            sampler yields nothing.

    Note:
        Each key is looked up with ``mapping[key]``; any lookup error raised
        by the mapping (e.g. :py:class:`KeyError` for a key it does not
        contain) propagates to the consumer of the iterator.

    Example:
        >>> mapping = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
        >>> for item in MapIterator(mapping):
        ...     print(item)
        ...
        a
        b
        c
        d
        e
        >>> sampler = range(4, -1, -2)
        >>> for item in MapIterator(mapping, sampler):
        ...     print(item)
        ...
        e
        c
        a
    """

    def __init__(
        self,
        mapping: Mapping[K, V],
        sampler: Iterable[K] | None = None,
    ) -> None:
        self.mapping = mapping
        self.sampler = sampler

    def __iter__(self) -> Iterator[V]:
        # Test against None explicitly instead of relying on truthiness:
        # an empty sampler must produce no items (not fall back to the whole
        # mapping), and some iterables (e.g. array-like objects) do not
        # define an unambiguous truth value.
        keys = self.mapping if self.sampler is None else self.sampler
        for key in keys:
            yield self.mapping[key]
21 changes: 1 addition & 20 deletions tests/spdl_unittest/dataloader/dataloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time

import pytest
from spdl.dataloader import DataLoader, MapIterator, MergeIterator
from spdl.dataloader import DataLoader, MergeIterator


def get_dl(*args, timeout=3, num_threads=2, **kwargs):
Expand Down Expand Up @@ -165,25 +165,6 @@ def agg(vals: list[int]) -> tuple[int, int, int, int]:
assert list(dl) == expected


def test_mapiterator():
    """Without a sampler, MapIterator yields every value of the mapping in order."""

    source = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}

    # Materialize the iterator and compare against the mapping's own values.
    produced = [*MapIterator(source)]
    expected = [*source.values()]
    assert produced == expected


def test_mapiterator_sampler():
    """With a sampler, MapIterator yields only the sampled keys' values, in sampler order."""

    source = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
    key_order = [4, 2, 0]

    produced = [*MapIterator(source, key_order)]
    assert produced == ["e", "c", "a"]


def test_mergeiterator_ordered():
"""MergeIterator iterates multiple iterators"""

Expand Down

0 comments on commit dabc1a1

Please sign in to comment.