Merge pull request #5 from meta-llama/registry
[SDK] adapt to server updates #201
yanxi0830 authored Oct 10, 2024
2 parents b3b16ef + 568af48 commit c0628cf
Showing 139 changed files with 6,794 additions and 2,147 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "llama_stack_client"
version = "0.0.39"
version = "0.0.40"
description = "The official Python library for the llama-stack-client API"
dynamic = ["readme"]
license = "Apache-2.0"
2 changes: 0 additions & 2 deletions src/llama_stack_client/__init__.py
@@ -4,7 +4,6 @@
from ._types import NOT_GIVEN, NoneType, NotGiven, Transport, ProxiesTypes
from ._utils import file_from_path
from ._client import (
ENVIRONMENTS,
Client,
Stream,
Timeout,
@@ -69,7 +68,6 @@
"AsyncStream",
"LlamaStackClient",
"AsyncLlamaStackClient",
"ENVIRONMENTS",
"file_from_path",
"BaseModel",
"DEFAULT_TIMEOUT",
14 changes: 12 additions & 2 deletions src/llama_stack_client/_base_client.py
@@ -143,6 +143,12 @@ def __init__(
self.url = url
self.params = params

@override
def __repr__(self) -> str:
if self.url:
return f"{self.__class__.__name__}(url={self.url})"
return f"{self.__class__.__name__}(params={self.params})"


class BasePage(GenericModel, Generic[_T]):
"""
@@ -412,7 +418,10 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()

headers.setdefault("x-stainless-retry-count", str(retries_taken))
# Don't set the retry count header if it was already set or removed by the caller. We check
# `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case.
if "x-stainless-retry-count" not in (header.lower() for header in custom_headers):
headers["x-stainless-retry-count"] = str(retries_taken)

return headers
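For illustration, a standalone sketch of the new rule (a plain function over dicts, not the SDK's actual _build_headers): the retry-count header is injected only when the caller has not already set or removed it.

def build_headers(custom_headers: dict, retries_taken: int) -> dict:
    headers = dict(custom_headers)
    # Case-insensitive membership test, mirroring the generator expression above.
    if "x-stainless-retry-count" not in (h.lower() for h in custom_headers):
        headers["x-stainless-retry-count"] = str(retries_taken)
    return headers

print(build_headers({}, 2))                                # header injected
print(build_headers({"X-Stainless-Retry-Count": "7"}, 2))  # caller's value preserved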

@@ -686,7 +695,8 @@ def _calculate_retry_timeout(
if retry_after is not None and 0 < retry_after <= 60:
return retry_after

nb_retries = max_retries - remaining_retries
# Also cap retry count to 1000 to avoid any potential overflows with `pow`
nb_retries = min(max_retries - remaining_retries, 1000)

# Apply exponential backoff, but not more than the max.
sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)
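A minimal sketch of the capped exponential backoff, under assumed delay constants (the SDK's actual INITIAL_RETRY_DELAY and MAX_RETRY_DELAY may differ; the Retry-After branch and any jitter are omitted):

INITIAL_RETRY_DELAY = 0.5  # assumed value for illustration
MAX_RETRY_DELAY = 8.0      # assumed value for illustration

def retry_timeout(max_retries: int, remaining_retries: int) -> float:
    # Cap the exponent at 1000 so pow() cannot overflow on pathological inputs.
    nb_retries = min(max_retries - remaining_retries, 1000)
    return min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)

print([retry_timeout(5, r) for r in range(5, 0, -1)])  # [0.5, 1.0, 2.0, 4.0, 8.0]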
202 changes: 84 additions & 118 deletions src/llama_stack_client/_client.py

Large diffs are not rendered by default.
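The collapsed _client.py diff is not rendered here, but taken together with the removed ENVIRONMENTS export above and the CLI changes below, it suggests clients are now pointed at a distribution by passing base_url directly; a minimal sketch (the endpoint is a placeholder):

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder endpoint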

3 changes: 3 additions & 0 deletions src/llama_stack_client/_response.py
@@ -192,6 +192,9 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
if cast_to == float:
return cast(R, float(response.text))

if cast_to == bool:
return cast(R, response.text.lower() == "true")

origin = get_origin(cast_to) or cast_to

if origin == APIResponse:
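A standalone illustration of the new bool branch (parse_bool is a hypothetical helper, not part of the SDK): only the literal string "true", in any casing, maps to True.

def parse_bool(body_text: str) -> bool:
    # Mirrors the added branch: response.text.lower() == "true"
    return body_text.lower() == "true"

assert parse_bool("True") is True
assert parse_bool("false") is False
assert parse_bool("1") is False  # numeric truthiness is not honored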
6 changes: 0 additions & 6 deletions src/llama_stack_client/lib/agents/event_logger.py
@@ -7,8 +7,6 @@
from typing import List, Optional, Union

from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agents import AgentsTurnStreamChunk

from termcolor import cprint


@@ -66,10 +64,6 @@ async def log(self, event_generator):
)
continue

if not isinstance(chunk, AgentsTurnStreamChunk):
yield LogEvent(chunk, color="yellow")
continue

event = chunk.event
event_type = event.payload.event_type

5 changes: 5 additions & 0 deletions src/llama_stack_client/lib/cli/common/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
20 changes: 20 additions & 0 deletions src/llama_stack_client/lib/cli/common/utils.py
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from tabulate import tabulate

def print_table_from_response(response, headers=[]):
if not headers:
headers = sorted(response[0].__dict__.keys())

rows = []
for spec in response:
rows.append(
[
spec.__dict__[headers[i]] for i in range(len(headers))
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
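A hedged usage sketch of the new helper with stand-in objects (real callers pass SDK response models such as client.models.list(); the identifiers below are made up):

from types import SimpleNamespace
from llama_stack_client.lib.cli.common.utils import print_table_from_response

rows = [
    SimpleNamespace(identifier="model-a", provider_id="provider-0"),  # placeholder data
    SimpleNamespace(identifier="model-b", provider_id="provider-1"),
]
print_table_from_response(rows, headers=["identifier", "provider_id"])
# Omitting headers falls back to the sorted attribute names of the first item.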
2 changes: 2 additions & 0 deletions src/llama_stack_client/lib/cli/llama_stack_client.py
@@ -7,6 +7,7 @@
import argparse

from .configure import ConfigureParser
from .providers import ProvidersParser
from .memory_banks import MemoryBanksParser

from .models import ModelsParser
@@ -31,6 +32,7 @@ def __init__(self):
MemoryBanksParser.create(subparsers)
ShieldsParser.create(subparsers)
ConfigureParser.create(subparsers)
ProvidersParser.create(subparsers)

def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()
24 changes: 9 additions & 15 deletions src/llama_stack_client/lib/cli/memory_banks/list.py
@@ -12,6 +12,7 @@
from llama_stack_client.lib.cli.subcommand import Subcommand

from tabulate import tabulate
from llama_stack_client.lib.cli.common.utils import print_table_from_response


class MemoryBanksList(Subcommand):
@@ -41,21 +42,14 @@ def _run_memory_banks_list_cmd(self, args: argparse.Namespace):
)

headers = [
"Memory Bank Type",
"Provider Type",
"Provider Config",
"identifier",
"provider_id",
"type",
"embedding_model",
"chunk_size_in_tokens",
"overlap_size_in_tokens",
]

memory_banks_list_response = client.memory_banks.list()
rows = []

for bank_spec in memory_banks_list_response:
rows.append(
[
bank_spec.bank_type,
bank_spec.provider_config.provider_type,
json.dumps(bank_spec.provider_config.config, indent=4),
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
if memory_banks_list_response:
print_table_from_response(memory_banks_list_response, headers)
16 changes: 4 additions & 12 deletions src/llama_stack_client/lib/cli/models/get.py
@@ -46,28 +46,20 @@ def _run_models_list_cmd(self, args: argparse.Namespace):
base_url=args.endpoint,
)

headers = [
"Model ID (model)",
"Model Metadata",
"Provider Type",
"Provider Config",
]

models_get_response = client.models.get(core_model_id=args.model_id)
models_get_response = client.models.retrieve(identifier=args.model_id)

if not models_get_response:
print(
f"Model {args.model_id} is not found at distribution endpoint {args.endpoint}. Please ensure endpoint is serving specified model. "
)
return

headers = sorted(models_get_response.__dict__.keys())

rows = []
rows.append(
[
models_get_response.llama_model["core_model_id"],
json.dumps(models_get_response.llama_model, indent=4),
models_get_response.provider_config.provider_type,
json.dumps(models_get_response.provider_config.config, indent=4),
models_get_response.__dict__[headers[i]] for i in range(len(headers))
]
)

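The CLI now uses the renamed retrieval method; a hedged equivalent in plain SDK code (endpoint and model identifier are placeholders):

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")        # placeholder endpoint
model = client.models.retrieve(identifier="Llama3.1-8B-Instruct")  # placeholder identifier
if model is not None:
    print(sorted(model.__dict__.keys()))  # the same keys the CLI uses as table headers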
29 changes: 8 additions & 21 deletions src/llama_stack_client/lib/cli/models/list.py
@@ -7,11 +7,10 @@
import json
import argparse

from tabulate import tabulate

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.configure import get_config
from llama_stack_client.lib.cli.subcommand import Subcommand
from llama_stack_client.lib.cli.common.utils import print_table_from_response


class ModelsList(Subcommand):
@@ -41,23 +40,11 @@ def _run_models_list_cmd(self, args: argparse.Namespace):
)

headers = [
"Model ID (model)",
"Model Metadata",
"Provider Type",
"Provider Config",
"identifier",
"llama_model",
"provider_id",
"metadata"
]

models_list_response = client.models.list()
rows = []

for model_spec in models_list_response:
rows.append(
[
model_spec.llama_model["core_model_id"],
json.dumps(model_spec.llama_model, indent=4),
model_spec.provider_config.provider_type,
json.dumps(model_spec.provider_config.config, indent=4),
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
response = client.models.list()
if response:
print_table_from_response(response, headers)
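A hedged sketch of reading the same listing directly from the SDK, using the field names implied by the new header list (the endpoint is a placeholder and the exact response schema may differ):

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder endpoint
models = client.models.list()
for model in models or []:
    print(model.identifier, model.provider_id, model.metadata)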
65 changes: 65 additions & 0 deletions src/llama_stack_client/lib/cli/providers.py
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os

import yaml
from tabulate import tabulate

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.subcommand import Subcommand
from llama_stack_client.lib.cli.configure import get_config


class ProvidersParser(Subcommand):
"""Configure Llama Stack Client CLI"""

def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"providers",
prog="llama-stack-client providers",
description="List available providers Llama Stack Client CLI",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_providers_cmd)

def _add_arguments(self):
self.endpoint = get_config().get("endpoint")
self.parser.add_argument(
"--endpoint",
type=str,
help="Llama Stack distribution endpoint",
default=self.endpoint,
)

def _run_providers_cmd(self, args: argparse.Namespace):
client = LlamaStackClient(
base_url=args.endpoint,
)

headers = [
"API",
"Provider ID",
"Provider Type",
]

providers_response = client.providers.list()
rows = []

for k, v in providers_response.items():
for provider_info in v:
rows.append(
[
k,
provider_info.provider_id,
provider_info.provider_type
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
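The loop above implies providers come back as a mapping from API name to a list of provider entries; a standalone illustration with placeholder data (not real provider ids or types):

from types import SimpleNamespace

providers_response = {  # placeholder data shaped like the expected response
    "inference": [SimpleNamespace(provider_id="provider-0", provider_type="meta-reference")],
    "memory": [SimpleNamespace(provider_id="provider-1", provider_type="meta-reference")],
}
for api, entries in providers_response.items():
    for p in entries:
        print(api, p.provider_id, p.provider_type)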
22 changes: 3 additions & 19 deletions src/llama_stack_client/lib/cli/shields/list.py
@@ -7,11 +7,10 @@
import json
import argparse

from tabulate import tabulate

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.configure import get_config
from llama_stack_client.lib.cli.subcommand import Subcommand
from llama_stack_client.lib.cli.common.utils import print_table_from_response


class ShieldsList(Subcommand):
@@ -44,22 +43,7 @@ def _run_shields_list_cmd(self, args: argparse.Namespace):
base_url=args.endpoint,
)

headers = [
"Shield Type (shield_type)",
"Provider Type",
"Provider Config",
]

shields_list_response = client.shields.list()
rows = []

for shield_spec in shields_list_response:
rows.append(
[
shield_spec.shield_type,
shield_spec.provider_config.provider_type,
json.dumps(shield_spec.provider_config.config, indent=4),
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
if shields_list_response:
print_table_from_response(shields_list_response)
25 changes: 6 additions & 19 deletions src/llama_stack_client/lib/inference/event_logger.py
@@ -3,12 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack_client.types import (
ChatCompletionStreamChunk,
InferenceChatCompletionResponse,
)
from termcolor import cprint


@@ -30,17 +24,10 @@ def print(self, flush=True):
class EventLogger:
async def log(self, event_generator):
for chunk in event_generator:
if isinstance(chunk, ChatCompletionStreamChunk):
event = chunk.event
if event.event_type == "start":
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == "progress":
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == "complete":
yield LogEvent("")
elif isinstance(chunk, InferenceChatCompletionResponse):
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk.completion_message.content, color="yellow")
else:
event = chunk.event
if event.event_type == "start":
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk, color="yellow")
elif event.event_type == "progress":
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == "complete":
yield LogEvent("")
(Diffs for the remaining changed files are not shown.)
