Portkey AI Inference Provider Integration #672

Open · wants to merge 6 commits into main
28 changes: 28 additions & 0 deletions distributions/dependencies.json
@@ -393,5 +393,33 @@
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"portkey": [
"aiosqlite",
"blobfile",
"portkey-ai",
"chardet",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
]
}
17 changes: 17 additions & 0 deletions distributions/portkey/build.yaml
@@ -0,0 +1,17 @@
version: '2'
name: portkey
distribution_spec:
  description: Use Portkey for running LLM inference
  docker_image: null
  providers:
    inference:
    - remote::portkey
    safety:
    - inline::llama-guard
    memory:
    - inline::meta-reference
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
image_type: conda
Empty file.
77 changes: 77 additions & 0 deletions distributions/portkey/run.yaml
@@ -0,0 +1,77 @@
version: '2'
image_name: portkey
docker_image: null
conda_env: portkey
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: portkey
    provider_type: remote::portkey
    config:
      base_url: https://api.portkey.ai/v1
      api_key: ${env.PORTKEY_API_KEY}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  memory:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/portkey}/faiss_store.db
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/portkey}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/portkey/trace_store.db}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/portkey}/registry.db
models:
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: portkey
  provider_model_id: llama3.1-8b
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.3-70B-Instruct
  provider_id: portkey
  provider_model_id: llama-3.3-70b
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  provider_model_id: null
  model_type: embedding
shields:
- params: null
  shield_id: meta-llama/Llama-Guard-3-8B
  provider_id: null
  provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
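As a quick way to exercise this run.yaml, the sketch below calls the stack through the llama-stack-client Python SDK. The port, the prompt, and the client package are assumptions on my part rather than part of this PR; the model_id matches the models section above.

```python
# Minimal smoke test against a server started from the run.yaml above.
# Assumes `pip install llama-stack-client` and a stack listening on port 5001.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # routed to provider_id: portkey above
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(response.completion_message.content)
```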
16 changes: 16 additions & 0 deletions llama_stack/providers/remote/inference/portkey/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

from .config import PortkeyImplConfig


async def get_adapter_impl(config: PortkeyImplConfig, _deps):
    from .portkey import PortkeyInferenceAdapter

    assert isinstance(
        config, PortkeyImplConfig
    ), f"Unexpected config type: {type(config)}"

    impl = PortkeyInferenceAdapter(config)

    await impl.initialize()

    return impl
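For reviewers who want to sanity-check the wiring without starting a full stack, a rough sketch of calling this entry point directly; the placeholder API key and the empty deps argument are mine, not part of the PR.

```python
# Hypothetical smoke test for get_adapter_impl; requires portkey-ai installed.
import asyncio

from llama_stack.providers.remote.inference.portkey import get_adapter_impl
from llama_stack.providers.remote.inference.portkey.config import PortkeyImplConfig


async def main() -> None:
    # Construction mirrors the adapter's own __init__ (base_url + api_key only);
    # no network request is made here.
    config = PortkeyImplConfig(api_key="placeholder-key")
    impl = await get_adapter_impl(config, {})  # second arg: provider deps, unused here
    print(type(impl).__name__)  # PortkeyInferenceAdapter


asyncio.run(main())
```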
40 changes: 40 additions & 0 deletions llama_stack/providers/remote/inference/portkey/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, Dict, Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field

DEFAULT_BASE_URL = "https://api.portkey.ai/v1"


@json_schema_type
class PortkeyImplConfig(BaseModel):
    base_url: str = Field(
        default=os.environ.get("PORTKEY_BASE_URL", DEFAULT_BASE_URL),
        description="Base URL for the Portkey API",
    )
    api_key: Optional[str] = Field(
        default=os.environ.get("PORTKEY_API_KEY"),
        description="Portkey API Key",
    )
    virtual_key: Optional[str] = Field(
        default=os.environ.get("PORTKEY_VIRTUAL_KEY"),
        description="Portkey Virtual Key",
    )
    config: Optional[str] = Field(
        default=os.environ.get("PORTKEY_CONFIG_ID"),
        description="Portkey Config ID",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "base_url": DEFAULT_BASE_URL,
            "api_key": "${env.PORTKEY_API_KEY}",
        }
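Because the Field defaults above are read from the environment at class-definition time, the variables only take effect if they are set before config.py is first imported. A minimal sketch of that behavior (the placeholder value is mine):

```python
import os

# The defaults are captured when the config module is imported,
# so set the variable before the first import.
os.environ.setdefault("PORTKEY_API_KEY", "placeholder-key")

from llama_stack.providers.remote.inference.portkey.config import PortkeyImplConfig

config = PortkeyImplConfig()
print(config.base_url)  # https://api.portkey.ai/v1 unless PORTKEY_BASE_URL overrides it
print(PortkeyImplConfig.sample_run_config())
# -> {'base_url': 'https://api.portkey.ai/v1', 'api_key': '${env.PORTKEY_API_KEY}'}
```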
193 changes: 193 additions & 0 deletions llama_stack/providers/remote/inference/portkey/portkey.py
@@ -0,0 +1,193 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator, List, Optional, Union

from portkey_ai import AsyncPortkey

from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
)

from .config import PortkeyImplConfig


model_aliases = [
    build_model_alias(
        "llama3.1-8b",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama-3.3-70b",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
]


class PortkeyInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: PortkeyImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self,
            model_aliases=model_aliases,
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

        self.client = AsyncPortkey(
            base_url=self.config.base_url, api_key=self.config.api_key
        )

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    async def _nonstream_completion(
        self, request: CompletionRequest
    ) -> CompletionResponse:
        params = await self._get_params(request)

        r = await self.client.completions.create(**params)

        return process_completion_response(r, self.formatter)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

        stream = await self.client.completions.create(**params)

        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )

        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        params = await self._get_params(request)

        r = await self.client.completions.create(**params)

        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        params = await self._get_params(request)

        stream = await self.client.completions.create(**params)

        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    async def _get_params(
        self, request: Union[ChatCompletionRequest, CompletionRequest]
    ) -> dict:
        if request.sampling_params and request.sampling_params.top_k:
            raise ValueError("`top_k` not supported by Portkey")

        prompt = ""
        if isinstance(request, ChatCompletionRequest):
            prompt = await chat_completion_request_to_prompt(
                request, self.get_llama_model(request.model), self.formatter
            )
        elif isinstance(request, CompletionRequest):
            prompt = await completion_request_to_prompt(request, self.formatter)
        else:
            raise ValueError(f"Unknown request type {type(request)}")

        return {
            "model": request.model,
            "prompt": prompt,
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
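For reference, the request that _get_params builds is roughly equivalent to calling the portkey-ai SDK directly as sketched below. The model slug comes from model_aliases above, while the prompt string and whether a Portkey workspace needs a virtual key or config are assumptions that depend on the reviewer's setup.

```python
# Rough hand-built equivalent of the adapter's completions call, for review only.
import asyncio
import os

from portkey_ai import AsyncPortkey


async def main() -> None:
    client = AsyncPortkey(
        base_url="https://api.portkey.ai/v1",
        api_key=os.environ["PORTKEY_API_KEY"],
        # Depending on your Portkey setup you may also need virtual_key=... or config=...;
        # the adapter itself currently forwards only base_url and api_key.
    )
    response = await client.completions.create(
        model="llama3.1-8b",  # provider_model_id registered in model_aliases
        prompt="Hello",       # the adapter derives the real prompt from the chat messages
        stream=False,
        temperature=0.7,      # get_sampling_options() maps SamplingParams to fields like this
    )
    print(response.choices[0].text)


asyncio.run(main())
```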