Skip to content

Commit

Permalink
Consistency Model: Fix JIT compilation (#224)
Browse files Browse the repository at this point in the history
* Consistency Models: fixes for TF jit

* Consistency Model: Fix JIT compilation

Adjusting the schedule requires special care to work with JIT
compilation.
With this commit, we pre-calculate the schedule, so that setting
`current_step` works in all backends.
  • Loading branch information
vpratz authored Oct 25, 2024
1 parent 3c155bd commit 0854460
Showing 1 changed file with 52 additions and 15 deletions.
67 changes: 52 additions & 15 deletions bayesflow/networks/consistency_models/consistency_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math

import keras
from keras import ops
from keras.saving import (
register_keras_serializable,
)

import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import find_network, keras_kwargs

Expand Down Expand Up @@ -82,32 +82,34 @@ def __init__(

self.s0 = float(s0)
self.s1 = float(s1)
self.current_step = 0.0
# create variable that works with JIT compilation
self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int32")
self.current_step.assign(0)

self.seed_generator = keras.random.SeedGenerator()

def _schedule_discretization(self) -> int:
def _schedule_discretization(self, step) -> float:
"""Schedule function for adjusting the discretization level `N` during
the course of training.
Implements the function N(k) from [2], Section 3.4.
"""

k_ = math.floor(self.total_steps / (math.log(self.s1 / self.s0) / math.log(2.0) + 1.0))
out = min(self.s0 * math.pow(2.0, math.floor(self.current_step / k_)), self.s1) + 1.0
return int(out)
k_ = ops.floor(self.total_steps / (ops.log(self.s1 / self.s0) / ops.log(2.0) + 1.0))
out = ops.minimum(self.s0 * ops.power(2.0, ops.floor(step / k_)), self.s1) + 1.0
return out

def _discretize_time(self, num_steps, rho=7.0):
"""Function for obtaining the discretized time according to [2],
Section 2, bottom of page 2.
"""

N = num_steps + 1.0
N = num_steps + 1
indices = ops.arange(1, N + 1, dtype="float32")
one_over_rho = 1.0 / rho
discretized_time = (
self.eps**one_over_rho
+ (indices - 1.0) / (N - 1.0) * (self.max_time**one_over_rho - self.eps**one_over_rho)
+ (indices - 1.0) / (ops.cast(N, "float32") - 1.0) * (self.max_time**one_over_rho - self.eps**one_over_rho)
) ** rho
return discretized_time

Expand All @@ -131,9 +133,36 @@ def build(self, xz_shape, conditions_shape=None):
self.student_projector.build(input_shape)

# Choose coefficient according to [2] Section 3.3
self.c_huber = 0.00054 * math.sqrt(xz_shape[-1])
self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
self.c_huber2 = self.c_huber**2

## Calculate discretization schedule in advance
# The Jax compiler requires fixed-size arrays, so we have
# to store all the discretized_times in one matrix in advance
# and later only access the relevant entries.

# First, we calculate all unique numbers of discretization steps n
# in a loop, as self.total_steps might be large
self.max_n = int(self._schedule_discretization(self.total_steps))
assert self.max_n == self.s1 + 1
unique_n = set()
for step in range(int(self.total_steps)):
unique_n.add(int(self._schedule_discretization(step)))
unique_n = sorted(list(unique_n))

# Next, we calculate the discretized times for each n
# and establish a mapping between n and the position i of the
# discretizated times in the vector
discretized_times = np.zeros((len(unique_n), self.max_n + 1))
discretization_map = np.zeros((self.max_n + 1,), dtype=np.int32)
for i, n in enumerate(unique_n):
disc = self._discretize_time(n)
discretized_times[i, : len(disc)] = disc
discretization_map[n] = i
# Finally, we convert the vectors to tensors
self.discretized_times = ops.convert_to_tensor(discretized_times, dtype="float32")
self.discretization_map = ops.convert_to_tensor(discretization_map)

def call(
self,
xz: Tensor,
Expand Down Expand Up @@ -224,19 +253,27 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr

# The discretization schedule requires the number of passed training steps.
# To be independent of external information, we track it here.
self.current_step += 1
if stage == "training":
self.current_step.assign_add(1)
self.current_step.assign(ops.minimum(self.current_step, self.total_steps - 1))

current_num_steps = self._schedule_discretization()
discretized_time = self._discretize_time(current_num_steps)
discretization_index = ops.take(
self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int32")
)
discretized_time = ops.take(self.discretized_times, discretization_index, axis=0)

# Randomly sample t_n and t_[n+1] and reshape to (batch_size, 1)
# adapted noise schedule from [2], Section 3.5
p_mean = -1.1
p_std = 2.0
log_p = ops.log(
p = ops.where(
discretized_time[1:] > 0.0,
ops.erf((ops.log(discretized_time[1:]) - p_mean) / (ops.sqrt(2.0) * p_std))
- ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std))
- ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std)),
0.0,
)

log_p = ops.log(p)
times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
t1 = ops.take(discretized_time, times)[..., None]
t2 = ops.take(discretized_time, times + 1)[..., None]
Expand Down

0 comments on commit 0854460

Please sign in to comment.