Skip to content

Commit

Permalink
add new measure_time capability in ./utils (#760)
Browse files Browse the repository at this point in the history
* add new measure_time capability in ./utils and move items related to Andrej's runner to andrej-runner

* add __init__ to make utils a package

* typo

* typo

* resolve name clash
  • Loading branch information
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 00c909f commit 95ae79c
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 30 deletions.
File renamed without changes.
File renamed without changes.
29 changes: 13 additions & 16 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Union
from utils.measure_time import measure_time

import torch
import torch._dynamo.config
Expand Down Expand Up @@ -381,10 +382,9 @@ def _initialize_model(
# quantize is None or quantize == "{ }"
# ), "quantize not valid for exported DSO model. Specify quantization during export."

t0 = time.time()
model = _load_model(builder_args, only_config=True)
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")
with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args, only_config=True)
device_sync(device=builder_args.device)

try:
# Replace model forward with the AOT-compiled forward
Expand All @@ -409,10 +409,9 @@ def _initialize_model(
# quantize is None or quantize == "{ }"
# ), "quantize not valid for exported PTE model. Specify quantization during export."

t0 = time.time()
model = _load_model(builder_args, only_config=True)
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")
with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args, only_config=True)
device_sync(device=builder_args.device)

try:
from build.model_et import PTEModel
Expand All @@ -421,17 +420,15 @@ def _initialize_model(
except Exception:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
else:
t0 = time.time()
model = _load_model(builder_args)
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")
with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args)
device_sync(device=builder_args.device)

if quantize:
t0q = time.time()
print(f"Quantizing the model with: {quantize}")
quantize_model(model, builder_args.device, quantize, tokenizer)
device_sync(device=builder_args.device)
print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")
with measure_time("Time to quantize model: {time:.02f} seconds"):
quantize_model(model, builder_args.device, quantize, tokenizer)
device_sync(device=builder_args.device)

if builder_args.setup_caches:
with torch.device(builder_args.device):
Expand Down
28 changes: 14 additions & 14 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
import torch._dynamo.config
import torch._inductor.config

from utils.measure_time import measure_time
from build.builder import (
_initialize_model,
_initialize_tokenizer,
Expand Down Expand Up @@ -141,9 +141,9 @@ def _model_call(self, inps):
)
)
x = seq.index_select(0, input_pos).view(1, -1)
start = time.time()
logits = model_forward(self._model, x, input_pos)
self.times.append(time.time() - start)
with measure_time(message=None) as measure:
logits = model_forward(self._model, x, input_pos)
self.times.append(measure.get_time())
return logits

def _model_generate(self, context, max_length, eos_token_id):
Expand Down Expand Up @@ -241,16 +241,16 @@ def main(args) -> None:
)
torch._inductor.config.coordinate_descent_tuning = True

t1 = time.time()
result = eval(
model.to(device),
tokenizer,
tasks,
limit,
max_seq_length,
device=builder_args.device,
)
print(f"Time to run eval: {time.time() - t1:.02f}s.")
with measure_time("Time to run eval: {time:.02f}s."):
result = eval(
model.to(device),
tokenizer,
tasks,
limit,
max_seq_length,
device=builder_args.device,
)

times = torch.tensor(result["times"])
print(
f"Time in model.forward: {times.sum():.02f}s, over {times.numel()} model evaluations"
Expand Down
Empty file added utils/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions utils/measure_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from time import perf_counter
from typing import Optional

class measure_time:
def __init__(
self,
message: Optional[str] = 'Time: {time:.3f} seconds'
):
self.message = message

def __enter__(
self,
):
self.start = perf_counter()
self.message
return self

def __exit__(self, type, value, traceback):
self.time = perf_counter() - self.start
if self.message is not None:
print(self.message.format(time=self.time))

def get_time(self):
return self.time

0 comments on commit 95ae79c

Please sign in to comment.