Skip to content

Commit

Permalink
Jaxify evidence calculation for flows.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Nov 6, 2023
1 parent 8ac00a6 commit bb25009
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions harmonic/evidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import cloudpickle
from harmonic import logs as lg
import jax.numpy as jnp
import jax


class Shifting(Enum):
Expand Down Expand Up @@ -192,6 +193,30 @@ def process_run(self):

return

def get_masks(self, chain_start_ixs: jnp.ndarray) -> jnp.ndarray:
"""Create mask array for a 2D array of concatenated chains of different lengths.
Args:
chain_start_ixs (jnp.ndarray[nchains+1]): Start indices of chains
in Chain object.
Returns:
jnp.ndarray[nchains,nsamples]: Mask array with each row corresponding to a chain
and entries with boolean values depending on if given sample at that
position is in that chain.
"""

nsamples = chain_start_ixs[-1]
range_vector = jnp.arange(nsamples)

# Create a mask array by broadcasting the range vector
masks_arr = (range_vector >= chain_start_ixs[:-1][:, None]) & (
range_vector < chain_start_ixs[1:][:, None]
)

return masks_arr

def add_chains(self, chains):
"""Add new chains and calculate an estimate of the inverse evidence, its
variance, and the variance of the variance.
Expand Down Expand Up @@ -266,32 +291,32 @@ def add_chains(self, chains):
# Shifts by the absolute maximum of log-posterior
self.set_shift(-lnargs[np.nanargmax(np.abs(lnargs))])

def get_running_sum(lnargs, mask):
running_sum = jnp.nansum(jnp.where(mask, jnp.exp(lnargs), 0.0))
return running_sum

def get_nans_per_chain(lnargs, mask):
nans_num = jnp.sum(jnp.where(mask, jnp.isnan(lnargs), 0.0))
return nans_num

if self.batch_calculation:
lnargs += self.shift_value
self.running_sum += jnp.array(
[
jnp.sum(
jnp.exp(
lnargs[
chains.start_indices[i_chains] : chains.start_indices[
i_chains + 1
]
]
)
)
for i_chains in range(nchains)
]

masks = self.get_masks(jnp.array(chains.start_indices))

running_sum_val = jax.vmap(get_running_sum, in_axes=(None, 0))(
lnargs, masks
)
self.running_sum += running_sum_val

# Count added number of samples per chain
added_nsamples_per_chain = np.diff(jnp.array(chains.start_indices))
self.nsamples_per_chain += added_nsamples_per_chain

# Count number of NaN values per chain and subtract to get effective
# number of added samples per chain
isnan_per_chain = np.split(np.isnan(lnargs), chains.start_indices[1:-1])
nan_count_per_chain = np.array(
[np.sum(nan_count) for nan_count in isnan_per_chain]
nan_count_per_chain = jax.vmap(get_nans_per_chain, in_axes=(None, 0))(
lnargs, masks
)
self.nsamples_eff_per_chain += (
added_nsamples_per_chain - nan_count_per_chain
Expand Down

0 comments on commit bb25009

Please sign in to comment.