[Feature] Add scheduler for alpha/beta parameters of PrioritizedSampler #2452

Merged (5 commits) on Sep 30, 2024
Changes from 4 commits
77 changes: 77 additions & 0 deletions test/test_rb.py
@@ -58,6 +58,11 @@
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.scheduler import (
LinearScheduler,
SchedulerList,
StepScheduler,
)

from torchrl.data.replay_buffers.storages import (
LazyMemmapStorage,
@@ -99,6 +104,7 @@
VecNorm,
)


OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
_has_gym = importlib.util.find_spec("gym") is not None
@@ -3026,6 +3032,77 @@ def test_prioritized_slice_sampler_episodes(device):
), "after priority update, only episode 1 and 3 are expected to be sampled"


@pytest.mark.parametrize("alpha", [0.6, torch.tensor(1.0)])
@pytest.mark.parametrize("beta", [0.7, torch.tensor(0.1)])
@pytest.mark.parametrize("gamma", [0.1])
@pytest.mark.parametrize("total_steps", [200])
@pytest.mark.parametrize("n_annealing_steps", [100])
@pytest.mark.parametrize("anneal_every_n", [10, 159])
@pytest.mark.parametrize("alpha_min", [0, 0.2])
@pytest.mark.parametrize("beta_max", [1, 1.4])
def test_prioritized_parameter_scheduler(
alpha,
beta,
gamma,
total_steps,
n_annealing_steps,
anneal_every_n,
alpha_min,
beta_max,
):
rb = TensorDictPrioritizedReplayBuffer(
alpha=alpha, beta=beta, storage=ListStorage(max_size=1000)
)
data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000)
rb.extend(data)
alpha_scheduler = LinearScheduler(
rb, param_name="alpha", final_value=alpha_min, num_steps=n_annealing_steps
)
beta_scheduler = StepScheduler(
rb,
param_name="beta",
gamma=gamma,
n_steps=anneal_every_n,
max_value=beta_max,
mode="additive",
)

scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))

alpha = alpha if torch.is_tensor(alpha) else torch.tensor(alpha)
alpha_min = torch.tensor(alpha_min)
expected_alpha_vals = torch.linspace(alpha, alpha_min, n_annealing_steps + 1)
expected_alpha_vals = torch.nn.functional.pad(
expected_alpha_vals, (0, total_steps - n_annealing_steps), value=alpha_min
)

expected_beta_vals = [beta]
annealing_steps = total_steps // anneal_every_n
gammas = torch.arange(0, annealing_steps + 1, dtype=torch.float32) * gamma
expected_beta_vals = (
(beta + gammas).repeat_interleave(anneal_every_n).clip(None, beta_max)
)
for i in range(total_steps):
curr_alpha = rb.sampler.alpha
torch.testing.assert_close(
curr_alpha
if torch.is_tensor(curr_alpha)
else torch.tensor(curr_alpha).float(),
expected_alpha_vals[i],
msg=f"expected {expected_alpha_vals[i]}, got {curr_alpha}",
)
curr_beta = rb.sampler.beta
torch.testing.assert_close(
curr_beta
if torch.is_tensor(curr_beta)
else torch.tensor(curr_beta).float(),
expected_beta_vals[i],
msg=f"expected {expected_beta_vals[i]}, got {curr_beta}",
)
rb.sample(20)
scheduler.step()


class TestEnsemble:
def _make_data(self, data_type):
if data_type is torch.Tensor:
16 changes: 16 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
@@ -395,6 +395,22 @@ def __repr__(self):
def max_size(self):
return self._max_capacity

@property
def alpha(self):
return self._alpha

@alpha.setter
def alpha(self, value):
self._alpha = value

@property
def beta(self):
return self._beta

@beta.setter
def beta(self, value):
self._beta = value

def __getstate__(self):
if get_spawning_popen() is not None:
raise RuntimeError(
263 changes: 263 additions & 0 deletions torchrl/data/replay_buffers/scheduler.py
@@ -0,0 +1,263 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from abc import ABC, abstractmethod

from typing import Any, Callable, Dict

import numpy as np

import torch

from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler


class ParameterScheduler(ABC):
"""Scheduler to adjust the value of a given parameter of a replay buffer's sampler.

The scheduler can, for example, be used to alter the alpha and beta values in the PrioritizedSampler.

Args:
obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust, or the sampler itself
param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the beta parameter
min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
Defaults to `None`.
max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
Defaults to `None`.

"""

def __init__(
self,
obj: ReplayBuffer | Sampler,
param_name: str,
min_value: int | float | None = None,
max_value: int | float | None = None,
):
if not isinstance(obj, (ReplayBuffer, Sampler)):
raise TypeError(
f"ParameterScheduler only supports Sampler class. Pass either `ReplayBuffer` or `Sampler` object. Got {type(obj)} instead."
)
self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
self.param_name = param_name
# treat `None` as "no bound"; avoid `or`, which would silently discard a bound of 0
self._min_val = min_value if min_value is not None else float("-inf")
self._max_val = max_value if max_value is not None else float("inf")
if not hasattr(self.sampler, self.param_name):
raise ValueError(
f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
)
initial_val = getattr(self.sampler, self.param_name)
if isinstance(initial_val, torch.Tensor):
initial_val = initial_val.clone()
self.backend = torch
else:
self.backend = np
self.initial_val = initial_val
self._step_cnt = 0

Review thread on `state_dict`:
Reviewer (Contributor): do we want to match the nn.Module.state_dict signature here?
Author (Contributor): that one was actually blindly copied from torch.optim.LRScheduler ehehe
Reviewer (Contributor): Oh wow! Ok then...

def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.

It contains an entry for every variable in ``self.__dict__`` which
is not the sampler.
"""
sd = dict(self.__dict__)
del sd["sampler"]
return sd

def load_state_dict(self, state_dict: Dict[str, Any]):
"""Load the scheduler's state.

Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
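
Editorial aside (not part of the diff): a minimal sketch of how the two methods above could be used to checkpoint and resume a scheduler, assuming `rb` is an existing replay buffer with a prioritized sampler. The sampler is deliberately excluded from the state dict, so the resumed scheduler is constructed against the same buffer before loading.

from torchrl.data.replay_buffers.scheduler import LinearScheduler

scheduler = LinearScheduler(rb, param_name="beta", final_value=1.0, num_steps=100)
for _ in range(50):
    scheduler.step()
state = scheduler.state_dict()  # holds _step_cnt, initial_val, etc., but not the sampler

resumed = LinearScheduler(rb, param_name="beta", final_value=1.0, num_steps=100)
resumed.load_state_dict(state)  # continues from step 50
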

def step(self):
self._step_cnt += 1
# Apply the step function
new_value = self._step()
# clip value to specified range
new_value_clipped = self.backend.clip(new_value, self._min_val, self._max_val)
# Set the new value of the parameter dynamically
setattr(self.sampler, self.param_name, new_value_clipped)

@abstractmethod
def _step(self):
...
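
Editorial aside (not part of the diff): to illustrate the subclassing contract, a custom scheduler only needs to implement `_step`, computing the new value from `self.initial_val` and `self._step_cnt`; `step` above handles clipping and writing the value back to the sampler. The `CosineAnnealingScheduler` below is hypothetical, a sketch only.

import math

from torchrl.data.replay_buffers.scheduler import ParameterScheduler

class CosineAnnealingScheduler(ParameterScheduler):
    """Anneal a parameter from its initial value to `final_value` along a cosine curve."""

    def __init__(self, obj, param_name, final_value, num_steps):
        super().__init__(obj, param_name)
        self.final_value = final_value
        self.num_steps = num_steps

    def _step(self):
        # progress runs from 0 to 1; the cosine weight decays from 1 to 0 over num_steps
        progress = min(self._step_cnt, self.num_steps) / self.num_steps
        weight = 0.5 * (1 + math.cos(math.pi * progress))
        return self.final_value + (self.initial_val - self.final_value) * weight
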


class LambdaScheduler(ParameterScheduler):
"""Sets a parameter to its initial value times a given function.

Similar to :class:`~torch.optim.lr_scheduler.LambdaLR`.

Args:
obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
beta parameter.
lambda_fn (Callable[[int], float]): A function which computes a multiplicative factor given an integer
parameter ``step_count``.
min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
Defaults to `None`.
max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
Defaults to `None`.

"""

def __init__(
self,
obj: ReplayBuffer | Sampler,
param_name: str,
lambda_fn: Callable[[int], float],
min_value: int | float | None = None,
max_value: int | float | None = None,
):
super().__init__(obj, param_name, min_value, max_value)
self.lambda_fn = lambda_fn

def _step(self):
return self.initial_val * self.lambda_fn(self._step_cnt)
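
Editorial aside (not part of the diff): a usage sketch for LambdaScheduler, mirroring how torch.optim.lr_scheduler.LambdaLR is used for learning rates. It assumes `rb` is a replay buffer with a prioritized sampler.

from torchrl.data.replay_buffers.scheduler import LambdaScheduler

# multiply the initial alpha by 0.99**step at every scheduler step
alpha_scheduler = LambdaScheduler(rb, param_name="alpha", lambda_fn=lambda step: 0.99**step)
for _ in range(100):
    rb.sample(32)
    alpha_scheduler.step()
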


class LinearScheduler(ParameterScheduler):
"""A linear scheduler for gradually altering a parameter in an object over a given number of steps.

This scheduler linearly interpolates between the initial value of the parameter and a final target value.

Args:
obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
beta parameter.
final_value (number): The final value that the parameter will reach after the
specified number of steps.
num_steps (int): The total number of steps over which the parameter
will be linearly altered.

Example:
>>> # xdoctest: +SKIP
>>> # Assuming sampler uses initial beta = 0.6
>>> # beta = 0.7 if step == 1
>>> # beta = 0.8 if step == 2
>>> # beta = 0.9 if step == 3
>>> # beta = 1.0 if step >= 4
>>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
"""

def __init__(
self,
obj: ReplayBuffer | Sampler,
param_name: str,
final_value: int | float,
num_steps: int,
):
super().__init__(obj, param_name)
if isinstance(self.initial_val, torch.Tensor):
# cast to same type as initial value
final_value = torch.tensor(final_value).to(self.initial_val)
self.final_val = final_value
self.num_steps = num_steps
self._delta = (self.final_val - self.initial_val) / self.num_steps

def _step(self):
if self._step_cnt < self.num_steps:
return self.initial_val + (self._delta * self._step_cnt)
else:
return self.final_val
Review thread on the control flow above:
Reviewer (Contributor): nit: if we ever want this to be compilable without graph breaks we should think of a way to remove the control flow, e.g.
return torch.where(self._step_cnt < self.num_steps, self.initial_val + (self._delta * self._step_cnt), self.final_val)
assuming that all of these are tensors.
Author (Contributor): interesting point. I stuck to the way torch schedulers handle different behavior for different epochs (e.g. here). You think that is okay for now?
Reviewer (Contributor): yeah that's fine, maybe let's add a comment to let someone know in the future that this should be fixed


class StepScheduler(ParameterScheduler):
"""A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.

The scheduler can apply:
1. Multiplicative changes: `new_val = curr_val * gamma`
2. Additive changes: `new_val = curr_val + gamma`

Args:
obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
beta parameter.
gamma (int or float, optional): The value by which to adjust the parameter,
either in a multiplicative or additive way. Defaults to 0.9.
n_steps (int, optional): The number of steps after which the parameter should be altered.
Defaults to 1.
mode (str, optional): The mode of scheduling. Can be either `'multiplicative'` or `'additive'`.
Defaults to `'multiplicative'`.
min_value (int or float, optional): a lower bound for the parameter to be adjusted.
Defaults to `None`.
max_value (int or float, optional): an upper bound for the parameter to be adjusted.
Defaults to `None`.

Example:
>>> # xdoctest: +SKIP
>>> # Assuming sampler uses initial beta = 0.6
>>> # beta = 0.6 if 0 <= step < 10
>>> # beta = 0.7 if 10 <= step < 20
>>> # beta = 0.8 if 20 <= step < 30
>>> # beta = 0.9 if 30 <= step < 40
>>> # beta = 1.0 if 40 <= step
>>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, mode='additive', max_value=1.0)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
"""

def __init__(
self,
obj: ReplayBuffer | Sampler,
param_name: str,
gamma: int | float = 0.9,
n_steps: int = 1,
mode: str = "multiplicative",
min_value: int | float | None = None,
max_value: int | float | None = None,
):

super().__init__(obj, param_name, min_value, max_value)
self.gamma = gamma
self.n_steps = n_steps
self.mode = mode
if mode == "additive":
operator = self.backend.add
elif mode == "multiplicative":
operator = self.backend.multiply
else:
raise ValueError(
f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
)
self.operator = operator

def _step(self):
"""Applies the scheduling logic to alter the parameter value every `n_steps`."""
# Check if the current step count is a multiple of n_steps
current_val = getattr(self.sampler, self.param_name)
if self._step_cnt % self.n_steps == 0:
return self.operator(current_val, self.gamma)
else:
return current_val
Review thread on lines +248 to +251:
Reviewer (Contributor): ditto

class SchedulerList:
"""Simple container abstracting a list of schedulers."""

def __init__(self, schedulers: list[ParameterScheduler]) -> None:
if isinstance(schedulers, ParameterScheduler):
schedulers = [schedulers]
self.schedulers = schedulers

def append(self, scheduler: ParameterScheduler):
self.schedulers.append(scheduler)

def step(self):
for scheduler in self.schedulers:
scheduler.step()
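
Editorial aside (not part of the diff): an end-to-end usage sketch combining the pieces above. Buffer construction mirrors the new test; the priority update in the loop is left as a placeholder.

import torch
from tensordict import TensorDict

from torchrl.data import TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.storages import ListStorage
from torchrl.data.replay_buffers.scheduler import (
    LinearScheduler,
    SchedulerList,
    StepScheduler,
)

rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7, beta=0.4, storage=ListStorage(max_size=1000)
)
rb.extend(TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000))

schedulers = SchedulerList(
    schedulers=(
        # anneal alpha linearly towards 0 over 1000 steps
        LinearScheduler(rb, param_name="alpha", final_value=0.0, num_steps=1000),
        # increase beta by 0.1 every 100 steps, capped at 1.0
        StepScheduler(
            rb, param_name="beta", gamma=0.1, n_steps=100, mode="additive", max_value=1.0
        ),
    )
)

for _ in range(1000):
    batch = rb.sample(32)
    # ... compute TD errors and update priorities here ...
    schedulers.step()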