Skip to content

Commit

Permalink
Merge pull request #32 from meta-llama/dineshyv/memorybank-update-delete
Browse files Browse the repository at this point in the history
memory bank unregister
  • Loading branch information
dineshyv authored Nov 16, 2024
2 parents 649e149 + d7d6f4f commit 5eb688f
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 60 deletions.
23 changes: 23 additions & 0 deletions src/llama_stack_client/lib/cli/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# the root directory of this source tree.
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from functools import wraps


def create_bar_chart(data, labels, title=""):
Expand All @@ -28,3 +30,24 @@ def create_bar_chart(data, labels, title=""):
table.add_row(label, f"[{color}]{bar}[/] {value}/{total_count}")

console.print(table)


def handle_client_errors(operation_name):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
console = Console()
console.print(
Panel.fit(
f"[bold red]Failed to {operation_name}[/bold red]\n\n"
f"[yellow]Error Type:[/yellow] {e.__class__.__name__}\n"
f"[yellow]Details:[/yellow] {str(e)}"
)
)

return wrapper

return decorator
3 changes: 3 additions & 0 deletions src/llama_stack_client/lib/cli/datasets/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.command("list")
@click.pass_context
@handle_client_errors("list datasets")
def list_datasets(ctx):
"""Show available datasets on distribution endpoint"""
client = ctx.obj["client"]
Expand Down
3 changes: 3 additions & 0 deletions src/llama_stack_client/lib/cli/datasets/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import click
import yaml

from ..common.utils import handle_client_errors


