diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 7526b1d8..9cd991a0 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -122,12 +122,18 @@ def as_time_series(self, keys: str | Sequence[str]): return self def broadcast( - self, keys: str | Sequence[str], *, to: str, expand: str | int | tuple = "left", exclude: int | tuple = -1 + self, + keys: str | Sequence[str], + *, + to: str, + expand: str | int | tuple = "left", + exclude: int | tuple = -1, + squeeze: int | tuple = None, ): if isinstance(keys, str): keys = [keys] - transform = Broadcast(keys, to=to, expand=expand, exclude=exclude) + transform = Broadcast(keys, to=to, expand=expand, exclude=exclude, squeeze=squeeze) self.transforms.append(transform) return self