[Feature] Add scheduler for alpha/beta parameters of PrioritizedSampler #2452
@@ -58,6 +58,11 @@
    SliceSampler,
    SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.scheduler import (
    LinearScheduler,
    SchedulerList,
    StepScheduler,
)
from torchrl.data.replay_buffers.storages import (
    LazyMemmapStorage,

@@ -99,6 +104,7 @@
    VecNorm,
)

OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
_has_gym = importlib.util.find_spec("gym") is not None
@@ -3026,6 +3032,51 @@ def test_prioritized_slice_sampler_episodes(device):
    ), "after priority update, only episode 1 and 3 are expected to be sampled"


def test_prioritized_parameter_scheduler():
    INIT_ALPHA = 0.7
    INIT_BETA = 0.6
    GAMMA = 0.1
    EVERY_N_STEPS = 10
    LINEAR_STEPS = 100
    TOTAL_STEPS = 200
    rb = TensorDictPrioritizedReplayBuffer(
        alpha=INIT_ALPHA, beta=INIT_BETA, storage=ListStorage(max_size=2000)
    )
    data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000)
    rb.extend(data)
    alpha_scheduler = LinearScheduler(
        rb, param_name="alpha", final_value=0.0, num_steps=LINEAR_STEPS
    )
    beta_scheduler = StepScheduler(
        rb,
        param_name="beta",
        gamma=GAMMA,
        n_steps=EVERY_N_STEPS,
        max_value=1.0,
        mode="additive",
    )
    scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))
    expected_alpha_vals = np.linspace(INIT_ALPHA, 0.0, num=LINEAR_STEPS + 1)
    expected_alpha_vals = np.pad(
        expected_alpha_vals, (0, TOTAL_STEPS - LINEAR_STEPS), constant_values=0.0
    )
Review comment: let's use torch here
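For reference, a torch-based sketch of the same construction the comment asks for (not the committed change):

```python
import torch

# Interpolate alpha from INIT_ALPHA down to 0 over LINEAR_STEPS steps, then
# hold it at 0 for the remaining TOTAL_STEPS - LINEAR_STEPS steps.
expected_alpha_vals = torch.linspace(INIT_ALPHA, 0.0, steps=LINEAR_STEPS + 1)
expected_alpha_vals = torch.nn.functional.pad(
    expected_alpha_vals, (0, TOTAL_STEPS - LINEAR_STEPS), value=0.0
)
```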
    expected_beta_vals = [INIT_BETA]
    for _ in range((TOTAL_STEPS // EVERY_N_STEPS - 1)):
        expected_beta_vals.append(expected_beta_vals[-1] + GAMMA)
    expected_beta_vals = (
        np.atleast_2d(expected_beta_vals).repeat(EVERY_N_STEPS).clip(None, 1.0)
    )
    for i in range(TOTAL_STEPS):
        assert np.isclose(
            rb.sampler.alpha, expected_alpha_vals[i]
        ), f"expected {expected_alpha_vals[i]}, got {rb.sampler.alpha}"
        assert np.isclose(
            rb.sampler.beta, expected_beta_vals[i]
        ), f"expected {expected_beta_vals[i]}, got {rb.sampler.beta}"
Review comment: let's use …
        rb.sample(20)
        scheduler.step()


class TestEnsemble:
    def _make_data(self, data_type):
        if data_type is torch.Tensor:

@@ -0,0 +1,240 @@
from typing import Any, Callable, Dict

import numpy as np

from .replay_buffers import ReplayBuffer
from .samplers import Sampler


class ParameterScheduler:
    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler.

    The scheduler can, for example, be used to alter the alpha and beta values in the PrioritizedSampler.

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the beta parameter.
        min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
            Defaults to None.
        max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
            Defaults to None.

    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        min_value: int | float = None,
        max_value: int | float = None,
    ):
        if not isinstance(obj, ReplayBuffer) and not isinstance(obj, Sampler):
            raise TypeError(
                f"ParameterScheduler only supports Sampler class. Pass either ReplayBuffer or Sampler object. Got {type(obj)}"
            )
        self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
        self.param_name = param_name
        self._min_val = min_value
        self._max_val = max_value
        if not hasattr(self.sampler, self.param_name):
            raise ValueError(
                f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
            )
        self.initial_val = getattr(self.sampler, self.param_name)
Review comment: do we want to copy that? If it's a tensor its value could change in-place
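A sketch of the defensive copy the comment is asking about (hypothetical, not part of the diff): clone tensors so that later in-place updates to the sampler's parameter cannot mutate the stored initial value.

```python
import copy

import torch

# inside ParameterScheduler.__init__, instead of a plain getattr assignment:
initial_val = getattr(self.sampler, self.param_name)
if isinstance(initial_val, torch.Tensor):
    # detach + clone breaks aliasing with the sampler's own tensor
    self.initial_val = initial_val.detach().clone()
else:
    self.initial_val = copy.deepcopy(initial_val)
```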
        self._step_cnt = 0

    def state_dict(self):
Review comment: do we want to match the nn.Module.state_dict signature here? Reply: that one was actually blindly copied from … Reply: Oh wow! Ok then...
"""Return the state of the scheduler as a :class:`dict`. | ||
LTluttmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
It contains an entry for every variable in self.__dict__ which | ||
LTluttmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
is not the optimizer. | ||
LTluttmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
return {key: value for key, value in self.__dict__.items() if key != "sampler"} | ||
LTluttmann marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def load_state_dict(self, state_dict: Dict[str, Any]): | ||
"""Load the scheduler's state. | ||
|
||
Args: | ||
state_dict (dict): scheduler state. Should be an object returned | ||
from a call to :meth:`state_dict`. | ||
""" | ||
self.__dict__.update(state_dict) | ||

    def step(self):
        self._step_cnt += 1
        # Apply the step function
        new_value = self._step()
        # clip value to specified range
        new_value_clipped = np.clip(new_value, a_min=self._min_val, a_max=self._max_val)
Review comment: what if it's a tensor? Reply: good catch. Imo there are two ways of handling this … What do you think?
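One way the clipping could be made tensor-aware (a sketch of one possible option, not the PR's resolution):

```python
import numpy as np
import torch


def _clip_value(value, min_val=None, max_val=None):
    # Leave the value untouched when no bound is given (np.clip raises in that case).
    if min_val is None and max_val is None:
        return value
    if isinstance(value, torch.Tensor):
        return value.clamp(min=min_val, max=max_val)
    return np.clip(value, a_min=min_val, a_max=max_val)
```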
        # Set the new value of the parameter dynamically
        setattr(self.sampler, self.param_name, new_value_clipped)

    def _step(self):
        raise NotImplementedError


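To illustrate the extension point of the base class (subclasses only implement `_step` and return the next raw value; clipping and the step counter are handled by `ParameterScheduler.step`), here is a hypothetical subclass that is not part of the PR:

```python
class ExponentialScheduler(ParameterScheduler):
    """Hypothetical example: decay the parameter by a constant factor per step."""

    def __init__(self, obj, param_name, decay: float = 0.99, min_value=None, max_value=None):
        super().__init__(obj, param_name, min_value, max_value)
        self.decay = decay

    def _step(self):
        # the next value is computed from the stored initial value, not the current one
        return self.initial_val * self.decay**self._step_cnt
```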
class LambdaScheduler(ParameterScheduler):
    """Sets a parameter to its initial value times a given function.

    Similar to torch.optim.lr_scheduler.LambdaLR.

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        lambda_fn (function): a function which computes a multiplicative factor given an integer
            step count.
        min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
            Defaults to None.
        max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
            Defaults to None.

    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        lambda_fn: Callable[[int], float],
        min_value: int | float = None,
        max_value: int | float = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.lambda_fn = lambda_fn

    def _step(self):
        return self.initial_val * self.lambda_fn(self._step_cnt)
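A usage sketch for `LambdaScheduler` (hypothetical values, in the style of the Example blocks below):

```python
>>> # xdoctest: +SKIP
>>> # Assuming the sampler uses initial alpha = 0.7, anneal it by 5% per step
>>> scheduler = LambdaScheduler(rb, param_name="alpha", lambda_fn=lambda step: 0.95**step)
>>> for epoch in range(100):
>>>     train(...)
>>>     scheduler.step()
```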
class LinearScheduler(ParameterScheduler):
    """A linear scheduler for gradually altering a parameter in an object over a given number of steps.

    This scheduler linearly interpolates between the initial value of the parameter and a final target value.

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        final_value (Union[int, float]): The final value that the parameter will reach after the
            specified number of steps.
        num_steps (int): The total number of steps over which the parameter
            will be linearly altered.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.7 if step == 1
        >>> # beta = 0.8 if step == 2
        >>> # beta = 0.9 if step == 3
        >>> # beta = 1.0 if step >= 4
        >>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        final_value: int | float,
        num_steps: int,
    ):
        super().__init__(obj, param_name)
        self.final_val = final_value
        self.num_steps = num_steps
        self._delta = (self.final_val - self.initial_val) / self.num_steps

    def _step(self):
        if self._step_cnt < self.num_steps:
            return self.initial_val + (self._delta * self._step_cnt)
        else:
            return self.final_val
Review comment: nit: if we ever want this to be compilable without graph breaks we should think of a way to remove the control flow, e.g. `return torch.where(self._step_cnt < self.num_steps, self.initial_val + (self._delta * self._step_cnt), self.final_val)`, assuming that all of these are tensors. Reply: interesting point. I stuck to the way torch schedulers handle different behavior for different epochs (e.g. here). Do you think that is okay for now? Reply: yeah that's fine, maybe let's add a comment to let someone know in the future that this should be fixed
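A runnable sketch of the branch-free variant suggested in the comment (assuming the scheduler state is held in tensors; this is not what the PR implements):

```python
import torch

def _step(self):
    # torch.where evaluates both branches but removes the Python if/else,
    # so torch.compile can trace the scheduler without graph breaks.
    step_cnt = torch.as_tensor(self._step_cnt)
    return torch.where(
        step_cnt < self.num_steps,
        self.initial_val + self._delta * step_cnt,
        torch.as_tensor(self.final_val),
    )
```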
class StepScheduler(ParameterScheduler):
    """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.

    The scheduler can apply:
    1. Multiplicative changes: `new_val = curr_val * gamma`
    2. Additive changes: `new_val = curr_val + gamma`

    Args:
        obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        gamma (int | float, optional): The value by which to adjust the parameter,
            either in a multiplicative or additive way. Defaults to 0.9.
        n_steps (int, optional): The number of steps after which the parameter should be altered.
            Defaults to 1.
        mode (str, optional): The mode of scheduling. Can be either 'multiplicative' or 'additive'.
            Defaults to 'multiplicative'.
        min_value (int | float, optional): a lower bound for the parameter to be adjusted.
            Defaults to None.
        max_value (int | float, optional): an upper bound for the parameter to be adjusted.
            Defaults to None.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.6 if 0 <= step < 10
        >>> # beta = 0.7 if 10 <= step < 20
        >>> # beta = 0.8 if 20 <= step < 30
        >>> # beta = 0.9 if 30 <= step < 40
        >>> # beta = 1.0 if 40 <= step
        >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, n_steps=10, mode='additive', max_value=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        gamma: int | float = 0.9,
        n_steps: int = 1,
        mode: str = "multiplicative",
        min_value: int | float = None,
        max_value: int | float = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.gamma = gamma
        self.n_steps = n_steps
        if mode == "additive":
            operator = np.add
        elif mode == "multiplicative":
            operator = np.multiply
        else:
            raise ValueError(
                f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
            )
        self.operator = operator

    def _step(self):
        """Applies the scheduling logic to alter the parameter value every `n_steps`."""
        # Check if the current step count is a multiple of n_steps
        current_val = getattr(self.sampler, self.param_name)
        if self._step_cnt % self.n_steps == 0:
            return self.operator(current_val, self.gamma)
        else:
            return current_val
Review comment on lines +248 to +251: ditto
class SchedulerList:
    """Simple container abstracting a list of schedulers."""

    def __init__(self, schedulers: list[ParameterScheduler]) -> None:
        if isinstance(schedulers, ParameterScheduler):
            schedulers = [schedulers]
        self.schedulers = schedulers

    def append(self, scheduler: ParameterScheduler):
        self.schedulers.append(scheduler)

    def step(self):
        for scheduler in self.schedulers:
            scheduler.step()
Review comment: let's maybe make these args to the func?
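For completeness, a usage sketch of `SchedulerList` driving both parameters of a prioritized buffer (hypothetical values, mirroring the test above):

```python
>>> # xdoctest: +SKIP
>>> alpha_scheduler = LinearScheduler(rb, param_name="alpha", final_value=0.0, num_steps=100)
>>> beta_scheduler = StepScheduler(rb, param_name="beta", gamma=0.1, n_steps=10, mode="additive", max_value=1.0)
>>> scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))
>>> for _ in range(200):
>>>     rb.sample(20)
>>>     scheduler.step()
```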