Skip to content

Commit

Permalink
error msg typo fix, worker_info simplification
Browse files Browse the repository at this point in the history
`mypy` (outside `pre-commit`) didn't like the `get_worker_info().num_workers`
  • Loading branch information
ryan-williams committed Oct 25, 2024
1 parent 67681a7 commit 4a435a0
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions src/tiledbsoma_ml/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,13 @@ def __iter__(self) -> Iterator[XObsDatum]:
experimental
"""

if (
self.return_sparse_X
and torch.utils.data.get_worker_info()
and torch.utils.data.get_worker_info().num_workers > 0
):
raise NotImplementedError(
"torch does not work with sparse tensors in multi-processing mode "
"(see https://github.com/pytorch/pytorch/issues/20248)"
)
if self.return_sparse_X:
worker_info = torch.utils.data.get_worker_info()
if worker_info and worker_info.num_workers > 0:
raise NotImplementedError(
"torch does not work with sparse tensors in multi-processing mode "
"(see https://github.com/pytorch/pytorch/issues/20248)"
)

world_size, rank = _get_distributed_world_rank()
n_workers, worker_id = _get_worker_world_rank()
Expand Down Expand Up @@ -426,7 +424,7 @@ def set_epoch(self, epoch: int) -> None:

def __getitem__(self, index: int) -> XObsDatum:
raise NotImplementedError(
"``ExperimentAxisQueryIterable can only be iterated - does not support mapping"
"`ExperimentAxisQueryIterable` can only be iterated - does not support mapping"
)

def _io_batch_iter(
Expand Down

0 comments on commit 4a435a0

Please sign in to comment.