Skip to content

Commit

Permalink
fix type hints on templates
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed May 29, 2024
1 parent 4a0c92d commit 89ec32f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@

from typing import Tuple, Union

import keras
from keras.saving import (
deserialize_keras_object,
Expand All @@ -7,7 +9,7 @@
)
from scipy.integrate import solve_ivp

from bayesflow.experimental.types import Tensor
from bayesflow.experimental.types import Shape, Tensor
from ..inference_network import InferenceNetwork


Expand Down Expand Up @@ -48,12 +50,12 @@ def train_step(self, data):
super().train_step(data)
self.call = call

def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Tensor | (Tensor, Tensor):
def _forward(self, x: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
# implement conditions = None and jacobian = False first
# then work your way up
raise NotImplementedError

def _inverse(self, z: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Tensor | (Tensor, Tensor):
def _inverse(self, z: Tensor, conditions: any = None, jacobian: bool = False, steps: int = 100, method: str = "RK45") -> Union[Tensor, Tuple[Tensor, Tensor]]:
raise NotImplementedError

def compute_loss(self, x=None, **kwargs):
Expand Down
8 changes: 5 additions & 3 deletions bayesflow/experimental/networks/inference_network.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@

from typing import Tuple, Union

import keras
from keras.saving import (
register_keras_serializable,
Expand All @@ -21,15 +23,15 @@ def get_config(self) -> dict:
config = {"base_distribution": serialize_keras_object(self.base_distribution)}
return base_config | config

def call(self, xz: Tensor, inverse: bool = False, **kwargs) -> Tensor | (Tensor, Tensor):
def call(self, xz: Tensor, inverse: bool = False, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
if inverse:
return self._inverse(xz, **kwargs)
return self._forward(xz, **kwargs)

def _forward(self, x: Tensor, **kwargs) -> Tensor | (Tensor, Tensor):
def _forward(self, x: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
raise NotImplementedError

def _inverse(self, z: Tensor, **kwargs) -> Tensor | (Tensor, Tensor):
def _inverse(self, z: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Tensor]]:
raise NotImplementedError

def sample(self, num_samples: int, **kwargs) -> Tensor:
Expand Down

0 comments on commit 89ec32f

Please sign in to comment.