From c8302cd4e3aee0d5d580932d44615772b0d87903 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Fri, 22 Nov 2024 11:51:19 -0800
Subject: [PATCH] precommit

---
 .pre-commit-config.yaml                            | 11 ++++-------
 src/llama_stack_client/lib/agents/agent.py         |  5 ++++-
 src/llama_stack_client/lib/agents/custom_tool.py   |  8 ++++----
 src/llama_stack_client/lib/agents/event_logger.py  |  2 +-
 src/llama_stack_client/lib/cli/common/utils.py     |  5 +++--
 src/llama_stack_client/lib/cli/configure.py        |  3 ++-
 src/llama_stack_client/lib/cli/eval_tasks/list.py  |  1 -
 .../lib/cli/llama_stack_client.py                  |  3 +--
 src/llama_stack_client/lib/cli/providers/list.py   |  2 ++
 src/llama_stack_client/lib/direct/direct.py        | 14 ++++++--------
 src/llama_stack_client/lib/direct/test.py          |  1 +
 11 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 93a60ca..0f60888 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,11 +26,8 @@ repos:
     - torchfix
     args: ['--config=.flake8']
 
-- repo: https://github.com/omnilib/ufmt
-  rev: v2.7.0
+- repo: https://github.com/pycqa/isort
+  rev: 5.13.2
   hooks:
-  - id: ufmt
-    files: ^src/llama_stack_client/lib/.*  # Only run on files in specific-folder
-    additional_dependencies:
-    - black == 24.4.2
-    - usort == 1.0.8
+  - id: isort
+    files: ^src/llama_stack_client/lib/.*
diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py
index 1d89c8b..93825a0 100644
--- a/src/llama_stack_client/lib/agents/agent.py
+++ b/src/llama_stack_client/lib/agents/agent.py
@@ -6,10 +6,13 @@
 from typing import List, Optional, Tuple, Union
 
 from llama_stack_client import LlamaStackClient
-from llama_stack_client.types import Attachment, ToolResponseMessage, UserMessage
+from llama_stack_client.types import (Attachment, ToolResponseMessage,
+                                      UserMessage)
 from llama_stack_client.types.agent_create_params import AgentConfig
+
 from .custom_tool import CustomTool
 
+
 class Agent:
     def __init__(self, client: LlamaStackClient, agent_config: AgentConfig, custom_tools: Tuple[CustomTool] = ()):
         self.client = client
diff --git a/src/llama_stack_client/lib/agents/custom_tool.py b/src/llama_stack_client/lib/agents/custom_tool.py
index e1205ca..0da5bb3 100644
--- a/src/llama_stack_client/lib/agents/custom_tool.py
+++ b/src/llama_stack_client/lib/agents/custom_tool.py
@@ -5,13 +5,13 @@
 # the root directory of this source tree.
 
 import json
-
 from abc import abstractmethod
 from typing import Dict, List, Union
 
-from llama_stack_client.types import FunctionCallToolDefinition, ToolResponseMessage, UserMessage
-
-from llama_stack_client.types.tool_param_definition_param import ToolParamDefinitionParam
+from llama_stack_client.types import (FunctionCallToolDefinition,
+                                      ToolResponseMessage, UserMessage)
+from llama_stack_client.types.tool_param_definition_param import \
+    ToolParamDefinitionParam
 
 
 class CustomTool:
diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py
index e356463..7926e3c 100644
--- a/src/llama_stack_client/lib/agents/event_logger.py
+++ b/src/llama_stack_client/lib/agents/event_logger.py
@@ -153,4 +153,4 @@ def log(self, event_generator):
         for chunk in event_generator:
             for log_event in self._get_log_event(chunk, previous_event_type, previous_step_type):
                 yield log_event
-            previous_event_type, previous_step_type = self._get_event_type_step_type(chunk)
\ No newline at end of file
+            previous_event_type, previous_step_type = self._get_event_type_step_type(chunk)
diff --git a/src/llama_stack_client/lib/cli/common/utils.py b/src/llama_stack_client/lib/cli/common/utils.py
index 4f7893a..faf9ac2 100644
--- a/src/llama_stack_client/lib/cli/common/utils.py
+++ b/src/llama_stack_client/lib/cli/common/utils.py
@@ -3,10 +3,11 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from functools import wraps
+
 from rich.console import Console
