forked from Chainlit/chainlit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add get_element() and test infra for sql_alchemy.py (Chainlit#1346)
* Add get_element() to sql_alchemy.py * Add test for create_element and get_element in SQLAlchemyDataLayer. * Add aiosqlite test dep. * Add missing attribute to mocked WebsocketSession object * Add mocked user to ChainlitContext in test --------- Co-authored-by: Mathijs de Bruin (aider) <[email protected]>
- Loading branch information
Showing
7 changed files
with
207 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import pytest | ||
|
||
from unittest.mock import AsyncMock | ||
|
||
from chainlit.data.base import BaseStorageClient | ||
|
||
|
||
@pytest.fixture | ||
def mock_storage_client(): | ||
mock_client = AsyncMock(spec=BaseStorageClient) | ||
mock_client.upload_file.return_value = { | ||
"url": "https://example.com/test.txt", | ||
"object_key": "test_user/test_element/test.txt", | ||
} | ||
return mock_client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import uuid | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
from sqlalchemy.ext.asyncio import create_async_engine | ||
from sqlalchemy import text | ||
|
||
from chainlit.data.base import BaseStorageClient | ||
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer | ||
from chainlit.element import Text | ||
|
||
|
||
@pytest.fixture | ||
async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path): | ||
db_file = tmp_path / "test_db.sqlite" | ||
conninfo = f"sqlite+aiosqlite:///{db_file}" | ||
|
||
# Create async engine | ||
engine = create_async_engine(conninfo) | ||
|
||
# Execute initialization statements | ||
# Ref: https://docs.chainlit.io/data-persistence/custom#sql-alchemy-data-layer | ||
async with engine.begin() as conn: | ||
await conn.execute( | ||
text(""" | ||
CREATE TABLE users ( | ||
"id" UUID PRIMARY KEY, | ||
"identifier" TEXT NOT NULL UNIQUE, | ||
"metadata" JSONB NOT NULL, | ||
"createdAt" TEXT | ||
); | ||
""") | ||
) | ||
|
||
await conn.execute( | ||
text(""" | ||
CREATE TABLE IF NOT EXISTS threads ( | ||
"id" UUID PRIMARY KEY, | ||
"createdAt" TEXT, | ||
"name" TEXT, | ||
"userId" UUID, | ||
"userIdentifier" TEXT, | ||
"tags" TEXT[], | ||
"metadata" JSONB, | ||
FOREIGN KEY ("userId") REFERENCES users("id") ON DELETE CASCADE | ||
); | ||
""") | ||
) | ||
|
||
await conn.execute( | ||
text(""" | ||
CREATE TABLE IF NOT EXISTS steps ( | ||
"id" UUID PRIMARY KEY, | ||
"name" TEXT NOT NULL, | ||
"type" TEXT NOT NULL, | ||
"threadId" UUID NOT NULL, | ||
"parentId" UUID, | ||
"disableFeedback" BOOLEAN NOT NULL, | ||
"streaming" BOOLEAN NOT NULL, | ||
"waitForAnswer" BOOLEAN, | ||
"isError" BOOLEAN, | ||
"metadata" JSONB, | ||
"tags" TEXT[], | ||
"input" TEXT, | ||
"output" TEXT, | ||
"createdAt" TEXT, | ||
"start" TEXT, | ||
"end" TEXT, | ||
"generation" JSONB, | ||
"showInput" TEXT, | ||
"language" TEXT, | ||
"indent" INT | ||
); | ||
""") | ||
) | ||
|
||
await conn.execute( | ||
text(""" | ||
CREATE TABLE IF NOT EXISTS elements ( | ||
"id" UUID PRIMARY KEY, | ||
"threadId" UUID, | ||
"type" TEXT, | ||
"url" TEXT, | ||
"chainlitKey" TEXT, | ||
"name" TEXT NOT NULL, | ||
"display" TEXT, | ||
"objectKey" TEXT, | ||
"size" TEXT, | ||
"page" INT, | ||
"language" TEXT, | ||
"forId" UUID, | ||
"mime" TEXT | ||
); | ||
""") | ||
) | ||
|
||
await conn.execute( | ||
text(""" | ||
CREATE TABLE IF NOT EXISTS feedbacks ( | ||
"id" UUID PRIMARY KEY, | ||
"forId" UUID NOT NULL, | ||
"threadId" UUID NOT NULL, | ||
"value" INT NOT NULL, | ||
"comment" TEXT | ||
); | ||
""") | ||
) | ||
|
||
# Create SQLAlchemyDataLayer instance | ||
data_layer = SQLAlchemyDataLayer(conninfo, storage_provider=mock_storage_client) | ||
|
||
yield data_layer | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_create_and_get_element( | ||
mock_chainlit_context, data_layer: SQLAlchemyDataLayer | ||
): | ||
async with mock_chainlit_context: | ||
text_element = Text( | ||
id=str(uuid.uuid4()), | ||
name="test.txt", | ||
mime="text/plain", | ||
content="test content", | ||
for_id="test_step_id", | ||
) | ||
|
||
await data_layer.create_element(text_element) | ||
|
||
retrieved_element = await data_layer.get_element( | ||
text_element.thread_id, text_element.id | ||
) | ||
assert retrieved_element is not None | ||
assert retrieved_element["id"] == text_element.id | ||
assert retrieved_element["name"] == text_element.name | ||
assert retrieved_element["mime"] == text_element.mime | ||
# The 'content' field is not part of the ElementDict, so we remove this assertion |