From 8d4e9af845759e0c27dbcfbaba2a99abbf50cb1e Mon Sep 17 00:00:00 2001 From: Tomasz Kalinowski Date: Wed, 4 Sep 2024 07:17:30 -0400 Subject: [PATCH] retether vignettes --- .tether/vignettes-src/distribution.Rmd | 4 +--- .../vignettes-src/parked/_custom_train_step_in_torch.Rmd | 4 ++-- .../parked/_writing_a_custom_training_loop_in_jax.Rmd | 2 +- .tether/vignettes-src/writing_your_own_callbacks.Rmd | 2 +- vignettes-src/distribution.Rmd | 9 ++++----- vignettes-src/parked/_custom_train_step_in_torch.Rmd | 4 ++-- .../parked/_writing_a_custom_training_loop_in_jax.Rmd | 2 +- 7 files changed, 12 insertions(+), 15 deletions(-) diff --git a/.tether/vignettes-src/distribution.Rmd b/.tether/vignettes-src/distribution.Rmd index d1fb7deeb..28f46bfa7 100644 --- a/.tether/vignettes-src/distribution.Rmd +++ b/.tether/vignettes-src/distribution.Rmd @@ -188,9 +188,7 @@ layout_map["d1/bias"] = ("model",) # You can also set the layout for the layer output like layout_map["d2/output"] = ("data", None) -model_parallel = keras.distribution.ModelParallel( - mesh_2d, layout_map, batch_dim_name="data" -) +model_parallel = keras.distribution.ModelParallel(layout_map, batch_dim_name="data") keras.distribution.set_distribution(model_parallel) diff --git a/.tether/vignettes-src/parked/_custom_train_step_in_torch.Rmd b/.tether/vignettes-src/parked/_custom_train_step_in_torch.Rmd index 505a4422f..e1b2edb55 100644 --- a/.tether/vignettes-src/parked/_custom_train_step_in_torch.Rmd +++ b/.tether/vignettes-src/parked/_custom_train_step_in_torch.Rmd @@ -2,7 +2,7 @@ title: Customizing what happens in `fit()` with PyTorch author: '[fchollet](https://twitter.com/fchollet)' date-created: 2023/06/27 -last-modified: 2023/06/27 +last-modified: 2024/08/01 description: Overriding the training step of the Model class with PyTorch. accelerator: GPU output: rmarkdown::html_vignette @@ -390,7 +390,7 @@ class GAN(keras.Model): def train_step(self, real_images): device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(real_images, tuple): + if isinstance(real_images, tuple) or isinstance(real_images, list): real_images = real_images[0] # Sample random points in the latent space batch_size = real_images.shape[0] diff --git a/.tether/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd b/.tether/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd index 48d789d6f..89ad7e599 100644 --- a/.tether/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd +++ b/.tether/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd @@ -153,7 +153,7 @@ def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y) ``` Once you have such a function, you can get the gradient function by -specifying `hax_aux` in `value_and_grad`: it tells JAX that the loss +specifying `has_aux` in `value_and_grad`: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output. diff --git a/.tether/vignettes-src/writing_your_own_callbacks.Rmd b/.tether/vignettes-src/writing_your_own_callbacks.Rmd index cee15b9c3..1803477d4 100644 --- a/.tether/vignettes-src/writing_your_own_callbacks.Rmd +++ b/.tether/vignettes-src/writing_your_own_callbacks.Rmd @@ -293,7 +293,7 @@ class EarlyStoppingAtMinLoss(keras.callbacks.Callback): # The epoch the training stops at. self.stopped_epoch = 0 # Initialize the best as infinity. - self.best = np.Inf + self.best = np.inf def on_epoch_end(self, epoch, logs=None): current = logs.get("loss") diff --git a/vignettes-src/distribution.Rmd b/vignettes-src/distribution.Rmd index c4b3b343c..a70cca771 100644 --- a/vignettes-src/distribution.Rmd +++ b/vignettes-src/distribution.Rmd @@ -93,7 +93,7 @@ mesh <- keras$distribution$DeviceMesh( # "data" as columns, and it is a [4, 2] grid when it mapped to the physical # devices on the mesh. layout_2d <- keras$distribution$TensorLayout( - axes = c("model", "data"), + axes = c("model", "data"), device_mesh = mesh ) @@ -131,8 +131,8 @@ data_parallel <- keras$distribution$DataParallel(devices = devices) # Or you can choose to create DataParallel with a 1D `DeviceMesh`. mesh_1d <- keras$distribution$DeviceMesh( - shape = shape(8), - axis_names = list("data"), + shape = shape(8), + axis_names = list("data"), devices = devices ) data_parallel <- keras$distribution$DataParallel(device_mesh = mesh_1d) @@ -213,8 +213,7 @@ layout_map["d1/bias"] <- tuple("model") layout_map["d2/output"] <- tuple("data", NULL) model_parallel <- keras$distribution$ModelParallel( - layout_map = layout_map, - batch_dim_name = "data" + layout_map, batch_dim_name = "data" ) keras$distribution$set_distribution(model_parallel) diff --git a/vignettes-src/parked/_custom_train_step_in_torch.Rmd b/vignettes-src/parked/_custom_train_step_in_torch.Rmd index f1d3106b3..078ae4d7e 100644 --- a/vignettes-src/parked/_custom_train_step_in_torch.Rmd +++ b/vignettes-src/parked/_custom_train_step_in_torch.Rmd @@ -2,7 +2,7 @@ title: Customizing what happens in `fit()` with PyTorch author: '[fchollet](https://twitter.com/fchollet)' date-created: 2023/06/27 -last-modified: 2023/06/27 +last-modified: 2024/08/01 description: Overriding the training step of the Model class with PyTorch. accelerator: GPU output: rmarkdown::html_vignette @@ -390,7 +390,7 @@ class GAN(keras.Model): def train_step(self, real_images): device = "cuda" if torch.cuda.is_available() else "cpu" - if isinstance(real_images, tuple): + if isinstance(real_images, tuple) or isinstance(real_images, list): real_images = real_images[0] # Sample random points in the latent space batch_size = real_images.shape[0] diff --git a/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd b/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd index 95c7a916f..0b483d53d 100644 --- a/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd +++ b/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd @@ -153,7 +153,7 @@ def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y) ``` Once you have such a function, you can get the gradient function by -specifying `hax_aux` in `value_and_grad`: it tells JAX that the loss +specifying `has_aux` in `value_and_grad`: it tells JAX that the loss computation function returns more outputs than just the loss. Note that the loss should always be the first output.