ContinuousApproximator.sample() fails without previous adapter calls (e.g., when loading data) #255

Open · elseml opened this issue Nov 21, 2024 · 6 comments
Labels: bug (Something isn't working)

@elseml (Member) commented Nov 21, 2024

I noticed that after switching from generating bf.datasets on-the-fly to loading pre-simulated data, ContinuousApproximator.sample() fails because the adapter is no longer called before sampling. Concretely, line 141 of continuous_approximator.py calls the adapter with strict=False to process the observed data (so that the parameter keys are not required):

conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs) 

This raises the following error in the adapter's forward() method when working with loaded data:

ValueError: Cannot call `forward` with `strict=False` before calling `forward` with `strict=True`.

The error is easily worked around by manually calling the adapter on the data before sampling, but this is unexpected for users and should therefore be handled internally.
@LarsKue @stefanradev93: what do you think would be a principled way to handle this behavior?
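To make the failure mode concrete, here is a toy Python model of an adapter that accepts `strict=False` only after it has processed at least one `strict=True` call. It is a hypothetical sketch, not BayesFlow's actual `Adapter`: the class name `ToyAdapter` and the assumption that the strict call records the expected keys are illustrative only.

```python
class ToyAdapter:
    """Hypothetical stand-in for an adapter (NOT BayesFlow's implementation).

    Assumption for illustration: the first strict=True call records the keys
    the adapter expects, and strict=False calls depend on that recorded state.
    """

    def __init__(self):
        self.expected_keys = None  # unset until the first strict=True call

    def __call__(self, data, *, strict=True, stage="training"):
        return self.forward(data, strict=strict, stage=stage)

    def forward(self, data, *, strict=True, stage="training"):
        if strict:
            # A full batch (parameters + observables) initializes the state.
            self.expected_keys = set(data)
        elif self.expected_keys is None:
            raise ValueError(
                "Cannot call `forward` with `strict=False` before calling "
                "`forward` with `strict=True`."
            )
        # With strict=False, keys missing from `data` (e.g. parameters) are tolerated.
        return {k: v for k, v in data.items() if k in self.expected_keys}


adapter = ToyAdapter()
observed = {"x": [1.0, 2.0]}  # conditions only, no parameter key

try:
    adapter(observed, strict=False, stage="inference")
except ValueError as err:
    print("fails:", err)

adapter({"theta": [0.5], "x": [1.0, 2.0]}, strict=True)  # e.g. a training batch
print(adapter(observed, strict=False, stage="inference"))
```

In this toy, the manual workaround corresponds to the second call: once the same object has processed a full batch, the `strict=False` call succeeds.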

@paul-buerkner (Contributor) commented Nov 21, 2024

Based on my understanding of what you are doing, I agree that this should be handled differently. To make sure I understand you correctly, could you add a small example that includes only the relevant parts of the code?

@paul-buerkner added the labels user interface (Changes to the user interface and improvements in usability) and v2 on Nov 21, 2024
@paul-buerkner added this to the BayesFlow 2.0 milestone on Nov 21, 2024
@paul-buerkner removed the v2 label on Nov 21, 2024
@elseml (Member, Author) commented Nov 21, 2024

I looked further into the issue. As far as I can see, it is caused by the OfflineDataset and the approximator no longer referring to the same adapter object in memory:

  • When the OfflineDataset is created right before training with the same adapter object, that adapter is already called during approximator.fit() via OfflineDataset.__getitem__, so sampling afterwards works.
  • When pre-simulated data is loaded, the adapter held by the OfflineDataset no longer refers to the same object in memory as the one the approximator uses. Thus, only OfflineDataset.adapter is called during training, never approximator.adapter, so sampling fails.
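The two bullet points can be reproduced with toy classes. In this hypothetical sketch, the names `ToyAdapter`, `ToyDataset`, and `ToyApproximator` are made up and only mimic the roles described above; the dataset calls its own adapter lazily in `__getitem__`, so `fit()` never touches the approximator's adapter unless both hold the same object.

```python
class ToyAdapter:
    """Hypothetical adapter: strict=False is only allowed after a strict=True call."""

    def __init__(self):
        self.initialized = False

    def __call__(self, data, strict=True):
        if strict:
            self.initialized = True
        elif not self.initialized:
            raise ValueError(
                "Cannot call `forward` with `strict=False` before calling "
                "`forward` with `strict=True`."
            )
        return data


class ToyDataset:
    """Hypothetical dataset that applies *its own* adapter when indexed."""

    def __init__(self, batches, adapter):
        self.batches = batches
        self.adapter = adapter

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, i):
        return self.adapter(self.batches[i])  # strict=True by default


class ToyApproximator:
    """Hypothetical approximator holding a reference to an adapter."""

    def __init__(self, adapter):
        self.adapter = adapter

    def fit(self, dataset):
        for i in range(len(dataset)):
            _ = dataset[i]  # only dataset.adapter is called here

    def sample(self, conditions):
        return self.adapter(conditions, strict=False)


batches = [{"theta": 0.1, "x": 1.0}, {"theta": 0.2, "x": 2.0}]

# Loaded data: the dataset's adapter is a different object from the approximator's.
loaded = ToyDataset(batches, ToyAdapter())
separate = ToyApproximator(ToyAdapter())
separate.fit(loaded)
print(loaded.adapter is separate.adapter)  # False
print(separate.adapter.initialized)        # False, so sample() would raise

# Same object in both places: fit() initializes the adapter used by sample().
shared_adapter = ToyAdapter()
simulated = ToyDataset(batches, shared_adapter)
shared = ToyApproximator(shared_adapter)
shared.fit(simulated)
print(shared.sample({"x": 1.5}))  # {'x': 1.5}
```

The second half also mirrors the workaround of passing the dataset's adapter to the approximator: what matters in the toy is object identity, not which variable name is used.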

Here is some reduced pseudocode to keep things concise:

Simulating at the beginning does not fail:

adapter = Adapter()
data = OfflineDataset(simulate(), adapter)  # dataset and approximator share one adapter object
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)     # __getitem__ calls the shared adapter
approximator.sample(data)  # works

When the data is loaded from an external source (where the adapter was also supplied to OfflineDataset), sampling fails:

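adapter = Adapter()
data = load_data(path)  # data.adapter is a different object than adapter
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)     # only data.adapter is called here
approximator.sample(data)  # ValueError: strict=False before strict=True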

Calling the adapter manually before sampling fixes the error:

adapter = Adapter()
data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
_ = adapter(data)          # first call on the approximator's own adapter
approximator.sample(data)  # works

Creating a new dataset before sampling (i.e., simply constructing an OfflineDataset with the same adapter) does not fix it, since the adapter is not called during OfflineDataset construction:

adapter = Adapter()
data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, adapter)
approximator.fit(data)
data_2 = OfflineDataset(simulate(), adapter)  # construction does not call the adapter
approximator.sample(data_2)  # still raises ValueError
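Why constructing `data_2` is not enough can be seen with a hypothetical counting adapter: a dataset that only stores the adapter in its constructor and applies it lazily in `__getitem__` (as assumed here for `OfflineDataset`) calls it zero times until a batch is actually requested. `CountingAdapter` and `LazyDataset` are illustrative names, not BayesFlow classes.

```python
class CountingAdapter:
    """Hypothetical adapter that counts how often it has been called."""

    def __init__(self):
        self.calls = 0

    def __call__(self, data):
        self.calls += 1
        return data


class LazyDataset:
    """Hypothetical dataset: the constructor only stores references."""

    def __init__(self, batches, adapter):
        self.batches = batches
        self.adapter = adapter  # no adapter call here

    def __getitem__(self, i):
        return self.adapter(self.batches[i])  # adapter runs on access


adapter = CountingAdapter()
data_2 = LazyDataset([{"theta": 0.3, "x": 3.0}], adapter)
print(adapter.calls)  # 0: construction alone never touches the adapter

_ = data_2[0]
print(adapter.calls)  # 1: the first access is the first adapter call
```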

@paul-buerkner (Contributor)

Thank you! This is very helpful! @LarsKue and @stefanradev93 what are your takes on how to fix this?

@paul-buerkner added the bug (Something isn't working) label and removed the user interface (Changes to the user interface and improvements in usability) label on Nov 21, 2024
@elseml (Member, Author) commented Nov 21, 2024

Indeed, when OfflineDataset.adapter is passed to the approximator, the error disappears (so it is less a bug than unexpected behavior). However, this workaround is unintuitive and should not be required of users.

data = load_data(path)
approximator = ContinuousApproximator(summary_net, inference_net, data.adapter)  # reuse the dataset's adapter object
approximator.fit(data)
approximator.sample(data)  # works

@paul-buerkner (Contributor)

It will appear to users as a bug because it should just work. In any case, we should fix it before the 2.0 release.

@LarsKue (Contributor) commented Nov 21, 2024

Could be faulty serialization in the Adapter. I will investigate next week.
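One way a serialization round-trip could silently drop adapter state is sketched below with a hypothetical `get_config`/`from_config` pair. The method names only mirror the common Keras-style convention and are not taken from BayesFlow's `Adapter`: if the saved config stores constructor arguments but not runtime state, the restored adapter comes back uninitialized.

```python
class ToyAdapter:
    """Hypothetical adapter with one constructor argument and one piece of runtime state."""

    def __init__(self, keys=None):
        self.keys = keys          # constructor argument
        self.initialized = False  # set by the first strict=True call

    def __call__(self, data, strict=True):
        if strict:
            self.initialized = True
        elif not self.initialized:
            raise ValueError(
                "Cannot call `forward` with `strict=False` before calling "
                "`forward` with `strict=True`."
            )
        return data

    def get_config(self, include_state=False):
        config = {"keys": self.keys}
        if include_state:
            config["initialized"] = self.initialized  # persist runtime state too
        return config

    @classmethod
    def from_config(cls, config):
        adapter = cls(keys=config["keys"])
        adapter.initialized = config.get("initialized", False)
        return adapter


original = ToyAdapter(keys=["theta", "x"])
original({"theta": 0.5, "x": 1.0})  # strict call initializes the state

lossy = ToyAdapter.from_config(original.get_config())
faithful = ToyAdapter.from_config(original.get_config(include_state=True))
print(lossy.initialized, faithful.initialized)  # False True
```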
