Skip to content

Commit

Permalink
Add get_element() and test infra for sql_alchemy.py (Chainlit#1346)
Browse files Browse the repository at this point in the history
* Add get_element() to sql_alchemy.py
* Add test for create_element and get_element in SQLAlchemyDataLayer.
* Add aiosqlite test dep.
* Add missing attribute to mocked WebsocketSession object
* Add mocked user to ChainlitContext in test

---------

Co-authored-by: Mathijs de Bruin (aider) <[email protected]>
  • Loading branch information
hayescode and dokterbob authored Sep 18, 2024
1 parent b86fa05 commit 2bdd541
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 2 deletions.
28 changes: 28 additions & 0 deletions backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,34 @@ async def delete_feedback(self, feedback_id: str) -> bool:
return True

###### Elements ######
async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]:
if self.show_logger:
logger.info(f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}")
query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id"""
parameters = {"thread_id": thread_id, "element_id": element_id}
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(query=query, parameters=parameters)
if isinstance(element, list) and element:
element_dict: Dict[str, Any] = element[0]
return ElementDict(
id=element_dict["id"],
threadId=element_dict.get("threadId"),
type=element_dict["type"],
chainlitKey=element_dict.get("chainlitKey"),
url=element_dict.get("url"),
objectKey=element_dict.get("objectKey"),
name=element_dict["name"],
display=element_dict["display"],
size=element_dict.get("size"),
language=element_dict.get("language"),
page=element_dict.get("page"),
autoPlay=element_dict.get("autoPlay"),
playerConfig=element_dict.get("playerConfig"),
forId=element_dict.get("forId"),
mime=element_dict.get("mime")
)
else:
return None

@queue_until_user_message()
async def create_element(self, element: "Element"):
if self.show_logger:
Expand Down
20 changes: 19 additions & 1 deletion backend/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ plotly = "^5.18.0"
slack_bolt = "^1.18.1"
discord = "^2.3.2"
botbuilder-core = "^4.15.0"
aiosqlite = "^0.20.0"

[tool.poetry.group.dev.dependencies]
black = "^24.8.0"
Expand Down Expand Up @@ -106,6 +107,7 @@ ignore_missing_imports = true




[tool.poetry.group.custom-data]
optional = true

Expand Down
6 changes: 5 additions & 1 deletion backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest_asyncio
from chainlit.context import ChainlitContext, context_var
from chainlit.session import HTTPSession, WebsocketSession
from chainlit.user import PersistedUser
from chainlit.user_session import UserSession


Expand All @@ -14,13 +15,16 @@ async def create_chainlit_context():
mock_session.id = "test_session_id"
mock_session.user_env = {"test_env": "value"}
mock_session.chat_settings = {}
mock_session.user = None
mock_user = Mock(spec=PersistedUser)
mock_user.id = "test_user_id"
mock_session.user = mock_user
mock_session.chat_profile = None
mock_session.http_referer = None
mock_session.client_type = "webapp"
mock_session.languages = ["en"]
mock_session.thread_id = "test_thread_id"
mock_session.emit = AsyncMock()
mock_session.has_first_interaction = True

context = ChainlitContext(mock_session)
token = context_var.set(context)
Expand Down
Empty file added backend/tests/data/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions backend/tests/data/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

from unittest.mock import AsyncMock

from chainlit.data.base import BaseStorageClient


@pytest.fixture
def mock_storage_client():
mock_client = AsyncMock(spec=BaseStorageClient)
mock_client.upload_file.return_value = {
"url": "https://example.com/test.txt",
"object_key": "test_user/test_element/test.txt",
}
return mock_client
138 changes: 138 additions & 0 deletions backend/tests/data/test_sql_alchemy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import uuid
from pathlib import Path

import pytest

from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import text

from chainlit.data.base import BaseStorageClient
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.element import Text


@pytest.fixture
async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
db_file = tmp_path / "test_db.sqlite"
conninfo = f"sqlite+aiosqlite:///{db_file}"

# Create async engine
engine = create_async_engine(conninfo)

# Execute initialization statements
# Ref: https://docs.chainlit.io/data-persistence/custom#sql-alchemy-data-layer
async with engine.begin() as conn:
await conn.execute(
text("""
CREATE TABLE users (
"id" UUID PRIMARY KEY,
"identifier" TEXT NOT NULL UNIQUE,
"metadata" JSONB NOT NULL,
"createdAt" TEXT
);
""")
)

await conn.execute(
text("""
CREATE TABLE IF NOT EXISTS threads (
"id" UUID PRIMARY KEY,
"createdAt" TEXT,
"name" TEXT,
"userId" UUID,
"userIdentifier" TEXT,
"tags" TEXT[],
"metadata" JSONB,
FOREIGN KEY ("userId") REFERENCES users("id") ON DELETE CASCADE
);
""")
)

await conn.execute(
text("""
CREATE TABLE IF NOT EXISTS steps (
"id" UUID PRIMARY KEY,
"name" TEXT NOT NULL,
"type" TEXT NOT NULL,
"threadId" UUID NOT NULL,
"parentId" UUID,
"disableFeedback" BOOLEAN NOT NULL,
"streaming" BOOLEAN NOT NULL,
"waitForAnswer" BOOLEAN,
"isError" BOOLEAN,
"metadata" JSONB,
"tags" TEXT[],
"input" TEXT,
"output" TEXT,
"createdAt" TEXT,
"start" TEXT,
"end" TEXT,
"generation" JSONB,
"showInput" TEXT,
"language" TEXT,
"indent" INT
);
""")
)

await conn.execute(
text("""
CREATE TABLE IF NOT EXISTS elements (
"id" UUID PRIMARY KEY,
"threadId" UUID,
"type" TEXT,
"url" TEXT,
"chainlitKey" TEXT,
"name" TEXT NOT NULL,
"display" TEXT,
"objectKey" TEXT,
"size" TEXT,
"page" INT,
"language" TEXT,
"forId" UUID,
"mime" TEXT
);
""")
)

await conn.execute(
text("""
CREATE TABLE IF NOT EXISTS feedbacks (
"id" UUID PRIMARY KEY,
"forId" UUID NOT NULL,
"threadId" UUID NOT NULL,
"value" INT NOT NULL,
"comment" TEXT
);
""")
)

# Create SQLAlchemyDataLayer instance
data_layer = SQLAlchemyDataLayer(conninfo, storage_provider=mock_storage_client)

yield data_layer


@pytest.mark.asyncio
async def test_create_and_get_element(
mock_chainlit_context, data_layer: SQLAlchemyDataLayer
):
async with mock_chainlit_context:
text_element = Text(
id=str(uuid.uuid4()),
name="test.txt",
mime="text/plain",
content="test content",
for_id="test_step_id",
)

await data_layer.create_element(text_element)

retrieved_element = await data_layer.get_element(
text_element.thread_id, text_element.id
)
assert retrieved_element is not None
assert retrieved_element["id"] == text_element.id
assert retrieved_element["name"] == text_element.name
assert retrieved_element["mime"] == text_element.mime
# The 'content' field is not part of the ElementDict, so we remove this assertion

0 comments on commit 2bdd541

Please sign in to comment.