Skip to content

Commit

Permalink
Fix profiling for Tensorflow and JAX (keras-team#20450)
Browse files Browse the repository at this point in the history
* Fix profiling for tensorflow and JAX

* Update doc

* Test fix
  • Loading branch information
nicolaspi authored Nov 5, 2024
1 parent 272bb90 commit 1a01cbd
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 21 deletions.
1 change: 1 addition & 0 deletions keras/src/backend/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from keras.src.backend.jax import nn
from keras.src.backend.jax import numpy
from keras.src.backend.jax import random
from keras.src.backend.jax import tensorboard
from keras.src.backend.jax.core import IS_THREAD_SAFE
from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.jax.core import Variable
Expand Down
23 changes: 23 additions & 0 deletions keras/src/backend/jax/tensorboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from keras.src.utils.module_utils import jax


def start_trace(logdir):
    """Start a JAX profiler trace that writes into `logdir`.

    Args:
        logdir: Directory handed to `jax.profiler.start_trace`. A falsy
            value (e.g. an empty string or `None`) turns this call into a
            no-op, so no profiler session is started.
    """
    if logdir:
        jax.profiler.start_trace(logdir)


def stop_trace(save):
    """Stop the running JAX profiler trace when `save` is truthy.

    Args:
        save: When truthy, `jax.profiler.stop_trace()` is called. When
            falsy, nothing is stopped and this call is a no-op.
    """
    if save:
        jax.profiler.stop_trace()


def start_batch_trace(batch):
    """Open a profiler annotation for one batch and return it.

    The annotation is entered manually (instead of through a `with` block)
    because it has to stay open after this function returns; the caller is
    expected to hand the returned object to `stop_batch_trace` to close it.

    Args:
        batch: Batch identifier interpolated into the annotation name.

    Returns:
        The already-entered `jax.profiler.TraceAnnotation`.
    """
    annotation = jax.profiler.TraceAnnotation(f"Profiled batch {batch}")
    annotation.__enter__()
    return annotation


def stop_batch_trace(batch_trace_context):
    """Close a batch annotation previously opened by `start_batch_trace`.

    Args:
        batch_trace_context: The entered context object returned by
            `start_batch_trace`. `None` is accepted and ignored so that
            callers which never opened a batch trace can call this safely.
    """
    if batch_trace_context is None:
        return
    # Exit without exception info: the annotation is closed normally.
    batch_trace_context.__exit__(None, None, None)
14 changes: 13 additions & 1 deletion keras/src/backend/tensorflow/tensorboard.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
from keras.src.utils.module_utils import tensorflow as tf


def start_trace(logdir):
Expand All @@ -7,3 +7,15 @@ def start_trace(logdir):

def stop_trace(save):
    """Stop the TensorFlow profiler.

    Args:
        save: Forwarded as the `save` keyword argument of
            `tf.profiler.experimental.stop`.
    """
    tf.profiler.experimental.stop(save=save)


def start_batch_trace(batch):
    """Open a profiler trace event for one batch and return it.

    The trace context is entered manually (instead of through a `with`
    block) because it has to stay open after this function returns; the
    caller is expected to hand the returned object to `stop_batch_trace`.

    Args:
        batch: Batch index, passed to the trace as `step_num`.

    Returns:
        The already-entered `tf.profiler.experimental.Trace` context.
    """
    trace = tf.profiler.experimental.Trace("Profiled batch", step_num=batch)
    trace.__enter__()
    return trace


def stop_batch_trace(batch_trace_context):
    """Close a batch trace previously opened by `start_batch_trace`.

    Args:
        batch_trace_context: The entered context object returned by
            `start_batch_trace`. `None` is accepted and ignored so that
            callers which never opened a batch trace can call this safely.
    """
    if batch_trace_context is None:
        return
    # Exit without exception info: the trace event is closed normally.
    batch_trace_context.__exit__(None, None, None)
44 changes: 31 additions & 13 deletions keras/src/callbacks/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ class TensorBoard(Callback):
[TensorBoard Scalars tutorial](
https://www.tensorflow.org/tensorboard/scalars_and_keras#batch-level_logging)
for more details.
profile_batch: (Not supported at this time)
Profile the batch(es) to sample compute characteristics.
profile_batch: Profile the batch(es) to sample compute characteristics.
profile_batch must be a non-negative integer or a tuple of integers.
A pair of positive integers signifies a range of batches to profile.
By default, profiling is disabled.
Expand Down Expand Up @@ -176,14 +175,23 @@ def __init__(
self.update_freq = 1 if update_freq == "batch" else update_freq
self.embeddings_freq = embeddings_freq
self.embeddings_metadata = embeddings_metadata
if profile_batch and backend.backend() != "tensorflow":
# TODO: profiling not available in JAX/torch
raise ValueError(
"Profiling is not yet available with the "
f"{backend.backend()} backend. Please open a PR "
"if you'd like to add this feature. Received: "
f"profile_batch={profile_batch} (must be 0)"
)
if profile_batch:
if backend.backend() not in ("jax", "tensorflow"):
# TODO: profiling not available in torch, numpy
raise ValueError(
"Profiling is not yet available with the "
f"{backend.backend()} backend. Please open a PR "
"if you'd like to add this feature. Received: "
f"profile_batch={profile_batch} (must be 0)"
)
elif backend.backend() == "jax":
if sys.version_info[1] < 12:
warnings.warn(
"Profiling with the "
f"{backend.backend()} backend requires python >= 3.12."
)
profile_batch = 0

self._init_profile_batch(profile_batch)
self._global_train_batch = 0
self._previous_epoch_iterations = 0
Expand Down Expand Up @@ -384,6 +392,8 @@ def _init_profile_batch(self, profile_batch):
# We track the status here to make sure callbacks do not interfere with
# each other. The callback will only stop the profiler it started.
self._profiler_started = False
self._batch_trace_context = None

if self._start_batch > 0:
# Warm up and improve the profiling accuracy.
self._start_profiler(logdir="")
Expand Down Expand Up @@ -437,6 +447,10 @@ def on_train_batch_begin(self, batch, logs=None):

if self._global_train_batch == self._start_batch:
self._start_trace()
if self._profiler_started:
self._batch_trace_context = backend.tensorboard.start_batch_trace(
batch
)

def on_train_batch_end(self, batch, logs=None):
if self._should_write_train_graph:
Expand All @@ -460,8 +474,12 @@ def on_train_batch_end(self, batch, logs=None):
if not self._should_trace:
return

if self._is_tracing and self._global_train_batch >= self._stop_batch:
self._stop_trace()
if self._is_tracing:
if self._profiler_started and self._batch_trace_context is not None:
backend.tensorboard.stop_batch_trace(self._batch_trace_context)
self._batch_trace_context = None
if self._global_train_batch >= self._stop_batch:
self._stop_trace()

def on_epoch_begin(self, epoch, logs=None):
# Keeps track of epoch for profiling.
Expand All @@ -483,7 +501,7 @@ def on_epoch_end(self, epoch, logs=None):

def _start_trace(self):
self.summary.trace_on(graph=True, profiler=False)
self._start_profiler(logdir=self.log_dir)
self._start_profiler(logdir=self._train_dir)
self._is_tracing = True

def _stop_trace(self, batch=None):
Expand Down
21 changes: 14 additions & 7 deletions keras/src/callbacks/tensorboard_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import os
import random
import sys

import numpy as np
import pytest
Expand Down Expand Up @@ -736,14 +737,10 @@ def test_TensorBoard_write_model(self):
pass

@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="The profiling test can only run with TF backend.",
backend.backend() not in ("jax", "tensorflow"),
reason="The profiling test can only run with TF and JAX backends.",
)
def test_TensorBoard_auto_trace(self):
# TODO: Waiting for implementation for torch/jax for profiling ops
# if backend.backend() == "jax":
# return
# TODO: Debug profiling for JAX
logdir, train_dir, validation_dir = self._get_log_dirs()
model = models.Sequential(
[
Expand All @@ -753,6 +750,16 @@ def test_TensorBoard_auto_trace(self):
]
)
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
if backend.backend() == "jax" and sys.version_info[1] < 12:
with pytest.warns(match="backend requires python >= 3.12"):
callbacks.TensorBoard(
logdir, histogram_freq=1, profile_batch=1, write_graph=False
)
self.skipTest(
"Profiling with JAX and python < 3.12 "
"raises segmentation fault."
)

tb_cbk = callbacks.TensorBoard(
logdir, histogram_freq=1, profile_batch=1, write_graph=False
)
Expand All @@ -773,5 +780,5 @@ def test_TensorBoard_auto_trace(self):
_ObservedSummary(logdir=train_dir, tag="batch_1"),
},
)
self.assertEqual(1, self._count_xplane_file(logdir=logdir))
self.assertEqual(1, self._count_xplane_file(logdir=train_dir))
pass

0 comments on commit 1a01cbd

Please sign in to comment.