-from rich.table import Table
 from rich.panel import Panel
-from functools import wraps
+from rich.table import Table
 
 
 def create_bar_chart(data, labels, title=""):
diff --git a/src/llama_stack_client/lib/cli/configure.py b/src/llama_stack_client/lib/cli/configure.py
index 48afbad..88f5b0b 100644
--- a/src/llama_stack_client/lib/cli/configure.py
+++ b/src/llama_stack_client/lib/cli/configure.py
@@ -11,7 +11,8 @@
 from prompt_toolkit import prompt
 from prompt_toolkit.validation import Validator
 
-from llama_stack_client.lib.cli.constants import get_config_file_path, LLAMA_STACK_CLIENT_CONFIG_DIR
+from llama_stack_client.lib.cli.constants import (
+    LLAMA_STACK_CLIENT_CONFIG_DIR, get_config_file_path)
 
 
 def get_config():
diff --git a/src/llama_stack_client/lib/cli/eval_tasks/list.py b/src/llama_stack_client/lib/cli/eval_tasks/list.py
index 6c054e8..68e20f5 100644
--- a/src/llama_stack_client/lib/cli/eval_tasks/list.py
+++ b/src/llama_stack_client/lib/cli/eval_tasks/list.py
@@ -8,7 +8,6 @@
 from rich.console import Console
 from rich.table import Table
 
-
 from ..common.utils import handle_client_errors
 
 
diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py b/src/llama_stack_client/lib/cli/llama_stack_client.py
index 6d4f50f..f6f3d60 100644
--- a/src/llama_stack_client/lib/cli/llama_stack_client.py
+++ b/src/llama_stack_client/lib/cli/llama_stack_client.py
@@ -10,8 +10,8 @@
 import yaml
 from llama_stack_client import LlamaStackClient
-from .configure import configure
 
+from .configure import configure
 from .constants import get_config_file_path
 from .datasets import datasets
 from .eval import eval
@@ -19,7 +19,6 @@
 from .inference import inference
 from .memory_banks import memory_banks
 from .models import models
-
 from .providers import providers
 from .scoring_functions import scoring_functions
 from .shields import shields
diff --git a/src/llama_stack_client/lib/cli/providers/list.py b/src/llama_stack_client/lib/cli/providers/list.py
index de5ad6e..8fb71a6 100644
--- a/src/llama_stack_client/lib/cli/providers/list.py
+++ b/src/llama_stack_client/lib/cli/providers/list.py
@@ -3,6 +3,8 @@
 from rich.table import Table
 
 from ..common.utils import handle_client_errors
+
+
 @click.command("list")
 @click.pass_context
 @handle_client_errors("list providers")
diff --git a/src/llama_stack_client/lib/direct/direct.py b/src/llama_stack_client/lib/direct/direct.py
index ac436e8..d3aeb5a 100644
--- a/src/llama_stack_client/lib/direct/direct.py
+++ b/src/llama_stack_client/lib/direct/direct.py
@@ -1,23 +1,21 @@
 import inspect
-import yaml
-from typing import Any, cast, get_args, get_origin, Type
+from typing import Any, Type, cast, get_args, get_origin
 
-from rich.console import Console
+import yaml
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import resolve_impls
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.server.server import is_streaming_request
-from llama_stack.distribution.stack import (
-    get_stack_run_config_from_template,
-)
-from llama_stack.distribution.stack import construct_stack
+from llama_stack.distribution.stack import (construct_stack,
+                                             get_stack_run_config_from_template)
 from pydantic import BaseModel
+from rich.console import Console
 
 from ..._base_client import ResponseT
 from ..._client import LlamaStackClient
 from ..._streaming import Stream
-from ..._types import Body, NOT_GIVEN, RequestFiles, RequestOptions
+from ..._types import NOT_GIVEN, Body, RequestFiles, RequestOptions
 
 
 class LlamaStackDirectClient(LlamaStackClient):
diff --git a/src/llama_stack_client/lib/direct/test.py b/src/llama_stack_client/lib/direct/test.py
index 4f21ce4..001f35a 100644
--- a/src/llama_stack_client/lib/direct/test.py
+++ b/src/llama_stack_client/lib/direct/test.py
@@ -2,6 +2,7 @@
 import yaml
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+
 from llama_stack_client.lib.direct.direct import LlamaStackDirectClient
 from llama_stack_client.types import UserMessage