
Use default_collate from public API
Differential Revision: D67815203

Pull Request resolved: #1419
kit1980 authored Jan 4, 2025
1 parent 0d2b0a0 commit 227d3d7
Showing 2 changed files with 24 additions and 27 deletions.
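
For readers skimming the diff, the point of the change is the import path: the tests previously reached into the private torch.utils.data._utils.collate module and now call the same helpers through the dataloader submodule that PyTorch exposes publicly (recent releases also document them as torch.utils.data.default_collate and torch.utils.data.default_convert). A minimal sketch of the before/after, assuming a recent PyTorch release; the sample values mirror the existing test cases rather than adding new behavior:

from torch.utils.data import dataloader  # public submodule, as imported by the updated tests

# Before (private path, free to change between releases):
#   from torch.utils.data import _utils
#   collated = _utils.collate.default_collate([1, 2, -1])

# After (public path used throughout this commit):
collated = dataloader.default_collate([1, 2, -1])
print(collated.dtype)  # torch.int64, inferred from the Python ints

# default_convert leaves plain scalars untouched and unpacks generic
# sequences such as range into lists.
print(dataloader.default_convert(range(2)))  # [0, 1]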
4 changes: 2 additions & 2 deletions examples/vision/imagefolder.py
@@ -9,13 +9,13 @@
 import re
 import threading

-import torchvision.datasets as datasets
 import torchvision.datasets.folder
-import torchvision.transforms as transforms
 from PIL import Image
 from torch.utils.data import DataLoader
 from torchdata.datapipes.iter import FileLister, HttpReader, IterDataPipe

+from torchvision import datasets, transforms
+
 IMAGES_ROOT = os.path.join("fakedata", "imagefolder")

 USE_FORK_DATAPIPE = False
47 changes: 22 additions & 25 deletions test/stateful_dataloader/test_dataloader.py
@@ -50,6 +50,7 @@
     _utils,
     ChainDataset,
     ConcatDataset,
+    dataloader,
     Dataset,
     IterableDataset,
     IterDataPipe,
@@ -588,7 +589,6 @@ def print_traces_of_all_threads(pid):
 # its `.exception` attribute.
 # Inspired by https://stackoverflow.com/a/33599967
 class ErrorTrackingProcess(mp.Process):
-
     # Why no *args?
     # py2 doesn't support def fn(x, *args, key=val, **kwargs)
     # Setting disable_stderr=True may generate a lot of unrelated error outputs
@@ -1767,7 +1767,6 @@ def test_shuffle_batch_workers_prefetch(self):
         self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, prefetch_factor=3))

     def test_random_sampler(self):
-
         from collections import Counter

         from torch.utils.data import RandomSampler
@@ -1888,7 +1887,6 @@ def test_distributed_sampler_invalid_rank(self):
             DistributedSampler(dataset, 3, -1)

     def test_duplicating_data_with_drop_last(self):
-
         from torch.utils.data.distributed import DistributedSampler

         num_processes = 4
@@ -2085,7 +2083,6 @@ def test_proper_exit(self):
         for is_iterable_dataset, use_workers, pin_memory, hold_iter_reference in itertools.product(
             [True, False], repeat=4
         ):
-
             # `hold_iter_reference` specifies whether we hold a reference to the
             # iterator. This is interesting because Python3 error traces holds a
             # reference to the frames, which hold references to all the local
@@ -2330,51 +2327,51 @@ def __len__(self):

     def test_default_convert_mapping_keep_type(self):
         data = CustomDict({"a": 1, "b": 2})
-        converted = _utils.collate.default_convert(data)
+        converted = dataloader.default_convert(data)

         self.assertEqual(converted, data)

     def test_default_convert_sequence_keep_type(self):
         data = CustomList([1, 2, 3])
-        converted = _utils.collate.default_convert(data)
+        converted = dataloader.default_convert(data)

         self.assertEqual(converted, data)

     def test_default_convert_sequence_dont_keep_type(self):
         data = range(2)
-        converted = _utils.collate.default_convert(data)
+        converted = dataloader.default_convert(data)

         self.assertEqual(converted, [0, 1])

     def test_default_collate_dtype(self):
         arr = [1, 2, -1]
-        collated = _utils.collate.default_collate(arr)
+        collated = dataloader.default_collate(arr)
         self.assertEqual(collated, torch.tensor(arr))
         self.assertEqual(collated.dtype, torch.int64)

         arr = [1.1, 2.3, -0.9]
-        collated = _utils.collate.default_collate(arr)
+        collated = dataloader.default_collate(arr)
         self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64))

         arr = [True, False]
-        collated = _utils.collate.default_collate(arr)
+        collated = dataloader.default_collate(arr)
         self.assertEqual(collated, torch.tensor(arr))
         self.assertEqual(collated.dtype, torch.bool)

         # Should be a no-op
         arr = ["a", "b", "c"]
-        self.assertEqual(arr, _utils.collate.default_collate(arr))
+        self.assertEqual(arr, dataloader.default_collate(arr))

     def test_default_collate_mapping_keep_type(self):
         batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
-        collated = _utils.collate.default_collate(batch)
+        collated = dataloader.default_collate(batch)

         expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])})
         self.assertEqual(collated, expected)

     def test_default_collate_sequence_keep_type(self):
         batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])]
-        collated = _utils.collate.default_collate(batch)
+        collated = dataloader.default_collate(batch)

         expected = CustomList(
             [
@@ -2387,7 +2384,7 @@ def test_default_collate_sequence_keep_type(self):

     def test_default_collate_sequence_dont_keep_type(self):
         batch = [range(2), range(2)]
-        collated = _utils.collate.default_collate(batch)
+        collated = dataloader.default_collate(batch)

         self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])])

@@ -2397,16 +2394,16 @@ def test_default_collate_bad_numpy_types(self):

         # Should be a no-op
         arr = np.array(["a", "b", "c"])
-        self.assertEqual(arr, _utils.collate.default_collate(arr))
+        self.assertEqual(arr, dataloader.default_collate(arr))

         arr = np.array([[["a", "b", "c"]]])
-        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
+        self.assertRaises(TypeError, lambda: dataloader.default_collate(arr))

         arr = np.array([object(), object(), object()])
-        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
+        self.assertRaises(TypeError, lambda: dataloader.default_collate(arr))

         arr = np.array([[[object(), object(), object()]]])
-        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
+        self.assertRaises(TypeError, lambda: dataloader.default_collate(arr))

     @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
     def test_default_collate_numpy_memmap(self):
@@ -2417,14 +2414,14 @@ def test_default_collate_numpy_memmap(self):
             arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape)
             arr_memmap[:] = arr[:]
             arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape)
-            tensor = _utils.collate.default_collate(list(arr_new))
+            tensor = dataloader.default_collate(list(arr_new))

         self.assertTrue((tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item())

     def test_default_collate_bad_sequence_type(self):
         batch = [["X"], ["X", "X"]]
-        self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch))
-        self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch[::-1]))
+        self.assertRaises(RuntimeError, lambda: dataloader.default_collate(batch))
+        self.assertRaises(RuntimeError, lambda: dataloader.default_collate(batch[::-1]))

     @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
     def test_default_collate_shared_tensor(self):
@@ -2435,17 +2432,17 @@ def test_default_collate_shared_tensor(self):

         self.assertEqual(t_in.is_shared(), False)

-        self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
-        self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)
+        self.assertEqual(dataloader.default_collate([t_in]).is_shared(), False)
+        self.assertEqual(dataloader.default_collate([n_in]).is_shared(), False)

         # FIXME: fix the following hack that makes `default_collate` believe
         # that it is in a worker process (since it tests
         # `get_worker_info() != None`), even though it is not.
         old = _utils.worker._worker_info
         try:
             _utils.worker._worker_info = "x"
-            self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
-            self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
+            self.assertEqual(dataloader.default_collate([t_in]).is_shared(), True)
+            self.assertEqual(dataloader.default_collate([n_in]).is_shared(), True)
         finally:
             _utils.worker._worker_info = old

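For context on what the touched tests assert, here is a small self-contained sketch of the container handling exercised above, again through the public path. CustomDict here is a local stand-in for the custom Mapping fixture used in the test file, so the names are illustrative rather than copied from the suite:

import torch
from torch.utils.data import dataloader

class CustomDict(dict):
    # Illustrative stand-in for the custom Mapping type used by the tests.
    pass

# default_collate stacks matching leaf values into tensors while preserving
# the mapping type of the batch elements.
batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
collated = dataloader.default_collate(batch)

assert isinstance(collated, CustomDict)
assert torch.equal(collated["a"], torch.tensor([1, 3]))
assert torch.equal(collated["b"], torch.tensor([2, 4]))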
