Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made Adapters Sliceable #285

Merged
merged 9 commits into from
Jan 15, 2025
71 changes: 64 additions & 7 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable, Sequence
from collections.abc import Callable, MutableSequence, Sequence

import numpy as np
from keras.saving import (
Expand All @@ -25,17 +25,16 @@
ToArray,
Transform,
)

from .transforms.filter_transform import Predicate


@serializable(package="bayesflow.adapters")
class Adapter:
class Adapter(MutableSequence[Transform]):
def __init__(self, transforms: Sequence[Transform] | None = None):
if transforms is None:
transforms = []

self.transforms = transforms
self.transforms = list(transforms)

@staticmethod
def create_default(inference_variables: Sequence[str]) -> "Adapter":
Expand Down Expand Up @@ -76,12 +75,70 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) ->
return self.forward(data, **kwargs)

def __repr__(self):
return f"Adapter([{' -> '.join(map(repr, self.transforms))}])"
result = ""
for i, transform in enumerate(self):
result += f"{i}: {transform!r}"
if i != len(self) - 1:
result += " -> "

return f"Adapter([{result}])"

# list methods

def append(self, value: Transform) -> "Adapter":
self.transforms.append(value)
return self

def __delitem__(self, key: int | slice):
del self.transforms[key]

def extend(self, values: Sequence[Transform]) -> "Adapter":
if isinstance(values, Adapter):
values = values.transforms

self.transforms.extend(values)

return self

def __getitem__(self, item: int | slice) -> "Adapter":
if isinstance(item, int):
return self.transforms[item]

return Adapter(self.transforms[item])

def insert(self, index: int, value: Transform | Sequence[Transform]) -> "Adapter":
if isinstance(value, Adapter):
value = value.transforms

if isinstance(value, Sequence):
# convenience: Adapters are always flat
self.transforms = self.transforms[:index] + list(value) + self.transforms[index:]
else:
self.transforms.insert(index, value)

return self

def __setitem__(self, key: int | slice, value: Transform | Sequence[Transform]) -> "Adapter":
if isinstance(value, Adapter):
value = value.transforms

if isinstance(key, int) and isinstance(value, Sequence):
if key < 0:
key += len(self.transforms)

key = slice(key, key + 1)

self.transforms[key] = value

def add_transform(self, transform: Transform):
self.transforms.append(transform)
return self

def __len__(self):
return len(self.transforms)

# adapter methods

add_transform = append

def apply(
self,
*,
Expand Down
Loading