Commit 568af48: refactor
Committed by yanxi0830 on Oct 10, 2024 (parent: 73c1f04)

Showing 7 changed files with 46 additions and 65 deletions.
pyproject.toml (1 addition, 1 deletion)
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "llama_stack_client"
-version = "0.0.39"
+version = "0.0.40"
 description = "The official Python library for the llama-stack-client API"
 dynamic = ["readme"]
 license = "Apache-2.0"
```
src/llama_stack_client/lib/cli/common/__init__.py (5 additions)
```diff
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
```
src/llama_stack_client/lib/cli/common/utils.py (20 additions)
```diff
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from tabulate import tabulate
+
+def print_table_from_response(response, headers=[]):
+    if not headers:
+        headers = sorted(response[0].__dict__.keys())
+
+    rows = []
+    for spec in response:
+        rows.append(
+            [
+                spec.__dict__[headers[i]] for i in range(len(headers))
+            ]
+        )
+
+    print(tabulate(rows, headers=headers, tablefmt="grid"))
```
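To make the new helper concrete, here is a minimal usage sketch. `FakeSpec` and its field values are hypothetical stand-ins for the SDK's response objects; the helper only reads each object's `__dict__`, so any attribute-bearing object works. It assumes the package and `tabulate` are installed.

```python
# Hypothetical usage sketch for the new shared helper.
from dataclasses import dataclass

from llama_stack_client.lib.cli.common.utils import print_table_from_response


# FakeSpec stands in for an SDK response object; print_table_from_response
# only reads __dict__, so any plain object with attributes works.
@dataclass
class FakeSpec:
    identifier: str
    provider_id: str


specs = [
    FakeSpec(identifier="bank-1", provider_id="meta-reference"),
    FakeSpec(identifier="bank-2", provider_id="meta-reference"),
]

# With no headers argument, columns default to the sorted attribute names
# of the first element.
print_table_from_response(specs)

# Explicit headers select the columns and their order, as the memory_banks
# and models commands below do.
print_table_from_response(specs, headers=["provider_id", "identifier"])
```

Because the helper indexes `response[0]` when deriving default headers, callers guard the call with a truthiness check on the response, as each refactored command below does.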
src/llama_stack_client/lib/cli/memory_banks/list.py (6 additions, 18 deletions)
```diff
@@ -12,6 +12,7 @@
 from llama_stack_client.lib.cli.subcommand import Subcommand
 
 from tabulate import tabulate
+from llama_stack_client.lib.cli.common.utils import print_table_from_response
 
 
 class MemoryBanksList(Subcommand):
@@ -41,27 +42,14 @@ def _run_memory_banks_list_cmd(self, args: argparse.Namespace):
         )
 
         headers = [
-            "Identifier",
-            "Provider ID",
-            "Memory Bank Type",
+            "identifier",
+            "provider_id",
+            "type",
             "embedding_model",
             "chunk_size_in_tokens",
             "overlap_size_in_tokens",
         ]
 
         memory_banks_list_response = client.memory_banks.list()
-        rows = []
-
-        for bank_spec in memory_banks_list_response:
-            rows.append(
-                [
-                    bank_spec.identifier,
-                    bank_spec.provider_id,
-                    bank_spec.type,
-                    bank_spec.embedding_model,
-                    bank_spec.chunk_size_in_tokens,
-                    bank_spec.overlap_size_in_tokens,
-                ]
-            )
-
-        print(tabulate(rows, headers=headers, tablefmt="grid"))
+        if memory_banks_list_response:
+            print_table_from_response(memory_banks_list_response, headers)
```
src/llama_stack_client/lib/cli/models/get.py (4 additions, 12 deletions)
```diff
@@ -46,28 +46,20 @@ def _run_models_list_cmd(self, args: argparse.Namespace):
             base_url=args.endpoint,
         )
 
-        headers = [
-            "Model ID (model)",
-            "Model Metadata",
-            "Provider Type",
-            "Provider Config",
-        ]
-
-        models_get_response = client.models.get(core_model_id=args.model_id)
+        models_get_response = client.models.retrieve(identifier=args.model_id)
 
         if not models_get_response:
             print(
                 f"Model {args.model_id} is not found at distribution endpoint {args.endpoint}. Please ensure endpoint is serving specified model. "
             )
             return
 
+        headers = sorted(models_get_response.__dict__.keys())
+
         rows = []
         rows.append(
             [
-                models_get_response.llama_model["core_model_id"],
-                json.dumps(models_get_response.llama_model, indent=4),
-                models_get_response.provider_config.provider_type,
-                json.dumps(models_get_response.provider_config.config, indent=4),
+                models_get_response.__dict__[headers[i]] for i in range(len(headers))
             ]
         )
```
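For readers tracking the API change in this file: the lookup moves from `client.models.get(core_model_id=...)` to `client.models.retrieve(identifier=...)`, and the table headers are now derived from the response object's attributes instead of being hard-coded. A hedged sketch of the new call path follows; the endpoint URL and model identifier are placeholders.

```python
# Sketch of the new single-model lookup path (placeholder values).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder endpoint

# This commit replaces client.models.get(core_model_id=...) with
# client.models.retrieve(identifier=...).
model = client.models.retrieve(identifier="Llama3.1-8B-Instruct")  # placeholder id

if not model:
    print("Model not found at the distribution endpoint.")
else:
    # Mirror the CLI: derive column names from the object's attributes so
    # new server-side fields appear without further client changes.
    headers = sorted(model.__dict__.keys())
    print({h: model.__dict__[h] for h in headers})
```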
src/llama_stack_client/lib/cli/models/list.py (8 additions, 21 deletions)
```diff
@@ -7,11 +7,10 @@
 import json
 import argparse
 
-from tabulate import tabulate
-
 from llama_stack_client import LlamaStackClient
 from llama_stack_client.lib.cli.configure import get_config
 from llama_stack_client.lib.cli.subcommand import Subcommand
+from llama_stack_client.lib.cli.common.utils import print_table_from_response
 
 
 class ModelsList(Subcommand):
@@ -41,23 +40,11 @@ def _run_models_list_cmd(self, args: argparse.Namespace):
         )
 
         headers = [
-            "Model ID (model)",
-            "Llama Model",
-            "Provider ID",
-            "Model Metadata ",
+            "identifier",
+            "llama_model",
+            "provider_id",
+            "metadata"
         ]
-
-        models_list_response = client.models.list()
-        rows = []
-
-        for model_spec in models_list_response:
-            rows.append(
-                [
-                    model_spec.identifier,
-                    model_spec.llama_model,
-                    model_spec.provider_id,
-                    json.dumps(model_spec.metadata, indent=4),
-                ]
-            )
-
-        print(tabulate(rows, headers=headers, tablefmt="grid"))
+        response = client.models.list()
+        if response:
+            print_table_from_response(response, headers)
```
src/llama_stack_client/lib/cli/shields/list.py (2 additions, 13 deletions)
```diff
@@ -7,11 +7,10 @@
 import json
 import argparse
 
-from tabulate import tabulate
-
 from llama_stack_client import LlamaStackClient
 from llama_stack_client.lib.cli.configure import get_config
 from llama_stack_client.lib.cli.subcommand import Subcommand
+from llama_stack_client.lib.cli.common.utils import print_table_from_response
 
 
 class ShieldsList(Subcommand):
@@ -45,16 +44,6 @@ def _run_shields_list_cmd(self, args: argparse.Namespace):
         )
 
         shields_list_response = client.shields.list()
-        rows = []
 
         if shields_list_response:
-            headers = sorted(shields_list_response[0].__dict__.keys())
-
-            for shield_spec in shields_list_response:
-                rows.append(
-                    [
-                        shield_spec.__dict__[headers[i]] for i in range(len(headers))
-                    ]
-                )
-
-            print(tabulate(rows, headers=headers, tablefmt="grid"))
+            print_table_from_response(shields_list_response)
```
