Skip to content

Commit

Permalink
Semantic parameter names for coupling flows
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed May 14, 2024
1 parent 525e7de commit c8005fb
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions bayesflow/experimental/networks/coupling_flow/coupling_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,39 +43,39 @@ def call(self, *args, **kwargs):
def compute_loss(self, x=None, y=None, y_pred=None, **kwargs):
z, log_det = y_pred
log_prob = self.base_distribution.log_prob(z)
nll = -keras.ops.mean(log_prob + log_det, axis=0)
nll = -keras.ops.mean(log_prob + log_det)

return nll

def compute_metrics(self, x, y, y_pred, **kwargs):
return {}

def forward(self, x, c=None, **kwargs):
z = x
def forward(self, targets, conditions=None, **kwargs) -> (Tensor, Tensor):
latents = targets
log_det = 0.
for coupling in self.layers:
z, det = coupling.forward(z, c, **kwargs)
latents, det = coupling.forward(latents, conditions, **kwargs)
log_det += det

return z, log_det
return latents, log_det

def inverse(self, z, c=None):
x = z
def inverse(self, latents, conditions=None) -> (Tensor, Tensor):
targets = latents
log_det = 0.
for coupling in reversed(self.layers):
x, det = coupling.inverse(x, c)
targets, det = coupling.inverse(targets, conditions)
log_det += det

return x, log_det
return targets, log_det

def sample(self, batch_shape: Shape):
z = self.base_distribution.sample(batch_shape)
x, _ = self.inverse(z)
def sample(self, batch_shape: Shape, conditions=None) -> Tensor:
latents = self.base_distribution.sample(batch_shape)
targets, _ = self.inverse(latents, conditions)

return x
return targets

def log_prob(self, x: Tensor, **kwargs) -> Tensor:
z, log_det = self.forward(x, **kwargs)
log_prob = self.base_distribution.log_prob(z)
def log_prob(self, targets: Tensor, conditions=None, **kwargs) -> Tensor:
latents, log_det = self.forward(targets, conditions, **kwargs)
log_prob = self.base_distribution.log_prob(latents)

return log_prob + log_det

0 comments on commit c8005fb

Please sign in to comment.