Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Nov 8, 2024
1 parent 8c900a9 commit acf845a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
6 changes: 2 additions & 4 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,11 @@ def as_set(self, keys: str | Sequence[str]):
self.transforms.append(transform)
return self

def broadcast(
self, keys: str | Sequence[str], *, to: str, batch_dims_only: bool = True, scalars_to_arrays: bool = True
):
def broadcast(self, keys: str | Sequence[str], *, to: str, batch_dims_only: bool = True):
if isinstance(keys, str):
keys = [keys]

transform = Broadcast(keys, to=to, batch_dims_only=batch_dims_only, scalars_to_arrays=scalars_to_arrays)
transform = Broadcast(keys, to=to, batch_dims_only=batch_dims_only)
self.transforms.append(transform)
return self

Expand Down
25 changes: 23 additions & 2 deletions bayesflow/adapters/transforms/broadcast.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,43 @@
from collections.abc import Sequence
import numpy as np

from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)

from .transform import Transform


@serializable(package="bayesflow.adapters")
class Broadcast(Transform):
"""
Broadcasts arrays or scalars to the shape of a given other array. Only batch dimensions
will be considered per default, i.e., all but the last dimension.
Examples: #TODO
"""

def __init__(self, keys: Sequence[str], *, to: str, batch_dims_only: bool = True, scalars_to_arrays: bool = True):
def __init__(self, keys: Sequence[str], *, to: str, batch_dims_only: bool = True):
super().__init__()
self.keys = keys
self.to = to
self.batch_dims_only = batch_dims_only
self.scalars_to_arrays = scalars_to_arrays

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
return cls(
keys=deserialize(config["keys"], custom_objects),
to=deserialize(config["to"], custom_objects),
batch_dims_only=deserialize(config["batch_dims_only"], custom_objects),
)

def get_config(self) -> dict:
return {
"keys": serialize(self.keys),
"to": serialize(self.to),
"batch_dims_only": serialize(self.batch_dims_only),
}

# noinspection PyMethodOverriding
def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
Expand Down

0 comments on commit acf845a

Please sign in to comment.