From 279296ea65a4e65f04a2f95cf8f0d79e6170893c Mon Sep 17 00:00:00 2001 From: Andreas Bergmeister Date: Fri, 8 Mar 2024 14:29:24 +0100 Subject: [PATCH] include cuda rng state in consistency_loss --- cm/karras_diffusion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cm/karras_diffusion.py b/cm/karras_diffusion.py index 87b0e18..34ecae1 100644 --- a/cm/karras_diffusion.py +++ b/cm/karras_diffusion.py @@ -187,7 +187,7 @@ def euler_solver(samples, t, next_t, x0): x_t = x_start + noise * append_dims(t, dims) - dropout_state = th.get_rng_state() + dropout_state = (th.get_rng_state(), th.cuda.get_rng_state()) distiller = denoise_fn(x_t, t) if teacher_model is None: @@ -195,7 +195,8 @@ def euler_solver(samples, t, next_t, x0): else: x_t2 = heun_solver(x_t, t, t2, x_start).detach() - th.set_rng_state(dropout_state) + th.set_rng_state(dropout_state[0]) + th.cuda.set_rng_state(dropout_state[1]) distiller_target = target_denoise_fn(x_t2, t2) distiller_target = distiller_target.detach()