diff --git a/src/llama_stack_client/lib/cli/memory_banks/list.py b/src/llama_stack_client/lib/cli/memory_banks/list.py index a205a90..d57e647 100644 --- a/src/llama_stack_client/lib/cli/memory_banks/list.py +++ b/src/llama_stack_client/lib/cli/memory_banks/list.py @@ -28,15 +28,15 @@ def __init__(self, subparsers: argparse._SubParsersAction): self.parser.set_defaults(func=self._run_memory_banks_list_cmd) def _add_arguments(self): - self.endpoint = get_config().get("endpoint") self.parser.add_argument( "--endpoint", type=str, help="Llama Stack distribution endpoint", - default=self.endpoint, ) def _run_memory_banks_list_cmd(self, args: argparse.Namespace): + args.endpoint = get_config().get("endpoint") or args.endpoint + client = LlamaStackClient( base_url=args.endpoint, ) diff --git a/src/llama_stack_client/lib/cli/models/get.py b/src/llama_stack_client/lib/cli/models/get.py index 3ae97b5..dd8a7bc 100644 --- a/src/llama_stack_client/lib/cli/models/get.py +++ b/src/llama_stack_client/lib/cli/models/get.py @@ -4,15 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json import argparse - -from tabulate import tabulate +import json from llama_stack_client import LlamaStackClient from llama_stack_client.lib.cli.configure import get_config from llama_stack_client.lib.cli.subcommand import Subcommand +from tabulate import tabulate + class ModelsGet(Subcommand): def __init__(self, subparsers: argparse._SubParsersAction): @@ -33,15 +33,17 @@ def _add_arguments(self): help="Model ID to query information about", ) - self.endpoint = get_config().get("endpoint") self.parser.add_argument( "--endpoint", type=str, help="Llama Stack distribution endpoint", - default=self.endpoint, ) def _run_models_list_cmd(self, args: argparse.Namespace): + config = get_config() + if config: + args.endpoint = config.get("endpoint") + client = LlamaStackClient( base_url=args.endpoint, ) @@ -58,9 +60,7 @@ def _run_models_list_cmd(self, args: argparse.Namespace): rows = [] rows.append( - [ - models_get_response.__dict__[headers[i]] for i in range(len(headers)) - ] + [models_get_response.__dict__[headers[i]] for i in range(len(headers))] ) print(tabulate(rows, headers=headers, tablefmt="grid")) diff --git a/src/llama_stack_client/lib/cli/models/list.py b/src/llama_stack_client/lib/cli/models/list.py index bb3fb48..48d55dd 100644 --- a/src/llama_stack_client/lib/cli/models/list.py +++ b/src/llama_stack_client/lib/cli/models/list.py @@ -26,15 +26,17 @@ def __init__(self, subparsers: argparse._SubParsersAction): self.parser.set_defaults(func=self._run_models_list_cmd) def _add_arguments(self): - self.endpoint = get_config().get("endpoint") self.parser.add_argument( "--endpoint", type=str, help="Llama Stack distribution endpoint", - default=self.endpoint, ) def _run_models_list_cmd(self, args: argparse.Namespace): + config = get_config() + if config: + args.endpoint = config.get("endpoint") + client = LlamaStackClient( base_url=args.endpoint, ) diff --git a/src/llama_stack_client/lib/cli/providers.py b/src/llama_stack_client/lib/cli/providers.py index a0575de..0006713 100644 --- a/src/llama_stack_client/lib/cli/providers.py +++ b/src/llama_stack_client/lib/cli/providers.py @@ -30,15 +30,17 @@ def __init__(self, subparsers: argparse._SubParsersAction): self.parser.set_defaults(func=self._run_providers_cmd) def _add_arguments(self): - self.endpoint = get_config().get("endpoint") self.parser.add_argument( "--endpoint", type=str, help="Llama Stack distribution endpoint", - default=self.endpoint, ) def _run_providers_cmd(self, args: argparse.Namespace): + config = get_config() + if config: + args.endpoint = config.get("endpoint") + client = LlamaStackClient( base_url=args.endpoint, ) diff --git a/src/llama_stack_client/lib/cli/shields/list.py b/src/llama_stack_client/lib/cli/shields/list.py index a53db0f..593b510 100644 --- a/src/llama_stack_client/lib/cli/shields/list.py +++ b/src/llama_stack_client/lib/cli/shields/list.py @@ -4,13 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json import argparse +import json from llama_stack_client import LlamaStackClient +from llama_stack_client.lib.cli.common.utils import print_table_from_response from llama_stack_client.lib.cli.configure import get_config from llama_stack_client.lib.cli.subcommand import Subcommand -from llama_stack_client.lib.cli.common.utils import print_table_from_response class ShieldsList(Subcommand): @@ -26,15 +26,18 @@ def __init__(self, subparsers: argparse._SubParsersAction): self.parser.set_defaults(func=self._run_shields_list_cmd) def _add_arguments(self): - self.endpoint = get_config().get("endpoint") self.parser.add_argument( "--endpoint", type=str, help="Llama Stack distribution endpoint", - default=self.endpoint, + default="", ) def _run_shields_list_cmd(self, args: argparse.Namespace): + config = get_config() + if config: + args.endpoint = config.get("endpoint") + if not args.endpoint: self.parser.error( "A valid endpoint is required. Please run llama-stack-client configure first or pass in a valid endpoint with --endpoint. "