Skip to content

Commit

Permalink
Format All the Things (Chainlit#1531)
Browse files Browse the repository at this point in the history
* Specify Python version for ruff to target.
* Bump ruff version.
* Sort all imports with ruff.
* Format all code using ruff.
* Enable fastcache for mypy.
  • Loading branch information
dokterbob authored Nov 20, 2024
1 parent a5612aa commit 2ea6aac
Show file tree
Hide file tree
Showing 50 changed files with 182 additions and 176 deletions.
9 changes: 3 additions & 6 deletions backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
import os
import pathlib
import shutil
import signal
import subprocess
import sys
from contextlib import contextmanager
from typing import Optional


class BuildError(Exception):
Expand Down Expand Up @@ -50,7 +47,7 @@ def copy_directory(src, dst, description):
shutil.rmtree(dst)
raise
except Exception as e:
raise BuildError(f"Failed to copy {src} to {dst}: {str(e)}")
raise BuildError(f"Failed to copy {src} to {dst}: {e!s}")


def copy_frontend(project_root):
Expand Down Expand Up @@ -94,10 +91,10 @@ def build():
print("\nBuild interrupted by user")
sys.exit(1)
except BuildError as e:
print(f"\nBuild failed: {str(e)}")
print(f"\nBuild failed: {e!s}")
sys.exit(1)
except Exception as e:
print(f"\nUnexpected error: {str(e)}")
print(f"\nUnexpected error: {e!s}")
sys.exit(1)


Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def acall(self):
"Image",
"Text",
"Component",
"Dataframe",
"Pyplot",
"File",
"Task",
Expand Down
5 changes: 3 additions & 2 deletions backend/chainlit/action.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import uuid
from typing import Optional

from chainlit.context import context
from chainlit.telemetry import trace_event
from dataclasses_json import DataClassJsonMixin
from pydantic.dataclasses import Field, dataclass

from chainlit.context import context
from chainlit.telemetry import trace_event


@dataclass
class Action(DataClassJsonMixin):
Expand Down
10 changes: 5 additions & 5 deletions backend/chainlit/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from typing import Any, Dict

import jwt
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer

from chainlit.config import config
from chainlit.data import get_data_layer
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.user import User
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer

reuseable_oauth = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False)

Expand Down Expand Up @@ -52,9 +53,8 @@ def create_jwt(data: User) -> str:
to_encode: Dict[str, Any] = data.to_dict()
to_encode.update(
{
"exp": datetime.utcnow() + timedelta(
seconds=config.project.user_session_timeout
),
"exp": datetime.utcnow()
+ timedelta(seconds=config.project.user_session_timeout),
}
)
encoded_jwt = jwt.encode(to_encode, get_jwt_secret(), algorithm="HS256")
Expand Down
4 changes: 2 additions & 2 deletions backend/chainlit/chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def add(self, message: "Message"):

if context.session.id not in chat_contexts:
chat_contexts[context.session.id] = []

if message not in chat_contexts[context.session.id]:
chat_contexts[context.session.id].append(message)

return message

def remove(self, message: "Message") -> bool:
Expand Down
3 changes: 2 additions & 1 deletion backend/chainlit/chat_settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List

from pydantic.dataclasses import Field, dataclass

from chainlit.context import context
from chainlit.input_widget import InputWidget
from pydantic.dataclasses import Field, dataclass


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def init_config(log=False):
dst = os.path.join(config_translation_dir, file)
if not os.path.exists(dst):
src = os.path.join(TRANSLATIONS_DIR, file)
with open(src, "r", encoding="utf-8") as f:
with open(src, encoding="utf-8") as f:
translation = json.load(f)
with open(dst, "w", encoding="utf-8") as f:
json.dump(translation, f, indent=4)
Expand Down Expand Up @@ -515,15 +515,15 @@ def load_config():
def lint_translations():
# Load the ground truth (en-US.json file from chainlit source code)
src = os.path.join(TRANSLATIONS_DIR, "en-US.json")
with open(src, "r", encoding="utf-8") as f:
with open(src, encoding="utf-8") as f:
truth = json.load(f)

