-
Notifications
You must be signed in to change notification settings - Fork 107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pre-conditioning matrix to Barker proposal #731
Add pre-conditioning matrix to Barker proposal #731
Conversation
This is a first draft of adding the pre-conditioning to the Barker proposal. This follows Algorithms 4 and 5 in Appendix G of the original Barker proposal paper. It's somewhat unclear from the paper, but the separate step size that was already implemented serves as a global scale for the normal distribution of the proposal. The function `_compute_acceptance_probability` now takes in the transpose sqrt mass matrix and the inverse, also it has been flattened to accomodate the corresponding matrix multiplicatios.
Fix typing of mass matrix.
The original docstring of step_size was incorrect, there is no sympletic integrator.
We make this possible by adding an identity pre-conditining matrix, which should make the test run in the same way as before.
We add a new test to barker.py to ensure that our implementation of the preconditioning matrix is correct. We follow Appendix G in the paper that mentions that algorithm 4 and 5 (which we implemented) should be equivalent to rescaling the parameters and the logdensity in a specific way. We implement both approaches when using the barker proposal to infer the mean and sigma of a normal distribution. We check that with two different random seeds the chains outputted are equivalent up to some tolerance. We also patch the original test in this file by adding an identity mass matrix.
@AdrienCorenflos I just added a test to specifically check the new implementation with the mass matrix and fixed all the tests that use Barker (by setting the inverse mass matrix to the identity). I'm open to suggestions on the new test or on any other changes. Thank you! |
Thanks for this. I think I now know how we can use general # Callable[[ArrayLikeTree, Arr], ArrayLikeTree]
def scale(position: ArrayLikeTree, vector: ArrayLikeTree, inv=False) -> ArrayLikeTree:
mass_matrix = mass_matrix_fn(position)
# ravel everything etc
if inv:
return triangular_solve(mass_matrix_sqrt, vector) # plus unravelling of course
return mass_matrix_sqrt @ vector # plus unravelling of course @junpenglao would that be ok to add a function like this (essentially scaling a vector by the If this is ok, then we can implement Barker's proposal with any sort of metric, even Riemannian, for cheap: z = metric.scale(y, y, True) - metric.scale(x, x, True) # Be careful here, it may be metric.scale(x, y, True) - metric.scale(y, x, True), need to check the detailed balanced condition properly and same below.
c_x = metric.scale(x, log_x)
c_y = metric.scale(y, log_y) And something similar here There would be some bookkeeping about shapes and all but these seem to be the only change. @ismael-mendoza @junpenglao what do you think? It;s |
This is not even covered in the papers by Livingstone and Zanella: it's a good example of why Blackjax composability is cool :D |
Thanks @AdrienCorenflos that seems like a good idea to me as it avoids the code redundancy and allows us to use Riemannian metrics for free like you pointed out. On your comment here
do you have suggestions for how to check the correct detailed balanced condition? I'm not too familiar with this and I assume the Barker paper wouldn't mention it as they only show the algorithm for a fixed mass matrix. Thanks! |
One quick follow up - am I missing something or would you need to implementations of |
Yes, it's no big deal: essentially, the matrix you used to propose the state needs to be one that ends up in the acceptance: So, here for the proposal we use changes to where So a bit different from what I had written originally but not too far :) Hopefully no typo/forgotten terms! |
We do need this (or similar) to be added I think, but let's wait for @junpenglao he will likely have a better opinion than me on the right API/design choice for these as he has thought about it a lot more than I have. |
Yes sounds like a great addition to |
Hello @AdrienCorenflos and @junpenglao, Thanks @AdrienCorenflos for implemeting the metric scaling. Now I can proceed finishing this PR, I just had two remaining questions. First, I noticed section 6.1 of the barker paper that the authors use the proposed adaptation of Andrieu and Thoms (2008) for their experiments. I think it would be nice to have the full version of their algorithm blackjax and I wondering if this adaptation has already been implemented somewhere for blackjax? or maybe there is already some clear alternative that should be used to adapt hmc-like algorithms? It's unclear to me if the Second, @AdrienCorenflos I was just wondering if you finished the derivation above and if there is some additional term that must be added? Thank you both for your help! |
Oh right, I had completely forgotten about this bit. I'm sure I've got the derivation on a piece of paper somewhere 😬 |
Regarding the adaptation bit, let's check when that's done |
Actually, I think the easiest is if you implement Barker for the standard fixed mass matrix using this new scale thing and then I'll adapt it to the manifold case |
Hi @AdrienCorenflos I noticed something strange in the
am I misunderstanding or the |
Yeah we have been discussing this with @junpenglao |
option to transpose the mass_matrix_sqrt or inv_mass_matrix_sqrt was necessary for the barker algorithm as far as I can tell. This has not been propagated to the riemannian metric
Make acceptance function metric agnostic
Add invariance test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've made changes so that the proposal works with Riemannian metrics.
There is one last thing to fix and then we're good and I'll leave it to you @ismael-mendoza
-> the barker_nd sampling function was meant to take fully flat vectors, and here you pass it the metric which acts on trees. You are not seeing issues because your tests do not use nested trees, but this will fail. Can you modify the logic so that it simply does everything using metric on the non-flat vectors? This may require changing a few tests on the function here and there but should be easy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've made changes so that the proposal works with Riemannian metrics.
There is one last thing to fix and then we're good and I'll leave it to you @ismael-mendoza
-> the barker_nd sampling function was meant to take fully flat vectors, and here you pass it the metric which acts on trees. You are not seeing issues because your tests do not use nested trees, but this will fail. Can you modify the logic so that it simply does everything using metric on the non-flat vectors? This may require changing a few tests on the function here and there but should be easy
Once this is done, it ships
Thanks @AdrienCorenflos
just to make sure I understand, do you want me to remove |
Well, I think you can just delete the nd version and do everything inside
the `_barker_sample`. It made sense when there was a pure flat logic, but
it's not the case anymore
…On Wed, 2 Oct 2024, 20:21 Ismael Mendoza, ***@***.***> wrote:
Thanks @AdrienCorenflos <https://github.com/AdrienCorenflos>
an you modify the logic so that it simply does everything using metric on the non-flat vectors?
just to make sure I understand, do you want me to remove metric from
_barker_sample_nd entirely and only pass in flat vectors to this
function? and the metric logic should live in _barker_sample ?
—
Reply to this email directly, view it on GitHub
<#731 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AEYGFZ5K2TMDKIFUVG3T233ZZRBURAVCNFSM6AAAAABNO57PLGVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDGOBZGUYDQMZXGU>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
oh I think I understand your suggestion now, will modify and commit soon |
@AdrienCorenflos I think I implemented what you had in mind |
Thanks, I'll check tomorrow. @junpenglao given I've touched that pr if you could also have a quick glance that would be useful |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor nit, could you try to see if the suggested changes works?
make inv and trans required kwarg with type bool in metric.scale Co-authored-by: Junpeng Lao <[email protected]>
lax.cond might not be needed in metric.scale as inv and trans are static kwarg Co-authored-by: Junpeng Lao <[email protected]>
@junpenglao thanks for the suggestion, I think works but I see that one of the HMC benchmarks had a regression. Although it's unclear to me how that might have happened after just implementing your suggestion |
Yeah the regression is not related, i will investigate. |
@AdrienCorenflos any more comments? Otherwise LGTM. |
Nope, I had no additional comments when I asked for your opinion :) |
Thank you @ismael-mendoza !! |
* Draft pre-conditioning matrix in Barker proposal. This is a first draft of adding the pre-conditioning to the Barker proposal. This follows Algorithms 4 and 5 in Appendix G of the original Barker proposal paper. It's somewhat unclear from the paper, but the separate step size that was already implemented serves as a global scale for the normal distribution of the proposal. The function `_compute_acceptance_probability` now takes in the transpose sqrt mass matrix and the inverse, also it has been flattened to accomodate the corresponding matrix multiplicatios. * Fix typing of inverse_mass_matrix argument Fix typing of mass matrix. * Fix docstrings. The original docstring of step_size was incorrect, there is no sympletic integrator. * Make test for Barker in test_sampling run again We make this possible by adding an identity pre-conditining matrix, which should make the test run in the same way as before. * Add test to ensure correctness of precond matrix We add a new test to barker.py to ensure that our implementation of the preconditioning matrix is correct. We follow Appendix G in the paper that mentions that algorithm 4 and 5 (which we implemented) should be equivalent to rescaling the parameters and the logdensity in a specific way. We implement both approaches when using the barker proposal to infer the mean and sigma of a normal distribution. We check that with two different random seeds the chains outputted are equivalent up to some tolerance. We also patch the original test in this file by adding an identity mass matrix. * Fix dimensionality of identity matrix * Add missing mass matrix in missing tests. * added option to transpose the matrix when scaling option to transpose the mass_matrix_sqrt or inv_mass_matrix_sqrt was necessary for the barker algorithm as far as I can tell. This has not been propagated to the riemannian metric * use the metric scaling function in barker Here we use the new metric.scale function to perform the operations required by the Barker proposal algorithm, instead of passing around the mass_matrix_sqrt and inv_mass_matrix_sqrt directly. We also make the `inverse_mass_matrix` argument optional to avoid breaking the API. * update test_sampling with barker api the mass matrix is now an optional argument in barker. * update test_barker so it works with metric.scale * fix tests add trans to scale * add trans argument to riemannian scaling * no default * Update barker.py Make acceptance function metric agnostic * Update test_barker.py Add invariance test * simplify logic to remove _barker_sample_nd * fix bug so now everything is tree_mapped in barker * fix test to not use _barker_sample_nd * Update blackjax/mcmc/metrics.py make inv and trans required kwarg with type bool in metric.scale Co-authored-by: Junpeng Lao <[email protected]> * Update blackjax/mcmc/metrics.py lax.cond might not be needed in metric.scale as inv and trans are static kwarg Co-authored-by: Junpeng Lao <[email protected]> * propagate changes of inv, trans as required kwarg * fix test metrics --------- Co-authored-by: Adrien Corenflos <[email protected]> Co-authored-by: Junpeng Lao <[email protected]>
Overview
This PR attempts to add a preconditioning matrix to the barker proposal implementation, as in appendix G of https://arxiv.org/abs/1908.11812
Discussion / Questions
_compute_acceptance_probability
to be flat given how it was necessary to matrix-multiply with the pre-conditioning matrix. Is that OK?Thank you for opening a PR!
A few important guidelines and requirements before we can merge your PR:
main
commit;pre-commit
is installed and configured on your machine, and you ran it before opening the PR;Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.