Skip to content

Commit

Permalink
test_pytorch.py: PipeClassType / IterableWrapperType fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
ryan-williams committed Oct 29, 2024
1 parent 4a435a0 commit 2112a3f
Showing 1 changed file with 25 additions and 33 deletions.
58 changes: 25 additions & 33 deletions tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
from unittest.mock import patch

import numpy as np
Expand All @@ -35,15 +35,21 @@

# These control which classes are tested (for most, but not all tests).
# Centralized to allow easy add/delete of specific test parameters.
PipeClassType = Union[
ExperimentAxisQueryIterable,
IterableWrapperType = Union[
Type[ExperimentAxisQueryIterDataPipe],
Type[ExperimentAxisQueryIterableDataset],
]
IterableWrappers = (
ExperimentAxisQueryIterDataPipe,
ExperimentAxisQueryIterableDataset,
)
PipeClassType = Union[
Type[ExperimentAxisQueryIterable],
IterableWrapperType,
]
PipeClasses = (
ExperimentAxisQueryIterable,
ExperimentAxisQueryIterDataPipe,
ExperimentAxisQueryIterableDataset,
*IterableWrappers,
)
XValueGen = Callable[[range, range], spmatrix]

Expand Down Expand Up @@ -450,24 +456,22 @@ def test_batching__partial_soma_batches_are_concatenated(
@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)]
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_multiprocessing__returns_full_result(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
) -> None:
"""Tests the ExperimentAxisQueryIterDataPipe provides all data, as collected from multiple processes that are managed by a
PyTorch DataLoader with multiple workers configured."""
"""Tests that ``ExperimentAxisQueryIterDataPipe`` / ``ExperimentAxisQueryIterableDataset``
provide all data, as collected from multiple processes that are managed by a PyTorch DataLoader
with multiple workers configured."""
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["soma_joinid", "label"],
io_batch_size=3, # two chunks, one per worker
)
# Note we're testing the ExperimentAxisQueryIterDataPipe via a DataLoader, since this is what sets up the multiprocessing
# Wrap with a DataLoader, which sets up the multiprocessing
dl = experiment_dataloader(dp, num_workers=2)

full_result = list(iter(dl))
Expand Down Expand Up @@ -593,12 +597,9 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__non_batched(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
Expand All @@ -624,12 +625,9 @@ def test_experiment_dataloader__non_batched(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__batched(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
Expand All @@ -656,12 +654,9 @@ def test_experiment_dataloader__batched(
for use_eager_fetch in (True, False)
],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__batched_length(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
Expand All @@ -682,12 +677,9 @@ def test_experiment_dataloader__batched_length(
"obs_range,var_range,X_value_gen,batch_size",
[(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
@pytest.mark.parametrize("PipeClass", IterableWrappers)
def test_experiment_dataloader__collate_fn(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
PipeClass: IterableWrapperType,
soma_experiment: Experiment,
batch_size: int,
) -> None:
Expand Down

0 comments on commit 2112a3f

Please sign in to comment.