# Find the local app translations
for file in os.listdir(config_translation_dir):
if file.endswith(".json"):
# Load the translation file
to_lint = os.path.join(config_translation_dir, file)
with open(to_lint, "r", encoding="utf-8") as f:
with open(to_lint, encoding="utf-8") as f:
translation = json.load(f)

# Lint the translation file
Expand Down
5 changes: 3 additions & 2 deletions backend/chainlit/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from contextvars import ContextVar
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from chainlit.session import ClientType, HTTPSession, WebsocketSession
from lazify import LazyProxy

from chainlit.session import ClientType, HTTPSession, WebsocketSession

if TYPE_CHECKING:
from chainlit.emitter import BaseChainlitEmitter
from chainlit.step import Step
Expand Down Expand Up @@ -104,7 +105,7 @@ def get_context() -> ChainlitContext:
try:
return context_var.get()
except LookupError as e:
raise ChainlitContextException() from e
raise ChainlitContextException from e


context: ChainlitContext = LazyProxy(get_context, enable_cache=False)
5 changes: 3 additions & 2 deletions backend/chainlit/data/acl.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from chainlit.data import get_data_layer
from fastapi import HTTPException

from chainlit.data import get_data_layer


async def is_thread_author(username: str, thread_id: str):
data_layer = get_data_layer()
if not data_layer:
raise HTTPException(status_code=400, detail="Data layer not initialized")

thread_author = await data_layer.get_thread_author(thread_id)

if not thread_author:
raise HTTPException(status_code=404, detail="Thread not found")

Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/data/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional

from chainlit.types import (
Feedback,
Expand Down
8 changes: 5 additions & 3 deletions backend/chainlit/data/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import aiohttp
import boto3 # type: ignore
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer

from chainlit.context import context
from chainlit.data.base import BaseDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
Expand All @@ -29,9 +30,10 @@
from chainlit.user import PersistedUser, User

if TYPE_CHECKING:
from chainlit.element import Element
from mypy_boto3_dynamodb import DynamoDBClient

from chainlit.element import Element


_logger = logger.getChild("DynamoDB")
_logger.setLevel(logging.WARNING)
Expand Down Expand Up @@ -402,15 +404,15 @@ async def delete_thread(self, thread_id: str):

BATCH_ITEM_SIZE = 25 # pylint: disable=invalid-name
for i in range(0, len(delete_requests), BATCH_ITEM_SIZE):
chunk = delete_requests[i : i + BATCH_ITEM_SIZE] # noqa: E203
chunk = delete_requests[i : i + BATCH_ITEM_SIZE]
response = self.client.batch_write_item(
RequestItems={
self.table_name: chunk, # type: ignore
}
)

backoff_time = 1
while "UnprocessedItems" in response and response["UnprocessedItems"]:
while response.get("UnprocessedItems"):
backoff_time *= 2
# Cap the backoff time at 32 seconds & add jitter
delay = min(backoff_time, 32) + random.uniform(0, 1)
Expand Down
8 changes: 3 additions & 5 deletions backend/chainlit/data/literalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,9 @@ async def delete_element(self, element_id: str, thread_id: Optional[str] = None)
async def create_step(self, step_dict: "StepDict"):
metadata = dict(
step_dict.get("metadata", {}),
**{
"waitForAnswer": step_dict.get("waitForAnswer"),
"language": step_dict.get("language"),
"showInput": step_dict.get("showInput"),
},
waitForAnswer=step_dict.get("waitForAnswer"),
language=step_dict.get("language"),
showInput=step_dict.get("showInput"),
)

step: LiteralStepDict = {
Expand Down
11 changes: 6 additions & 5 deletions backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

import aiofiles
import aiohttp
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from chainlit.data.base import BaseDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.utils import queue_until_user_message
Expand All @@ -23,10 +28,6 @@
ThreadFilter,
)
from chainlit.user import PersistedUser, User
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

if TYPE_CHECKING:
from chainlit.element import Element, ElementDict
Expand Down Expand Up @@ -167,7 +168,7 @@ async def _get_user_id_by_thread(self, thread_id: str) -> Optional[str]:
async def create_user(self, user: User) -> Optional[PersistedUser]:
if self.show_logger:
logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
existing_user: Optional[PersistedUser] = await self.get_user(user.identifier)
user_dict: Dict[str, Any] = {
"identifier": str(user.identifier),
"metadata": json.dumps(user.metadata) or {},
Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/data/storage_clients/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DataLakeServiceClient,
FileSystemClient,
)

from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.logger import logger

Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/data/storage_clients/s3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Union

import boto3 # type: ignore

from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.logger import logger

Expand Down
3 changes: 2 additions & 1 deletion backend/chainlit/discord/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import discord
import filetype
import httpx
from discord.ui import Button, View

from chainlit.config import config
from chainlit.context import ChainlitContext, HTTPSession, context, context_var
from chainlit.data import get_data_layer
Expand All @@ -23,7 +25,6 @@
from chainlit.types import Feedback
from chainlit.user import PersistedUser, User
from chainlit.user_session import user_session
from discord.ui import Button, View


class FeedbackView(View):
Expand Down
10 changes: 5 additions & 5 deletions backend/chainlit/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
)

