diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index 5b5a031961..0070756d8e 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_message_to_dict,
+    request_has_media,
 )

 from .config import FireworksImplConfig
@@ -82,14 +84,14 @@ async def completion(
     async def _nonstream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> CompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         r = await client.completion.acreate(**params)
         return process_completion_response(r, self.formatter)

     async def _stream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         stream = client.completion.acreate(**params)

         async for chunk in process_completion_stream_response(stream, self.formatter):
@@ -128,33 +130,55 @@ async def chat_completion(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = await client.completion.acreate(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = await client.chat.completions.acreate(**params)
+        else:
+            r = await client.completion.acreate(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
+
+        if "messages" in params:
+            stream = client.chat.completions.acreate(**params)
+        else:
+            stream = client.completion.acreate(**params)

-        stream = client.completion.acreate(**params)
         async for chunk in process_chat_completion_stream_response(
             stream, self.formatter
         ):
             yield chunk

-    def _get_params(self, request) -> dict:
-        prompt = ""
-        if type(request) == ChatCompletionRequest:
-            prompt = chat_completion_request_to_prompt(request, self.formatter)
-        elif type(request) == CompletionRequest:
-            prompt = completion_request_to_prompt(request, self.formatter)
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        input_dict = {}
+        media_present = request_has_media(request)
+
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                input_dict["messages"] = [
+                    await convert_message_to_dict(m) for m in request.messages
+                ]
+            else:
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        elif isinstance(request, CompletionRequest):
+            assert (
+                not media_present
+            ), "Fireworks does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
         else:
             raise ValueError(f"Unknown request type {type(request)}")

         # Fireworks always prepends with BOS
-        if prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
+        if "prompt" in input_dict:
+            if input_dict["prompt"].startswith("<|begin_of_text|>"):
+                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

         options = get_sampling_options(request.sampling_params)
         options.setdefault("max_tokens", 512)
@@ -172,9 +196,10 @@ def _get_params(self, request) -> dict:
             }
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
+
         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": prompt,
+            **input_dict,
             "stream": request.stream,
             **options,
         }
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index 916241a7c6..3530e1234a 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -29,6 +29,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_image_media_to_url,
+    request_has_media,
 )

 OLLAMA_SUPPORTED_MODELS = {
@@ -38,6 +40,7 @@
     "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
     "Llama-Guard-3-8B": "llama-guard3:8b",
     "Llama-Guard-3-1B": "llama-guard3:1b",
+    "Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
 }


@@ -109,22 +112,8 @@ async def completion(
         else:
             return await self._nonstream_completion(request)

-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        sampling_options = get_sampling_options(request.sampling_params)
-        # This is needed since the Ollama API expects num_predict to be set
-        # for early truncation instead of max_tokens.
-        if sampling_options["max_tokens"] is not None:
-            sampling_options["num_predict"] = sampling_options["max_tokens"]
-        return {
-            "model": OLLAMA_SUPPORTED_MODELS[request.model],
-            "prompt": completion_request_to_prompt(request, self.formatter),
-            "options": sampling_options,
-            "raw": True,
-            "stream": request.stream,
-        }
-
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
             s = await self.client.generate(**params)
@@ -142,7 +131,7 @@ async def _generate_and_convert_to_openai_compat():
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
         r = await self.client.generate(**params)

         assert isinstance(r, dict)
@@ -183,26 +172,66 @@ async def chat_completion(
         else:
             return await self._nonstream_chat_completion(request)

-    def _get_params(self, request: ChatCompletionRequest) -> dict:
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        sampling_options = get_sampling_options(request.sampling_params)
+        # This is needed since the Ollama API expects num_predict to be set
+        # for early truncation instead of max_tokens.
+        if sampling_options.get("max_tokens") is not None:
+            sampling_options["num_predict"] = sampling_options["max_tokens"]
+
+        input_dict = {}
+        media_present = request_has_media(request)
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                contents = [
+                    await convert_message_to_dict_for_ollama(m)
+                    for m in request.messages
+                ]
+                # flatten the list of lists
+                input_dict["messages"] = [
+                    item for sublist in contents for item in sublist
+                ]
+            else:
+                input_dict["raw"] = True
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        else:
+            assert (
+                not media_present
+            ), "Ollama does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["raw"] = True
+
         return {
             "model": OLLAMA_SUPPORTED_MODELS[request.model],
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
-            "options": get_sampling_options(request.sampling_params),
-            "raw": True,
+            **input_dict,
+            "options": sampling_options,
             "stream": request.stream,
         }

     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = await self.client.generate(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = await self.client.chat(**params)
+        else:
+            r = await self.client.generate(**params)
         assert isinstance(r, dict)

-        choice = OpenAICompatCompletionChoice(
-            finish_reason=r["done_reason"] if r["done"] else None,
-            text=r["response"],
-        )
+        if "message" in r:
+            choice = OpenAICompatCompletionChoice(
+                finish_reason=r["done_reason"] if r["done"] else None,
+                text=r["message"]["content"],
+            )
+        else:
+            choice = OpenAICompatCompletionChoice(
+                finish_reason=r["done_reason"] if r["done"] else None,
+                text=r["response"],
+            )
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
@@ -211,15 +240,24 @@ async def _nonstream_chat_completion(
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
+            if "messages" in params:
+                s = await self.client.chat(**params)
+            else:
+                s = await self.client.generate(**params)
             async for chunk in s:
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
-                    text=chunk["response"],
-                )
+                if "message" in chunk:
+                    choice = OpenAICompatCompletionChoice(
+                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                        text=chunk["message"]["content"],
+                    )
+                else:
+                    choice = OpenAICompatCompletionChoice(
+                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                        text=chunk["response"],
+                    )
                 yield OpenAICompatCompletionResponse(
                     choices=[choice],
                 )
@@ -236,3 +274,26 @@ async def embeddings(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+
+async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
+    async def _convert_content(content) -> dict:
+        if isinstance(content, ImageMedia):
+            return {
+                "role": message.role,
+                "images": [
+                    await convert_image_media_to_url(
+                        content, download=True, include_format=False
+                    )
+                ],
+            }
+        else:
+            return {
+                "role": message.role,
+                "content": content,
+            }
+
+    if isinstance(message.content, list):
+        return [await _convert_content(c) for c in message.content]
+    else:
+        return [await _convert_content(message.content)]
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 5decea4823..28a5664158 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_message_to_dict,
+    request_has_media,
 )

 from .config import TogetherImplConfig
@@ -97,12 +99,12 @@ def _get_client(self) -> Together:
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
         r = self._get_client().completions.create(**params)
         return process_completion_response(r, self.formatter)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
@@ -131,14 +133,6 @@ def _build_options(

         return options

-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        return {
-            "model": self.map_to_provider_model(request.model),
-            "prompt": completion_request_to_prompt(request, self.formatter),
-            "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
-        }
-
     async def chat_completion(
         self,
         model: str,
@@ -171,18 +165,24 @@ async def chat_completion(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = self._get_client().completions.create(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = self._get_client().chat.completions.create(**params)
+        else:
+            r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self._get_client().completions.create(**params)
+            if "messages" in params:
+                s = self._get_client().chat.completions.create(**params)
+            else:
+                s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk

@@ -192,10 +192,29 @@ async def _to_async_generator():
         ):
             yield chunk

-    def _get_params(self, request: ChatCompletionRequest) -> dict:
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        input_dict = {}
+        media_present = request_has_media(request)
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                input_dict["messages"] = [
+                    await convert_message_to_dict(m) for m in request.messages
+                ]
+            else:
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        else:
+            assert (
+                not media_present
+            ), "Together does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+
         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
         }
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 5588be6c0d..b643ac238f 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -14,6 +14,11 @@
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate

+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_media_to_url,
+    request_has_media,
+)
+
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
@@ -87,6 +92,7 @@ async def completion(
             logprobs=logprobs,
         )
         self.check_model(request)
+        request = await request_with_localized_media(request)

         if request.stream:
             return self._stream_completion(request)
@@ -211,6 +217,7 @@ async def chat_completion(
             logprobs=logprobs,
         )
         self.check_model(request)
+        request = await request_with_localized_media(request)

         if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
@@ -388,3 +395,31 @@ async def embeddings(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+
+async def request_with_localized_media(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+) -> Union[ChatCompletionRequest, CompletionRequest]:
+    if not request_has_media(request):
+        return request
+
+    async def _convert_single_content(content):
+        if isinstance(content, ImageMedia):
+            url = await convert_image_media_to_url(content, download=True)
+            return ImageMedia(image=URL(uri=url))
+        else:
+            return content
+
+    async def _convert_content(content):
+        if isinstance(content, list):
+            return [await _convert_single_content(c) for c in content]
+        else:
+            return await _convert_single_content(content)
+
+    if isinstance(request, ChatCompletionRequest):
+        for m in request.messages:
+            m.content = await _convert_content(m.content)
+    else:
+        request.content = await _convert_content(request.content)
+
+    return request
diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py
index 71253871de..ba60b9925a 100644
--- a/llama_stack/providers/tests/inference/conftest.py
+++ b/llama_stack/providers/tests/inference/conftest.py
@@ -19,12 +19,11 @@ def pytest_addoption(parser):


 def pytest_configure(config):
-    config.addinivalue_line(
-        "markers", "llama_8b: mark test to run only with the given model"
-    )
-    config.addinivalue_line(
-        "markers", "llama_3b: mark test to run only with the given model"
-    )
+    for model in ["llama_8b", "llama_3b", "llama_vision"]:
+        config.addinivalue_line(
+            "markers", f"{model}: mark test to run only with the given model"
+        )
+
     for fixture_name in INFERENCE_FIXTURES:
         config.addinivalue_line(
             "markers",
@@ -37,6 +36,14 @@ def pytest_configure(config):
     pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
 ]

+VISION_MODEL_PARAMS = [
+    pytest.param(
+        "Llama3.2-11B-Vision-Instruct",
+        marks=pytest.mark.llama_vision,
+        id="llama_vision",
+    ),
+]
+

 def pytest_generate_tests(metafunc):
     if "inference_model" in metafunc.fixturenames:
@@ -44,7 +51,11 @@ def pytest_generate_tests(metafunc):
         if model:
             params = [pytest.param(model, id="")]
         else:
-            params = MODEL_PARAMS
+            cls_name = metafunc.cls.__name__
+            if "Vision" in cls_name:
+                params = VISION_MODEL_PARAMS
+            else:
+                params = MODEL_PARAMS

         metafunc.parametrize(
             "inference_model",
diff --git a/llama_stack/providers/tests/inference/pasta.jpeg b/llama_stack/providers/tests/inference/pasta.jpeg
new file mode 100644
index 0000000000..e8299321c3
Binary files /dev/null and b/llama_stack/providers/tests/inference/pasta.jpeg differ
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 29fdc43a4c..342117536d 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import itertools

 import pytest

@@ -15,6 +14,9 @@

 from llama_stack.distribution.datatypes import *  # noqa: F403

+from .utils import group_chunks
+
+
 # How to run this test:
 #
 # pytest -v -s llama_stack/providers/tests/inference/test_inference.py
@@ -22,15 +24,6 @@
 #   --env FIREWORKS_API_KEY=


-def group_chunks(response):
-    return {
-        event_type: list(group)
-        for event_type, group in itertools.groupby(
-            response, key=lambda chunk: chunk.event.event_type
-        )
-    }
-
-
 def get_expected_stop_reason(model: str):
     return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py
new file mode 100644
index 0000000000..1939d6934f
--- /dev/null
+++ b/llama_stack/providers/tests/inference/test_vision_inference.py
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+import pytest
+from PIL import Image as PIL_Image
+
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+
+from .utils import group_chunks
+
+THIS_DIR = Path(__file__).parent
+
+
+class TestVisionModelInference:
+    @pytest.mark.asyncio
+    async def test_vision_chat_completion_non_streaming(
+        self, inference_model, inference_stack
+    ):
+        inference_impl, _ = inference_stack
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type not in (
+            "meta-reference",
+            "remote::together",
+            "remote::fireworks",
+            "remote::ollama",
+        ):
+            pytest.skip(
+                "Other inference providers don't support vision chat completion() yet"
+            )
+
+        images = [
+            ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
+            ImageMedia(
+                image=URL(
+                    uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
+                )
+            ),
+        ]
+
+        # These are a bit hit-and-miss, need to be careful
+        expected_strings_to_check = [
+            ["spaghetti"],
+            ["puppy"],
+        ]
+        for image, expected_strings in zip(images, expected_strings_to_check):
+            response = await inference_impl.chat_completion(
+                model=inference_model,
+                messages=[
+                    SystemMessage(content="You are a helpful assistant."),
+                    UserMessage(
+                        content=[image, "Describe this image in two sentences."]
+                    ),
+                ],
+                stream=False,
+            )
+
+            assert isinstance(response, ChatCompletionResponse)
+            assert response.completion_message.role == "assistant"
+            assert isinstance(response.completion_message.content, str)
+            for expected_string in expected_strings:
+                assert expected_string in response.completion_message.content
+
+    @pytest.mark.asyncio
+    async def test_vision_chat_completion_streaming(
+        self, inference_model, inference_stack
+    ):
+        inference_impl, _ = inference_stack
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type not in (
+            "meta-reference",
+            "remote::together",
+            "remote::fireworks",
+            "remote::ollama",
+        ):
+            pytest.skip(
+                "Other inference providers don't support vision chat completion() yet"
+            )
+
+        images = [
+            ImageMedia(
+                image=URL(
+                    uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
+                )
+            ),
+        ]
+        expected_strings_to_check = [
+            ["puppy"],
+        ]
+        for image, expected_strings in zip(images, expected_strings_to_check):
+            response = [
+                r
+                async for r in await inference_impl.chat_completion(
+                    model=inference_model,
+                    messages=[
+                        SystemMessage(content="You are a helpful assistant."),
+                        UserMessage(
+                            content=[image, "Describe this image in two sentences."]
+                        ),
+                    ],
+                    stream=True,
+                )
+            ]
+
+            assert len(response) > 0
+            assert all(
+                isinstance(chunk, ChatCompletionResponseStreamChunk)
+                for chunk in response
+            )
+            grouped = group_chunks(response)
+            assert len(grouped[ChatCompletionResponseEventType.start]) == 1
+            assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
+            assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
+
+            content = "".join(
+                chunk.event.delta
+                for chunk in grouped[ChatCompletionResponseEventType.progress]
+            )
+            for expected_string in expected_strings:
+                assert expected_string in content
diff --git a/llama_stack/providers/tests/inference/utils.py b/llama_stack/providers/tests/inference/utils.py
new file mode 100644
index 0000000000..aa8d377e9d
--- /dev/null
+++ b/llama_stack/providers/tests/inference/utils.py
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import itertools
+
+
+def group_chunks(response):
+    return {
+        event_type: list(group)
+        for event_type, group in itertools.groupby(
+            response, key=lambda chunk: chunk.event.event_type
+        )
+    }
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 086227c731..cc3e7a2ce1 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -46,6 +46,9 @@ def text_from_choice(choice) -> str:
     if hasattr(choice, "delta") and choice.delta:
         return choice.delta.content

+    if hasattr(choice, "message"):
+        return choice.message.content
+
     return choice.text


@@ -99,7 +102,6 @@ def process_chat_completion_response(
 async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
 ) -> AsyncGenerator:
-
     stop_reason = None

     async for chunk in stream:
@@ -158,6 +160,10 @@ async def process_chat_completion_stream_response(
             break

         text = text_from_choice(choice)
+        if not text:
+            # Sometimes you get empty chunks from providers
+            continue
+
         # check if its a tool call ( aka starts with <|python_tag|> )
         if not ipython and text.startswith("<|python_tag|>"):
             ipython = True
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 386146ed9b..9decf5a007 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -3,10 +3,16 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+import base64
+import io
 import json
 from typing import Tuple

+import httpx
+
 from llama_models.llama3.api.chat_format import ChatFormat
+from PIL import Image as PIL_Image
 from termcolor import cprint

 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -24,6 +30,90 @@
 from llama_stack.providers.utils.inference import supported_inference_models


+def content_has_media(content: InterleavedTextMedia):
+    def _has_media_content(c):
+        return isinstance(c, ImageMedia)
+
+    if isinstance(content, list):
+        return any(_has_media_content(c) for c in content)
+    else:
+        return _has_media_content(content)
+
+
+def messages_have_media(messages: List[Message]):
+    return any(content_has_media(m.content) for m in messages)
+
+
+def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
+    if isinstance(request, ChatCompletionRequest):
+        return messages_have_media(request.messages)
+    else:
+        return content_has_media(request.content)
+
+
+async def convert_image_media_to_url(
+    media: ImageMedia, download: bool = False, include_format: bool = True
+) -> str:
+    if isinstance(media.image, PIL_Image.Image):
+        if media.image.format == "PNG":
+            format = "png"
+        elif media.image.format == "GIF":
+            format = "gif"
+        elif media.image.format == "JPEG":
+            format = "jpeg"
+        else:
+            raise ValueError(f"Unsupported image format {media.image.format}")
+
+        bytestream = io.BytesIO()
+        media.image.save(bytestream, format=media.image.format)
+        bytestream.seek(0)
+        content = bytestream.getvalue()
+    else:
+        if not download:
+            return media.image.uri
+        else:
+            assert isinstance(media.image, URL)
+            async with httpx.AsyncClient() as client:
+                r = await client.get(media.image.uri)
+                content = r.content
+                content_type = r.headers.get("content-type")
+                if content_type:
+                    format = content_type.split("/")[-1]
+                else:
+                    format = "png"
+
+    if include_format:
+        return f"data:image/{format};base64," + base64.b64encode(content).decode(
+            "utf-8"
+        )
+    else:
+        return base64.b64encode(content).decode("utf-8")
+
+
+async def convert_message_to_dict(message: Message) -> dict:
+    async def _convert_content(content) -> dict:
+        if isinstance(content, ImageMedia):
+            return {
+                "type": "image_url",
+                "image_url": {
+                    "url": await convert_image_media_to_url(content),
+                },
+            }
+        else:
+            assert isinstance(content, str)
+            return {"type": "text", "text": content}
+
+    if isinstance(message.content, list):
+        content = [await _convert_content(c) for c in message.content]
+    else:
+        content = [await _convert_content(message.content)]
+
+    return {
+        "role": message.role,
+        "content": content,
+    }
+
+
 def completion_request_to_prompt(
     request: CompletionRequest, formatter: ChatFormat
 ) -> str:
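Below is a minimal, illustrative sketch (not part of the diff) of what the new convert_message_to_dict and messages_have_media helpers in prompt_adapter.py return for a mixed image-and-text message; the image URL is a placeholder and the asyncio wrapper only drives the coroutine.

# Illustrative sketch only -- not part of the diff above.
import asyncio

from llama_models.llama3.api.datatypes import *  # noqa: F403

from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_message_to_dict,
    messages_have_media,
)


async def main() -> None:
    message = UserMessage(
        content=[
            ImageMedia(image=URL(uri="https://example.com/pasta.jpeg")),  # placeholder URL
            "Describe this image in two sentences.",
        ]
    )
    assert messages_have_media([message])

    # With the default download=False, a URL-backed image passes through untouched,
    # so the helper returns an OpenAI-style payload of the form:
    # {"role": "user",
    #  "content": [
    #      {"type": "image_url", "image_url": {"url": "https://example.com/pasta.jpeg"}},
    #      {"type": "text", "text": "Describe this image in two sentences."}]}
    print(await convert_message_to_dict(message))


if __name__ == "__main__":
    asyncio.run(main())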
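A second sketch (also not part of the diff) of convert_image_media_to_url for PIL-backed images, contrasting the data-URL form used by the OpenAI-compatible adapters with the bare base64 form (include_format=False) that the Ollama adapter places in its "images" field; the tiny in-memory PNG is purely for demonstration.

# Illustrative sketch only -- not part of the diff above.
import asyncio
import io

from PIL import Image as PIL_Image

from llama_models.llama3.api.datatypes import *  # noqa: F403

from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_image_media_to_url,
)


async def main() -> None:
    # Round-trip a small image through PNG so PIL reports .format == "PNG",
    # which the helper uses to pick the mime type.
    buf = io.BytesIO()
    PIL_Image.new("RGB", (8, 8), color="red").save(buf, format="PNG")
    buf.seek(0)
    media = ImageMedia(image=PIL_Image.open(buf))

    data_url = await convert_image_media_to_url(media)
    assert data_url.startswith("data:image/png;base64,")

    raw_b64 = await convert_image_media_to_url(media, include_format=False)
    assert not raw_b64.startswith("data:")


if __name__ == "__main__":
    asyncio.run(main())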
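The new vision tests can be selected the same way as the existing ones, by combining the llama_vision marker added in conftest.py with one of the provider fixture markers; the provider name and API key below are placeholders, so adjust them for your setup:

pytest -v -s llama_stack/providers/tests/inference/test_vision_inference.py \
    -m "llama_vision and fireworks" --env FIREWORKS_API_KEY=<your_api_key>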