From 014b658909426f544059434834f5480ab484f824 Mon Sep 17 00:00:00 2001 From: edknv <109497216+edknv@users.noreply.github.com> Date: Fri, 14 Apr 2023 11:17:56 -0700 Subject: [PATCH] Revert "Set device for torch tensors with gpu > 1 (#132)" (#134) This reverts commit 8782c9d6c10c56de57607e06b1b1dd6a4e9b6430. --- merlin/dataloader/loader_base.py | 3 +-- merlin/dataloader/torch.py | 6 ------ 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/merlin/dataloader/loader_base.py b/merlin/dataloader/loader_base.py index 171c3ece..81f5fd5c 100644 --- a/merlin/dataloader/loader_base.py +++ b/merlin/dataloader/loader_base.py @@ -447,8 +447,7 @@ def _to_tensor(self, df_or_series): if self.device == "cpu": tensor = df_or_series.to_numpy() else: - with cupy.cuda.Device(self.device): - tensor = df_or_series.to_cupy() + tensor = df_or_series.to_cupy() return tensor diff --git a/merlin/dataloader/torch.py b/merlin/dataloader/torch.py index 6921ce7e..148126ec 100644 --- a/merlin/dataloader/torch.py +++ b/merlin/dataloader/torch.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import contextlib from functools import partial from merlin.core.compat.torch import torch as th @@ -119,11 +118,6 @@ def map(self, fn): return self - def _get_device_ctx(self, dev): - if dev == "cpu" or not th: - return contextlib.nullcontext() - return th.cuda.device(f"cuda:{dev}") - class DLDataLoader(th.utils.data.DataLoader): """