TF Profiler Instrumentation
DEKHTIARJonathan committed Apr 15, 2022
1 parent 263043b commit eb97b19
Showing 2 changed files with 71 additions and 38 deletions.
8 changes: 8 additions & 0 deletions tftrt/examples/benchmark_args.py
@@ -245,6 +245,14 @@ def __init__(self):
"to the set location in JSON format for further processing."
)

self._parser.add_argument(
"--tf_profile_export_path",
type=str,
default=None,
help="If set, the script will export tf.profile files for further "
"performance analysis."
)

self._add_bool_argument(
name="debug",
default=False,
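The new flag is consumed in benchmark_runner.py below, where it becomes the logdir of a tf.profiler.experimental.Profile context. A minimal standalone sketch of the same mechanism (the path and the workload are illustrative, not taken from this commit):

import tensorflow as tf

# Everything executed inside the Profile context is traced and written to
# the given logdir in the format TensorBoard's Profile tab understands.
with tf.profiler.experimental.Profile("/tmp/tf_profile"):
    x = tf.random.uniform((1024, 1024))
    y = tf.matmul(x, x)  # this op will appear in the trace viewer

# Inspect the result with: tensorboard --logdir /tmp/tf_profile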
101 changes: 63 additions & 38 deletions tftrt/examples/benchmark_runner.py
@@ -5,6 +5,7 @@
import os

import abc
import contextlib
import copy
import json
import logging
@@ -390,7 +391,12 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
)

@force_gpu_resync
@tf.function()
@tf.function(jit_compile=self._args.use_xla)
def dequeue_batch(ds_iter):
return next(ds_iter)

@force_gpu_resync
@tf.function(jit_compile=self._args.use_xla)
def force_data_on_gpu(data, device="/gpu:0"):
with tf.device(device):
if isinstance(data, (list, tuple)):
@@ -403,58 +409,77 @@ def force_data_on_gpu(data, device="/gpu:0"):
output_data[k] = tf.identity(v)
else:
output_data = tf.identity(data)

return output_data
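# Illustrative sketch, not part of this commit: dequeue_batch and
# force_data_on_gpu above are wrapped in force_gpu_resync (defined elsewhere
# in this repo) so the time.time() deltas below measure completed GPU work,
# not just asynchronous op dispatch. One common shape for such a decorator
# (an assumption -- the repo's actual implementation may differ):
def _force_gpu_resync_sketch(func):
    sync_tensor = tf.constant(0.0)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        # Fetching a tiny op's result back to host blocks until previously
        # enqueued GPU work has drained, making the timing honest.
        (sync_tensor + 1.0).numpy()
        return result
    return wrapper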

if self._args.tf_profile_export_path:
profiling_ctx = tf.profiler.experimental.Profile(
self._args.tf_profile_export_path
)
tracing_ctx = tf.profiler.experimental.Trace
else:
profiling_ctx = contextlib.nullcontext()
tracing_ctx = lambda *a, **kw: contextlib.nullcontext()
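# Note on the disabled path (sketch, not part of this commit): the loop
# below CALLS tracing_ctx(...) at every with-site, e.g.
#     with tracing_ctx('GPU Inference', step_num=step_idx, _r=1): ...
# so the stand-in must be a callable that swallows Trace's arguments and
# returns a context manager -- hence the lambda around nullcontext().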

step_idx = 0
ds_iter = iter(dataset)

while True:
with profiling_ctx:

try:
start_time = time.time()
data_batch = next(ds_iter)
dequeue_times.append(time.time() - start_time)
except (StopIteration, tf.errors.OutOfRangeError):
break

start_time = time.time()
data_batch = force_data_on_gpu(data_batch)
memcopy_times.append(time.time() - start_time)

x, y = self.preprocess_model_inputs(data_batch)

start_time = time.time()
y_pred = infer_batch(x)
iter_times.append(time.time() - start_time)

if not self._args.debug_performance:
log_step(
step_idx + 1,
display_every=self._args.display_every,
iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
)
else:
print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
print(f"{'Data MemCopyHtoD Time':18s}: {memcpyHtoD_time[-1]:08.4f}s")
print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
while True:

if not self._args.use_synthetic_data:
data_aggregator.aggregate_data(y_pred, y)
step_idx += 1

if (self._args.num_iterations is not None and
step_idx + 1 >= self._args.num_iterations):
break
if (self._args.num_iterations is not None and
step_idx >= self._args.num_iterations):
break

with tracing_ctx('Inference Step', step_num=step_idx, _r=1):
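# Each iteration is wrapped in a step-level Trace; the nested Traces
# below (dequeue, H2D copy, preprocessing, inference) show up as
# children of this step in TensorBoard's trace viewer.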

with tracing_ctx('Input Dequeueing', step_num=step_idx, _r=1):
try:
start_time = time.time()
data_batch = dequeue_batch(ds_iter)
dequeue_times.append(time.time() - start_time)
except (StopIteration, tf.errors.OutOfRangeError):
print("[Exiting] Reached end of dataset ...")
break

with tracing_ctx('Inputs MemcpyHtoD', step_num=step_idx, _r=1):
start_time = time.time()
data_batch = force_data_on_gpu(data_batch)
memcopy_times.append(time.time() - start_time)

with tracing_ctx('Inputs Preprocessing', step_num=step_idx, _r=1):
x, y = self.preprocess_model_inputs(data_batch)

with tracing_ctx('GPU Inference', step_num=step_idx, _r=1):
start_time = time.time()
y_pred = infer_batch(x)
iter_times.append(time.time() - start_time)

if not self._args.debug_performance:
log_step(
step_idx,
display_every=self._args.display_every,
iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
)
else:
print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
print(f"{'Data MemCopyHtoD Time':18s}: {memcpyHtoD_time[-1]:08.4f}s")
print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")

step_idx += 1
if not self._args.use_synthetic_data:
data_aggregator.aggregate_data(y_pred, y)

if (
not self._args.debug_performance and
step_idx % self._args.display_every != 0
): # avoids double printing
log_step(
step_idx + 1,
step_idx,
display_every=1, # force print
iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
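The step_num/_r=1 arguments used above follow TensorFlow's documented recipe for instrumenting custom loops: _r=1 marks the Trace as a step event, which lets TensorBoard group everything recorded inside it under a per-step timeline. A compact, self-contained version of the instrumentation this commit adds (the model and shapes are illustrative):

import tensorflow as tf

logdir = "/tmp/tf_profile"  # illustrative; the benchmark takes it from --tf_profile_export_path

with tf.profiler.experimental.Profile(logdir):
    for step in range(10):
        # _r=1 marks this Trace as a step boundary for the trace viewer.
        with tf.profiler.experimental.Trace("Inference Step", step_num=step, _r=1):
            with tf.profiler.experimental.Trace("GPU Inference", step_num=step, _r=1):
                tf.matmul(tf.ones((256, 256)), tf.ones((256, 256)))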
