Skip to content

Commit

Permalink
Enable steps_per_execution tuning for custom training loops
Browse files Browse the repository at this point in the history
Makes the underlying class available and adds documentation.

PiperOrigin-RevId: 540673611
  • Loading branch information
grasskin authored and tensorflower-gardener committed Aug 16, 2023
1 parent cdffff8 commit 5c79875
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
path: "tensorflow.keras.utils.StepsPerExecutionTuner"
tf_class {
is_instance: "<class \'keras.utils.steps_per_execution_tuning.StepsPerExecutionTuner\'>"
is_instance: "<type \'object\'>"
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'optimizer\', \'spe_variable\', \'interval\', \'change_spe_interval\', \'change_threshold\'], varargs=None, keywords=None, defaults=[\'5\', \'10\', \'0.1\'], "
}
member_method {
name: "start"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "stop"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
4 changes: 4 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.utils.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ tf_module {
name: "SequenceEnqueuer"
mtype: "<type \'type\'>"
}
member {
name: "StepsPerExecutionTuner"
mtype: "<type \'type\'>"
}
member {
name: "custom_object_scope"
mtype: "<type \'type\'>"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
path: "tensorflow.keras.utils.StepsPerExecutionTuner"
tf_class {
is_instance: "<class \'keras.utils.steps_per_execution_tuning.StepsPerExecutionTuner\'>"
is_instance: "<type \'object\'>"
member {
name: "steps_per_execution"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'optimizer\', \'spe_variable\', \'interval\', \'change_spe_interval\', \'change_threshold\'], varargs=None, keywords=None, defaults=[\'5\', \'10\', \'0.1\'], "
}
member_method {
name: "start"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "stop"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
4 changes: 4 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.utils.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ tf_module {
name: "SidecarEvaluator"
mtype: "<type \'type\'>"
}
member {
name: "StepsPerExecutionTuner"
mtype: "<type \'type\'>"
}
member {
name: "TimedThread"
mtype: "<type \'type\'>"
Expand Down
22 changes: 1 addition & 21 deletions keras/engine/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ py_library(
":input_spec",
":keras_tensor",
":node",
":steps_per_execution_tuning",
"//:expect_h5py_installed",
"//:expect_tensorboard_installed",
"//:expect_tensorflow_installed",
Expand All @@ -70,6 +69,7 @@ py_library(
"//keras/utils:engine_utils",
"//keras/utils:metrics_utils",
"//keras/utils:mode_keys",
"//keras/utils:steps_per_execution_tuning",
"//keras/utils:tf_utils",
"//keras/utils:version_utils",
],
Expand Down Expand Up @@ -206,26 +206,6 @@ py_library(
],
)

py_library(
name = "steps_per_execution_tuning",
srcs = ["steps_per_execution_tuning.py"],
srcs_version = "PY3",
deps = [
"//:expect_numpy_installed",
],
)

tf_py_test(
name = "steps_per_execution_tuning_test",
srcs = ["steps_per_execution_tuning_test.py"],
python_version = "PY3",
deps = [
":steps_per_execution_tuning",
"//:expect_tensorflow_installed",
"//keras/testing_infra:test_combinations",
],
)

tf_py_test(
name = "base_layer_utils_test",
srcs = ["base_layer_utils_test.py"],
Expand Down
2 changes: 1 addition & 1 deletion keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from keras.engine import compile_utils
from keras.engine import data_adapter
from keras.engine import input_layer as input_layer_module
from keras.engine import steps_per_execution_tuning
from keras.engine import training_utils
from keras.metrics import base_metric
from keras.mixed_precision import loss_scale_optimizer as lso
Expand All @@ -55,6 +54,7 @@
from keras.utils import generic_utils
from keras.utils import io_utils
from keras.utils import layer_utils
from keras.utils import steps_per_execution_tuning
from keras.utils import tf_inspect
from keras.utils import tf_utils
from keras.utils import traceback_utils
Expand Down
20 changes: 20 additions & 0 deletions keras/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,26 @@ py_library(
srcs_version = "PY3",
)

py_library(
name = "steps_per_execution_tuning",
srcs = ["steps_per_execution_tuning.py"],
srcs_version = "PY3",
deps = [
"//:expect_numpy_installed",
],
)

tf_py_test(
name = "steps_per_execution_tuning_test",
srcs = ["steps_per_execution_tuning_test.py"],
python_version = "PY3",
deps = [
":steps_per_execution_tuning",
"//:expect_tensorflow_installed",
"//keras/testing_infra:test_combinations",
],
)

tf_py_test(
name = "sidecar_evaluator_test",
size = "medium",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import time

import numpy as np
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.utils.StepsPerExecutionTuner")
class StepsPerExecutionTuner:
"""Steps per execution tuner class.
Expand All @@ -37,7 +39,82 @@ class StepsPerExecutionTuner:
before tuning. Defaults to 10.
change_threshold: Optional float, the percent difference in throughput to
trigger a `steps_per_execution` change. For example, `0.1` triggers
changes if throughput ()
changes if throughput changes more than 10%.
Examples:
If you're using `model.compile` and `model.fit`, this functionality is
available at compile time with `steps_per_execution='auto'`
```python
model.compile(..., steps_per_execution='auto')
```
Custom training loop usage:
```python
# Get model
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (_, _) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
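train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.batch(batch_size)
# Create the iterator consumed by the training step below
iterator = iter(train_dataset)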
# Create our steps per execution variable
steps_per_execution = tf.Variable(
1,
dtype="int64",
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA
)
# Create the tuner
tuner = StepsPerExecutionTuner(
optimizer, steps_per_execution
)
# Create a step function that runs a single training step
@tf.function
def step_fn(iterator):
batch_data, labels = next(iterator)
with tf.GradientTape() as tape:
logits = model(batch_data, training=True)
loss_value = loss_fn(labels, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
# We can now pack multiple execution steps into one call
@tf.function
def multi_step_train_fn(iterator, steps_per_execution):
for _ in tf.range(steps_per_execution):
outputs = step_fn(iterator)
return
initial_steps_per_execution = 1
steps_per_epoch = 100
epochs = 2
# Start the tuner before training
tuner.start()
# We can now call our multi step training with our data
for epoch in range(epochs):
for _ in range(steps_per_epoch):
multi_step_train_fn(iterator, steps_per_execution)
# End the tuner after training
tuner.stop()
```
"""

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@

import tensorflow.compat.v2 as tf

from keras.engine import steps_per_execution_tuning
from keras import Input
from keras import Model
from keras import losses
from keras import optimizers
from keras.layers import Dense
from keras.testing_infra import test_combinations
from keras.utils import steps_per_execution_tuning


class mockOptimizer:
Expand Down Expand Up @@ -67,6 +72,69 @@ def test_settable_steps_per_execution(self):
assert spe_variable.numpy().item() == 5
assert tuner.init_spe == 5

def test_custom_training_loop(self):
    """Smoke-test the tuner around a hand-written multi-step training loop.

    Builds a small dense classifier, wraps a single training step in a
    `tf.function`, runs that step `steps_per_execution` times per call, and
    trains for a few epochs between `tuner.start()` and `tuner.stop()`.
    The test makes no assertions; it passes if nothing raises.
    """
    dataset = _get_dataset()
    iterator = iter(dataset)

    inputs = Input(shape=(784,), name="digits")
    x = Dense(64, activation="relu", name="dense_1")(inputs)
    x = Dense(64, activation="relu", name="dense_2")(x)
    outputs = Dense(10, name="predictions")(x)
    model = Model(inputs=inputs, outputs=outputs)
    optimizer = optimizers.SGD(learning_rate=1e-3)
    loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True)

    # Create our steps per execution variable
    steps_per_execution = tf.Variable(
        1,
        dtype="int64",
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
    )

    # Create the tuner
    tuner = steps_per_execution_tuning.StepsPerExecutionTuner(
        optimizer, steps_per_execution
    )

    # Create a step function that runs a single training step
    @tf.function
    def step_fn(iterator):
        batch_data, labels = next(iterator)
        with tf.GradientTape() as tape:
            logits = model(batch_data, training=True)
            loss_value = loss_fn(labels, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # We can now pack multiple execution steps into one call
    @tf.function
    def multi_step_train_fn(iterator, steps_per_execution):
        for _ in tf.range(steps_per_execution):
            step_fn(iterator)

    steps_per_epoch = 10
    epochs = 2

    # Start the tuner before training
    tuner.start()
    try:
        for _ in range(epochs):
            for _ in range(steps_per_epoch):
                multi_step_train_fn(iterator, steps_per_execution)
    finally:
        # Always end the tuner, even if a training step raises, so a
        # failure in the loop does not leave the tuner running.
        tuner.stop()


def _get_dataset():
    """Return a batched, all-zeros `tf.data.Dataset` of (inputs, targets).

    The dataset holds 1000 examples of 784 float32 features with scalar
    float32 targets, grouped into batches of 10.
    """
    num_examples = 1000
    features = tf.zeros((num_examples, 784), dtype=tf.float32)
    labels = tf.zeros((num_examples,), dtype=tf.float32)
    return tf.data.Dataset.from_tensor_slices((features, labels)).batch(10)


if __name__ == "__main__":
tf.test.main()

0 comments on commit 5c79875

Please sign in to comment.