def data_url_from_file(file_path: str) -> str:
if not os.path.exists(file_path):
Expand All @@ -38,6 +40,7 @@ def data_url_from_file(file_path: str) -> str:
)
@click.option("--schema", type=str, help="JSON schema of the dataset", required=True)
@click.pass_context
@handle_client_errors("register dataset")
def register(
ctx,
dataset_id: str,
Expand Down
2 changes: 2 additions & 0 deletions src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import click
import yaml

from ..common.utils import handle_client_errors
from .list import list_eval_tasks


Expand All @@ -28,6 +29,7 @@ def eval_tasks():
@click.option("--provider-eval-task-id", help="Provider's eval task ID", default=None)
@click.option("--metadata", type=str, help="Metadata for the eval task in JSON format")
@click.pass_context
@handle_client_errors("register eval task")
def register(
ctx,
eval_task_id: str,
Expand Down
4 changes: 4 additions & 0 deletions src/llama_stack_client/lib/cli/eval_tasks/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from rich.table import Table


from ..common.utils import handle_client_errors


@click.command("list")
@click.pass_context
@handle_client_errors("list eval tasks")
def list_eval_tasks(ctx):
"""Show available eval tasks on distribution endpoint"""

Expand Down
82 changes: 66 additions & 16 deletions src/llama_stack_client/lib/cli/memory_banks/memory_banks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
def memory_banks():
Expand All @@ -20,36 +22,66 @@ def memory_banks():

@click.command("list")
@click.pass_context
@handle_client_errors("list memory banks")
def list(ctx):
"""Show available memory banks on distribution endpoint"""

client = ctx.obj["client"]
console = Console()
memory_banks_list_response = client.memory_banks.list()
headers = []
if memory_banks_list_response and len(memory_banks_list_response) > 0:
headers = sorted(memory_banks_list_response[0].__dict__.keys())

if memory_banks_list_response:
table = Table()
for header in headers:
table.add_column(header)
# Add our specific columns
table.add_column("identifier")
table.add_column("provider_id")
table.add_column("provider_resource_id")
table.add_column("memory_bank_type")
table.add_column("params")

for item in memory_banks_list_response:
table.add_row(*[str(getattr(item, header)) for header in headers])
# Create a dict of all attributes
item_dict = item.__dict__

# Extract our main columns
identifier = str(item_dict.pop("identifier", ""))
provider_id = str(item_dict.pop("provider_id", ""))
provider_resource_id = str(item_dict.pop("provider_resource_id", ""))
memory_bank_type = str(item_dict.pop("memory_bank_type", ""))
# Convert remaining attributes to YAML string for params column
params = yaml.dump(item_dict, default_flow_style=False)

table.add_row(identifier, provider_id, provider_resource_id, memory_bank_type, params)

console.print(table)


@memory_banks.command()
@click.option("--memory-bank-id", required=True, help="Id of the memory bank")
@click.argument("memory-bank-id")
@click.option("--type", type=click.Choice(["vector", "keyvalue", "keyword", "graph"]), required=True)
@click.option("--provider-id", help="Provider ID for the memory bank", default=None)
@click.option("--provider-memory-bank-id", help="Provider's memory bank ID", default=None)
@click.option("--chunk-size", type=int, help="Chunk size in tokens (for vector type)", default=512)
@click.option("--embedding-model", type=str, help="Embedding model (for vector type)", default="all-MiniLM-L6-v2")
@click.option("--overlap-size", type=int, help="Overlap size in tokens (for vector type)", default=64)
@click.option(
"--chunk-size",
type=int,
help="Chunk size in tokens (for vector type)",
default=512,
)
@click.option(
"--embedding-model",
type=str,
help="Embedding model (for vector type)",
default="all-MiniLM-L6-v2",
)
@click.option(
"--overlap-size",
type=int,
help="Overlap size in tokens (for vector type)",
default=64,
)
@click.pass_context
def create(
@handle_client_errors("register memory bank")
def register(
ctx,
memory_bank_id: str,
type: str,
Expand All @@ -65,18 +97,24 @@ def create(
config = None
if type == "vector":
config = {
"type": "vector",
"memory_bank_type": "vector",
"chunk_size_in_tokens": chunk_size,
"embedding_model": embedding_model,
}
if overlap_size:
config["overlap_size_in_tokens"] = overlap_size
elif type == "keyvalue":
config = {"type": "keyvalue"}
config = {"memory_bank_type": "keyvalue"}
elif type == "keyword":
config = {"type": "keyword"}
config = {"memory_bank_type": "keyword"}
elif type == "graph":
config = {"type": "graph"}
config = {"memory_bank_type": "graph"}

from rich import print as rprint
from rich.pretty import pprint

rprint("\n[bold blue]Memory Bank Configuration:[/bold blue]")
pprint(config, expand_all=True)

response = client.memory_banks.register(
memory_bank_id=memory_bank_id,
Expand All @@ -88,6 +126,18 @@ def create(
click.echo(yaml.dump(response.dict()))


@memory_banks.command()
@click.argument("memory-bank-id")
@click.pass_context
@handle_client_errors("delete memory bank")
def unregister(ctx, memory_bank_id: str):
"""Delete a memory bank"""
client = ctx.obj["client"]
client.memory_banks.unregister(memory_bank_id=memory_bank_id)
click.echo(f"Memory bank '{memory_bank_id}' deleted successfully")


# Register subcommands
memory_banks.add_command(list)
memory_banks.add_command(create)
memory_banks.add_command(register)
memory_banks.add_command(unregister)
65 changes: 22 additions & 43 deletions src/llama_stack_client/lib/cli/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
def models():
Expand All @@ -19,6 +21,7 @@ def models():

@click.command(name="list", help="Show available llama models at distribution endpoint")
@click.pass_context
@handle_client_errors("list models")
def list_models(ctx):
client = ctx.obj["client"]
console = Console()
Expand All @@ -43,6 +46,7 @@ def list_models(ctx):
@click.command(name="get")
@click.argument("model_id")
@click.pass_context
@handle_client_errors("get model details")
def get_model(ctx, model_id: str):
"""Show available llama models at distribution endpoint"""
client = ctx.obj["client"]
Expand All @@ -51,9 +55,10 @@ def get_model(ctx, model_id: str):
models_get_response = client.models.retrieve(identifier=model_id)

if not models_get_response:
click.echo(
console.print(
f"Model {model_id} is not found at distribution endpoint. "
"Please ensure endpoint is serving specified model."
"Please ensure endpoint is serving specified model.",
style="bold red",
)
return

Expand All @@ -72,62 +77,36 @@ def get_model(ctx, model_id: str):
@click.option("--provider-model-id", help="Provider's model ID", default=None)
@click.option("--metadata", help="JSON metadata for the model", default=None)
@click.pass_context
@handle_client_errors("register model")
def register_model(
ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
):
"""Register a new model at distribution endpoint"""
client = ctx.obj["client"]
console = Console()

try:
response = client.models.register(
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
)
if response:
click.echo(f"Successfully registered model {model_id}")
except Exception as e:
click.echo(f"Failed to register model: {str(e)}")


@click.command(name="update", help="Update an existing model at distribution endpoint")
@click.argument("model_id")
@click.option("--provider-id", help="Provider ID for the model", default=None)
@click.option("--provider-model-id", help="Provider's model ID", default=None)
@click.option("--metadata", help="JSON metadata for the model", default=None)
@click.pass_context
def update_model(
ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
):
"""Update an existing model at distribution endpoint"""
client = ctx.obj["client"]

try:
response = client.models.update(
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
)
if response:
click.echo(f"Successfully updated model {model_id}")
except Exception as e:
click.echo(f"Failed to update model: {str(e)}")
response = client.models.register(
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
)
if response:
console.print(f"[green]Successfully registered model {model_id}[/green]")


@click.command(name="delete", help="Delete a model from distribution endpoint")
@click.command(name="unregister", help="Unregister a model from distribution endpoint")
@click.argument("model_id")
@click.pass_context
def delete_model(ctx, model_id: str):
"""Delete a model from distribution endpoint"""
@handle_client_errors("unregister model")
def unregister_model(ctx, model_id: str):
client = ctx.obj["client"]
console = Console()

try:
response = client.models.delete(model_id=model_id)
if response:
click.echo(f"Successfully deleted model {model_id}")
except Exception as e:
click.echo(f"Failed to delete model: {str(e)}")
response = client.models.unregister(model_id=model_id)
if response:
console.print(f"[green]Successfully deleted model {model_id}[/green]")


# Register subcommands
models.add_command(list_models)
models.add_command(get_model)
models.add_command(register_model)
models.add_command(update_model)
models.add_command(delete_model)
models.add_command(unregister_model)
3 changes: 2 additions & 1 deletion src/llama_stack_client/lib/cli/providers/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from rich.console import Console
from rich.table import Table


from ..common.utils import handle_client_errors
@click.command("list")
@click.pass_context
@handle_client_errors("list providers")
def list_providers(ctx):
"""Show available providers on distribution endpoint"""
client = ctx.obj["client"]
Expand Down
3 changes: 3 additions & 0 deletions src/llama_stack_client/lib/cli/scoring_functions/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.command("list")
@click.pass_context
@handle_client_errors("list scoring functions")
def list_scoring_functions(ctx):
"""Show available scoring functions on distribution endpoint"""

Expand Down
4 changes: 4 additions & 0 deletions src/llama_stack_client/lib/cli/shields/shields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
def shields():
Expand All @@ -20,6 +22,7 @@ def shields():

@click.command("list")
@click.pass_context
@handle_client_errors("list shields")
def list(ctx):
"""Show available safety shields on distribution endpoint"""
client = ctx.obj["client"]
Expand All @@ -46,6 +49,7 @@ def list(ctx):
@click.option("--provider-shield-id", help="Provider's shield ID", default=None)
@click.option("--params", type=str, help="JSON configuration parameters for the shield", default=None)
@click.pass_context
@handle_client_errors("register shield")
def register(
ctx,
shield_id: str,
Expand Down

0 comments on commit 5eb688f

Please sign in to comment.