import filetype
from pydantic.dataclasses import Field, dataclass
from syncer import asyncio

from chainlit.context import context
from chainlit.data import get_data_layer
from chainlit.logger import logger
from chainlit.telemetry import trace_event
from chainlit.types import FileDict
from pydantic.dataclasses import Field, dataclass
from syncer import asyncio

mime_types = {
"text": "text/plain",
Expand Down Expand Up @@ -154,7 +155,7 @@ async def _create(self) -> bool:
try:
asyncio.create_task(data_layer.create_element(self))
except Exception as e:
logger.error(f"Failed to create element: {str(e)}")
logger.error(f"Failed to create element: {e!s}")
if not self.url and (not self.chainlit_key or self.updatable):
file_dict = await context.session.persist_file(
name=self.name,
Expand Down Expand Up @@ -352,8 +353,7 @@ class Plotly(Element):
content: str = ""

def __post_init__(self) -> None:
from plotly import graph_objects as go
from plotly import io as pio
from plotly import graph_objects as go, io as pio

if not isinstance(self.figure, go.Figure):
raise TypeError("figure must be a plotly.graph_objects.Figure")
Expand Down
6 changes: 3 additions & 3 deletions backend/chainlit/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,9 @@ async def send_ask_user(
# End the task temporarily so that the User can answer the prompt
await self.task_end()

final_res: Optional[
Union["StepDict", "AskActionResponse", List["FileDict"]]
] = None
final_res: Optional[Union[StepDict, AskActionResponse, List[FileDict]]] = (
None
)

if user_res:
interaction: Union[str, None] = None
Expand Down
9 changes: 5 additions & 4 deletions backend/chainlit/haystack/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import re
from typing import Any, Generic, List, Optional, TypeVar

from chainlit import Message
from chainlit.step import Step
from chainlit.sync import run_sync
from haystack.agents import Agent, Tool
from haystack.agents.agent_step import AgentStep
from literalai.helper import utc_now

from chainlit import Message
from chainlit.step import Step
from chainlit.sync import run_sync

T = TypeVar("T")


Expand Down Expand Up @@ -131,7 +132,7 @@ def on_tool_finish(
tool_result: str,
tool_name: Optional[str] = None,
tool_input: Optional[str] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
# Tool finished, send step with tool_result
tool_step = self.stack.pop()
Expand Down
Loading

0 comments on commit 2ea6aac

Please sign in to comment.