Skip to content

Commit

Permalink
[LSC] change uses of jax.random.KeyArray and jax.random.PRNGKeyArray …
Browse files Browse the repository at this point in the history
…to jax.Array

This change replaces uses of jax.random.KeyArray and jax.random.PRNGKeyArray in the context of type annotations with jax.Array, which is the correct annotation for JAX PRNG keys moving forward.

The purpose of this change is to remove references to KeyArray and PRNGKeyArray, which are deprecated (jax-ml/jax#17594) and will soon be removed from JAX. The design and thought process behind this is described in https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html.

Note that KeyArray and PRNGKeyArray have always been aliased to Any, so the new type annotation is far more specific than the old one.

PiperOrigin-RevId: 574006700
  • Loading branch information
Jake VanderPlas authored and Magenta Team committed Oct 17, 2023
1 parent 3ed19db commit e91a537
Show file tree
Hide file tree
Showing 30 changed files with 39 additions and 39 deletions.
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/audio_codecs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/beam/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/dump_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/event_codec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/event_codec_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/feature_converters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/layers_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/metrics_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -234,7 +234,7 @@ def predict_x0_from_v(*,


def get_diffusion_training_input(
rng: jax.random.KeyArray,
rng: jax.Array,
x0: jnp.ndarray,
diffusion_config: DiffusionConfig
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
Expand Down Expand Up @@ -396,7 +396,7 @@ def ddpm_step(i: jnp.ndarray, rng: jnp.ndarray, logsnr_s: jnp.ndarray,


def eval_step(
rng: jax.random.KeyArray,
rng: jax.Array,
diffusion_config: DiffusionConfig,
batch_size: int,
pred_fn: Callable[..., jnp.ndarray]
Expand Down Expand Up @@ -453,7 +453,7 @@ def body(z_t, i):
return body


def eval_scan(rng: jax.random.KeyArray,
def eval_scan(rng: jax.Array,
target_shape: Tuple[int],
pred_fn: Callable[..., jnp.ndarray],
diffusion_config: DiffusionConfig) -> jnp.ndarray:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
14 changes: 7 additions & 7 deletions music_spectrogram_diffusion/models/diffusion/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,12 +50,12 @@ def _compute_logits(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
dropout_rng: Optional[jax.random.KeyArray] = None) -> jnp.ndarray:
dropout_rng: Optional[jax.Array] = None) -> jnp.ndarray:
raise NotImplementedError("Not used for the diffusion model.")

def get_initial_variables(
self,
rng: jax.random.KeyArray,
rng: jax.Array,
input_shapes: Mapping[str, models.Array],
input_types: Optional[Mapping[str, jnp.dtype]] = None
) -> flax_scope.FrozenVariableDict:
Expand Down Expand Up @@ -150,7 +150,7 @@ def predict_batch_with_aux(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
rng: Optional[jax.random.KeyArray] = jax.random.PRNGKey(0),
rng: Optional[jax.Array] = jax.random.PRNGKey(0),
) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
"""Predict by doing a loop over the forward diffusion process.
Expand Down Expand Up @@ -224,12 +224,12 @@ def _compute_logits(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
dropout_rng: Optional[jax.random.KeyArray] = None) -> jnp.ndarray:
dropout_rng: Optional[jax.Array] = None) -> jnp.ndarray:
raise NotImplementedError("Not used for the diffusion model.")

def get_initial_variables(
self,
rng: jax.random.KeyArray,
rng: jax.Array,
input_shapes: Mapping[str, models.Array],
input_types: Optional[Mapping[str, jnp.dtype]] = None
) -> flax_scope.FrozenVariableDict:
Expand Down Expand Up @@ -341,7 +341,7 @@ def predict_batch_with_aux(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
rng: Optional[jax.random.KeyArray] = jax.random.PRNGKey(0),
rng: Optional[jax.Array] = jax.random.PRNGKey(0),
) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
"""Predict by doing a loop over the forward diffusion process.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/models/diffusion/network.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/note_sequences.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/note_sequences_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/preprocessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/run_length_encoding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/run_length_encoding_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/transcription_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/vocabularies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion music_spectrogram_diffusion/vocabularies_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The Music Spectrogram Diffusion Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down

0 comments on commit e91a537

Please sign in to comment.