diff --git a/merlin/dataloader/loader_base.py b/merlin/dataloader/loader_base.py index 81f5fd5c..171c3ece 100644 --- a/merlin/dataloader/loader_base.py +++ b/merlin/dataloader/loader_base.py @@ -447,7 +447,8 @@ def _to_tensor(self, df_or_series): if self.device == "cpu": tensor = df_or_series.to_numpy() else: - tensor = df_or_series.to_cupy() + with cupy.cuda.Device(self.device): + tensor = df_or_series.to_cupy() return tensor diff --git a/merlin/dataloader/torch.py b/merlin/dataloader/torch.py index 148126ec..6921ce7e 100644 --- a/merlin/dataloader/torch.py +++ b/merlin/dataloader/torch.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import contextlib from functools import partial from merlin.core.compat.torch import torch as th @@ -118,6 +119,11 @@ def map(self, fn): return self + def _get_device_ctx(self, dev): + if dev == "cpu" or not th: + return contextlib.nullcontext() + return th.cuda.device(f"cuda:{dev}") + class DLDataLoader(th.utils.data.DataLoader): """