From c6483f9dccf4d53582aeab650cc7a001e0ef090e Mon Sep 17 00:00:00 2001 From: danielward27 Date: Sat, 9 Apr 2022 10:54:56 +0100 Subject: [PATCH] leave bnaf bijection to throw error when sampling --- flowjax/bijections/abc.py | 1 + flowjax/flows.py | 6 ------ 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/flowjax/bijections/abc.py b/flowjax/bijections/abc.py index f9e5427b..346b94a7 100644 --- a/flowjax/bijections/abc.py +++ b/flowjax/bijections/abc.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod import jax.numpy as jnp + class Bijection(ABC): """Basic bijection class. All bijections should support conditioning variables (even if ignored).""" diff --git a/flowjax/flows.py b/flowjax/flows.py index 1d9c341c..49ae7f72 100644 --- a/flowjax/flows.py +++ b/flowjax/flows.py @@ -190,9 +190,3 @@ def __init__( ) bijection = Chain(bijections) super().__init__(bijection, target_dim, base_log_prob, base_sample) - - def sample(self, *args, **kwargs): - raise NotImplementedError( - "Sampling this flow would require numerical inversion of the bijection" - ) -