-
Notifications
You must be signed in to change notification settings - Fork 366
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes to pyro model initialisation & sampling [WIP] #2695
base: main
Are you sure you want to change the base?
Conversation
I don't fully understand the reason for the errors - they don't happen in The difference maybe the timing when the plates are first used. I will look into this later. |
Also this code for posterior sampling is indeed ~2-3x faster but it creates samples of huge observed data matrixes (copies data n_samples times - eg 1000): if isinstance(self.module.guide, poutine.messenger.Messenger):
# This already includes trace-replay behavior.
sample = self.module.guide(*args, **kwargs) An alternative way to deal with this issue would be this: if isinstance(self.module.guide, poutine.messenger.Messenger):
# This already includes trace-replay behavior.
sample = self.module.guide(*args, **kwargs)
# include and exclude requested sites
sample = {k: v for k, v in sample.items() if k in return_sites}
sample = {k: v for k, v in sample.items() if k not in exclude_vars} # this has to be provided by model developer @martinkim0 What do you think we should do? What do you think about the initialisation solution? |
@vitkl hey sorry for the delay, I'm planning on taking a look at this tomorrow! |
This is actually my first time at taking a look at some of our Pyro code - I hadn't really interacted with it before. So I don't really understand the reason why some things are done, e.g., the warmup callbacks. I definitely need to take a deep dive into all of this. However, it looks like both Regarding the sampling changes, would it be possible to include that in a separate PR? And then we can discuss that there. Thanks! |
Just a brief reply. Happy to have a zoom call about pyro. Pyro automatic variational distribution (Guide) doesn’t have any parameters until you do a first pass through the model and guide. When moving my code to multi-GPU training I found that this needs to be done in setup step of the Lightning workflow - otherwise parameters created on GPU don’t get moved between devices correctly - so it’s it would not in on_train_start. However, in the latest version the setup step also doesn’t work - as reported in the original issue. Moving the code to this function and calling it before using any Lightning workflow steps seems to solve the problem for cell2location and my other project. Actually the reason for the errors might be resolved if you call both the model and guide with one batch (it’s possibly the issue with LDA model that uses a custom guide). |
args, kwargs = pl_module.module._get_fn_args_from_batch(tens) | ||
pyro_guide(*args, **kwargs) | ||
break | ||
for tensors in dataloader: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to do next(iter(dataloader)) to get a single batch. I think still having the class makes sense. Within this class, there can be a manual_start function.
break | ||
|
||
|
||
class PyroModelGuideWarmup(Callback): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do those two classes exist in the first place?
Please split into two PRs. One for the warmup changes and one for the inference changes. This makes it easier to follow changes. |
Addresses #2616
Replaces #1805