Skip to content

Commit

Permalink
add mps to processor benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
hoesler authored and rlouf committed Jan 2, 2025
1 parent 2f0740e commit 6a8612b
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions benchmarks/bench_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,12 @@ def get_mock_processor_inputs(array_library, num_tokens=30000):
logits: (4, 30,000 ) dtype=float
input_ids shape: (4, 2048) dtype=int
"""
if array_library == "torch":
logits = torch.rand((4, num_tokens), dtype=torch.float)
input_ids = torch.randint(
low=0, high=num_tokens, size=(4, 2048), dtype=torch.int
)
elif array_library == "torch_cuda":
logits = torch.rand((4, num_tokens), dtype=torch.float, device="cuda")
if array_library.startswith("torch"):
device = array_library.split("_")[1] if "_" in array_library else "cpu"

logits = torch.rand((4, num_tokens), dtype=torch.float, device=device)
input_ids = torch.randint(
low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device="cuda"
low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device=device
)
elif array_library == "numpy":
logits = np.random.rand(4, num_tokens).astype(np.float32)
Expand Down Expand Up @@ -88,6 +85,8 @@ class LogitsProcessorPassthroughBenchmark:
params += ["mlx"]
if torch.cuda.is_available():
params += ["torch_cuda"]
if torch.mps.is_available():
params += ["torch_mps"]
if is_jax_allowed():
params += ["jax"]

Expand All @@ -108,9 +107,10 @@ class LogitsProcessorStructuredBenchmark:
array_libraries = ["torch", "numpy"]
if is_mlx_lm_allowed():
array_libraries += ["mlx"]
# PR TODO
if torch.cuda.is_available():
array_libraries += ["torch_cuda"]
if torch.mps.is_available():
array_libraries += ["torch_mps"]

# accept very many or very few tokens, respectively
patterns = [r"[^Z]*", "Z*"]
Expand Down

0 comments on commit 6a8612b

Please sign in to comment.