Consistency Model: Fix JIT compilation #224

Merged · 2 commits · Oct 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 52 additions & 15 deletions bayesflow/networks/consistency_models/consistency_model.py
@@ -1,11 +1,11 @@
-import math
-
 import keras
 from keras import ops
 from keras.saving import (
     register_keras_serializable,
 )
 
+import numpy as np
+
 from bayesflow.types import Tensor
 from bayesflow.utils import find_network, keras_kwargs

@@ -82,32 +82,34 @@ def __init__(
 
         self.s0 = float(s0)
         self.s1 = float(s1)
-        self.current_step = 0.0
+        # create variable that works with JIT compilation
+        self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int32")
+        self.current_step.assign(0)
 
         self.seed_generator = keras.random.SeedGenerator()
 
-    def _schedule_discretization(self) -> int:
+    def _schedule_discretization(self, step) -> float:
         """Schedule function for adjusting the discretization level `N` during
         the course of training.
 
         Implements the function N(k) from [2], Section 3.4.
         """
 
-        k_ = math.floor(self.total_steps / (math.log(self.s1 / self.s0) / math.log(2.0) + 1.0))
-        out = min(self.s0 * math.pow(2.0, math.floor(self.current_step / k_)), self.s1) + 1.0
-        return int(out)
+        k_ = ops.floor(self.total_steps / (ops.log(self.s1 / self.s0) / ops.log(2.0) + 1.0))
+        out = ops.minimum(self.s0 * ops.power(2.0, ops.floor(step / k_)), self.s1) + 1.0
+        return out
 
     def _discretize_time(self, num_steps, rho=7.0):
         """Function for obtaining the discretized time according to [2],
         Section 2, bottom of page 2.
         """
 
-        N = num_steps + 1.0
+        N = num_steps + 1
         indices = ops.arange(1, N + 1, dtype="float32")
         one_over_rho = 1.0 / rho
         discretized_time = (
             self.eps**one_over_rho
-            + (indices - 1.0) / (N - 1.0) * (self.max_time**one_over_rho - self.eps**one_over_rho)
+            + (indices - 1.0) / (ops.cast(N, "float32") - 1.0) * (self.max_time**one_over_rho - self.eps**one_over_rho)
         ) ** rho
         return discretized_time
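
The switch from a Python float to a backend variable is the core of the JIT fix: under `jit_compile=True`, Keras traces the training step once, so ordinary Python attributes like `self.current_step = 0.0` are frozen into the graph as constants, while a non-trainable weight can still be updated on every step. A minimal sketch of the same pattern outside bayesflow (the `CallCounter` layer is hypothetical):

```python
import keras
from keras import ops


class CallCounter(keras.layers.Layer):
    """Toy layer that counts its calls in a JIT-compatible way."""

    def build(self, input_shape):
        # A scalar, non-trainable integer variable lives in the graph;
        # a plain `self.n = 0` would be baked in as a constant at trace time.
        self.n = self.add_weight(name="n", initializer="zeros", trainable=False, dtype="int32")

    def call(self, x):
        self.n.assign_add(1)  # graph-side update, survives jit_compile=True
        return x + ops.cast(self.n, x.dtype)
```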

@@ -131,9 +133,36 @@ def build(self, xz_shape, conditions_shape=None):
         self.student_projector.build(input_shape)
 
         # Choose coefficient according to [2] Section 3.3
-        self.c_huber = 0.00054 * math.sqrt(xz_shape[-1])
+        self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
         self.c_huber2 = self.c_huber**2
 
+        ## Calculate discretization schedule in advance
+        # The Jax compiler requires fixed-size arrays, so we have
+        # to store all the discretized_times in one matrix in advance
+        # and later only access the relevant entries.
+
+        # First, we calculate all unique numbers of discretization steps n
+        # in a loop, as self.total_steps might be large
+        self.max_n = int(self._schedule_discretization(self.total_steps))
+        assert self.max_n == self.s1 + 1
+        unique_n = set()
+        for step in range(int(self.total_steps)):
+            unique_n.add(int(self._schedule_discretization(step)))
+        unique_n = sorted(list(unique_n))
+
+        # Next, we calculate the discretized times for each n
+        # and establish a mapping between n and the position i of the
+        # discretized times in the vector
+        discretized_times = np.zeros((len(unique_n), self.max_n + 1))
+        discretization_map = np.zeros((self.max_n + 1,), dtype=np.int32)
+        for i, n in enumerate(unique_n):
+            disc = self._discretize_time(n)
+            discretized_times[i, : len(disc)] = disc
+            discretization_map[n] = i
+        # Finally, we convert the vectors to tensors
+        self.discretized_times = ops.convert_to_tensor(discretized_times, dtype="float32")
+        self.discretization_map = ops.convert_to_tensor(discretization_map)
+
     def call(
         self,
         xz: Tensor,
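
The new block in `build` is the heart of the workaround: XLA (and hence the JAX backend) cannot handle arrays whose size depends on a traced value, so every possible output of `_discretize_time` is materialized eagerly with NumPy, padded to a common width `max_n + 1`, and later merely indexed. A stripped-down sketch of the same table-plus-lookup idea (the levels and schedule values here are made up):

```python
import numpy as np
from keras import ops

levels = [2, 4, 8]  # the unique discretization levels n
max_len = max(levels) + 1

# Eager precomputation: one zero-padded row per level, fixed shape for XLA.
table = np.zeros((len(levels), max_len), dtype="float32")
row_of = np.zeros(max_len, dtype=np.int32)  # maps level n -> row index
for i, n in enumerate(levels):
    times = np.linspace(0.002, 80.0, n + 1)  # stand-in for _discretize_time(n)
    table[i, : n + 1] = times
    row_of[n] = i

table = ops.convert_to_tensor(table)
row_of = ops.convert_to_tensor(row_of)

# Inside a compiled step, n may be a traced tensor; the double `ops.take`
# stays entirely in-graph and returns a statically shaped row.
n = ops.convert_to_tensor(4, dtype="int32")
times = ops.take(table, ops.take(row_of, n), axis=0)  # shape (max_len,)
```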
@@ -224,19 +253,27 @@ def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training"
 
         # The discretization schedule requires the number of passed training steps.
         # To be independent of external information, we track it here.
-        self.current_step += 1
+        if stage == "training":
+            self.current_step.assign_add(1)
+            self.current_step.assign(ops.minimum(self.current_step, self.total_steps - 1))
 
-        current_num_steps = self._schedule_discretization()
-        discretized_time = self._discretize_time(current_num_steps)
+        discretization_index = ops.take(
+            self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int32")
+        )
+        discretized_time = ops.take(self.discretized_times, discretization_index, axis=0)
 
         # Randomly sample t_n and t_[n+1] and reshape to (batch_size, 1)
         # adapted noise schedule from [2], Section 3.5
         p_mean = -1.1
         p_std = 2.0
-        log_p = ops.log(
+        p = ops.where(
+            discretized_time[1:] > 0.0,
             ops.erf((ops.log(discretized_time[1:]) - p_mean) / (ops.sqrt(2.0) * p_std))
-            - ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std))
+            - ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std)),
+            0.0,
         )
+
+        log_p = ops.log(p)
         times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
         t1 = ops.take(discretized_time, times)[..., None]
         t2 = ops.take(discretized_time, times + 1)[..., None]
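
The `ops.where` guard exists because the padded tail of each row in `self.discretized_times` is zero: zeroing those interval probabilities before the `ops.log` turns them into `-inf` logits, which `keras.random.categorical` never draws, so a padded entry can never be sampled as a training time. A self-contained sketch of that masking pattern (the time values are illustrative):

```python
import keras
from keras import ops

# One padded schedule row: only the first four entries are real.
t = ops.convert_to_tensor([0.002, 1.0, 10.0, 80.0, 0.0, 0.0])

p_mean, p_std = -1.1, 2.0


def log_normal_cdf(x):
    return ops.erf((ops.log(x) - p_mean) / (ops.sqrt(2.0) * p_std))


# Mask padded intervals *before* the log; log(0) = -inf makes them
# impossible to draw from the categorical distribution below.
p = ops.where(t[1:] > 0.0, log_normal_cdf(t[1:]) - log_normal_cdf(t[:-1]), 0.0)
log_p = ops.log(p)

seed = keras.random.SeedGenerator(42)
idx = keras.random.categorical(ops.expand_dims(log_p, 0), 16, seed=seed)[0]
# `idx` only ever contains 0..2 here, i.e. indices of real intervals.
```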
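
With the step counter and the discretization schedule now living entirely in the graph, the model no longer breaks when training is XLA-compiled. A hedged usage sketch, assuming `approximator` is any Keras 3 model wrapping this network and `dataset` is a ready-made training dataset:

```python
# jit_compile=True previously failed at trace time because of the Python-side
# step counter; after this PR the whole training step can be XLA-compiled.
approximator.compile(optimizer="adam", jit_compile=True)
approximator.fit(dataset, epochs=10)
```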