From 085446053a18317ce2f66d13bc80e548b3de40ce Mon Sep 17 00:00:00 2001
From: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
Date: Sat, 26 Oct 2024 01:50:54 +0200
Subject: [PATCH] Consistency Model: Fix JIT compilation (#224)

* Consistency Models: fixes for TF jit

* Consistency Model: Fix JIT compilation

Adjusting the schedule requires special care to work with JIT
compilation. With this commit, we pre-calculate the schedule, so that
setting `current_step` works in all backends.

---
 .../consistency_models/consistency_model.py   | 67 ++++++++++++++-----
 1 file changed, 52 insertions(+), 15 deletions(-)

diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py
index d316709b8..138668abf 100644
--- a/bayesflow/networks/consistency_models/consistency_model.py
+++ b/bayesflow/networks/consistency_models/consistency_model.py
@@ -1,11 +1,11 @@
-import math
-
 import keras
 from keras import ops
 from keras.saving import (
     register_keras_serializable,
 )
 
+import numpy as np
+
 from bayesflow.types import Tensor
 from bayesflow.utils import find_network, keras_kwargs
 
@@ -82,32 +82,34 @@ def __init__(
         self.s0 = float(s0)
         self.s1 = float(s1)
 
-        self.current_step = 0.0
+        # create variable that works with JIT compilation
+        self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int32")
+        self.current_step.assign(0)
 
         self.seed_generator = keras.random.SeedGenerator()
 
-    def _schedule_discretization(self) -> int:
+    def _schedule_discretization(self, step) -> float:
         """Schedule function for adjusting the discretization level `N` during
         the course of training.
 
         Implements the function N(k) from [2], Section 3.4.
         """
 
-        k_ = math.floor(self.total_steps / (math.log(self.s1 / self.s0) / math.log(2.0) + 1.0))
-        out = min(self.s0 * math.pow(2.0, math.floor(self.current_step / k_)), self.s1) + 1.0
-        return int(out)
+        k_ = ops.floor(self.total_steps / (ops.log(self.s1 / self.s0) / ops.log(2.0) + 1.0))
+        out = ops.minimum(self.s0 * ops.power(2.0, ops.floor(step / k_)), self.s1) + 1.0
+        return out
 
     def _discretize_time(self, num_steps, rho=7.0):
         """Function for obtaining the discretized time according to [2],
         Section 2, bottom of page 2.
         """
 
-        N = num_steps + 1.0
+        N = num_steps + 1
         indices = ops.arange(1, N + 1, dtype="float32")
         one_over_rho = 1.0 / rho
         discretized_time = (
             self.eps**one_over_rho
-            + (indices - 1.0) / (N - 1.0) * (self.max_time**one_over_rho - self.eps**one_over_rho)
+            + (indices - 1.0) / (ops.cast(N, "float32") - 1.0) * (self.max_time**one_over_rho - self.eps**one_over_rho)
         ) ** rho
         return discretized_time
 
@@ -131,9 +133,36 @@ def build(self, xz_shape, conditions_shape=None):
         self.student_projector.build(input_shape)
 
         # Choose coefficient according to [2] Section 3.3
-        self.c_huber = 0.00054 * math.sqrt(xz_shape[-1])
+        self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
         self.c_huber2 = self.c_huber**2
 
+        ## Calculate discretization schedule in advance
+        # The Jax compiler requires fixed-size arrays, so we have
+        # to store all the discretized_times in one matrix in advance
+        # and later only access the relevant entries.
+
+        # First, we calculate all unique numbers of discretization steps n
+        # in a loop, as self.total_steps might be large
+        self.max_n = int(self._schedule_discretization(self.total_steps))
+        assert self.max_n == self.s1 + 1
+        unique_n = set()
+        for step in range(int(self.total_steps)):
+            unique_n.add(int(self._schedule_discretization(step)))
+        unique_n = sorted(list(unique_n))
+
+        # Next, we calculate the discretized times for each n
+        # and establish a mapping between n and the position i of the
+        # discretized times in the vector
+        discretized_times = np.zeros((len(unique_n), self.max_n + 1))
+        discretization_map = np.zeros((self.max_n + 1,), dtype=np.int32)
+        for i, n in enumerate(unique_n):
+            disc = self._discretize_time(n)
+            discretized_times[i, : len(disc)] = disc
+            discretization_map[n] = i
+        # Finally, we convert the vectors to tensors
+        self.discretized_times = ops.convert_to_tensor(discretized_times, dtype="float32")
+        self.discretization_map = ops.convert_to_tensor(discretization_map)
+
     def call(
         self,
         xz: Tensor,
@@ -224,19 +253,27 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "tr
 
         # The discretization schedule requires the number of passed training steps.
         # To be independent of external information, we track it here.
-        self.current_step += 1
+        if stage == "training":
+            self.current_step.assign_add(1)
+            self.current_step.assign(ops.minimum(self.current_step, self.total_steps - 1))
 
-        current_num_steps = self._schedule_discretization()
-        discretized_time = self._discretize_time(current_num_steps)
+        discretization_index = ops.take(
+            self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int32")
+        )
+        discretized_time = ops.take(self.discretized_times, discretization_index, axis=0)
 
         # Randomly sample t_n and t_[n+1] and reshape to (batch_size, 1)
         # adapted noise schedule from [2], Section 3.5
         p_mean = -1.1
         p_std = 2.0
-        log_p = ops.log(
+        p = ops.where(
+            discretized_time[1:] > 0.0,
             ops.erf((ops.log(discretized_time[1:]) - p_mean) / (ops.sqrt(2.0) * p_std))
-            - ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std))
+            - ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std)),
+            0.0,
         )
+
+        log_p = ops.log(p)
         times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
         t1 = ops.take(discretized_time, times)[..., None]
         t2 = ops.take(discretized_time, times + 1)[..., None]
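
As a standalone sketch of the precompute-then-lookup pattern this patch
introduces, the NumPy snippet below restates the idea outside the
ConsistencyModel class. All parameter values (s0, s1, total_steps, eps,
max_time, rho) and the names times_table / row_of_n are illustrative
placeholders, not BayesFlow API; the two helpers are simplified
restatements of N(k) ([2], Section 3.4) and the time discretization
([2], Section 2).

    import numpy as np

    # Illustrative placeholder settings, not values taken from BayesFlow
    s0, s1 = 10.0, 50.0     # initial / final discretization levels
    total_steps = 1000      # total number of training steps
    eps, max_time, rho = 0.002, 80.0, 7.0

    def schedule_discretization(step):
        # N(k) from [2], Section 3.4: the level doubles every k_ steps, capped at s1
        k_ = np.floor(total_steps / (np.log2(s1 / s0) + 1.0))
        return np.minimum(s0 * 2.0 ** np.floor(step / k_), s1) + 1.0

    def discretize_time(num_steps):
        # rho-spaced time grid between eps and max_time ([2], Section 2)
        n = num_steps + 1
        indices = np.arange(1, n + 1, dtype="float32")
        one_over_rho = 1.0 / rho
        return (
            eps**one_over_rho
            + (indices - 1.0) / (n - 1.0) * (max_time**one_over_rho - eps**one_over_rho)
        ) ** rho

    # Precompute one fixed-size row per unique discretization level, so a
    # JIT-compiled training step only performs static-shape integer
    # indexing instead of building variable-length arrays on the fly.
    max_n = int(schedule_discretization(total_steps))
    unique_n = sorted({int(schedule_discretization(step)) for step in range(total_steps)})
    times_table = np.zeros((len(unique_n), max_n + 1), dtype="float32")
    row_of_n = np.zeros(max_n + 1, dtype="int32")
    for i, n in enumerate(unique_n):
        times = discretize_time(n)
        times_table[i, : len(times)] = times  # shorter rows stay zero-padded
        row_of_n[n] = i

    # Lookup at a given training step: fixed shapes, no recomputation
    step = 123
    n = int(schedule_discretization(step))
    discretized_time = times_table[row_of_n[n]]

The zero padding of the shorter rows is also why compute_metrics masks
with ops.where before taking the log: padded entries get p = 0, hence
log_p = -inf, so keras.random.categorical never samples them.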