Skip to content

Commit

Permalink
Merge pull request #9 from meta-llama/cli_fix
Browse files Browse the repository at this point in the history
[bugfix] llama-stack-client endpoint fix
  • Loading branch information
yanxi0830 authored Oct 29, 2024
2 parents c22a1c5 + 0459788 commit 251ba23
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/llama_stack_client/lib/cli/memory_banks/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def __init__(self, subparsers: argparse._SubParsersAction):
self.parser.set_defaults(func=self._run_memory_banks_list_cmd)

def _add_arguments(self):
self.endpoint = get_config().get("endpoint")
self.parser.add_argument(
"--endpoint",
type=str,
help="Llama Stack distribution endpoint",
default=self.endpoint,
)

def _run_memory_banks_list_cmd(self, args: argparse.Namespace):
args.endpoint = get_config().get("endpoint") or args.endpoint

client = LlamaStackClient(
base_url=args.endpoint,
)
Expand Down
16 changes: 8 additions & 8 deletions src/llama_stack_client/lib/cli/models/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import argparse

from tabulate import tabulate
import json

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.configure import get_config
from llama_stack_client.lib.cli.subcommand import Subcommand

from tabulate import tabulate


class ModelsGet(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
Expand All @@ -33,15 +33,17 @@ def _add_arguments(self):
help="Model ID to query information about",
)

self.endpoint = get_config().get("endpoint")
self.parser.add_argument(
"--endpoint",
type=str,
help="Llama Stack distribution endpoint",
default=self.endpoint,
)

def _run_models_list_cmd(self, args: argparse.Namespace):
config = get_config()
if config:
args.endpoint = config.get("endpoint")

client = LlamaStackClient(
base_url=args.endpoint,
)
Expand All @@ -58,9 +60,7 @@ def _run_models_list_cmd(self, args: argparse.Namespace):

rows = []
rows.append(
[
models_get_response.__dict__[headers[i]] for i in range(len(headers))
]
[models_get_response.__dict__[headers[i]] for i in range(len(headers))]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
6 changes: 4 additions & 2 deletions src/llama_stack_client/lib/cli/models/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@ def __init__(self, subparsers: argparse._SubParsersAction):
self.parser.set_defaults(func=self._run_models_list_cmd)

def _add_arguments(self):
self.endpoint = get_config().get("endpoint")
self.parser.add_argument(
"--endpoint",
type=str,
help="Llama Stack distribution endpoint",
default=self.endpoint,
)

def _run_models_list_cmd(self, args: argparse.Namespace):
config = get_config()
if config:
args.endpoint = config.get("endpoint")

client = LlamaStackClient(
base_url=args.endpoint,
)
Expand Down
6 changes: 4 additions & 2 deletions src/llama_stack_client/lib/cli/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,17 @@ def __init__(self, subparsers: argparse._SubParsersAction):
self.parser.set_defaults(func=self._run_providers_cmd)

def _add_arguments(self):
self.endpoint = get_config().get("endpoint")
self.parser.add_argument(
"--endpoint",
type=str,
help="Llama Stack distribution endpoint",
default=self.endpoint,
)

def _run_providers_cmd(self, args: argparse.Namespace):
config = get_config()
if config:
args.endpoint = config.get("endpoint")

client = LlamaStackClient(
base_url=args.endpoint,
)
Expand Down
11 changes: 7 additions & 4 deletions src/llama_stack_client/lib/cli/shields/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import argparse
import json

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.common.utils import print_table_from_response
from llama_stack_client.lib.cli.configure import get_config
from llama_stack_client.lib.cli.subcommand import Subcommand
from llama_stack_client.lib.cli.common.utils import print_table_from_response


class ShieldsList(Subcommand):
Expand All @@ -26,15 +26,18 @@ def __init__(self, subparsers: argparse._SubParsersAction):
self.parser.set_defaults(func=self._run_shields_list_cmd)

def _add_arguments(self):
self.endpoint = get_config().get("endpoint")
self.parser.add_argument(
"--endpoint",
type=str,
help="Llama Stack distribution endpoint",
default=self.endpoint,
default="",
)

def _run_shields_list_cmd(self, args: argparse.Namespace):
config = get_config()
if config:
args.endpoint = config.get("endpoint")

if not args.endpoint:
self.parser.error(
"A valid endpoint is required. Please run llama-stack-client configure first or pass in a valid endpoint with --endpoint. "
Expand Down

0 comments on commit 251ba23

Please sign in to comment.