diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py index 6e63ef5b6..d8b7e032c 100644 --- a/outlines/models/mlxlm.py +++ b/outlines/models/mlxlm.py @@ -167,12 +167,7 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]: prob = softmax_logits[0, token] return token, prob - kv_heads = ( - [self.model.n_kv_heads] * len(self.model.layers) - if isinstance(self.model.n_kv_heads, int) - else self.model.n_kv_heads - ) - cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads] + cache = mlx_lm.models.cache.make_prompt_cache(self.model) # kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model() unprocessed_input_ids = prompt diff --git a/pyproject.toml b/pyproject.toml index 1fd2897aa..e4f12f76c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ test = [ "beartype<0.16.0", "responses", "llama-cpp-python", - "mlx-lm; platform_machine == 'arm64' and sys_platform == 'darwin'", + "mlx-lm>=0.19.2; platform_machine == 'arm64' and sys_platform == 'darwin'", "huggingface_hub", "openai>=1.0.0", "vllm; sys_platform != 'darwin'", diff --git a/tests/models/test_mlxlm.py b/tests/models/test_mlxlm.py new file mode 100644 index 000000000..20e59da81 --- /dev/null +++ b/tests/models/test_mlxlm.py @@ -0,0 +1,100 @@ +import pytest + +from outlines.models.mlxlm import mlxlm +from outlines.models.transformers import TransformerTokenizer + +try: + import mlx.core as mx + + HAS_MLX = mx.metal.is_available() +except ImportError: + HAS_MLX = False + + +TEST_MODEL = "mlx-community/SmolLM-135M-Instruct-4bit" + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_model(): + model = mlxlm(TEST_MODEL) + assert hasattr(model, "model") + assert hasattr(model, "tokenizer") + assert isinstance(model.tokenizer, TransformerTokenizer) + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_tokenizer(): + model = mlxlm(TEST_MODEL) + + # Test single string encoding/decoding + test_text = "Hello, world!" + token_ids = mx.array(model.mlx_tokenizer.encode(test_text)) + assert isinstance(token_ids, mx.array) + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_generate(): + from outlines.generate.api import GenerationParameters, SamplingParameters + + model = mlxlm(TEST_MODEL) + prompt = "Write a haiku about programming:" + + # Test with basic generation parameters + gen_params = GenerationParameters(max_tokens=50, stop_at=None, seed=None) + + # Test with different sampling parameters + sampling_params = SamplingParameters( + sampler="multinomial", num_samples=1, top_p=0.9, top_k=None, temperature=0.7 + ) + + # Test generation + output = model.generate(prompt, gen_params, None, sampling_params) + assert isinstance(output, str) + assert len(output) > 0 + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_stream(): + from outlines.generate.api import GenerationParameters, SamplingParameters + + model = mlxlm(TEST_MODEL) + prompt = "Count from 1 to 5:" + + gen_params = GenerationParameters(max_tokens=20, stop_at=None, seed=None) + + sampling_params = SamplingParameters( + sampler="greedy", # Use greedy sampling for deterministic output + num_samples=1, + top_p=None, + top_k=None, + temperature=0.0, + ) + + # Test streaming + stream = model.stream(prompt, gen_params, None, sampling_params) + tokens = list(stream) + assert len(tokens) > 0 + assert all(isinstance(token, str) for token in tokens) + + # Test that concatenated streaming output matches generate output + streamed_text = "".join(tokens) + generated_text = model.generate(prompt, gen_params, None, sampling_params) + assert streamed_text == generated_text + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_errors(): + model = mlxlm(TEST_MODEL) + + # Test batch inference (should raise NotImplementedError) + with pytest.raises(NotImplementedError): + from outlines.generate.api import GenerationParameters, SamplingParameters + + gen_params = GenerationParameters(max_tokens=10, stop_at=None, seed=None) + sampling_params = SamplingParameters("multinomial", 1, None, None, 1.0) + model.generate(["prompt1", "prompt2"], gen_params, None, sampling_params) + + # Test beam search (should raise NotImplementedError) + with pytest.raises(NotImplementedError): + sampling_params = SamplingParameters("beam_search", 1, None, None, 1.0) + model.generate("test prompt", gen_params, None, sampling_params)