From aa333be9b9ef291ce2cb5f8d374241fd197e4589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20H=C3=B6sler?= Date: Thu, 28 Nov 2024 13:16:25 +0100 Subject: [PATCH 1/3] remove non_blocking=True --- outlines/processors/structured.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index d2bc15f77..104257cdf 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -108,7 +108,7 @@ def process_logits( batch_indices = [] for i, guide_state in enumerate(sequence_states): allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to( - mask.device, non_blocking=True + mask.device ) allowed_tokens_batch.append(allowed_tokens) batch_indices.append( From 55cc500ed0b5b2756804164044359759021c6c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20H=C3=B6sler?= Date: Thu, 5 Dec 2024 15:24:00 +0100 Subject: [PATCH 2/3] optimize tensor creation --- outlines/processors/structured.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index 104257cdf..7d5c9cc70 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -102,22 +102,19 @@ def process_logits( sequence_states.append(self._guide_states[curr_state_key]) - mask = torch.ones_like(logits, dtype=torch.bool) - allowed_tokens_batch = [] batch_indices = [] for i, guide_state in enumerate(sequence_states): - allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to( - mask.device - ) + allowed_tokens = self.guide.get_next_instruction(guide_state).tokens allowed_tokens_batch.append(allowed_tokens) batch_indices.append( torch.full_like(allowed_tokens, i) ) # Store batch index for each allowed token - allowed_tokens_concat = torch.cat(allowed_tokens_batch) - batch_indices_concat = torch.cat(batch_indices) + allowed_tokens_concat = 
torch.cat(allowed_tokens_batch).to(logits.device) + batch_indices_concat = torch.cat(batch_indices).to(logits.device) + mask = torch.ones_like(logits, dtype=torch.bool) mask[batch_indices_concat, allowed_tokens_concat] = False logits.masked_fill_(mask, float("-inf")) From 35fc711a9862d0a3b8ce69f9668dc6e2c0bf90cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20H=C3=B6sler?= Date: Thu, 5 Dec 2024 21:18:32 +0100 Subject: [PATCH 3/3] add mps to processor benchmark --- benchmarks/bench_processors.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/benchmarks/bench_processors.py b/benchmarks/bench_processors.py index db1e4a8f1..02ea52b79 100644 --- a/benchmarks/bench_processors.py +++ b/benchmarks/bench_processors.py @@ -37,15 +37,12 @@ def get_mock_processor_inputs(array_library, num_tokens=30000): logits: (4, 30,000 ) dtype=float input_ids shape: (4, 2048) dtype=int """ - if array_library == "torch": - logits = torch.rand((4, num_tokens), dtype=torch.float) - input_ids = torch.randint( - low=0, high=num_tokens, size=(4, 2048), dtype=torch.int - ) - elif array_library == "torch_cuda": - logits = torch.rand((4, num_tokens), dtype=torch.float, device="cuda") + if array_library.startswith("torch"): + device = array_library.split("_")[1] if "_" in array_library else "cpu" + + logits = torch.rand((4, num_tokens), dtype=torch.float, device=device) input_ids = torch.randint( - low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device="cuda" + low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device=device ) elif array_library == "numpy": logits = np.random.rand(4, num_tokens).astype(np.float32) @@ -88,6 +85,8 @@ class LogitsProcessorPassthroughBenchmark: params += ["mlx"] if torch.cuda.is_available(): params += ["torch_cuda"] + if torch.backends.mps.is_available(): + params += ["torch_mps"] if is_jax_allowed(): params += ["jax"] @@ -108,9 +107,10 @@ class LogitsProcessorStructuredBenchmark: array_libraries = ["torch", "numpy"] if 
is_mlx_lm_allowed(): array_libraries += ["mlx"] - # PR TODO if torch.cuda.is_available(): array_libraries += ["torch_cuda"] + if torch.backends.mps.is_available(): + array_libraries += ["torch_mps"] # accept very many or very few tokens, respectively patterns = [r"[^Z]*", "Z*"]