Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 16, 2024
2 parents 9364724 + ad37c22 commit 1c278b2
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,7 +1520,9 @@ def __init__(
use_register: bool = False,
mask: torch.Tensor | None = None,
):
dtype, device = _default_dtype_and_device(dtype, device)
dtype, device = _default_dtype_and_device(
dtype, device, allow_none_device=False
)
self.use_register = use_register
space = CategoricalBox(n)
if shape is None:
Expand Down Expand Up @@ -2046,7 +2048,9 @@ def __init__(
if len(kwargs):
raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.")

dtype, device = _default_dtype_and_device(dtype, device)
dtype, device = _default_dtype_and_device(
dtype, device, allow_none_device=False
)
if dtype is None:
dtype = torch.get_default_dtype()
if domain is None:
Expand Down Expand Up @@ -2644,7 +2648,9 @@ def __init__(
if isinstance(shape, int):
shape = _size([shape])

dtype, device = _default_dtype_and_device(dtype, device)
dtype, device = _default_dtype_and_device(
dtype, device, allow_none_device=False
)
if dtype == torch.bool:
min_value = False
max_value = True
Expand Down Expand Up @@ -2851,7 +2857,9 @@ def __init__(
mask: torch.Tensor | None = None,
):
self.nvec = nvec
dtype, device = _default_dtype_and_device(dtype, device)
dtype, device = _default_dtype_and_device(
dtype, device, allow_none_device=False
)
if shape is None:
shape = _size((sum(nvec),))
else:
Expand Down Expand Up @@ -3327,7 +3335,9 @@ def __init__(
):
if shape is None:
shape = _size([])
dtype, device = _default_dtype_and_device(dtype, device)
dtype, device = _default_dtype_and_device(
dtype, device, allow_none_device=False
)
space = CategoricalBox(n)
super().__init__(
shape=shape, space=space, device=device, dtype=dtype, domain="discrete"
Expand Down Expand Up @@ -3874,7 +3884,9 @@ def __init__(
if nvec.ndim < 1:
nvec = nvec.unsqueeze(0)
self.nvec = nvec
dtype, device = _default_dtype_and_device(dtype, device)
dtype, device = _default_dtype_and_device(
dtype, device, allow_none_device=False
)
if shape is None:
shape = nvec.shape
else:
Expand Down

0 comments on commit 1c278b2

Please sign in to comment.