Skip to content

Commit

Permalink
Revert "Set device for torch tensors with gpu > 1 (#132)" (#134)
Browse files Browse the repository at this point in the history
This reverts commit 8782c9d.
  • Loading branch information
edknv authored Apr 14, 2023
1 parent 8782c9d commit 014b658
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 8 deletions.
3 changes: 1 addition & 2 deletions merlin/dataloader/loader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,7 @@ def _to_tensor(self, df_or_series):
if self.device == "cpu":
tensor = df_or_series.to_numpy()
else:
with cupy.cuda.Device(self.device):
tensor = df_or_series.to_cupy()
tensor = df_or_series.to_cupy()

return tensor

Expand Down
6 changes: 0 additions & 6 deletions merlin/dataloader/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
from functools import partial

from merlin.core.compat.torch import torch as th
Expand Down Expand Up @@ -119,11 +118,6 @@ def map(self, fn):

return self

def _get_device_ctx(self, dev):
if dev == "cpu" or not th:
return contextlib.nullcontext()
return th.cuda.device(f"cuda:{dev}")


class DLDataLoader(th.utils.data.DataLoader):
"""
Expand Down

0 comments on commit 014b658

Please sign in to comment.