
TensorFlowTrainer: Add low-level API unrolled_steps_per_execution parameter #20451

Merged (2 commits) on Nov 6, 2024

Conversation

@nicolaspi (Contributor) commented on Nov 5, 2024

Add low-level API parameter unrolled_steps_per_execution for TensorFlowTrainer.

This parameter specifies how many steps of the steps_per_execution loop to unroll. Increasing this value can reduce kernel launch overhead, but increases memory usage and compilation time.

Usage:

model.unrolled_steps_per_execution = 4
model.fit(...)
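
For intuition, here is a minimal, self-contained sketch (illustrative only, not the actual TensorFlowTrainer internals) of the difference between a rolled and an unrolled multi-step loop inside a tf.function: a symbolic tf.range loop traces the step body once and iterates it at runtime, while a Python range loop inlines the body N times, which can hide kernel launch latency at the cost of a larger graph and longer compilation.

import tensorflow as tf

# Illustrative sketch only; not the Keras TensorFlowTrainer code.
steps_per_execution = 4

w = tf.Variable(1.0)


def train_step(x):
    # Toy "training step": a single variable update per batch.
    w.assign_add(tf.reduce_mean(x))
    return w.read_value()


@tf.function
def rolled_multi_step(iterator):
    # Symbolic loop: one traced step body, iterated at runtime.
    outputs = train_step(next(iterator))
    for _ in tf.range(steps_per_execution - 1):
        outputs = train_step(next(iterator))
    return outputs


@tf.function
def unrolled_multi_step(iterator):
    # Python loop: the step body is traced (unrolled) 4 times, so the
    # compiled function issues the kernels for all steps back to back.
    for _ in range(steps_per_execution):
        outputs = train_step(next(iterator))
    return outputs


ds = tf.data.Dataset.from_tensor_slices(tf.ones((32, 8))).batch(4)
rolled_multi_step(iter(ds))
unrolled_multi_step(iter(ds))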

The following example shows how this parameter can reduce GPU idling caused by kernel launch delays when a model is trained with steps_per_execution = 4, resulting in a 2x speedup.

With unrolled_steps_per_execution = 1 (default value):

[profiler trace screenshot]

With unrolled_steps_per_execution = 4:

[profiler trace screenshot]

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf  # For tf.data
import tensorflow_datasets as tfds

from keras.applications import MobileNet as app_model
from keras import callbacks

IMG_SIZE = 224
BATCH_SIZE = 1  # small batch size to reduce the data pipeline overhead
steps_per_execution = 4
dataset_size = BATCH_SIZE * max(steps_per_execution, 32)

dataset_name = "stanford_dogs"
(ds_train, ds_test), ds_info = tfds.load(
    dataset_name, split=["train", "test"], with_info=True, as_supervised=True
)
NUM_CLASSES = ds_info.features["label"].num_classes

size = (IMG_SIZE, IMG_SIZE)
ds_train = ds_train.take(dataset_size * (len(ds_train) // dataset_size)).map(
    lambda image, label: (tf.image.resize(image, size), label)
)
ds_test = ds_test.take(dataset_size * (len(ds_test) // dataset_size)).map(
    lambda image, label: (tf.image.resize(image, size), label)
)


# One-hot / categorical encoding
def input_preprocess(image, label):
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label


ds_train = ds_train.map(input_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(batch_size=BATCH_SIZE, drop_remainder=True)
ds_train = (
    ds_train.take(4 * steps_per_execution).cache().prefetch(tf.data.AUTOTUNE)
)

ds_test = ds_test.map(input_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = (
    ds_test.batch(batch_size=BATCH_SIZE, drop_remainder=True)
    .take(4 * steps_per_execution)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

model = app_model(
    include_top=True,
    weights=None,
    classes=NUM_CLASSES,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
    steps_per_execution=steps_per_execution,
    jit_compile=True,
)

# model.summary()
logdir = "logs/"
tb_cbk = callbacks.TensorBoard(
    logdir, histogram_freq=1, profile_batch="3,4", write_graph=False
)
model.unrolled_steps_per_execution = 4
epochs = 2
hist = model.fit(
    ds_train,
    epochs=epochs,
    validation_data=ds_test,
    callbacks=[tb_cbk],
)
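
To quantify the effect on this setup beyond the profiler traces, a rough timing comparison can be appended to the script above (a sketch; it reuses model and ds_train from the script and assumes that re-compiling rebuilds the train function with the new unroll factor):

import time

for unroll in (1, 4):
    # Re-compile so the train function is rebuilt with the chosen unroll factor.
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"],
        steps_per_execution=steps_per_execution,
        jit_compile=True,
    )
    model.unrolled_steps_per_execution = unroll
    model.fit(ds_train, epochs=1, verbose=0)  # warm-up: tracing + XLA compilation
    start = time.perf_counter()
    model.fit(ds_train, epochs=1, verbose=0)
    print(
        f"unrolled_steps_per_execution={unroll}: "
        f"{time.perf_counter() - start:.3f}s per epoch"
    )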

@codecov-commenter commented on Nov 5, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.00%. Comparing base (272bb90) to head (daf4370).
Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20451      +/-   ##
==========================================
- Coverage   82.01%   82.00%   -0.01%     
==========================================
  Files         514      515       +1     
  Lines       47239    47282      +43     
  Branches     7413     7422       +9     
==========================================
+ Hits        38741    38773      +32     
- Misses       6704     6712       +8     
- Partials     1794     1797       +3     
Flag               Coverage Δ
keras              81.85% <100.00%> (-0.01%) ⬇️
keras-jax          64.87% <0.00%> (-0.04%) ⬇️
keras-numpy        59.82% <0.00%> (-0.04%) ⬇️
keras-tensorflow   65.90% <100.00%> (+<0.01%) ⬆️
keras-torch        64.81% <0.00%> (-0.04%) ⬇️

Flags with carried forward coverage won't be shown.


@fchollet (Member) left a comment

This makes sense. It is going to be very niche (and currently undocumented), but it is useful to keep it around.

google-ml-butler bot added the kokoro:force-run and ready to pull labels on Nov 6, 2024
fchollet merged commit 0d96e54 into keras-team:master on Nov 6, 2024
7 checks passed
google-ml-butler bot removed the ready to pull and kokoro:force-run labels on Nov 6, 2024