Type-fixes / errata #18

Merged 3 commits on Oct 31, 2024
pyproject.toml (3 changes: 0 additions & 3 deletions)

@@ -56,9 +56,6 @@ version = {attr = "tiledbsoma_ml.__version__"}
 [tool.setuptools.package-data]
 "tiledbsoma_ml" = ["py.typed"]
 
-[tool.setuptools_scm]
-root = "../../.."
-
 [tool.mypy]
 show_error_codes = true
 ignore_missing_imports = true
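Note: with the [tool.setuptools_scm] table removed, the package version appears to come only from the version = {attr = "tiledbsoma_ml.__version__"} attribute visible in the hunk context above. A minimal, hedged way to check the result from an environment where the wheel is installed (the distribution name "tiledbsoma-ml" is an assumption here):

from importlib.metadata import PackageNotFoundError, version

try:
    # Prints whatever version the build backend recorded for the installed wheel.
    print(version("tiledbsoma-ml"))
except PackageNotFoundError:
    print("tiledbsoma-ml is not installed in this environment")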
src/tiledbsoma_ml/pytorch.py (18 changes: 8 additions & 10 deletions)

@@ -328,15 +328,13 @@ def __iter__(self) -> Iterator[XObsDatum]:
         experimental
         """
 
-        if (
-            self.return_sparse_X
-            and torch.utils.data.get_worker_info()
-            and torch.utils.data.get_worker_info().num_workers > 0
-        ):
-            raise NotImplementedError(
-                "torch does not work with sparse tensors in multi-processing mode "
-                "(see https://github.com/pytorch/pytorch/issues/20248)"
-            )
+        if self.return_sparse_X:
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info and worker_info.num_workers > 0:
+                raise NotImplementedError(
+                    "torch does not work with sparse tensors in multi-processing mode "
+                    "(see https://github.com/pytorch/pytorch/issues/20248)"
+                )
 
         world_size, rank = _get_distributed_world_rank()
         n_workers, worker_id = _get_worker_world_rank()
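Note: the rewrite above is purely structural; it calls get_worker_info() once instead of twice and keeps the same error. A self-contained sketch of the same guard pattern using only torch, with an illustrative toy dataset (SparseGuardedDataset is not tiledbsoma-ml API):

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class SparseGuardedDataset(IterableDataset):
    """Toy iterable that refuses to emit sparse tensors from DataLoader workers."""

    def __init__(self, return_sparse_X: bool = True) -> None:
        self.return_sparse_X = return_sparse_X

    def __iter__(self):
        if self.return_sparse_X:
            worker_info = get_worker_info()  # None when iterated in the main process
            if worker_info and worker_info.num_workers > 0:
                raise NotImplementedError(
                    "torch does not work with sparse tensors in multi-processing mode "
                    "(see https://github.com/pytorch/pytorch/issues/20248)"
                )
        yield torch.eye(3).to_sparse()  # one sparse COO tensor per pass

# Single-process iteration works; num_workers > 0 would raise inside each worker.
batch = next(iter(DataLoader(SparseGuardedDataset(), batch_size=None)))
assert batch.is_sparse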
@@ -426,7 +424,7 @@ def set_epoch(self, epoch: int) -> None:
 
     def __getitem__(self, index: int) -> XObsDatum:
         raise NotImplementedError(
-            "``ExperimentAxisQueryIterable can only be iterated - does not support mapping"
+            "`ExperimentAxisQueryIterable` can only be iterated - does not support mapping"
         )
 
     def _io_batch_iter(
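Note: besides the backtick fix, the message reflects a real constraint: the class is iterable-only, and __getitem__ exists solely to fail fast if it is used as a map-style dataset. A minimal, generic illustration of that pattern (IterOnly is a toy class, not tiledbsoma-ml code):

from torch.utils.data import IterableDataset

class IterOnly(IterableDataset):
    """Streams items; defining __getitem__ only to raise makes misuse obvious."""

    def __iter__(self):
        yield from range(3)

    def __getitem__(self, index: int):
        raise NotImplementedError("`IterOnly` can only be iterated - does not support mapping")

assert list(IterOnly()) == [0, 1, 2]  # iteration is the supported access pattern
# IterOnly()[0] would raise NotImplementedError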
tests/test_pytorch.py (58 changes: 25 additions & 33 deletions)

@@ -8,7 +8,7 @@
 import sys
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
 from unittest.mock import patch
 
 import numpy as np
@@ -35,15 +35,21 @@
 
 # These control which classes are tested (for most, but not all tests).
 # Centralized to allow easy add/delete of specific test parameters.
-PipeClassType = Union[
-    ExperimentAxisQueryIterable,
+IterableWrapperType = Union[
+    Type[ExperimentAxisQueryIterDataPipe],
+    Type[ExperimentAxisQueryIterableDataset],
+]
+IterableWrappers = (
     ExperimentAxisQueryIterDataPipe,
     ExperimentAxisQueryIterableDataset,
+)
+PipeClassType = Union[
+    Type[ExperimentAxisQueryIterable],
+    IterableWrapperType,
 ]
 PipeClasses = (
     ExperimentAxisQueryIterable,
-    ExperimentAxisQueryIterDataPipe,
-    ExperimentAxisQueryIterableDataset,
+    *IterableWrappers,
 )
 XValueGen = Callable[[range, range], spmatrix]
 
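Note: the Type[...] wrappers are the point of the new aliases: pytest.mark.parametrize hands the tests the classes themselves, which the tests then instantiate, so the annotations describe class objects rather than instances. A self-contained sketch of that pattern with toy classes (DataPipeA and DataPipeB are illustrative only):

from typing import Tuple, Type, Union

class DataPipeA: ...
class DataPipeB: ...

# A union of class objects (what parametrize passes in), not of instances.
WrapperType = Union[Type[DataPipeA], Type[DataPipeB]]
Wrappers: Tuple[WrapperType, ...] = (DataPipeA, DataPipeB)

def build(cls: WrapperType) -> Union[DataPipeA, DataPipeB]:
    return cls()  # the test body instantiates whichever class it was handed

for cls in Wrappers:
    assert isinstance(build(cls), cls)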
@@ -450,24 +456,22 @@ def test_batching__partial_soma_batches_are_concatenated(
 @pytest.mark.parametrize(
     "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)]
 )
-@pytest.mark.parametrize(
-    "PipeClass",
-    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
-)
+@pytest.mark.parametrize("PipeClass", IterableWrappers)
 def test_multiprocessing__returns_full_result(
-    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    PipeClass: IterableWrapperType,
     soma_experiment: Experiment,
 ) -> None:
-    """Tests the ExperimentAxisQueryIterDataPipe provides all data, as collected from multiple processes that are managed by a
-    PyTorch DataLoader with multiple workers configured."""
+    """Tests that ``ExperimentAxisQueryIterDataPipe`` / ``ExperimentAxisQueryIterableDataset``
+    provide all data, as collected from multiple processes that are managed by a PyTorch DataLoader
+    with multiple workers configured."""
     with soma_experiment.axis_query(measurement_name="RNA") as query:
         dp = PipeClass(
             query,
             X_name="raw",
             obs_column_names=["soma_joinid", "label"],
             io_batch_size=3,  # two chunks, one per worker
         )
-        # Note we're testing the ExperimentAxisQueryIterDataPipe via a DataLoader, since this is what sets up the multiprocessing
+        # Wrap with a DataLoader, which sets up the multiprocessing
         dl = experiment_dataloader(dp, num_workers=2)
 
         full_result = list(iter(dl))
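Note: the property this test checks (all rows are returned even with num_workers=2) relies on each DataLoader worker reading a disjoint shard, which the tiledbsoma-ml iterables handle internally via the worker rank seen in pytorch.py above. A self-contained sketch of that contract with a toy dataset (ShardedRange is illustrative only, not the library's implementation):

from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class ShardedRange(IterableDataset):
    """Each DataLoader worker yields a disjoint slice, so the union over workers is complete."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        info = get_worker_info()
        worker_id, num_workers = (info.id, info.num_workers) if info else (0, 1)
        yield from range(worker_id, self.n, num_workers)

if __name__ == "__main__":  # guard required when workers start via spawn
    dl = DataLoader(ShardedRange(6), batch_size=None, num_workers=2)
    assert sorted(int(x) for x in dl) == list(range(6))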
@@ -593,12 +597,9 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
     "obs_range,var_range,X_value_gen,use_eager_fetch",
     [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
 )
-@pytest.mark.parametrize(
-    "PipeClass",
-    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
-)
+@pytest.mark.parametrize("PipeClass", IterableWrappers)
 def test_experiment_dataloader__non_batched(
-    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    PipeClass: IterableWrapperType,
     soma_experiment: Experiment,
     use_eager_fetch: bool,
 ) -> None:
@@ -624,12 +625,9 @@ def test_experiment_dataloader__non_batched(
     "obs_range,var_range,X_value_gen,use_eager_fetch",
     [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
 )
-@pytest.mark.parametrize(
-    "PipeClass",
-    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
-)
+@pytest.mark.parametrize("PipeClass", IterableWrappers)
 def test_experiment_dataloader__batched(
-    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    PipeClass: IterableWrapperType,
     soma_experiment: Experiment,
     use_eager_fetch: bool,
 ) -> None:
@@ -656,12 +654,9 @@
         for use_eager_fetch in (True, False)
     ],
 )
-@pytest.mark.parametrize(
-    "PipeClass",
-    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
-)
+@pytest.mark.parametrize("PipeClass", IterableWrappers)
 def test_experiment_dataloader__batched_length(
-    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    PipeClass: IterableWrapperType,
     soma_experiment: Experiment,
     use_eager_fetch: bool,
 ) -> None:
@@ -682,12 +677,9 @@
     "obs_range,var_range,X_value_gen,batch_size",
     [(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)],
 )
-@pytest.mark.parametrize(
-    "PipeClass",
-    (ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
-)
+@pytest.mark.parametrize("PipeClass", IterableWrappers)
 def test_experiment_dataloader__collate_fn(
-    PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
+    PipeClass: IterableWrapperType,
     soma_experiment: Experiment,
     batch_size: int,
 ) -> None: