Per token latency outliers (#225)
IlyasMoutawwakil authored Jul 3, 2024
1 parent 8ebe853 commit 7999050
Showing 2 changed files with 20 additions and 17 deletions.
optimum_benchmark/import_utils.py (12 changes: 6 additions & 6 deletions)
```diff
@@ -246,16 +246,16 @@ def get_hf_libs_info():
     return {
         "optimum_benchmark_version": optimum_benchmark_version(),
         "optimum_benchmark_commit": get_git_revision_hash("optimum_benchmark"),
-        "transformers_version": transformers_version(),
+        "transformers_version": transformers_version() if is_transformers_available() else None,
         "transformers_commit": get_git_revision_hash("transformers"),
-        "accelerate_version": accelerate_version(),
+        "accelerate_version": accelerate_version() if is_accelerate_available() else None,
         "accelerate_commit": get_git_revision_hash("accelerate"),
-        "diffusers_version": diffusers_version(),
+        "diffusers_version": diffusers_version() if is_diffusers_available() else None,
         "diffusers_commit": get_git_revision_hash("diffusers"),
-        "optimum_version": optimum_version(),
+        "optimum_version": optimum_version() if is_optimum_available() else None,
         "optimum_commit": get_git_revision_hash("optimum"),
-        "timm_version": timm_version(),
+        "timm_version": timm_version() if is_timm_available() else None,
         "timm_commit": get_git_revision_hash("timm"),
-        "peft_version": peft_version(),
+        "peft_version": peft_version() if is_peft_available() else None,
         "peft_commit": get_git_revision_hash("peft"),
     }
```
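The `is_*_available()` guards keep `get_hf_libs_info()` from calling a version helper for a package that is not installed. Their implementations are not part of this diff; a minimal sketch, assuming the usual `importlib`-based pattern (the real helpers in `import_utils.py` may differ):

```python
import importlib.metadata
import importlib.util

# Probe once at import time; find_spec returns None when the package is absent.
_transformers_available = importlib.util.find_spec("transformers") is not None

def is_transformers_available() -> bool:
    return _transformers_available

def transformers_version() -> str:
    # Raises importlib.metadata.PackageNotFoundError if transformers is missing,
    # which is why callers guard with is_transformers_available() first.
    return importlib.metadata.version("transformers")
```

The same probe-once pattern would apply to accelerate, diffusers, optimum, timm, and peft.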
optimum_benchmark/trackers/latency.py (25 changes: 14 additions & 11 deletions)
```diff
@@ -264,7 +264,7 @@ def __init__(self, device: str, backend: str):
         self.start_time: Optional[float] = None
         self.prefilled: Optional[bool] = None
 
-        self.per_token_events: List[Union[float, torch.cuda.Event]] = []
+        self.per_token_events: List[List[Union[float, torch.cuda.Event]]] = []
         self.prefill_start_events: List[Union[float, torch.cuda.Event]] = []
         self.prefill_end_events: List[Union[float, torch.cuda.Event]] = []
         self.decode_start_events: List[Union[float, torch.cuda.Event]] = []
```
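Before the `track()` hunks below, it helps to see how this tracker is presumably driven: `__call__(input_ids, scores)` matches the Hugging Face `LogitsProcessor` protocol, so `generate()` invokes it once per decoded token. A sketch under assumptions: the surrounding class is guessed to be `PerTokenLatencyLogitsProcessor` (only its methods appear in these hunks), and the model and loop are illustrative:

```python
from optimum_benchmark.trackers.latency import PerTokenLatencyLogitsProcessor
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Constructor arguments are assumptions; the diff only shows the signature.
tracker = PerTokenLatencyLogitsProcessor(device="cpu", backend="pytorch")

for _ in range(3):  # several tracked generate calls in one benchmark run
    with tracker.track():  # opens a fresh inner list in per_token_events
        model.generate(
            **inputs,
            max_new_tokens=16,
            do_sample=False,
            logits_processor=LogitsProcessorList([tracker]),  # one event per decoded token
        )

print(tracker.get_per_token_latency())
```

Each entry of `per_token_events` now corresponds to one such `generate()` call, which is exactly what the nested type above encodes.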
```diff
@@ -282,6 +282,9 @@ def reset(self):
 
     @contextmanager
     def track(self):
+        self.prefilled = False
+        self.per_token_events.append([])
+
         if self.is_distributed:
             torch.distributed.barrier()
 
```
```diff
@@ -291,14 +294,10 @@ def track(self):
         else:
             self.prefill_start_events.append(time.perf_counter())
 
-        self.prefilled = False
-
         # this is where generate is called,
         # and for each decoded token, we record an event
         yield
 
-        self.prefilled = None
-
         if self.is_asynchronous:
             self.decode_end_events.append(torch.cuda.Event(enable_timing=True))
             self.decode_end_events[-1].record()
```
```diff
@@ -308,6 +307,8 @@
         if self.is_distributed:
             torch.distributed.barrier()
 
+        self.prefilled = False
+
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
         assert (
             self.prefilled is not None
```
```diff
@@ -319,13 +320,13 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
         else:
             event = time.perf_counter()
 
-        self.per_token_events.append(event)
-
         if not self.prefilled:
             self.prefill_end_events.append(event)
             self.decode_start_events.append(event)
             self.prefilled = True
 
+        self.per_token_events[-1].append(event)
+
         return scores
 
     def get_prefill_latency(self) -> Latency:
```
```diff
@@ -368,13 +369,15 @@ def get_per_token_latency(self) -> Latency:
             torch.cuda.synchronize()
 
             latencies_list = [
-                self.per_token_events[i].elapsed_time(self.per_token_events[i + 1]) / 1e3
-                for i in range(0, len(self.per_token_events) - 1)
+                self.per_token_events[i][j].elapsed_time(self.per_token_events[i][j + 1]) / 1e3
+                for i in range(len(self.per_token_events))
+                for j in range(0, len(self.per_token_events[i]) - 1)
             ]
         else:
             latencies_list = [
-                (self.per_token_events[i + 1] - self.per_token_events[i])
-                for i in range(0, len(self.per_token_events) - 1)
+                (self.per_token_events[i][j + 1] - self.per_token_events[i][j])
+                for i in range(len(self.per_token_events))
+                for j in range(0, len(self.per_token_events[i]) - 1)
             ]
 
         assert not any(latency < 0 for latency in latencies_list), "Negative latency detected"
```
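This hunk is the heart of the fix. With the old flat list, events from consecutive `track()` calls sat side by side, so the comprehension also paired the last token of one `generate()` call with the first token of the next; that gap includes everything that happened between the two calls and surfaced as a huge per-token latency outlier (the negative-latency assert never fired, because the gap is positive). A self-contained sketch of the arithmetic, with made-up timestamps:

```python
# Two generate calls, token timestamps in seconds (illustrative values).
call_1 = [0.00, 0.02, 0.04]   # ~20 ms per token
call_2 = [5.00, 5.02, 5.04]   # second call starts ~5 s later

# Old behavior: one flat list across calls.
flat = call_1 + call_2
old = [flat[i + 1] - flat[i] for i in range(len(flat) - 1)]
print(old)  # ~[0.02, 0.02, 4.96, 0.02, 0.02]  <- 4.96 s outlier spans the gap between calls

# New behavior: one inner list per tracked call; pairs never cross the boundary.
nested = [call_1, call_2]
new = [
    nested[i][j + 1] - nested[i][j]
    for i in range(len(nested))
    for j in range(len(nested[i]) - 1)
]
print(new)  # ~[0.02, 0.02, 0.02, 0.02]
```

Nesting the events per call, as the diff does, makes cross-call pairs impossible by construction.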