From da82fb22f6ab53a5f83087c3f2b85b389ec2de45 Mon Sep 17 00:00:00 2001
From: Aidan Do
Date: Sun, 22 Dec 2024 15:47:39 +1100
Subject: [PATCH 1/2] Add JSON structured outputs to Ollama

---
 .../providers/remote/inference/ollama/ollama.py  | 17 +++++++++++++----
 .../tests/inference/test_text_inference.py       |  2 ++
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index bf55c5ad2c..823c7267c8 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -48,10 +48,10 @@
         "llama3.1:8b-instruct-fp16",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_model_alias_with_just_provider_model_id(
-        "llama3.1:8b",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
+    # build_model_alias_with_just_provider_model_id(
+    #     "llama3.1:8b",
+    #     CoreModelId.llama3_1_8b_instruct.value,
+    # ),
     build_model_alias(
         "llama3.1:70b-instruct-fp16",
         CoreModelId.llama3_1_70b_instruct.value,
@@ -214,6 +214,7 @@ async def chat_completion(
             tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            response_format=response_format,
         )
         if stream:
             return self._stream_chat_completion(request)
@@ -257,6 +258,14 @@ async def _get_params(
             )
             input_dict["raw"] = True
 
+        if fmt := request.response_format:
+            if fmt.type == "json_schema":
+                input_dict["format"] = fmt.json_schema
+            elif fmt.type == "grammar":
+                raise NotImplementedError("Grammar response format is not supported")
+            else:
+                raise ValueError(f"Unknown response format type: {fmt.type}")
+
         return {
             "model": request.model,
             **input_dict,
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 99a62ac080..3d3d1ad555 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -191,6 +191,7 @@ async def test_completion_structured_output(self, inference_model, inference_sta
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "inline::meta-reference",
+            "remote::ollama",
             "remote::tgi",
             "remote::together",
             "remote::fireworks",
@@ -253,6 +254,7 @@ async def test_structured_output(
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "inline::meta-reference",
+            "remote::ollama",
             "remote::fireworks",
             "remote::tgi",
             "remote::together",

From 0ffcbb8cdbeb45c520e03f30fc06457d6395b435 Mon Sep 17 00:00:00 2001
From: Aidan Do
Date: Sun, 22 Dec 2024 16:07:02 +1100
Subject: [PATCH 2/2] uncomment

---
 llama_stack/providers/remote/inference/ollama/ollama.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 823c7267c8..74fdcff667 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -48,10 +48,10 @@
         "llama3.1:8b-instruct-fp16",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    # build_model_alias_with_just_provider_model_id(
-    #     "llama3.1:8b",
-    #     CoreModelId.llama3_1_8b_instruct.value,
-    # ),
+    build_model_alias_with_just_provider_model_id(
+        "llama3.1:8b",
+        CoreModelId.llama3_1_8b_instruct.value,
+    ),
     build_model_alias(
         "llama3.1:70b-instruct-fp16",
         CoreModelId.llama3_1_70b_instruct.value,
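
For context, below is a minimal standalone sketch (not part of the patches above) of the mapping that the new response_format branch in _get_params introduces: a json_schema response format is forwarded to Ollama's "format" option, a grammar format is rejected, and anything else is an error. The ResponseFormat dataclass is a hypothetical stand-in for the llama_stack request field; only the attributes the patch actually reads are modeled.

# Hypothetical stand-alone sketch of the response_format -> Ollama "format" mapping.
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ResponseFormat:
    # Stand-in for the llama_stack response format object; only the fields
    # read by the patched _get_params are modeled here.
    type: str                                 # "json_schema" or "grammar"
    json_schema: Optional[Dict[str, Any]] = None


def build_format_option(response_format: Optional[ResponseFormat]) -> Dict[str, Any]:
    # Mirrors the patched branch: a JSON schema is forwarded verbatim as the
    # "format" option, grammar formats are not supported, unknown types raise.
    params: Dict[str, Any] = {}
    if fmt := response_format:
        if fmt.type == "json_schema":
            params["format"] = fmt.json_schema
        elif fmt.type == "grammar":
            raise NotImplementedError("Grammar response format is not supported")
        else:
            raise ValueError(f"Unknown response format type: {fmt.type}")
    return params


if __name__ == "__main__":
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    print(build_format_option(ResponseFormat(type="json_schema", json_schema=schema)))
    # -> {'format': {'type': 'object', 'properties': {'name': {'type': 'string'}}}}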