Skip to content

Commit

Permalink
Set device for torch tensors with gpu > 1 (#132)
Browse files Browse the repository at this point in the history
* Set device for torch tensors with gpu > 1

* Remove device type validation
  • Loading branch information
edknv authored Apr 12, 2023
1 parent 12372f4 commit 8782c9d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
3 changes: 2 additions & 1 deletion merlin/dataloader/loader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,8 @@ def _to_tensor(self, df_or_series):
if self.device == "cpu":
tensor = df_or_series.to_numpy()
else:
tensor = df_or_series.to_cupy()
with cupy.cuda.Device(self.device):
tensor = df_or_series.to_cupy()

return tensor

Expand Down
6 changes: 6 additions & 0 deletions merlin/dataloader/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
from functools import partial

from merlin.core.compat.torch import torch as th
Expand Down Expand Up @@ -118,6 +119,11 @@ def map(self, fn):

return self

def _get_device_ctx(self, dev):
if dev == "cpu" or not th:
return contextlib.nullcontext()
return th.cuda.device(f"cuda:{dev}")


class DLDataLoader(th.utils.data.DataLoader):
"""
Expand Down

0 comments on commit 8782c9d

Please sign in to comment.