Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add scheduler for alpha/beta parameters of PrioritizedSampler #2452

Merged
merged 5 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.scheduler import (
LinearScheduler,
SchedulerList,
StepScheduler,
)

from torchrl.data.replay_buffers.storages import (
LazyMemmapStorage,
Expand Down Expand Up @@ -99,6 +104,7 @@
VecNorm,
)


OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
_has_gym = importlib.util.find_spec("gym") is not None
Expand Down Expand Up @@ -3026,6 +3032,51 @@ def test_prioritized_slice_sampler_episodes(device):
), "after priority update, only episode 1 and 3 are expected to be sampled"


def test_prioritized_parameter_scheduler():
    """Check that LinearScheduler/StepScheduler drive the alpha/beta parameters
    of a TensorDictPrioritizedReplayBuffer as expected over 200 steps."""
    INIT_ALPHA = 0.7
    INIT_BETA = 0.6
    GAMMA = 0.1
    EVERY_N_STEPS = 10
    LINEAR_STEPS = 100
    TOTAL_STEPS = 200
    rb = TensorDictPrioritizedReplayBuffer(
        alpha=INIT_ALPHA, beta=INIT_BETA, storage=ListStorage(max_size=2000)
    )
    data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000)
    rb.extend(data)
    alpha_scheduler = LinearScheduler(
        rb, param_name="alpha", final_value=0.0, num_steps=LINEAR_STEPS
    )
    beta_scheduler = StepScheduler(
        rb,
        param_name="beta",
        gamma=GAMMA,
        n_steps=EVERY_N_STEPS,
        max_value=1.0,
        mode="additive",
    )
    scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))
    # alpha decays linearly from INIT_ALPHA to 0 over LINEAR_STEPS, then stays at 0
    expected_alpha_vals = np.linspace(INIT_ALPHA, 0.0, num=LINEAR_STEPS + 1)
    expected_alpha_vals = np.pad(
        expected_alpha_vals, (0, TOTAL_STEPS - LINEAR_STEPS), constant_values=0.0
    )
    # beta grows by GAMMA once every EVERY_N_STEPS steps, clipped at 1.0
    expected_beta_vals = [INIT_BETA]
    for _ in range(TOTAL_STEPS // EVERY_N_STEPS - 1):
        expected_beta_vals.append(expected_beta_vals[-1] + GAMMA)
    # repeat each plateau value EVERY_N_STEPS times; a 1-D array suffices
    # (the previous np.atleast_2d was flattened by repeat() anyway)
    expected_beta_vals = (
        np.array(expected_beta_vals).repeat(EVERY_N_STEPS).clip(None, 1.0)
    )
    for i in range(TOTAL_STEPS):
        assert np.isclose(
            rb.sampler.alpha, expected_alpha_vals[i]
        ), f"expected {expected_alpha_vals[i]}, got {rb.sampler.alpha}"
        assert np.isclose(
            rb.sampler.beta, expected_beta_vals[i]
        ), f"expected {expected_beta_vals[i]}, got {rb.sampler.beta}"
        rb.sample(20)
        scheduler.step()


class TestEnsemble:
def _make_data(self, data_type):
if data_type is torch.Tensor:
Expand Down
16 changes: 16 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,22 @@ def __repr__(self):
def max_size(self):
return self._max_capacity

@property
def alpha(self):
    # Priority exponent; exposed as a property so external code
    # (e.g. parameter schedulers) can read and rewrite it.
    return self._alpha

@alpha.setter
def alpha(self, new_alpha):
    self._alpha = new_alpha

@property
def beta(self):
    # Importance-sampling exponent; exposed as a property so external code
    # (e.g. parameter schedulers) can read and rewrite it.
    return self._beta

@beta.setter
def beta(self, new_beta):
    self._beta = new_beta

def __getstate__(self):
if get_spawning_popen() is not None:
raise RuntimeError(
Expand Down
240 changes: 240 additions & 0 deletions torchrl/data/replay_buffers/scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
from typing import Any, Callable, Dict
LTluttmann marked this conversation as resolved.
Show resolved Hide resolved

import numpy as np

from .replay_buffers import ReplayBuffer
from .samplers import Sampler
LTluttmann marked this conversation as resolved.
Show resolved Hide resolved


class ParameterScheduler:
    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler.

    Can for example be used to alter the alpha and beta values in the PrioritizedSampler.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer or sampler whose parameter
            to adjust.
        param_name (str): the name of the attribute to adjust, e.g. ``beta`` to adjust
            the beta parameter.
        min_value (int or float, optional): a lower bound for the parameter to be
            adjusted. Defaults to ``None`` (no lower bound).
        max_value (int or float, optional): an upper bound for the parameter to be
            adjusted. Defaults to ``None`` (no upper bound).

    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        if not isinstance(obj, (ReplayBuffer, Sampler)):
            raise TypeError(
                f"ParameterScheduler only supports Sampler class. Pass either ReplayBuffer or Sampler object. Got {type(obj)}"
            )
        self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
        self.param_name = param_name
        self._min_val = min_value
        self._max_val = max_value
        if not hasattr(self.sampler, self.param_name):
            # BUG FIX: `obj` is an instance, so `obj.__name__` raised AttributeError;
            # use the class name instead.
            raise ValueError(
                f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
            )
        initial_val = getattr(self.sampler, self.param_name)
        # Keep a detached copy: a tensor-valued parameter could otherwise be
        # mutated in place, silently shifting the scheduler's reference point.
        self.initial_val = (
            initial_val.clone() if hasattr(initial_val, "clone") else initial_val
        )
        self._step_cnt = 0

    def state_dict(self):
        """Return the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in ``self.__dict__`` which
        is not the sampler.
        """
        return {key: value for key, value in self.__dict__.items() if key != "sampler"}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self):
        """Advance the step counter and write the updated, clipped value to the sampler."""
        self._step_cnt += 1
        # Apply the subclass-specific update rule
        new_value = self._step()
        # Clip value to the specified range; np.clip ignores None bounds
        new_value_clipped = np.clip(new_value, a_min=self._min_val, a_max=self._max_val)
        # Set the new value of the parameter dynamically
        setattr(self.sampler, self.param_name, new_value_clipped)

    def _step(self):
        # Subclasses implement the actual schedule here.
        raise NotImplementedError

class LambdaScheduler(ParameterScheduler):
    """Sets a parameter to its initial value times a given function.

    Similar to :class:`torch.optim.LambdaLR`.

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust
            (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. ``beta`` to
            adjust the beta parameter.
        lambda_fn (function): a function which computes a multiplicative factor
            given an integer parameter ``step_count``.
        min_value (Union[int, float], optional): a lower bound for the parameter
            to be adjusted. Defaults to None.
        max_value (Union[int, float], optional): an upper bound for the parameter
            to be adjusted. Defaults to None.

    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        lambda_fn: Callable[[int], float],
        min_value: int | float = None,
        max_value: int | float = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.lambda_fn = lambda_fn

    def _step(self):
        # Scale the initial value by the user-supplied factor for this step count.
        factor = self.lambda_fn(self._step_cnt)
        return self.initial_val * factor


class LinearScheduler(ParameterScheduler):
    """A linear scheduler for gradually altering a parameter in an object over a given number of steps.

    This scheduler linearly interpolates between the initial value of the parameter
    and a final target value.

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust
            (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. ``beta`` to
            adjust the beta parameter.
        final_value (int or float): the final value that the parameter will reach
            after the specified number of steps.
        num_steps (int): the total number of steps over which the parameter will
            be linearly altered. Must be positive.

    Raises:
        ValueError: if ``num_steps`` is not a positive integer.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.7 if step == 1
        >>> # beta = 0.8 if step == 2
        >>> # beta = 0.9 if step == 3
        >>> # beta = 1.0 if step >= 4
        >>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        final_value: int | float,
        num_steps: int,
    ):
        super().__init__(obj, param_name)
        if num_steps <= 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError(f"num_steps must be a positive integer, got {num_steps}")
        self.final_val = final_value
        self.num_steps = num_steps
        # Per-step increment of the linear interpolation.
        self._delta = (self.final_val - self.initial_val) / self.num_steps

    def _step(self):
        # TODO: this data-dependent branch graph-breaks under torch.compile;
        # it could be replaced by torch.where if all operands become tensors.
        if self._step_cnt < self.num_steps:
            return self.initial_val + (self._delta * self._step_cnt)
        else:
            return self.final_val

class StepScheduler(ParameterScheduler):
    """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.

    The scheduler can apply:
    1. Multiplicative changes: ``new_val = curr_val * gamma``
    2. Additive changes: ``new_val = curr_val + gamma``

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust
            (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. ``beta`` to
            adjust the beta parameter.
        gamma (int | float, optional): the value by which to adjust the parameter,
            either in a multiplicative or additive way. Defaults to 0.9.
        n_steps (int, optional): the number of steps after which the parameter
            should be altered. Defaults to 1.
        mode (str, optional): the mode of scheduling. Can be either
            ``'multiplicative'`` or ``'additive'``. Defaults to ``'multiplicative'``.
        min_value (int | float, optional): a lower bound for the parameter to be
            adjusted. Defaults to None.
        max_value (int | float, optional): an upper bound for the parameter to be
            adjusted. Defaults to None.

    Raises:
        ValueError: if ``mode`` is neither ``'multiplicative'`` nor ``'additive'``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.6 if 0 <= step < 10
        >>> # beta = 0.7 if 10 <= step < 20
        >>> # beta = 0.8 if 20 <= step < 30
        >>> # beta = 0.9 if 30 <= step < 40
        >>> # beta = 1.0 if 40 <= step
        >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, mode='additive', max_value=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        gamma: int | float = 0.9,
        n_steps: int = 1,
        mode: str = "multiplicative",
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.gamma = gamma
        self.n_steps = n_steps
        if mode == "additive":
            operator = np.add
        elif mode == "multiplicative":
            operator = np.multiply
        else:
            # BUG FIX: the message previously interpolated `self.mode`, which is
            # never assigned, so an invalid mode raised AttributeError rather
            # than the intended ValueError.
            raise ValueError(
                f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
            )
        self.operator = operator

    def _step(self):
        """Applies the scheduling logic to alter the parameter value every `n_steps`."""
        current_val = getattr(self.sampler, self.param_name)
        # Only change the value on multiples of n_steps; otherwise keep it as-is.
        # NOTE: data-dependent branching graph-breaks under torch.compile.
        if self._step_cnt % self.n_steps == 0:
            return self.operator(current_val, self.gamma)
        else:
            return current_val

class SchedulerList:
    """Simple container abstracting a list of schedulers.

    Args:
        schedulers (list[ParameterScheduler] or ParameterScheduler): the
            schedulers to step together; a single scheduler is wrapped in a list.
    """

    def __init__(self, schedulers: list[ParameterScheduler]) -> None:
        if isinstance(schedulers, ParameterScheduler):
            schedulers = [schedulers]
        # BUG FIX: normalize to a list so that `append` works even when a
        # tuple (or other iterable) of schedulers is passed in.
        self.schedulers = list(schedulers)

    def append(self, scheduler: ParameterScheduler):
        """Add another scheduler to the container."""
        self.schedulers.append(scheduler)

    def step(self):
        """Step every contained scheduler, in order."""
        for scheduler in self.schedulers:
            scheduler.step()
Loading