Enable vision models for (Together, Fireworks, Meta-Reference, Ollama) (#376)

* Enable vision models for Together and Fireworks

* Works with the Ollama 0.4.0 pre-release and its vision model

* Localize media for meta_reference inference

* Fix
ashwinb authored Nov 6, 2024
1 parent db30809 commit cde9bc1
Showing 11 changed files with 465 additions and 81 deletions.
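
In short, each adapter now checks whether the request contains image media: vision requests are forwarded as structured messages to the provider's chat endpoint, while text-only requests keep the existing flattened-prompt path. Below is a minimal, self-contained sketch of that routing rule; the dataclasses and helpers are simplified stand-ins for the real llama_stack types and prompt_adapter utilities, not the code in this commit.

# Simplified stand-ins for ImageMedia / Message and the request_has_media helper.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class ImageMedia:
    url: str


@dataclass
class Message:
    role: str
    content: Union[str, List[Union[str, ImageMedia]]]


def request_has_media(messages: List[Message]) -> bool:
    # True if any message carries an image part.
    for m in messages:
        parts = m.content if isinstance(m.content, list) else [m.content]
        if any(isinstance(p, ImageMedia) for p in parts):
            return True
    return False


def build_input(messages: List[Message]) -> dict:
    # Vision path: keep structured messages so the provider's chat endpoint can
    # see the image (the real adapters convert each message via convert_message_to_dict).
    if request_has_media(messages):
        return {"messages": [{"role": m.role, "content": m.content} for m in messages]}
    # Text-only path: flatten to a single prompt string (formatter elided here).
    return {"prompt": "\n".join(m.content for m in messages)}


if __name__ == "__main__":
    text_only = [Message(role="user", content="Describe llamas.")]
    with_image = [
        Message(
            role="user",
            content=[ImageMedia(url="https://example.com/llama.jpg"), "What is in this image?"],
        )
    ]
    print(build_input(text_only))   # -> {"prompt": ...}
    print(build_input(with_image))  # -> {"messages": [...]}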
55 changes: 40 additions & 15 deletions llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -26,6 +26,8 @@
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
convert_message_to_dict,
request_has_media,
)

from .config import FireworksImplConfig
@@ -82,14 +84,14 @@ async def completion(
async def _nonstream_completion(
self, request: CompletionRequest, client: Fireworks
) -> CompletionResponse:
params = self._get_params(request)
params = await self._get_params(request)
r = await client.completion.acreate(**params)
return process_completion_response(r, self.formatter)

async def _stream_completion(
self, request: CompletionRequest, client: Fireworks
) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)

stream = client.completion.acreate(**params)
async for chunk in process_completion_stream_response(stream, self.formatter):
@@ -128,33 +130,55 @@ async def chat_completion(
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await client.completion.acreate(**params)
params = await self._get_params(request)
if "messages" in params:
r = await client.chat.completions.acreate(**params)
else:
r = await client.completion.acreate(**params)
return process_chat_completion_response(r, self.formatter)

async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)

if "messages" in params:
stream = client.chat.completions.acreate(**params)
else:
stream = client.completion.acreate(**params)

stream = client.completion.acreate(**params)
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
yield chunk

def _get_params(self, request) -> dict:
prompt = ""
if type(request) == ChatCompletionRequest:
prompt = chat_completion_request_to_prompt(request, self.formatter)
elif type(request) == CompletionRequest:
prompt = completion_request_to_prompt(request, self.formatter)
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
input_dict = {}
media_present = request_has_media(request)

if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_dict(m) for m in request.messages
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
)
elif isinstance(request, CompletionRequest):
assert (
not media_present
), "Fireworks does not support media for Completion requests"
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
else:
raise ValueError(f"Unknown request type {type(request)}")

# Fireworks always prepends with BOS
if prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
if "prompt" in input_dict:
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

options = get_sampling_options(request.sampling_params)
options.setdefault("max_tokens", 512)
@@ -172,9 +196,10 @@ def _get_params(self, request) -> dict:
}
else:
raise ValueError(f"Unknown response format {fmt.type}")

return {
"model": self.map_to_provider_model(request.model),
"prompt": prompt,
**input_dict,
"stream": request.stream,
**options,
}
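
One detail from the Fireworks hunk above: Fireworks always prepends BOS on its own, so the adapter strips a leading <|begin_of_text|> token from the rendered prompt, and after this change it does so only on the "prompt" path (structured "messages" are left untouched). A tiny standalone sketch of that string handling:

# BOS stripping on the prompt path; "<|begin_of_text|>" is the Llama BOS token
# that Fireworks adds itself, so leaving it in the prompt would duplicate it.
BOS = "<|begin_of_text|>"


def strip_bos(prompt: str) -> str:
    return prompt[len(BOS):] if prompt.startswith(BOS) else prompt


assert strip_bos("<|begin_of_text|>Hello") == "Hello"
assert strip_bos("Hello") == "Hello"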
125 changes: 93 additions & 32 deletions llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -29,6 +29,8 @@
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
convert_image_media_to_url,
request_has_media,
)

OLLAMA_SUPPORTED_MODELS = {
@@ -38,6 +40,7 @@
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
"Llama-Guard-3-8B": "llama-guard3:8b",
"Llama-Guard-3-1B": "llama-guard3:1b",
"Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
}


@@ -109,22 +112,8 @@ async def completion(
else:
return await self._nonstream_completion(request)

def _get_params_for_completion(self, request: CompletionRequest) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options["max_tokens"] is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": completion_request_to_prompt(request, self.formatter),
"options": sampling_options,
"raw": True,
"stream": request.stream,
}

async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params_for_completion(request)
params = await self._get_params(request)

async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
@@ -142,7 +131,7 @@ async def _generate_and_convert_to_openai_compat():
yield chunk

async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params_for_completion(request)
params = await self._get_params(request)
r = await self.client.generate(**params)
assert isinstance(r, dict)

@@ -183,26 +172,66 @@ async def chat_completion(
else:
return await self._nonstream_chat_completion(request)

def _get_params(self, request: ChatCompletionRequest) -> dict:
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options.get("max_tokens") is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]

input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
if media_present:
contents = [
await convert_message_to_dict_for_ollama(m)
for m in request.messages
]
# flatten the list of lists
input_dict["messages"] = [
item for sublist in contents for item in sublist
]
else:
input_dict["raw"] = True
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
)
else:
assert (
not media_present
), "Ollama does not support media for Completion requests"
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
input_dict["raw"] = True

return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"options": get_sampling_options(request.sampling_params),
"raw": True,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await self.client.generate(**params)
params = await self._get_params(request)
if "messages" in params:
r = await self.client.chat(**params)
else:
r = await self.client.generate(**params)
assert isinstance(r, dict)

choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
if "message" in r:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
@@ -211,15 +240,24 @@ async def _nonstream_chat_completion(
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)

async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
if "messages" in params:
s = await self.client.chat(**params)
else:
s = await self.client.generate(**params)
async for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
if "message" in chunk:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
@@ -236,3 +274,26 @@ async def embeddings(
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()


async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
async def _convert_content(content) -> dict:
if isinstance(content, ImageMedia):
return {
"role": message.role,
"images": [
await convert_image_media_to_url(
content, download=True, include_format=False
)
],
}
else:
return {
"role": message.role,
"content": content,
}

if isinstance(message.content, list):
return [await _convert_content(c) for c in message.content]
else:
return [await _convert_content(message.content)]
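
For context on the flattening step in Ollama's _get_params: convert_message_to_dict_for_ollama returns a list per message, splitting mixed content into separate Ollama-style entries that all carry the message's role. A rough illustration of the resulting shapes (the base64 placeholder stands in for whatever convert_image_media_to_url(..., download=True, include_format=False) actually returns; the exact value is an assumption):

# Hypothetical output for one user message whose content is [ImageMedia, "What is in this image?"]:
converted = [
    {"role": "user", "images": ["<base64-encoded image bytes>"]},   # from the ImageMedia part
    {"role": "user", "content": "What is in this image?"},          # from the text part
]

# _get_params collects one such list per request message and flattens them
# into the single messages list that client.chat(**params) expects.
contents = [converted]
messages = [item for sublist in contents for item in sublist]
print(messages)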
51 changes: 35 additions & 16 deletions llama_stack/providers/adapters/inference/together/together.py
@@ -26,6 +26,8 @@
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
convert_message_to_dict,
request_has_media,
)

from .config import TogetherImplConfig
@@ -97,12 +99,12 @@ def _get_client(self) -> Together:
async def _nonstream_completion(
self, request: CompletionRequest
) -> ChatCompletionResponse:
params = self._get_params_for_completion(request)
params = await self._get_params(request)
r = self._get_client().completions.create(**params)
return process_completion_response(r, self.formatter)

async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params_for_completion(request)
params = await self._get_params(request)

# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
@@ -131,14 +133,6 @@ def _build_options(

return options

def _get_params_for_completion(self, request: CompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format),
}

async def chat_completion(
self,
model: str,
@@ -171,18 +165,24 @@ async def chat_completion(
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = self._get_client().completions.create(**params)
params = await self._get_params(request)
if "messages" in params:
r = self._get_client().chat.completions.create(**params)
else:
r = self._get_client().completions.create(**params)
return process_chat_completion_response(r, self.formatter)

async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)

# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = self._get_client().completions.create(**params)
if "messages" in params:
s = self._get_client().chat.completions.create(**params)
else:
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk

@@ -192,10 +192,29 @@ async def _to_async_generator():
):
yield chunk

def _get_params(self, request: ChatCompletionRequest) -> dict:
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_dict(m) for m in request.messages
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
)
else:
assert (
not media_present
), "Together does not support media for Completion requests"
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)

return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format),
}