[DataPipe] extract keys #406

Open: wants to merge 6 commits into base: main (showing changes from 2 commits)

14 changes: 14 additions & 0 deletions test/test_iterdatapipe.py
@@ -32,6 +32,7 @@
    MapKeyZipper,
    MaxTokenBucketizer,
    ParagraphAggregator,
    ExtractKeys,
    Rows2Columnar,
    SampleMultiplexer,
    UnZipper,
@@ -902,6 +903,19 @@ def test_mux_longest_iterdatapipe(self):
        with self.assertRaises(TypeError):
            len(output_dp)

    def test_extractor(self):
Review comment (Contributor):

Suggested change
- def test_extractor(self):
+ def test_key_extractor(self):

nit: We used to have a different extractor


        # Functional Test: verify that extracting by patterns yields correct output
        stage1 = IterableWrapper([
            {"1.txt": "1", "1.bin": "1b"},
            {"2.txt": "2", "2.bin": "2b"},
        ])
        stage2 = ExtractKeys(stage1, "*.txt", "*.bin")
        output = list(iter(stage2))
        assert len(output) == 2
tmbdev marked this conversation as resolved.
        assert output[0][0] == "1"
        assert output[0][1] == "1b"

tmbdev marked this conversation as resolved.
    def test_zip_longest_iterdatapipe(self):

        # Functional Test: raises TypeError when an input is not of type `IterDataPipe`
8 changes: 7 additions & 1 deletion torchdata/datapipes/iter/__init__.py
@@ -109,7 +109,12 @@
     TFRecordLoaderIterDataPipe as TFRecordLoader,
 )
 from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper
-from torchdata.datapipes.iter.util.webdataset import WebDatasetIterDataPipe as WebDataset
+from torchdata.datapipes.iter.util.webdataset import (
+    WebDatasetIterDataPipe as WebDataset,
+)
+from torchdata.datapipes.iter.util.extractkeys import (
+    ExtractKeysIterDataPipe as ExtractKeys,
+)
 from torchdata.datapipes.iter.util.xzfileloader import (
     XzFileLoaderIterDataPipe as XzFileLoader,
     XzFileReaderIterDataPipe as XzFileReader,
@@ -136,6 +141,7 @@
     "Demultiplexer",
     "EndOnDiskCacheHolder",
     "Enumerator",
+    "ExtractKeys",
     "Extractor",
     "FSSpecFileLister",
     "FSSpecFileOpener",
59 changes: 59 additions & 0 deletions torchdata/datapipes/iter/util/extractkeys.py
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from fnmatch import fnmatch
from typing import Dict, Iterator, Tuple

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


@functional_datapipe("extract_keys")
class ExtractKeysIterDataPipe(IterDataPipe[Dict]):
Review comment (Contributor):

Can we rename this to KeyExtractor to follow our naming convention? Thanks.

We can still keep "extract_keys" as the functional name.

    r"""
    Given a stream of dictionaries, return a stream of tuples by selecting keys using glob patterns.
Review comment (Contributor):

Suggested change
- Given a stream of dictionaries, return a stream of tuples by selecting keys using glob patterns.
+ Given a stream of dictionaries, return a stream of dicts (or tuples) by selecting keys using glob patterns.


    Args:
        source_datapipe: a DataPipe yielding a stream of dictionaries.
        duplicate_is_error: it is an error if the same key is selected twice (True)
        ignore_missing: skip any dictionaries where one or more patterns don't match (False)
Comment on lines +21 to +22
Review comment (Contributor):

Suggested change
- duplicate_is_error: it is an error if the same key is selected twice (True)
- ignore_missing: skip any dictionaries where one or more patterns don't match (False)

Duplicate lines of descriptions

        *args: glob patterns or lists of glob patterns

    Returns:
        a DataPipe yielding a stream of tuples

    Examples:
        >>> dp = FileLister(...).load_from_tar().webdataset().decode(...).extract_keys(["*.jpg", "*.png"], "*.gt.txt")
    """
Comment on lines +32 to +33
Review comment (Contributor):

In addition to the one example with webdataset, please add an example with sample outputs here. Copying from the test cases is totally fine to me.
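
One possible shape for the requested example with sample outputs, as a minimal sketch: the function below re-implements the `__iter__` logic from this diff in plain Python so that it runs without torchdata. `extract_keys_sketch` is a hypothetical name used for illustration only, not part of the PR or the library; the input data is copied from the test case.

```python
from fnmatch import fnmatch
from typing import Dict, Iterable, Iterator, Tuple


def extract_keys_sketch(
    samples: Iterable[Dict], *patterns, duplicate_is_error: bool = True, ignore_missing: bool = False
) -> Iterator[Tuple]:
    # Hypothetical stand-in for ExtractKeys; mirrors the __iter__ logic in this diff.
    for sample in samples:
        result = []
        for pattern in patterns:
            # Each argument is either one glob pattern or a list/tuple of alternatives.
            group = pattern if isinstance(pattern, (list, tuple)) else [pattern]
            matches = [k for k in sample if any(fnmatch(k, p) for p in group)]
            if not matches:
                if ignore_missing:
                    continue
                raise ValueError(f"Cannot find {group} in sample keys {list(sample)}.")
            if len(matches) > 1 and duplicate_is_error:
                raise ValueError(f"Multiple sample keys {list(sample)} match {group}.")
            result.append(sample[matches[0]])
        yield tuple(result)


samples = [
    {"1.txt": "1", "1.bin": "1b"},
    {"2.txt": "2", "2.bin": "2b"},
]
output = list(extract_keys_sketch(samples, "*.txt", "*.bin"))
print(output)  # [('1', '1b'), ('2', '2b')]
```

In the docstring itself, the same data could be shown with `IterableWrapper(...)` and `.extract_keys("*.txt", "*.bin")`, as in the test.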


    def __init__(
        self, source_datapipe: IterDataPipe[Dict], *args, duplicate_is_error=True, ignore_missing=False
    ) -> None:
        super().__init__()
        self.source_datapipe: IterDataPipe[Dict] = source_datapipe
        self.duplicate_is_error = duplicate_is_error
        self.patterns = args
        self.ignore_missing = ignore_missing

    def __iter__(self) -> Iterator[Tuple]:
        for sample in self.source_datapipe:
            result = []
            for pattern in self.patterns:
                pattern = [pattern] if not isinstance(pattern, (list, tuple)) else pattern
                matches = [x for x in sample.keys() if any(fnmatch(x, p) for p in pattern)]
                if len(matches) == 0:
                    if self.ignore_missing:
                        continue
                    else:
                        raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.")
                if len(matches) > 1 and self.duplicate_is_error:
tmbdev marked this conversation as resolved.
                    raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.")
                value = sample[matches[0]]
                result.append(value)
            yield tuple(result)

    def __len__(self) -> int:
        return len(self.source_datapipe)
Comment on lines +72 to +73
Review comment (Contributor, @NivekT, Sep 13, 2022):

Question: A sample will always be yielded even if nothing matches, right?
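
Reading the `__iter__` as shown in this diff, the answer appears to be yes: with `ignore_missing=True` the `continue` skips only the current pattern, so a tuple is still yielded for every sample, possibly empty or shorter than the number of patterns. That differs from the docstring, which describes skipping the dictionary. The sketch below isolates that branch in plain Python; `select` is a hypothetical helper, not part of the PR, and the duplicate check is left out for brevity.

```python
from fnmatch import fnmatch


def select(sample, patterns, ignore_missing):
    # Hypothetical helper: the per-sample body of __iter__ from this diff,
    # without the duplicate_is_error check.
    result = []
    for pattern in patterns:
        group = pattern if isinstance(pattern, (list, tuple)) else [pattern]
        matches = [k for k in sample if any(fnmatch(k, p) for p in group)]
        if not matches:
            if ignore_missing:
                continue  # skips this pattern only, not the whole sample
            raise ValueError(f"Cannot find {group} in sample keys {list(sample)}.")
        result.append(sample[matches[0]])
    return tuple(result)


nothing_matches = select({"1.jpg": "img"}, ["*.txt", "*.bin"], ignore_missing=True)
partial_match = select({"1.txt": "1"}, ["*.txt", "*.bin"], ignore_missing=True)
print(nothing_matches)  # ()
print(partial_match)    # ('1',)
```

If skipping the whole sample is the intended behavior, the loop would need to abandon the sample rather than the pattern, and `__len__` could then no longer simply return the source length.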