Skip to content

Commit

Permalink
refactor and clean
Browse files Browse the repository at this point in the history
  • Loading branch information
sky-2002 committed Oct 14, 2024
1 parent f77c000 commit 722a25a
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions tests/processors/test_base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,29 @@
import pytest
import torch

# Import mlx and the processor class
try:
import mlx.core as mx

MLX_AVAILABLE = True
except ImportError:
MLX_AVAILABLE = False

from outlines.processors.base_logits_processor import OutlinesLogitsProcessor

arrays = {
"list": [[1.0, 2.0], [3.0, 4.0]],
"np": np.array([[1, 2], [3, 4]], dtype=np.float32),
"jax": jnp.array([[1, 2], [3, 4]], dtype=jnp.float32),
"torch": torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
"mlx": mx.array([[1, 2], [3, 4]], dtype=mx.float32),
}

try:
import mlx.core as mx

arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
except ImportError:
pass

try:
import jax.numpy as jnp

arrays["jax"] = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
except ImportError:
pass


# Mock implementation of the abstract class for testing
class MockLogitsProcessor(OutlinesLogitsProcessor):
Expand Down

0 comments on commit 722a25a

Please sign in to comment.