From 9df95cee1468b43711381e37bdc90c346f285ee8 Mon Sep 17 00:00:00 2001 From: Leona Odole Date: Fri, 13 Dec 2024 13:16:25 +0100 Subject: [PATCH 1/9] in test phase, print statements included but get item implemented, need to define set item, might be nice to list indecies of transforms in the default print --- bayesflow/adapters/adapter.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 0e1c68466..1a6c44fdc 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -78,6 +78,27 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) -> def __repr__(self): return f"Adapter([{' -> '.join(map(repr, self.transforms))}])" + def __getitem__(self, index): + if isinstance(index, slice): + sliced_transforms = self.transforms[index] + print("Are the sliced transforms a sequence") + print(isinstance(sliced_transforms, Sequence)) + print("Is there an associate print method?") + print(sliced_transforms) + + new_adapter = Adapter(transforms = sliced_transforms) + return new_adapter + elif isinstance(index, int): + if index < 0: + index = index + len(self.transforms) # negative indexing + if index < 0 or index >= len(self.transforms): + raise IndexError("Adapter index out of range.") + sliced_transforms = self.transforms[index] + new_adapter = Adapter(transforms = sliced_transforms) + return new_adapter + else: + raise TypeError("Invalid index type. Must be int or slice.") + def add_transform(self, transform: Transform): self.transforms.append(transform) return self @@ -104,6 +125,7 @@ def apply( self.transforms.append(transform) return self + # Begin of transformed derived from transform classes def as_set(self, keys: str | Sequence[str]): if isinstance(keys, str): keys = [keys] From 0e24fa243b74cb10171e746abcced0ecd6c654f2 Mon Sep 17 00:00:00 2001 From: Leona Odole Date: Fri, 13 Dec 2024 14:56:28 +0100 Subject: [PATCH 2/9] finished slice implementation for set item, noticed also that element transform was not added to the list for things to imported in the adapter file so added that --- bayesflow/adapters/adapter.py | 54 +++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 1a6c44fdc..55c65ab2b 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -15,6 +15,7 @@ ConvertDType, Drop, ExpandDims, + ElementwiseTransform, # why wasn't this added before? FilterTransform, Keep, LambdaTransform, @@ -79,15 +80,25 @@ def __repr__(self): return f"Adapter([{' -> '.join(map(repr, self.transforms))}])" def __getitem__(self, index): + if isinstance(index, slice): - sliced_transforms = self.transforms[index] - print("Are the sliced transforms a sequence") - print(isinstance(sliced_transforms, Sequence)) - print("Is there an associate print method?") - print(sliced_transforms) - - new_adapter = Adapter(transforms = sliced_transforms) - return new_adapter + if index.start > index.stop: + raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") + if index.stop < len(self.transforms): + # print("What is the slice?") + # print(index) + # print(type(index)) + # check that the slice is in range + sliced_transforms = self.transforms[index] + # print("Are the sliced transforms a sequence") + # print(isinstance(sliced_transforms, Sequence)) + # print("What is in the slice?") + # print(sliced_transforms) + new_adapter = Adapter(transforms = sliced_transforms) + return new_adapter + else: + raise IndexError("Index slice out of range") + elif isinstance(index, int): if index < 0: index = index + len(self.transforms) # negative indexing @@ -98,6 +109,33 @@ def __getitem__(self, index): return new_adapter else: raise TypeError("Invalid index type. Must be int or slice.") + + + def __setitem__(self, index, new_value): + + if isinstance(index, slice): + if index.start > index.stop: + raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") + if index.stop < len(self.transforms): + new_transform = new_value.transforms + # print("what is self.transforms[index]?") + # print(self.transforms[index]) + # print("what is the value of the newvalue") + # print(new_transform) + # print(type(new_transform)) + self.transforms[index] = new_transform + # else raise theory + else: + raise IndexError("Index slice out of range") + + elif isinstance(index, int): + return + # check if in range + # if not inrange but it is just the len of the transforms (append ) + # + # else raise error + else: + raise TypeError("Invalid index type. Must be int or slice.") def add_transform(self, transform: Transform): self.transforms.append(transform) From 71374ef50ed1ed6100e096224357877e525eecfb Mon Sep 17 00:00:00 2001 From: Leona Odole Date: Tue, 17 Dec 2024 08:12:49 +0100 Subject: [PATCH 3/9] finished adapter slicing, i also made a test file but its not in the commits --- bayesflow/adapters/adapter.py | 39 ++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 55c65ab2b..d6eb8ce52 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -112,28 +112,43 @@ def __getitem__(self, index): def __setitem__(self, index, new_value): + + if not isinstance(new_value, Adapter): + raise TypeError("new_value must be an Adapter instance") + + + new_transform = new_value.transforms + if len(new_transform) == 0: + raise ValueError("new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform.") + + if isinstance(index, slice): if index.start > index.stop: raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") + if index.stop < len(self.transforms): - new_transform = new_value.transforms - # print("what is self.transforms[index]?") - # print(self.transforms[index]) - # print("what is the value of the newvalue") - # print(new_transform) - # print(type(new_transform)) self.transforms[index] = new_transform - # else raise theory + else: raise IndexError("Index slice out of range") + elif isinstance(index, int): - return - # check if in range - # if not inrange but it is just the len of the transforms (append ) - # - # else raise error + if index < 0: # negative indexing + index = index + len(self.transforms) + + if index < 0 or index >= len(self.transforms): + raise IndexError("Index out of range.") + # could add that if the index is out of range, like index == len + # then we just add the transform + print("what is self.transforms[index]?") + print(self.transforms[index]) + print("what is the value of the newvalue") + print(new_transform) + print(type(new_transform)) + + self.transforms[index] = new_transform else: raise TypeError("Invalid index type. Must be int or slice.") From 4fce74420a3dbc4d4ce6ce7edb01f7d96dbcc9a0 Mon Sep 17 00:00:00 2001 From: Leona Odole Date: Tue, 17 Dec 2024 08:24:11 +0100 Subject: [PATCH 4/9] ran linter --- bayesflow/adapters/adapter.py | 75 ++++++++++++++++------------------- 1 file changed, 35 insertions(+), 40 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index d6eb8ce52..f60973e34 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -15,7 +15,6 @@ ConvertDType, Drop, ExpandDims, - ElementwiseTransform, # why wasn't this added before? FilterTransform, Keep, LambdaTransform, @@ -80,78 +79,74 @@ def __repr__(self): return f"Adapter([{' -> '.join(map(repr, self.transforms))}])" def __getitem__(self, index): - - if isinstance(index, slice): - if index.start > index.stop: + if isinstance(index, slice): + if index.start > index.stop: raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") - if index.stop < len(self.transforms): + if index.stop < len(self.transforms): # print("What is the slice?") # print(index) # print(type(index)) - # check that the slice is in range + # check that the slice is in range sliced_transforms = self.transforms[index] # print("Are the sliced transforms a sequence") # print(isinstance(sliced_transforms, Sequence)) # print("What is in the slice?") # print(sliced_transforms) - new_adapter = Adapter(transforms = sliced_transforms) + new_adapter = Adapter(transforms=sliced_transforms) return new_adapter - else: + else: raise IndexError("Index slice out of range") - - elif isinstance(index, int): + + elif isinstance(index, int): if index < 0: - index = index + len(self.transforms) # negative indexing - if index < 0 or index >= len(self.transforms): + index = index + len(self.transforms) # negative indexing + if index < 0 or index >= len(self.transforms): raise IndexError("Adapter index out of range.") sliced_transforms = self.transforms[index] - new_adapter = Adapter(transforms = sliced_transforms) + new_adapter = Adapter(transforms=sliced_transforms) return new_adapter else: raise TypeError("Invalid index type. Must be int or slice.") - - - def __setitem__(self, index, new_value): - if not isinstance(new_value, Adapter): + def __setitem__(self, index, new_value): + if not isinstance(new_value, Adapter): raise TypeError("new_value must be an Adapter instance") - - - new_transform = new_value.transforms - - if len(new_transform) == 0: - raise ValueError("new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform.") + new_transform = new_value.transforms - if isinstance(index, slice): - if index.start > index.stop: + if len(new_transform) == 0: + raise ValueError( + "new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform." + ) + + if isinstance(index, slice): + if index.start > index.stop: raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") - + if index.stop < len(self.transforms): self.transforms[index] = new_transform - - else: + + else: raise IndexError("Index slice out of range") - - elif isinstance(index, int): - if index < 0: # negative indexing + elif isinstance(index, int): + if index < 0: # negative indexing index = index + len(self.transforms) - - if index < 0 or index >= len(self.transforms): + + if index < 0 or index >= len(self.transforms): raise IndexError("Index out of range.") - # could add that if the index is out of range, like index == len - # then we just add the transform + # could add that if the index is out of range, like index == len + # then we just add the transform print("what is self.transforms[index]?") print(self.transforms[index]) print("what is the value of the newvalue") print(new_transform) print(type(new_transform)) - + self.transforms[index] = new_transform - else: - raise TypeError("Invalid index type. Must be int or slice.") - + else: + raise TypeError("Invalid index type. Must be int or slice.") + def add_transform(self, transform: Transform): self.transforms.append(transform) return self @@ -178,7 +173,7 @@ def apply( self.transforms.append(transform) return self - # Begin of transformed derived from transform classes + # Begin of transformed derived from transform classes def as_set(self, keys: str | Sequence[str]): if isinstance(keys, str): keys = [keys] From 78dd5d5551b5f259b99fa7212dc33da99f5ed344 Mon Sep 17 00:00:00 2001 From: Leona Odole Date: Tue, 17 Dec 2024 08:31:41 +0100 Subject: [PATCH 5/9] removed print statements --- bayesflow/adapters/adapter.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index f60973e34..fbd640c73 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -83,15 +83,7 @@ def __getitem__(self, index): if index.start > index.stop: raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") if index.stop < len(self.transforms): - # print("What is the slice?") - # print(index) - # print(type(index)) - # check that the slice is in range sliced_transforms = self.transforms[index] - # print("Are the sliced transforms a sequence") - # print(isinstance(sliced_transforms, Sequence)) - # print("What is in the slice?") - # print(sliced_transforms) new_adapter = Adapter(transforms=sliced_transforms) return new_adapter else: @@ -137,11 +129,6 @@ def __setitem__(self, index, new_value): raise IndexError("Index out of range.") # could add that if the index is out of range, like index == len # then we just add the transform - print("what is self.transforms[index]?") - print(self.transforms[index]) - print("what is the value of the newvalue") - print(new_transform) - print(type(new_transform)) self.transforms[index] = new_transform else: From 2de69c944612cc9eabbf3263c4f6b73d366c2b43 Mon Sep 17 00:00:00 2001 From: Leona Odole Date: Wed, 1 Jan 2025 14:28:03 -0600 Subject: [PATCH 6/9] indexing added to print statements for adapter --- bayesflow/adapters/adapter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index fbd640c73..179d690c5 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -76,7 +76,12 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) -> return self.forward(data, **kwargs) def __repr__(self): - return f"Adapter([{' -> '.join(map(repr, self.transforms))}])" + str_transf = '' + for i in range(0, len(self.transforms)): + str_transf = str_transf + str(i) + ': ' + repr(self.transforms[i]) + if i != len(self.transforms) - 1: + str_transf = str_transf + ' -> ' + return f"Adapter([{str_transf}])" def __getitem__(self, index): if isinstance(index, slice): From 9f6a5110f2301882af0a138e8244e11b889b7c4f Mon Sep 17 00:00:00 2001 From: Leona Odole Date: Sat, 4 Jan 2025 15:13:58 -0600 Subject: [PATCH 7/9] made modifications in line with Lars comments and they passed the tests that I had previously written --- bayesflow/adapters/adapter.py | 61 +++++++++-------------------------- 1 file changed, 16 insertions(+), 45 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 179d690c5..6aa7355ae 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -76,66 +76,37 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) -> return self.forward(data, **kwargs) def __repr__(self): - str_transf = '' - for i in range(0, len(self.transforms)): - str_transf = str_transf + str(i) + ': ' + repr(self.transforms[i]) - if i != len(self.transforms) - 1: - str_transf = str_transf + ' -> ' - return f"Adapter([{str_transf}])" + str_transf = "" + if isinstance(self.transforms, list): + for i in range(0, len(self.transforms)): + str_transf = str_transf + str(i) + ": " + repr(self.transforms[i]) + if i != len(self.transforms) - 1: + str_transf = str_transf + " -> " + return f"Adapter([{str_transf}])" + else: + return f"Adapter([ 0: {repr(self.transforms)}])" def __getitem__(self, index): - if isinstance(index, slice): - if index.start > index.stop: - raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") - if index.stop < len(self.transforms): - sliced_transforms = self.transforms[index] - new_adapter = Adapter(transforms=sliced_transforms) - return new_adapter - else: - raise IndexError("Index slice out of range") - - elif isinstance(index, int): - if index < 0: - index = index + len(self.transforms) # negative indexing - if index < 0 or index >= len(self.transforms): - raise IndexError("Adapter index out of range.") - sliced_transforms = self.transforms[index] - new_adapter = Adapter(transforms=sliced_transforms) - return new_adapter - else: - raise TypeError("Invalid index type. Must be int or slice.") + return Adapter(transforms=self.transforms[index]) def __setitem__(self, index, new_value): if not isinstance(new_value, Adapter): raise TypeError("new_value must be an Adapter instance") - new_transform = new_value.transforms + # new_transform = new_value.transforms - if len(new_transform) == 0: + # To be tested + if len(new_value.transforms) == 0: raise ValueError( "new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform." ) if isinstance(index, slice): - if index.start > index.stop: - raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") - - if index.stop < len(self.transforms): - self.transforms[index] = new_transform - - else: - raise IndexError("Index slice out of range") + self.transforms[index] = new_value.transforms[:] elif isinstance(index, int): - if index < 0: # negative indexing - index = index + len(self.transforms) - - if index < 0 or index >= len(self.transforms): - raise IndexError("Index out of range.") - # could add that if the index is out of range, like index == len - # then we just add the transform + self.transforms[index : index + 1] = new_value.transforms[:] - self.transforms[index] = new_transform else: raise TypeError("Invalid index type. Must be int or slice.") @@ -165,7 +136,7 @@ def apply( self.transforms.append(transform) return self - # Begin of transformed derived from transform classes + # Begin of transforms derived from transform classes def as_set(self, keys: str | Sequence[str]): if isinstance(keys, str): keys = [keys] From aebf317019d41ee50bf74b8d1df34a002bef2ba8 Mon Sep 17 00:00:00 2001 From: Leona Odole Date: Sat, 4 Jan 2025 15:15:34 -0600 Subject: [PATCH 8/9] removed unecessary comments2 --- bayesflow/adapters/adapter.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 6aa7355ae..8b96c3b39 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -93,9 +93,6 @@ def __setitem__(self, index, new_value): if not isinstance(new_value, Adapter): raise TypeError("new_value must be an Adapter instance") - # new_transform = new_value.transforms - - # To be tested if len(new_value.transforms) == 0: raise ValueError( "new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform." @@ -136,7 +133,7 @@ def apply( self.transforms.append(transform) return self - # Begin of transforms derived from transform classes + def as_set(self, keys: str | Sequence[str]): if isinstance(keys, str): keys = [keys] From 08f6da0cc0375504a5c6979edb3fae720f5f2e90 Mon Sep 17 00:00:00 2001 From: larskue Date: Wed, 15 Jan 2025 16:03:14 +0100 Subject: [PATCH 9/9] add basic list methods to Adapter --- bayesflow/adapters/adapter.py | 87 +++++++++++++++++++++++------------ 1 file changed, 57 insertions(+), 30 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 8b96c3b39..e6b898d87 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -1,4 +1,4 @@ -from collections.abc import Callable, Sequence +from collections.abc import Callable, MutableSequence, Sequence import numpy as np from keras.saving import ( @@ -25,17 +25,16 @@ ToArray, Transform, ) - from .transforms.filter_transform import Predicate @serializable(package="bayesflow.adapters") -class Adapter: +class Adapter(MutableSequence[Transform]): def __init__(self, transforms: Sequence[Transform] | None = None): if transforms is None: transforms = [] - self.transforms = transforms + self.transforms = list(transforms) @staticmethod def create_default(inference_variables: Sequence[str]) -> "Adapter": @@ -76,41 +75,70 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) -> return self.forward(data, **kwargs) def __repr__(self): - str_transf = "" - if isinstance(self.transforms, list): - for i in range(0, len(self.transforms)): - str_transf = str_transf + str(i) + ": " + repr(self.transforms[i]) - if i != len(self.transforms) - 1: - str_transf = str_transf + " -> " - return f"Adapter([{str_transf}])" - else: - return f"Adapter([ 0: {repr(self.transforms)}])" + result = "" + for i, transform in enumerate(self): + result += f"{i}: {transform!r}" + if i != len(self) - 1: + result += " -> " + + return f"Adapter([{result}])" + + # list methods + + def append(self, value: Transform) -> "Adapter": + self.transforms.append(value) + return self + + def __delitem__(self, key: int | slice): + del self.transforms[key] - def __getitem__(self, index): - return Adapter(transforms=self.transforms[index]) + def extend(self, values: Sequence[Transform]) -> "Adapter": + if isinstance(values, Adapter): + values = values.transforms - def __setitem__(self, index, new_value): - if not isinstance(new_value, Adapter): - raise TypeError("new_value must be an Adapter instance") + self.transforms.extend(values) - if len(new_value.transforms) == 0: - raise ValueError( - "new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform." - ) + return self - if isinstance(index, slice): - self.transforms[index] = new_value.transforms[:] + def __getitem__(self, item: int | slice) -> "Adapter": + if isinstance(item, int): + return self.transforms[item] - elif isinstance(index, int): - self.transforms[index : index + 1] = new_value.transforms[:] + return Adapter(self.transforms[item]) + def insert(self, index: int, value: Transform | Sequence[Transform]) -> "Adapter": + if isinstance(value, Adapter): + value = value.transforms + + if isinstance(value, Sequence): + # convenience: Adapters are always flat + self.transforms = self.transforms[:index] + list(value) + self.transforms[index:] else: - raise TypeError("Invalid index type. Must be int or slice.") + self.transforms.insert(index, value) - def add_transform(self, transform: Transform): - self.transforms.append(transform) return self + def __setitem__(self, key: int | slice, value: Transform | Sequence[Transform]) -> "Adapter": + if isinstance(value, Adapter): + value = value.transforms + + if isinstance(key, int) and isinstance(value, Sequence): + if key < 0: + key += len(self.transforms) + + key = slice(key, key + 1) + + self.transforms[key] = value + + return self + + def __len__(self): + return len(self.transforms) + + # adapter methods + + add_transform = append + def apply( self, *, @@ -133,7 +161,6 @@ def apply( self.transforms.append(transform) return self - def as_set(self, keys: str | Sequence[str]): if isinstance(keys, str): keys = [keys]