[Feature] Composite replay buffers #1768

Merged
merged 10 commits into from
Jan 9, 2024

@vmoens vmoens commented Jan 3, 2024

Description

In this PR, I propose a composite Replay Buffer API where one can put together multiple replay buffers and sample from each of them either in a structured or random way.

The composite replay buffer can be built in two different ways:

rb = ReplayBufferEnsemble(*list_of_replay_buffers)

or

rb = ReplayBufferEnsemble(storages=StorageEnsemble(*storages), ...)

That way, users either have full control over the storage, sampler, etc. (provided they have the appropriate signatures) or can simply combine existing buffers naively.

The ReplayBufferEnsemble class accepts a p argument that sets how much weight is given to each replay buffer during sampling. A collate_fn can be passed too; it is executed after the collate_fn of each individual replay buffer has been called. A sample_from_all argument lets users choose whether each sampled batch should contain a sub-batch from every buffer. If not, num_buffer_sampled indicates how many buffers are sampled from at each call (using the probabilities p if provided).
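To make the collate order concrete, here is a toy sketch in plain Python. It is not the torchrl implementation: `per_buffer_collate`, `ensemble_collate`, and `sample_ensemble` are hypothetical names used only to show that each buffer collates its own raw items first and the ensemble-level collate then runs on the already-collated results.

```python
from typing import Callable, List, Sequence


def per_buffer_collate(items: Sequence[int]) -> List[int]:
    # Stand-in for a single buffer's collate_fn: raw items -> one batch.
    return list(items)


def ensemble_collate(batches: Sequence[List[int]]) -> List[List[int]]:
    # Stand-in for the ensemble-level collate_fn. It only ever sees
    # batches that were already collated by their own buffer.
    return [sorted(batch) for batch in batches]


def sample_ensemble(
    raw_per_buffer: Sequence[Sequence[int]],
    buffer_collates: Sequence[Callable[[Sequence[int]], List[int]]],
    final_collate: Callable[[Sequence[List[int]]], List[List[int]]],
) -> List[List[int]]:
    # Step 1: every buffer applies its own collate_fn.
    collated = [fn(raw) for fn, raw in zip(buffer_collates, raw_per_buffer)]
    # Step 2: the ensemble collate_fn runs on the collated sub-batches.
    return final_collate(collated)


out = sample_ensemble(
    raw_per_buffer=[(3, 1, 2), (9, 7)],
    buffer_collates=[per_buffer_collate, per_buffer_collate],
    final_collate=ensemble_collate,
)
```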

Samples are stacked along the first dimension, giving samples of shape [B, S] where B=num_buffer_sampled and S=batch_size//num_buffer_sampled.
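The shape arithmetic can be checked with a small pure-Python sketch. This is an illustration under stated assumptions rather than the torchrl code: `ensemble_sample` is a hypothetical helper, buffers are plain lists, and buffers are assumed to be drawn with replacement according to p.

```python
import random
from typing import List, Optional, Sequence


def ensemble_sample(
    buffers: Sequence[Sequence[int]],
    batch_size: int,
    p: Optional[Sequence[float]] = None,
    sample_from_all: bool = False,
    num_buffer_sampled: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> List[List[int]]:
    rng = rng or random.Random()
    if sample_from_all:
        # Every buffer contributes one sub-batch.
        chosen = list(range(len(buffers)))
    else:
        k = num_buffer_sampled if num_buffer_sampled is not None else len(buffers)
        # Assumption: buffer ids are drawn with replacement, weighted by p.
        chosen = rng.choices(range(len(buffers)), weights=p, k=k)
    sub_batch = batch_size // len(chosen)
    # Outer list has length B, each inner list has length S.
    return [rng.choices(buffers[i], k=sub_batch) for i in chosen]


buffers = [list(range(0, 100)), list(range(100, 200)), list(range(200, 300))]
batch = ensemble_sample(
    buffers, batch_size=32, p=[0.5, 0.3, 0.2],
    num_buffer_sampled=2, rng=random.Random(0),
)
shape = (len(batch), len(batch[0]))  # (B, S) with B=2 and S=32//2
```

With `sample_from_all=True` and three buffers, the same helper returns three sub-batches of `batch_size // 3` items each.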

Thanks to LazyStackedTensorDict, we can stack any kind of data, but one may wish to make the data uniform first. For this, independent transforms have to be attached to each buffer (e.g., when key names differ across buffers or the image sizes do not match).
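A minimal sketch of the per-buffer transform idea, using plain dictionaries instead of tensordicts. The helpers `rename_key`, `fit_length`, and `chain` are hypothetical and only mimic what a key-renaming or resizing transform would do before samples from different buffers are stacked.

```python
from typing import Callable, Dict, List

Sample = Dict[str, List[int]]
Transform = Callable[[Sample], Sample]


def rename_key(old: str, new: str) -> Transform:
    # Return a transform that renames one key and leaves the others alone.
    def transform(sample: Sample) -> Sample:
        return {(new if k == old else k): v for k, v in sample.items()}
    return transform


def fit_length(key: str, size: int) -> Transform:
    # Crude stand-in for an image resize: truncate or zero-pad to `size`.
    def transform(sample: Sample) -> Sample:
        data = sample[key][:size]
        return {**sample, key: data + [0] * (size - len(data))}
    return transform


def chain(*transforms: Transform) -> Transform:
    # Apply several transforms in order.
    def transform(sample: Sample) -> Sample:
        for t in transforms:
            sample = t(sample)
        return sample
    return transform


# Two buffers whose raw samples disagree on key names and sizes.
raw_a: Sample = {"pixels": [1, 2, 3, 4], "action": [0]}
raw_b: Sample = {"image": [5, 6], "action": [1]}

# One transform per buffer, applied before the samples are stacked.
transform_a = fit_length("pixels", 3)
transform_b = chain(rename_key("image", "pixels"), fit_length("pixels", 3))

stacked = [transform_a(raw_a), transform_b(raw_b)]
```

After the transforms, both samples share the same keys and the same "pixels" length, so they can be stacked densely.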

Questions:

  • What should we do with the writers? I think it would be clunky and hard to maintain to allow writing to the ensemble when we could simply write to the buffers independently (i.e., the writer should just raise NotImplementedError or RuntimeError when called).
  • When calling rb.__getitem__ on a regular replay buffer, we get the item from the storage and massage it a bit (using transforms, etc.). Do we want to allow that method in the ensemble too? Maybe with rb[buffer_id, index]?
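Both questions can be pictured with a toy sketch. `ToyBuffer` and `ToyEnsemble` are hypothetical classes, not torchrl ones: the ensemble refuses writes with a RuntimeError, and `rb[buffer_id, index]` forwards to the chosen buffer, which applies its own transform to the stored item.

```python
from typing import Any, Callable, List, Optional, Tuple


class ToyBuffer:
    def __init__(self, transform: Optional[Callable[[Any], Any]] = None) -> None:
        self.storage: List[Any] = []
        self.transform = transform or (lambda item: item)

    def extend(self, items: List[Any]) -> None:
        self.storage.extend(items)

    def __getitem__(self, index: int) -> Any:
        # Fetch from storage, then "massage" the item with the transform.
        return self.transform(self.storage[index])


class ToyEnsemble:
    def __init__(self, *buffers: ToyBuffer) -> None:
        self.buffers = buffers

    def extend(self, items: List[Any]) -> None:
        # The ensemble cannot know which buffer the data belongs to.
        raise RuntimeError("Write to the individual buffers, not the ensemble.")

    def __getitem__(self, key: Tuple[int, int]) -> Any:
        buffer_id, index = key
        return self.buffers[buffer_id][index]


a = ToyBuffer()
b = ToyBuffer(transform=lambda x: x * 10)
a.extend([1, 2, 3])
b.extend([4, 5])
rb = ToyEnsemble(a, b)
item = rb[1, 0]  # item 0 of buffer 1, after that buffer's transform
```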

cc @nicklashansen

pytorch-bot bot commented Jan 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1768

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit c965c9e with merge base fd27cb7:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the "CLA Signed" label on Jan 3, 2024

github-actions bot commented Jan 3, 2024

⚠ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: 4. Worsened: 2.

| Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
|---|---|---|---|---|---|
| test_single | 66.5333ms | 65.8355ms | 15.1894 Ops/s | 15.5575 Ops/s | -2.37% |
| test_sync | 45.8878ms | 35.7992ms | 27.9336 Ops/s | 29.0743 Ops/s | -3.92% |
| test_async | 0.1396s | 34.0399ms | 29.3773 Ops/s | 29.0986 Ops/s | +0.96% |
| test_simple | 0.5144s | 0.4543s | 2.2011 Ops/s | 2.2317 Ops/s | -1.37% |
| test_transformed | 0.6893s | 0.6298s | 1.5877 Ops/s | 1.6131 Ops/s | -1.57% |
| test_serial | 1.4288s | 1.4159s | 0.7063 Ops/s | 0.7206 Ops/s | -1.99% |
| test_parallel | 1.4643s | 1.4099s | 0.7093 Ops/s | 0.7392 Ops/s | -4.05% |
| test_step_mdp_speed[True-True-True-True-True] | 0.1758ms | 21.7590μs | 45.9580 KOps/s | 45.9904 KOps/s | -0.07% |
| test_step_mdp_speed[True-True-True-True-False] | 37.4810μs | 13.2630μs | 75.3975 KOps/s | 75.4736 KOps/s | -0.10% |
| test_step_mdp_speed[True-True-True-False-True] | 37.4600μs | 12.7726μs | 78.2929 KOps/s | 77.5589 KOps/s | +0.95% |
| test_step_mdp_speed[True-True-True-False-False] | 32.5610μs | 7.7298μs | 129.3694 KOps/s | 127.6989 KOps/s | +1.31% |
| test_step_mdp_speed[True-True-False-True-True] | 49.2310μs | 23.1353μs | 43.2240 KOps/s | 43.4060 KOps/s | -0.42% |
| test_step_mdp_speed[True-True-False-True-False] | 41.5780μs | 14.5241μs | 68.8513 KOps/s | 69.1504 KOps/s | -0.43% |
| test_step_mdp_speed[True-True-False-False-True] | 50.9550μs | 14.0375μs | 71.2376 KOps/s | 71.0467 KOps/s | +0.27% |
| test_step_mdp_speed[True-True-False-False-False] | 35.6670μs | 9.0403μs | 110.6159 KOps/s | 108.9582 KOps/s | +1.52% |
| test_step_mdp_speed[True-False-True-True-True] | 70.7210μs | 24.5598μs | 40.7169 KOps/s | 40.7271 KOps/s | -0.03% |
| test_step_mdp_speed[True-False-True-True-False] | 36.1970μs | 15.9912μs | 62.5342 KOps/s | 63.5697 KOps/s | -1.63% |
| test_step_mdp_speed[True-False-True-False-True] | 52.1370μs | 14.0975μs | 70.9345 KOps/s | 71.2415 KOps/s | -0.43% |
| test_step_mdp_speed[True-False-True-False-False] | 34.6540μs | 9.1422μs | 109.3832 KOps/s | 109.5730 KOps/s | -0.17% |
| test_step_mdp_speed[True-False-False-True-True] | 60.3720μs | 25.8903μs | 38.6245 KOps/s | 39.2416 KOps/s | -1.57% |
| test_step_mdp_speed[True-False-False-True-False] | 41.9780μs | 17.2177μs | 58.0798 KOps/s | 59.0845 KOps/s | -1.70% |
| test_step_mdp_speed[True-False-False-False-True] | 39.6840μs | 15.2528μs | 65.5616 KOps/s | 65.2408 KOps/s | +0.49% |
| test_step_mdp_speed[True-False-False-False-False] | 37.5400μs | 10.4137μs | 96.0270 KOps/s | 98.1069 KOps/s | -2.12% |
| test_step_mdp_speed[False-True-True-True-True] | 52.6990μs | 24.4032μs | 40.9783 KOps/s | 40.7625 KOps/s | +0.53% |
| test_step_mdp_speed[False-True-True-True-False] | 44.0730μs | 15.9716μs | 62.6113 KOps/s | 63.8387 KOps/s | -1.92% |
| test_step_mdp_speed[False-True-True-False-True] | 42.8700μs | 16.4262μs | 60.8783 KOps/s | 62.0903 KOps/s | -1.95% |
| test_step_mdp_speed[False-True-True-False-False] | 33.3830μs | 10.4324μs | 95.8556 KOps/s | 97.9388 KOps/s | -2.13% |
| test_step_mdp_speed[False-True-False-True-True] | 55.4740μs | 25.7543μs | 38.8285 KOps/s | 39.1031 KOps/s | -0.70% |
| test_step_mdp_speed[False-True-False-True-False] | 41.0770μs | 17.3445μs | 57.6552 KOps/s | 59.0123 KOps/s | -2.30% |
| test_step_mdp_speed[False-True-False-False-True] | 54.5710μs | 17.7152μs | 56.4488 KOps/s | 57.5823 KOps/s | -1.97% |
| test_step_mdp_speed[False-True-False-False-False] | 39.8940μs | 11.5740μs | 86.4007 KOps/s | 87.4339 KOps/s | -1.18% |
| test_step_mdp_speed[False-False-True-True-True] | 61.6150μs | 27.0971μs | 36.9044 KOps/s | 37.3588 KOps/s | -1.22% |
| test_step_mdp_speed[False-False-True-True-False] | 49.0610μs | 18.6908μs | 53.5024 KOps/s | 54.9454 KOps/s | -2.63% |
| test_step_mdp_speed[False-False-True-False-True] | 54.5220μs | 17.8971μs | 55.8749 KOps/s | 57.5038 KOps/s | -2.83% |
| test_step_mdp_speed[False-False-True-False-False] | 41.1770μs | 11.5965μs | 86.2328 KOps/s | 87.6154 KOps/s | -1.58% |
| test_step_mdp_speed[False-False-False-True-True] | 56.9770μs | 28.2018μs | 35.4588 KOps/s | 36.1169 KOps/s | -1.82% |
| test_step_mdp_speed[False-False-False-True-False] | 71.4430μs | 19.7088μs | 50.7387 KOps/s | 52.4144 KOps/s | -3.20% |
| test_step_mdp_speed[False-False-False-False-True] | 53.1590μs | 18.5907μs | 53.7902 KOps/s | 54.8111 KOps/s | -1.86% |
| test_step_mdp_speed[False-False-False-False-False] | 42.0190μs | 12.6898μs | 78.8036 KOps/s | 80.4447 KOps/s | -2.04% |
| test_values[generalized_advantage_estimate-True-True] | 12.6040ms | 12.1700ms | 82.1690 Ops/s | 83.7964 Ops/s | -1.94% |
| test_values[vec_generalized_advantage_estimate-True-True] | 33.9161ms | 27.0678ms | 36.9442 Ops/s | 36.7131 Ops/s | +0.63% |
| test_values[td0_return_estimate-False-False] | 0.2507ms | 0.1804ms | 5.5428 KOps/s | 5.6690 KOps/s | -2.23% |
| test_values[td1_return_estimate-False-False] | 26.6356ms | 26.0467ms | 38.3926 Ops/s | 39.3849 Ops/s | -2.52% |
| test_values[vec_td1_return_estimate-False-False] | 34.8029ms | 27.0060ms | 37.0289 Ops/s | 36.7216 Ops/s | +0.84% |
| test_values[td_lambda_return_estimate-True-False] | 39.5485ms | 36.8755ms | 27.1183 Ops/s | 28.3652 Ops/s | -4.40% |
| test_values[vec_td_lambda_return_estimate-True-False] | 35.5449ms | 27.1416ms | 36.8438 Ops/s | 35.8746 Ops/s | +2.70% |
| test_gae_speed[generalized_advantage_estimate-False-1-512] | 10.7621ms | 8.1921ms | 122.0683 Ops/s | 127.1215 Ops/s | -3.98% |
| test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.3794ms | 1.9471ms | 513.5786 Ops/s | 507.6935 Ops/s | +1.16% |
| test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 8.1313ms | 0.4379ms | 2.2839 KOps/s | 2.2828 KOps/s | +0.05% |
| test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 44.6121ms | 35.9901ms | 27.7854 Ops/s | 25.7245 Ops/s | **+8.01%** |
| test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 10.8606ms | 2.6445ms | 378.1430 Ops/s | 375.6544 Ops/s | +0.66% |
| test_dqn_speed | 15.4502ms | 7.9214ms | 126.2408 Ops/s | 119.1806 Ops/s | **+5.92%** |
| test_ddpg_speed | 25.0039ms | 15.0543ms | 66.4263 Ops/s | 68.2076 Ops/s | -2.61% |
| test_sac_speed | 38.2798ms | 30.5423ms | 32.7414 Ops/s | 32.0148 Ops/s | +2.27% |
| test_redq_speed | 0.1134s | 38.9810ms | 25.6535 Ops/s | 27.9733 Ops/s | **-8.29%** |
| test_redq_deprec_speed | 28.7924ms | 26.1583ms | 38.2288 Ops/s | 35.7690 Ops/s | **+6.88%** |
| test_td3_speed | 29.1727ms | 20.8101ms | 48.0536 Ops/s | 48.6749 Ops/s | -1.28% |
| test_cql_speed | 96.2713ms | 89.7115ms | 11.1468 Ops/s | 11.2182 Ops/s | -0.64% |
| test_a2c_speed | 30.8791ms | 27.9656ms | 35.7582 Ops/s | 36.5026 Ops/s | -2.04% |
| test_ppo_speed | 30.8128ms | 28.0346ms | 35.6702 Ops/s | 36.6291 Ops/s | -2.62% |
| test_reinforce_speed | 30.4663ms | 26.2784ms | 38.0540 Ops/s | 37.1868 Ops/s | +2.33% |
| test_iql_speed | 65.8526ms | 63.9610ms | 15.6345 Ops/s | 15.6930 Ops/s | -0.37% |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.0679ms | 1.4528ms | 688.3138 Ops/s | 702.5539 Ops/s | -2.03% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 8.8750ms | 0.5250ms | 1.9047 KOps/s | 1.9372 KOps/s | -1.67% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 8.7906ms | 0.5113ms | 1.9558 KOps/s | 2.0097 KOps/s | -2.68% |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 2.3757ms | 1.4485ms | 690.3786 Ops/s | 694.8770 Ops/s | -0.65% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.2647ms | 0.5141ms | 1.9451 KOps/s | 1.9302 KOps/s | +0.77% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 8.8160ms | 0.5077ms | 1.9698 KOps/s | 1.9820 KOps/s | -0.62% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 2.4254ms | 1.6430ms | 608.6295 Ops/s | 610.8497 Ops/s | -0.36% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 9.3983ms | 0.6637ms | 1.5067 KOps/s | 1.5147 KOps/s | -0.53% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 9.1908ms | 0.6479ms | 1.5436 KOps/s | 1.5729 KOps/s | -1.87% |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.1381ms | 1.4622ms | 683.9166 Ops/s | 587.9130 Ops/s | **+16.33%** |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.1303s | 0.6649ms | 1.5040 KOps/s | 1.8934 KOps/s | **-20.56%** |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 8.8955ms | 0.5180ms | 1.9304 KOps/s | 1.9729 KOps/s | -2.15% |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 1.8455ms | 1.4388ms | 695.0310 Ops/s | 700.0512 Ops/s | -0.72% |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6405ms | 0.5145ms | 1.9436 KOps/s | 1.9112 KOps/s | +1.70% |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 6.5954ms | 0.5093ms | 1.9636 KOps/s | 1.9722 KOps/s | -0.43% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 2.4701ms | 1.6556ms | 604.0186 Ops/s | 612.9676 Ops/s | -1.46% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 9.3296ms | 0.6666ms | 1.5003 KOps/s | 1.5082 KOps/s | -0.53% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 8.6120ms | 0.6537ms | 1.5297 KOps/s | 1.5294 KOps/s | +0.02% |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1342s | 17.0665ms | 58.5945 Ops/s | 58.2524 Ops/s | +0.59% |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 14.9078ms | 12.3163ms | 81.1932 Ops/s | 81.1796 Ops/s | +0.02% |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 2.1373ms | 1.5306ms | 653.3206 Ops/s | 663.3238 Ops/s | -1.51% |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1271s | 17.1688ms | 58.2453 Ops/s | 58.8684 Ops/s | -1.06% |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 14.4560ms | 12.2702ms | 81.4984 Ops/s | 81.3410 Ops/s | +0.19% |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 2.1182ms | 1.5339ms | 651.9471 Ops/s | 652.2481 Ops/s | -0.05% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1148s | 16.8067ms | 59.5002 Ops/s | 59.6353 Ops/s | -0.23% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 15.2751ms | 12.5077ms | 79.9505 Ops/s | 78.9771 Ops/s | +1.23% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.3429ms | 1.7305ms | 577.8766 Ops/s | 578.3555 Ops/s | -0.08% |

github-actions bot commented Jan 3, 2024

⚠ Warning: Result of GPU Benchmark Tests

Total Benchmarks: 92. Improved: 5. Worsened: 2.

| Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
|---|---|---|---|---|---|
| test_single | 0.1207s | 0.1205s | 8.2972 Ops/s | 8.3323 Ops/s | -0.42% |
| test_sync | 0.1821s | 0.1107s | 9.0337 Ops/s | 9.1032 Ops/s | -0.76% |
| test_async | 0.2716s | 99.8951ms | 10.0105 Ops/s | 10.0067 Ops/s | +0.04% |
| test_single_pixels | 0.1339s | 0.1317s | 7.5955 Ops/s | 7.0107 Ops/s | **+8.34%** |
| test_sync_pixels | 99.0392ms | 96.3589ms | 10.3779 Ops/s | 10.4772 Ops/s | -0.95% |
| test_async_pixels | 0.2521s | 91.8474ms | 10.8876 Ops/s | 11.0110 Ops/s | -1.12% |
| test_simple | 0.9724s | 0.8939s | 1.1187 Ops/s | 1.1124 Ops/s | +0.57% |
| test_transformed | 1.1974s | 1.1317s | 0.8836 Ops/s | 0.8672 Ops/s | +1.89% |
| test_serial | 2.5744s | 2.5224s | 0.3965 Ops/s | 0.3930 Ops/s | +0.87% |
| test_parallel | 2.5424s | 2.4910s | 0.4014 Ops/s | 0.4079 Ops/s | -1.59% |
| test_step_mdp_speed[True-True-True-True-True] | 0.1074ms | 32.6617μs | 30.6169 KOps/s | 29.6390 KOps/s | +3.30% |
| test_step_mdp_speed[True-True-True-True-False] | 52.3510μs | 19.7836μs | 50.5468 KOps/s | 50.7845 KOps/s | -0.47% |
| test_step_mdp_speed[True-True-True-False-True] | 36.6310μs | 19.2597μs | 51.9219 KOps/s | 51.8314 KOps/s | +0.17% |
| test_step_mdp_speed[True-True-True-False-False] | 30.9100μs | 11.4075μs | 87.6614 KOps/s | 86.5647 KOps/s | +1.27% |
| test_step_mdp_speed[True-True-False-True-True] | 59.4410μs | 35.4737μs | 28.1899 KOps/s | 27.8682 KOps/s | +1.15% |
| test_step_mdp_speed[True-True-False-True-False] | 45.3610μs | 21.4970μs | 46.5181 KOps/s | 45.4105 KOps/s | +2.44% |
| test_step_mdp_speed[True-True-False-False-True] | 46.5900μs | 20.7810μs | 48.1208 KOps/s | 45.6825 KOps/s | **+5.34%** |
| test_step_mdp_speed[True-True-False-False-False] | 33.9400μs | 13.3901μs | 74.6821 KOps/s | 74.3020 KOps/s | +0.51% |
| test_step_mdp_speed[True-False-True-True-True] | 67.7000μs | 36.8970μs | 27.1025 KOps/s | 26.3535 KOps/s | +2.84% |
| test_step_mdp_speed[True-False-True-True-False] | 43.3000μs | 23.3202μs | 42.8812 KOps/s | 41.1847 KOps/s | +4.12% |
| test_step_mdp_speed[True-False-True-False-True] | 44.8500μs | 20.8158μs | 48.0405 KOps/s | 45.9308 KOps/s | +4.59% |
| test_step_mdp_speed[True-False-True-False-False] | 28.5600μs | 13.4168μs | 74.5337 KOps/s | 74.2254 KOps/s | +0.42% |
| test_step_mdp_speed[True-False-False-True-True] | 68.9910μs | 38.2449μs | 26.1472 KOps/s | 25.3451 KOps/s | +3.17% |
| test_step_mdp_speed[True-False-False-True-False] | 51.7410μs | 25.4663μs | 39.2676 KOps/s | 38.1301 KOps/s | +2.98% |
| test_step_mdp_speed[True-False-False-False-True] | 42.9610μs | 22.7048μs | 44.0436 KOps/s | 42.3749 KOps/s | +3.94% |
| test_step_mdp_speed[True-False-False-False-False] | 52.4700μs | 15.1918μs | 65.8248 KOps/s | 64.9290 KOps/s | +1.38% |
| test_step_mdp_speed[False-True-True-True-True] | 64.2800μs | 36.4773μs | 27.4143 KOps/s | 26.3549 KOps/s | +4.02% |
| test_step_mdp_speed[False-True-True-True-False] | 94.8510μs | 23.5533μs | 42.4569 KOps/s | 41.3397 KOps/s | +2.70% |
| test_step_mdp_speed[False-True-True-False-True] | 50.2000μs | 24.8987μs | 40.1627 KOps/s | 38.0523 KOps/s | **+5.55%** |
| test_step_mdp_speed[False-True-True-False-False] | 34.2600μs | 15.0430μs | 66.4762 KOps/s | 64.9331 KOps/s | +2.38% |
| test_step_mdp_speed[False-True-False-True-True] | 67.4410μs | 38.9981μs | 25.6423 KOps/s | 25.3545 KOps/s | +1.14% |
| test_step_mdp_speed[False-True-False-True-False] | 51.2100μs | 25.6339μs | 39.0108 KOps/s | 38.4597 KOps/s | +1.43% |
| test_step_mdp_speed[False-True-False-False-True] | 44.8310μs | 27.0379μs | 36.9851 KOps/s | 35.6810 KOps/s | +3.65% |
| test_step_mdp_speed[False-True-False-False-False] | 88.4610μs | 16.9142μs | 59.1220 KOps/s | 57.4791 KOps/s | +2.86% |
| test_step_mdp_speed[False-False-True-True-True] | 0.1192ms | 40.7806μs | 24.5214 KOps/s | 24.1259 KOps/s | +1.64% |
| test_step_mdp_speed[False-False-True-True-False] | 52.8010μs | 27.9174μs | 35.8200 KOps/s | 35.6174 KOps/s | +0.57% |
| test_step_mdp_speed[False-False-True-False-True] | 93.8310μs | 27.2793μs | 36.6578 KOps/s | 35.5998 KOps/s | +2.97% |
| test_step_mdp_speed[False-False-True-False-False] | 40.0520μs | 17.0705μs | 58.5807 KOps/s | 57.8333 KOps/s | +1.29% |
| test_step_mdp_speed[False-False-False-True-True] | 81.6300μs | 42.6841μs | 23.4279 KOps/s | 23.3692 KOps/s | +0.25% |
| test_step_mdp_speed[False-False-False-True-False] | 54.9800μs | 29.5258μs | 33.8687 KOps/s | 33.6158 KOps/s | +0.75% |
| test_step_mdp_speed[False-False-False-False-True] | 55.8110μs | 28.8163μs | 34.7026 KOps/s | 34.3729 KOps/s | +0.96% |
| test_step_mdp_speed[False-False-False-False-False] | 38.8500μs | 18.6645μs | 53.5777 KOps/s | 52.3722 KOps/s | +2.30% |
| test_values[generalized_advantage_estimate-True-True] | 23.4808ms | 22.9264ms | 43.6179 Ops/s | 43.7657 Ops/s | -0.34% |
| test_values[vec_generalized_advantage_estimate-True-True] | 86.3864ms | 3.2750ms | 305.3477 Ops/s | 311.2794 Ops/s | -1.91% |
| test_values[td0_return_estimate-False-False] | 95.7420μs | 59.9290μs | 16.6864 KOps/s | 16.7947 KOps/s | -0.64% |
| test_values[td1_return_estimate-False-False] | 50.1425ms | 49.7486ms | 20.1011 Ops/s | 20.1366 Ops/s | -0.18% |
| test_values[vec_td1_return_estimate-False-False] | 2.0016ms | 1.7342ms | 576.6197 Ops/s | 577.7087 Ops/s | -0.19% |
| test_values[td_lambda_return_estimate-True-False] | 81.9325ms | 79.8371ms | 12.5255 Ops/s | 12.5211 Ops/s | +0.04% |
| test_values[vec_td_lambda_return_estimate-True-False] | 2.0111ms | 1.7329ms | 577.0631 Ops/s | 579.7247 Ops/s | -0.46% |
| test_gae_speed[generalized_advantage_estimate-False-1-512] | 21.9661ms | 21.7653ms | 45.9446 Ops/s | 46.7835 Ops/s | -1.79% |
| test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 0.8152ms | 0.6748ms | 1.4820 KOps/s | 1.5005 KOps/s | -1.23% |
| test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.6911ms | 0.6291ms | 1.5896 KOps/s | 1.6034 KOps/s | -0.86% |
| test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.4874ms | 1.4351ms | 696.7960 Ops/s | 699.7321 Ops/s | -0.42% |
| test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.9171ms | 0.6513ms | 1.5354 KOps/s | 1.5517 KOps/s | -1.05% |
| test_dqn_speed | 7.9188ms | 7.5892ms | 131.7655 Ops/s | 130.9254 Ops/s | +0.64% |
| test_ddpg_speed | 15.7385ms | 14.9625ms | 66.8336 Ops/s | 61.3633 Ops/s | **+8.91%** |
| test_sac_speed | 30.7745ms | 29.9298ms | 33.4115 Ops/s | 33.2759 Ops/s | +0.41% |
| test_redq_speed | 36.8373ms | 35.4743ms | 28.1894 Ops/s | 27.9406 Ops/s | +0.89% |
| test_redq_deprec_speed | 25.8377ms | 24.8975ms | 40.1646 Ops/s | 40.1586 Ops/s | +0.01% |
| test_td3_speed | 29.2499ms | 20.4054ms | 49.0067 Ops/s | 49.0738 Ops/s | -0.14% |
| test_cql_speed | 88.8533ms | 87.0682ms | 11.4853 Ops/s | 11.5685 Ops/s | -0.72% |
| test_a2c_speed | 28.9193ms | 27.9918ms | 35.7248 Ops/s | 35.8621 Ops/s | -0.38% |
| test_ppo_speed | 30.4202ms | 28.3667ms | 35.2526 Ops/s | 35.6118 Ops/s | -1.01% |
| test_reinforce_speed | 28.1678ms | 27.2443ms | 36.7049 Ops/s | 36.8563 Ops/s | -0.41% |
| test_iql_speed | 0.1584s | 66.1541ms | 15.1162 Ops/s | 16.4931 Ops/s | **-8.35%** |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.2688ms | 1.8950ms | 527.6946 Ops/s | 523.8007 Ops/s | +0.74% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.1852ms | 0.8370ms | 1.1947 KOps/s | 1.1985 KOps/s | -0.31% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.9946ms | 0.8233ms | 1.2147 KOps/s | 1.2075 KOps/s | +0.59% |
| test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 2.3940ms | 1.8679ms | 535.3633 Ops/s | 530.4679 Ops/s | +0.92% |
| test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.0135ms | 0.8240ms | 1.2136 KOps/s | 1.2141 KOps/s | -0.04% |
| test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.9699ms | 0.8132ms | 1.2297 KOps/s | 1.2263 KOps/s | +0.28% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 2.7653ms | 2.1509ms | 464.9265 Ops/s | 405.2037 Ops/s | **+14.74%** |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 3.4013ms | 0.9562ms | 1.0458 KOps/s | 1.0444 KOps/s | +0.13% |
| test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.0802ms | 0.9456ms | 1.0575 KOps/s | 1.0456 KOps/s | +1.14% |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 2.5331ms | 1.9030ms | 525.4965 Ops/s | 519.3719 Ops/s | +1.18% |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.1218ms | 0.8360ms | 1.1962 KOps/s | 1.1940 KOps/s | +0.18% |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.9959ms | 0.8256ms | 1.2113 KOps/s | 1.2113 KOps/s | +0.00% |
| test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 1.9420ms | 1.8486ms | 540.9486 Ops/s | 528.1226 Ops/s | +2.43% |
| test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.1453s | 1.0246ms | 976.0207 Ops/s | 1.2087 KOps/s | **-19.25%** |
| test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.9982ms | 0.8151ms | 1.2268 KOps/s | 1.2230 KOps/s | +0.31% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 2.7487ms | 2.1568ms | 463.6410 Ops/s | 456.6340 Ops/s | +1.53% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 3.4110ms | 0.9568ms | 1.0452 KOps/s | 1.0409 KOps/s | +0.41% |
| test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.1134ms | 0.9472ms | 1.0558 KOps/s | 1.0517 KOps/s | +0.38% |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.1326s | 17.5201ms | 57.0773 Ops/s | 54.8393 Ops/s | +4.08% |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 15.4900ms | 12.1817ms | 82.0906 Ops/s | 81.9435 Ops/s | +0.18% |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 5.8462ms | 1.9210ms | 520.5501 Ops/s | 538.2234 Ops/s | -3.28% |
| test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1228s | 17.2253ms | 58.0541 Ops/s | 57.3649 Ops/s | +1.20% |
| test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 15.0388ms | 12.1831ms | 82.0808 Ops/s | 81.6385 Ops/s | +0.54% |
| test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 5.9203ms | 1.9185ms | 521.2315 Ops/s | 539.9364 Ops/s | -3.46% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1219s | 17.4039ms | 57.4584 Ops/s | 56.5541 Ops/s | +1.60% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 15.3453ms | 12.2921ms | 81.3528 Ops/s | 81.2820 Ops/s | +0.09% |
| test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.7319ms | 1.9776ms | 505.6674 Ops/s | 489.3360 Ops/s | +3.34% |

vmoens added the "enhancement" label on Jan 3, 2024

nicklashansen commented Jan 4, 2024

@vmoens

> What should we do with the writers? I think it will be clunky and hard to maintain to have the possibility to write in that buffer when we could simply just write in the buffers independently (ie, the writer should just raise NotImplementedError or RuntimeError when called)

I'm not sure I fully understand the problem here. Are you suggesting that writing to the ensemble buffer would raise an error, but that users can still access and write to each of the buffers in the ensemble? One use case for writing would be multi-task online RL, where the user may wish to maintain separate buffers for each task and perhaps sample data from each task at different rates.

> when calling rb.__getitem__, we get the item in the storage and massage it a bit (using transforms etc) with a regular replay buffer. Do we want to allow that method in the ensemble too? Maybe with rb[buffer_id, index]?

Sounds reasonable to me!

vmoens commented Jan 4, 2024

> users can still access and write to each of the buffers in the ensemble

Exactly, that seems to be the most intuitive thing to do. I don't see how extend or add could be routed to the appropriate replay buffer.

vmoens added 2 commits January 4, 2024 16:57
# Conflicts:
#	torchrl/data/replay_buffers/samplers.py

vmoens commented Jan 8, 2024

@nicklashansen Since you can now do ensemble_replay_buffer[0].extend(...), I think the problem of the writer is solved.
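A toy illustration of that write path for the multi-task case, in plain Python. `IndexableEnsemble` is a hypothetical stand-in, not the torchrl class: integer indexing hands back the underlying buffer (here a plain list), so writes go to the sub-buffer and the ensemble never has to route data.

```python
from typing import List


class IndexableEnsemble:
    """Toy ensemble: integer indexing returns the underlying buffer."""

    def __init__(self, *buffers: List[dict]) -> None:
        self._buffers = list(buffers)

    def __getitem__(self, buffer_id: int) -> List[dict]:
        return self._buffers[buffer_id]

    def __len__(self) -> int:
        return len(self._buffers)


# One plain list per task stands in for one replay buffer per task.
task_a: List[dict] = []
task_b: List[dict] = []
ensemble = IndexableEnsemble(task_a, task_b)

# Writes go through the sub-buffer selected by its index.
ensemble[0].extend([{"task": "a", "reward": 1.0}])
ensemble[1].extend([{"task": "b", "reward": 0.5}, {"task": "b", "reward": 0.0}])

sizes = [len(ensemble[i]) for i in range(len(ensemble))]
```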

vmoens commented Jan 8, 2024

@nicklashansen not too far from done, check the doc here https://docs-preview.pytorch.org/pytorch/rl/1768/reference/data.html#composing-datasets

vmoens marked this pull request as ready for review on January 8, 2024 17:11
vmoens changed the title from "[WIP, Feature] Composite replay buffers" to "[Feature] Composite replay buffers" on Jan 9, 2024
vmoens added the "Data" label on Jan 9, 2024
vmoens merged commit 11a82c3 into main on Jan 9, 2024
62 of 65 checks passed
vmoens deleted the composite-buffer branch on January 9, 2024 10:18
nicklashansen commented

@vmoens Thanks! It looks pretty straightforward based on the documentation.

I did try it out on the OpenXExperienceReplay dataset, but am getting an error during sampling. Seems like there is some dimension mismatch in the batch size? This is a slightly different use case than the one in the documentation, since I'm working with datasets rather than replay buffers.

Example:

self._dataset = OpenXExperienceReplay(
    dataset_id=cfg.get('dataset_id', 'ucsd_pick_and_place_dataset_converted_externally_to_rlds'),
    shuffle=True,
    slice_len=cfg.horizon + 1,
    streaming=False,
    root=cfg.data_dir,
    download=True,
)
self._datasets = ReplayBufferEnsemble(
    self._dataset,
    self._dataset,
    p=[0.5, 0.5],
)
self._datasets.sample(32)

results in the error

Traceback (most recent call last):
  File "/data2/nihansen/code/openx/tdmpc2/train.py", line 94, in launch
    train(0, cfg)
  File "/data2/nihansen/code/openx/tdmpc2/train.py", line 70, in train
    buffer=OpenX(cfg),
  File "/data2/nihansen/code/openx/tdmpc2/common/openx.py", line 36, in __init__
    data = self._datasets.sample(32)
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 444, in sample
    ret = self._sample(batch_size)
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1320, in _sample
    sample, info = super()._sample(*args, **kwargs)
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/torchrl/data/replay_buffers/utils.py", line 47, in decorated_fun
    output = fun(self, *args, **kwargs)
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 380, in _sample
    info["index"] = index
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/tensordict/_td.py", line 535, in __setitem__
    self._set_tuple(
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/tensordict/_td.py", line 1351, in _set_tuple
    return self._set_str(key[0], value, inplace=inplace, validated=validated)
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/tensordict/_td.py", line 1321, in _set_str
    value = self._validate_value(value, check_shape=True)
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/tensordict/base.py", line 3817, in _validate_value
    value.batch_size = self.batch_size
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/tensordict/_td.py", line 1271, in batch_size
    self._batch_size_setter(new_size)
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/tensordict/base.py", line 477, in _batch_size_setter
    self._check_new_batch_size(new_batch_size)
  File "/data/nihansen/miniconda3/envs/openx/lib/python3.9/site-packages/tensordict/base.py", line 3774, in _check_new_batch_size
    raise RuntimeError(
RuntimeError: the tensor index has shape torch.Size([2, 16]) which is incompatible with the batch-size torch.Size([2, 16, 1]).

Any suggestions for how to resolve this?
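For readers following the traceback, the check that fails can be sketched in plain Python. This is a simplified reading of the rule that a value's leading dimensions must equal the batch size of the container it is written into; `leading_dims_match` is a hypothetical helper, and the sketch does not identify where the extra trailing dimension of size 1 comes from.

```python
from typing import Sequence


def leading_dims_match(value_shape: Sequence[int], batch_size: Sequence[int]) -> bool:
    """Return True if value_shape starts with batch_size."""
    return tuple(value_shape[: len(batch_size)]) == tuple(batch_size)


# Shapes taken from the error message above.
index_shape = (2, 16)
container_batch_size = (2, 16, 1)

# The index is missing the trailing dimension of size 1, so the check fails.
ok = leading_dims_match(index_shape, container_batch_size)

# An index with a matching trailing dimension would pass the same check.
ok_with_trailing_dim = leading_dims_match((2, 16, 1), container_batch_size)
```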
