diff --git a/bayesflow/simulation.py b/bayesflow/simulation.py index 9615957b8..1f5716bed 100644 --- a/bayesflow/simulation.py +++ b/bayesflow/simulation.py @@ -736,15 +736,15 @@ def __init__( parameters or on single parameter vectors via tha `simulator_is_batched` argument. """ - if type(prior) is not Prior: + if not isinstance(prior, Prior): prior_args = {"batch_prior_fun": prior} if prior_is_batched else {"prior_fun": prior} self.prior = Prior(**prior_args) self.prior_is_batched = prior_is_batched else: self.prior = prior - self.prior_is_batched = prior_is_batched + self.prior_is_batched = self.prior.is_batched - if type(simulator) is not Simulator: + if not isinstance(simulator, Simulator): self.simulator = self._config_custom_simulator(simulator, simulator_is_batched) else: self.simulator = simulator