Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 17, 2024
1 parent e32d624 commit 284683c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 15 deletions.
4 changes: 1 addition & 3 deletions sheeprl/algos/dreamer_v3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def __init__(
layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity},
activation=activation,
norm_layer=[layer_norm_cls] * stages,
norm_args=[
{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)
],
norm_args=[{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)],
),
nn.Flatten(-3, -1),
)
Expand Down
20 changes: 8 additions & 12 deletions sheeprl/data/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,10 @@ def to_tensor(
return buf

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...

def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None:
"""Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten.
Expand Down Expand Up @@ -617,12 +615,10 @@ def __len__(self) -> int:
return self.buffer_size

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...

def add(
self,
Expand Down Expand Up @@ -860,17 +856,17 @@ def __len__(self) -> int:
return self._cum_lengths[-1] if len(self._buf) > 0 else 0

@typing.overload
def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None:
...
def add(
self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False
) -> None: ...

@typing.overload
def add(
self,
data: Dict[str, np.ndarray],
env_idxes: Sequence[int] | None = None,
validate_args: bool = False,
) -> None:
...
) -> None: ...

def add(
self,
Expand Down
2 changes: 2 additions & 0 deletions sheeprl/utils/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ class OneHotCategoricalValidateArgs(Distribution):
probs (Tensor): event probabilities
logits (Tensor): event log probabilities (unnormalized)
"""

arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True
Expand Down Expand Up @@ -391,6 +392,7 @@ class OneHotCategoricalStraightThroughValidateArgs(OneHotCategoricalValidateArgs
[1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
(Bengio et al, 2013)
"""

has_rsample = True

def rsample(self, sample_shape=torch.Size()):
Expand Down

0 comments on commit 284683c

Please sign in to comment.