Skip to content

Commit

Permalink
Reenable the model fit test for DTensor model under layout scope.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551608308
  • Loading branch information
qlzh727 authored and tensorflower-gardener committed Jul 31, 2023
1 parent 397ad57 commit 066daf2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
4 changes: 0 additions & 4 deletions keras/dtensor/mnist_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,11 @@ def test_mnist_training(self):
self.assertEqual(train_losses, sorted(train_losses, reverse=True))

def test_model_fit(self):
if self.mesh.device_type() == "GPU":
self.skipTest("TODO(b/292596476)")

layout_map = layout_map_lib.LayoutMap(self.mesh)
with layout_map.scope():
model = integration_test_utils.get_model()

optimizer = adam.Adam(learning_rate=0.001, mesh=self.mesh)
optimizer.build(model.trainable_variables)

global_batch_size = 64
model.compile(
Expand Down
34 changes: 18 additions & 16 deletions keras/dtensor/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,31 +117,33 @@ def reset_logical_devices(device_type, count):
device_type: The device_type to reset.
count: number of virtual devices to reset to.
"""
reset_context()
devices = tf.config.list_physical_devices(device_type)
if device_type.upper() == "CPU":
tf.config.set_logical_device_configuration(
devices[0],
[
tf.config.LogicalDeviceConfiguration(),
]
* count,
if device_type.upper() not in ["CPU", "GPU"]:
raise ValueError(
"resetting logical device for non-supported device type: "
f"{device_type}"
)
elif device_type.upper() == "GPU":
reset_context()

cpus = tf.config.list_physical_devices("CPU")
if device_type.upper() == "GPU":
gpus = tf.config.list_physical_devices(device_type)
tf.config.set_logical_device_configuration(
devices[0],
gpus[0],
[
tf.config.LogicalDeviceConfiguration(
memory_limit=_DEFAULT_GPU_MEMORY_LIMIT
),
]
* count,
)
else:
dt = device_type
raise ValueError(
f"resetting logical device for non-supported device type: {dt}"
)
# Always configure the CPU mesh as the host mesh for DTensor.
tf.config.set_logical_device_configuration(
cpus[0],
[
tf.config.LogicalDeviceConfiguration(),
]
* count,
)


def reset_dtensor():
Expand Down

0 comments on commit 066daf2

Please sign in to comment.