From 066daf2a2be866abd9b71237fa3be73c619d6472 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 27 Jul 2023 12:27:56 -0700 Subject: [PATCH] Reenable the model fit test for DTensor model under layout scope. PiperOrigin-RevId: 551608308 --- keras/dtensor/mnist_model_test.py | 4 ---- keras/dtensor/test_util.py | 34 ++++++++++++++++--------------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/keras/dtensor/mnist_model_test.py b/keras/dtensor/mnist_model_test.py index ffb172c8c7e..13cd15d5a4a 100644 --- a/keras/dtensor/mnist_model_test.py +++ b/keras/dtensor/mnist_model_test.py @@ -64,15 +64,11 @@ def test_mnist_training(self): self.assertEqual(train_losses, sorted(train_losses, reverse=True)) def test_model_fit(self): - if self.mesh.device_type() == "GPU": - self.skipTest("TODO(b/292596476)") - layout_map = layout_map_lib.LayoutMap(self.mesh) with layout_map.scope(): model = integration_test_utils.get_model() optimizer = adam.Adam(learning_rate=0.001, mesh=self.mesh) - optimizer.build(model.trainable_variables) global_batch_size = 64 model.compile( diff --git a/keras/dtensor/test_util.py b/keras/dtensor/test_util.py index 84ed3458b04..44e2b7f709c 100644 --- a/keras/dtensor/test_util.py +++ b/keras/dtensor/test_util.py @@ -117,19 +117,18 @@ def reset_logical_devices(device_type, count): device_type: The device_type to reset. count: numbers of virtual device to reset to. """ - reset_context() - devices = tf.config.list_physical_devices(device_type) - if device_type.upper() == "CPU": - tf.config.set_logical_device_configuration( - devices[0], - [ - tf.config.LogicalDeviceConfiguration(), - ] - * count, + if device_type.upper() not in ["CPU", "GPU"]: + raise ValueError( + "resetting logical device for non-supported device type: " + f"{device_type}" ) - elif device_type.upper() == "GPU": + reset_context() + + cpus = tf.config.list_physical_devices("CPU") + if device_type.upper() == "GPU": + gpus = tf.config.list_physical_devices(device_type) tf.config.set_logical_device_configuration( - devices[0], + gpus[0], [ tf.config.LogicalDeviceConfiguration( memory_limit=_DEFAULT_GPU_MEMORY_LIMIT @@ -137,11 +136,14 @@ def reset_logical_devices(device_type, count): ] * count, ) - else: - dt = device_type - raise ValueError( - f"resetting logical device for non-supported device type: {dt}" - ) + # Always config CPU mesh as the host mesh for DTensor + tf.config.set_logical_device_configuration( + cpus[0], + [ + tf.config.LogicalDeviceConfiguration(), + ] + * count, + ) def reset_dtensor():