From e91a5376f605774d4e1055baa6ed8c8115e2d637 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 16 Oct 2023 20:30:42 -0700 Subject: [PATCH] [LSC] change uses of jax.random.KeyArray and jax.random.PRNGKeyArray to jax.Array This change replaces uses of jax.random.KeyArray and jax.random.PRNGKeyArray in the context of type annotations with jax.Array, which is the correct annotation for JAX PRNG keys moving forward. The purpose of this change is to remove references to KeyArray and PRNGKeyArray, which are deprecated (https://github.com/google/jax/pull/17594) and will soon be removed from JAX. The design and thought process behind this is described in https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html. Note that KeyArray and PRNGKeyArray have always been aliased to Any, so the new type annotation is far more specific than the old one. PiperOrigin-RevId: 574006700 --- music_spectrogram_diffusion/__init__.py | 2 +- music_spectrogram_diffusion/audio_codecs.py | 2 +- music_spectrogram_diffusion/beam/evaluation.py | 2 +- music_spectrogram_diffusion/datasets.py | 2 +- music_spectrogram_diffusion/dump_task.py | 2 +- music_spectrogram_diffusion/event_codec.py | 2 +- music_spectrogram_diffusion/event_codec_test.py | 2 +- music_spectrogram_diffusion/feature_converters.py | 2 +- music_spectrogram_diffusion/inference.py | 2 +- music_spectrogram_diffusion/layers.py | 2 +- music_spectrogram_diffusion/layers_test.py | 2 +- music_spectrogram_diffusion/metrics.py | 2 +- music_spectrogram_diffusion/metrics_test.py | 2 +- .../models/autoregressive/models.py | 2 +- .../models/autoregressive/network.py | 2 +- .../models/autoregressive/output_functions.py | 2 +- .../models/diffusion/diffusion_utils.py | 8 ++++---- .../models/diffusion/feature_converters.py | 2 +- .../models/diffusion/models.py | 14 +++++++------- .../models/diffusion/network.py | 2 +- music_spectrogram_diffusion/note_sequences.py | 2 +- music_spectrogram_diffusion/note_sequences_test.py | 2 +- music_spectrogram_diffusion/postprocessors.py | 2 +- music_spectrogram_diffusion/preprocessors.py | 2 +- music_spectrogram_diffusion/run_length_encoding.py | 2 +- .../run_length_encoding_test.py | 2 +- music_spectrogram_diffusion/tasks.py | 2 +- .../transcription_inference.py | 2 +- music_spectrogram_diffusion/vocabularies.py | 2 +- music_spectrogram_diffusion/vocabularies_test.py | 2 +- 30 files changed, 39 insertions(+), 39 deletions(-) diff --git a/music_spectrogram_diffusion/__init__.py b/music_spectrogram_diffusion/__init__.py index 556872e..715d986 100644 --- a/music_spectrogram_diffusion/__init__.py +++ b/music_spectrogram_diffusion/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/audio_codecs.py b/music_spectrogram_diffusion/audio_codecs.py index 0088880..b64cd43 100644 --- a/music_spectrogram_diffusion/audio_codecs.py +++ b/music_spectrogram_diffusion/audio_codecs.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/beam/evaluation.py b/music_spectrogram_diffusion/beam/evaluation.py index 80a42c9..fb9dff0 100644 --- a/music_spectrogram_diffusion/beam/evaluation.py +++ b/music_spectrogram_diffusion/beam/evaluation.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/datasets.py b/music_spectrogram_diffusion/datasets.py index e0f5688..cebb914 100644 --- a/music_spectrogram_diffusion/datasets.py +++ b/music_spectrogram_diffusion/datasets.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/dump_task.py b/music_spectrogram_diffusion/dump_task.py index dfc239a..17e08a4 100644 --- a/music_spectrogram_diffusion/dump_task.py +++ b/music_spectrogram_diffusion/dump_task.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/event_codec.py b/music_spectrogram_diffusion/event_codec.py index 7df3dec..b8515d9 100644 --- a/music_spectrogram_diffusion/event_codec.py +++ b/music_spectrogram_diffusion/event_codec.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/event_codec_test.py b/music_spectrogram_diffusion/event_codec_test.py index 0e6d687..aeb3d07 100644 --- a/music_spectrogram_diffusion/event_codec_test.py +++ b/music_spectrogram_diffusion/event_codec_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/feature_converters.py b/music_spectrogram_diffusion/feature_converters.py index e429502..dbf3652 100644 --- a/music_spectrogram_diffusion/feature_converters.py +++ b/music_spectrogram_diffusion/feature_converters.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/inference.py b/music_spectrogram_diffusion/inference.py index e6f23f5..887a9f6 100644 --- a/music_spectrogram_diffusion/inference.py +++ b/music_spectrogram_diffusion/inference.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/layers.py b/music_spectrogram_diffusion/layers.py index 93f3127..fc9a035 100644 --- a/music_spectrogram_diffusion/layers.py +++ b/music_spectrogram_diffusion/layers.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/layers_test.py b/music_spectrogram_diffusion/layers_test.py index de14c6d..f82a415 100644 --- a/music_spectrogram_diffusion/layers_test.py +++ b/music_spectrogram_diffusion/layers_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/metrics.py b/music_spectrogram_diffusion/metrics.py index dbc11ec..998400b 100644 --- a/music_spectrogram_diffusion/metrics.py +++ b/music_spectrogram_diffusion/metrics.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/metrics_test.py b/music_spectrogram_diffusion/metrics_test.py index cd72bfd..a66bc89 100644 --- a/music_spectrogram_diffusion/metrics_test.py +++ b/music_spectrogram_diffusion/metrics_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/models/autoregressive/models.py b/music_spectrogram_diffusion/models/autoregressive/models.py index f26c55c..73586cd 100644 --- a/music_spectrogram_diffusion/models/autoregressive/models.py +++ b/music_spectrogram_diffusion/models/autoregressive/models.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/models/autoregressive/network.py b/music_spectrogram_diffusion/models/autoregressive/network.py index e984d5a..6a9f675 100644 --- a/music_spectrogram_diffusion/models/autoregressive/network.py +++ b/music_spectrogram_diffusion/models/autoregressive/network.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/models/autoregressive/output_functions.py b/music_spectrogram_diffusion/models/autoregressive/output_functions.py index 48e23f6..8c9c358 100644 --- a/music_spectrogram_diffusion/models/autoregressive/output_functions.py +++ b/music_spectrogram_diffusion/models/autoregressive/output_functions.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/models/diffusion/diffusion_utils.py b/music_spectrogram_diffusion/models/diffusion/diffusion_utils.py index ff0dcd0..ca33303 100644 --- a/music_spectrogram_diffusion/models/diffusion/diffusion_utils.py +++ b/music_spectrogram_diffusion/models/diffusion/diffusion_utils.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -234,7 +234,7 @@ def predict_x0_from_v(*, def get_diffusion_training_input( - rng: jax.random.KeyArray, + rng: jax.Array, x0: jnp.ndarray, diffusion_config: DiffusionConfig ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: @@ -396,7 +396,7 @@ def ddpm_step(i: jnp.ndarray, rng: jnp.ndarray, logsnr_s: jnp.ndarray, def eval_step( - rng: jax.random.KeyArray, + rng: jax.Array, diffusion_config: DiffusionConfig, batch_size: int, pred_fn: Callable[..., jnp.ndarray] @@ -453,7 +453,7 @@ def body(z_t, i): return body -def eval_scan(rng: jax.random.KeyArray, +def eval_scan(rng: jax.Array, target_shape: Tuple[int], pred_fn: Callable[..., jnp.ndarray], diffusion_config: DiffusionConfig) -> jnp.ndarray: diff --git a/music_spectrogram_diffusion/models/diffusion/feature_converters.py b/music_spectrogram_diffusion/models/diffusion/feature_converters.py index eee99e1..f2c89bb 100644 --- a/music_spectrogram_diffusion/models/diffusion/feature_converters.py +++ b/music_spectrogram_diffusion/models/diffusion/feature_converters.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/models/diffusion/models.py b/music_spectrogram_diffusion/models/diffusion/models.py index f4fc0f6..841700e 100644 --- a/music_spectrogram_diffusion/models/diffusion/models.py +++ b/music_spectrogram_diffusion/models/diffusion/models.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,12 +50,12 @@ def _compute_logits( self, params: PyTree, batch: Mapping[str, jnp.ndarray], - dropout_rng: Optional[jax.random.KeyArray] = None) -> jnp.ndarray: + dropout_rng: Optional[jax.Array] = None) -> jnp.ndarray: raise NotImplementedError("Not used for the diffusion model.") def get_initial_variables( self, - rng: jax.random.KeyArray, + rng: jax.Array, input_shapes: Mapping[str, models.Array], input_types: Optional[Mapping[str, jnp.dtype]] = None ) -> flax_scope.FrozenVariableDict: @@ -150,7 +150,7 @@ def predict_batch_with_aux( self, params: PyTree, batch: Mapping[str, jnp.ndarray], - rng: Optional[jax.random.KeyArray] = jax.random.PRNGKey(0), + rng: Optional[jax.Array] = jax.random.PRNGKey(0), ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: """Predict by doing a loop over the forward diffusion process. @@ -224,12 +224,12 @@ def _compute_logits( self, params: PyTree, batch: Mapping[str, jnp.ndarray], - dropout_rng: Optional[jax.random.KeyArray] = None) -> jnp.ndarray: + dropout_rng: Optional[jax.Array] = None) -> jnp.ndarray: raise NotImplementedError("Not used for the diffusion model.") def get_initial_variables( self, - rng: jax.random.KeyArray, + rng: jax.Array, input_shapes: Mapping[str, models.Array], input_types: Optional[Mapping[str, jnp.dtype]] = None ) -> flax_scope.FrozenVariableDict: @@ -341,7 +341,7 @@ def predict_batch_with_aux( self, params: PyTree, batch: Mapping[str, jnp.ndarray], - rng: Optional[jax.random.KeyArray] = jax.random.PRNGKey(0), + rng: Optional[jax.Array] = jax.random.PRNGKey(0), ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: """Predict by doing a loop over the forward diffusion process. diff --git a/music_spectrogram_diffusion/models/diffusion/network.py b/music_spectrogram_diffusion/models/diffusion/network.py index a6bbae0..32a65a1 100644 --- a/music_spectrogram_diffusion/models/diffusion/network.py +++ b/music_spectrogram_diffusion/models/diffusion/network.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/note_sequences.py b/music_spectrogram_diffusion/note_sequences.py index 918fc3c..d4bcd10 100644 --- a/music_spectrogram_diffusion/note_sequences.py +++ b/music_spectrogram_diffusion/note_sequences.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/note_sequences_test.py b/music_spectrogram_diffusion/note_sequences_test.py index 95f40b2..813dcbf 100644 --- a/music_spectrogram_diffusion/note_sequences_test.py +++ b/music_spectrogram_diffusion/note_sequences_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/postprocessors.py b/music_spectrogram_diffusion/postprocessors.py index 2effc08..123d6b2 100644 --- a/music_spectrogram_diffusion/postprocessors.py +++ b/music_spectrogram_diffusion/postprocessors.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/preprocessors.py b/music_spectrogram_diffusion/preprocessors.py index 4e0c87f..ccfb0ec 100644 --- a/music_spectrogram_diffusion/preprocessors.py +++ b/music_spectrogram_diffusion/preprocessors.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/run_length_encoding.py b/music_spectrogram_diffusion/run_length_encoding.py index eb78b6c..2679455 100644 --- a/music_spectrogram_diffusion/run_length_encoding.py +++ b/music_spectrogram_diffusion/run_length_encoding.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/run_length_encoding_test.py b/music_spectrogram_diffusion/run_length_encoding_test.py index 016a659..aded33b 100644 --- a/music_spectrogram_diffusion/run_length_encoding_test.py +++ b/music_spectrogram_diffusion/run_length_encoding_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/tasks.py b/music_spectrogram_diffusion/tasks.py index 528cc85..61cd63a 100644 --- a/music_spectrogram_diffusion/tasks.py +++ b/music_spectrogram_diffusion/tasks.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/transcription_inference.py b/music_spectrogram_diffusion/transcription_inference.py index dee1b48..a49e72a 100644 --- a/music_spectrogram_diffusion/transcription_inference.py +++ b/music_spectrogram_diffusion/transcription_inference.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/vocabularies.py b/music_spectrogram_diffusion/vocabularies.py index ea8adbf..f6c0f3e 100644 --- a/music_spectrogram_diffusion/vocabularies.py +++ b/music_spectrogram_diffusion/vocabularies.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/music_spectrogram_diffusion/vocabularies_test.py b/music_spectrogram_diffusion/vocabularies_test.py index 3ce9631..a4a9cec 100644 --- a/music_spectrogram_diffusion/vocabularies_test.py +++ b/music_spectrogram_diffusion/vocabularies_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2023 The Music Spectrogram Diffusion Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.