Skip to content

Commit

Permalink
Resolve Chainlit#828: frontend connection resume after connection loss
Browse files Browse the repository at this point in the history
Update websocket's thread id header with currentThreadId to ensure session continuation after backend restart.
  • Loading branch information
qtangs authored Aug 28, 2024
1 parent 8da9ad2 commit 86798bc
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 9 deletions.
32 changes: 27 additions & 5 deletions cypress/e2e/data_layer/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import os.path
import pickle
from typing import Dict, List, Optional

import chainlit.data as cl_data
from chainlit.socket import persist_user_session
from chainlit.step import StepDict
from literalai.helper import utc_now

import chainlit as cl

now = utc_now()

create_step_counter = 0


thread_history = [
{
"id": "test1",
Expand Down Expand Up @@ -61,6 +61,22 @@
] # type: List[cl_data.ThreadDict]
deleted_thread_ids = [] # type: List[str]

THREAD_HISTORY_PICKLE_PATH = os.getenv("THREAD_HISTORY_PICKLE_PATH")
if THREAD_HISTORY_PICKLE_PATH and os.path.exists(THREAD_HISTORY_PICKLE_PATH):
with open(THREAD_HISTORY_PICKLE_PATH, "rb") as f:
thread_history = pickle.load(f)


async def save_thread_history():
if THREAD_HISTORY_PICKLE_PATH:
# Force saving of thread history for reload when server restarts
await persist_user_session(
cl.context.session.thread_id, cl.context.session.to_persistable()
)

with open(THREAD_HISTORY_PICKLE_PATH, "wb") as out_file:
pickle.dump(thread_history, out_file)


class TestDataLayer(cl_data.BaseDataLayer):
async def get_user(self, identifier: str):
Expand Down Expand Up @@ -101,8 +117,9 @@ async def update_thread(

@cl_data.queue_until_user_message()
async def create_step(self, step_dict: StepDict):
global create_step_counter
create_step_counter += 1
cl.user_session.set(
"create_step_counter", cl.user_session.get("create_step_counter") + 1
)

thread = next(
(t for t in thread_history if t["id"] == step_dict.get("threadId")), None
Expand Down Expand Up @@ -138,11 +155,14 @@ async def delete_thread(self, thread_id: str):


async def send_count():
create_step_counter = cl.user_session.get("create_step_counter")
await cl.Message(f"Create step counter: {create_step_counter}").send()


@cl.on_chat_start
async def main():
# Add step counter to session so that it is saved in thread metadata
cl.user_session.set("create_step_counter", 0)
await cl.Message("Hello, send me a message!").send()
await send_count()

Expand All @@ -157,6 +177,8 @@ async def handle_message():
await cl.Message("Ok!").send()
await send_count()

await save_thread_history()


@cl.password_auth_callback
def auth_callback(username: str, password: str) -> Optional[cl.User]:
Expand Down
67 changes: 65 additions & 2 deletions cypress/e2e/data_layer/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { sep } from 'path';

import { runTestServer, submitMessage } from '../../support/testUtils';
import { ExecutionMode } from '../../support/utils';

function login() {
cy.get("[id='email']").type('admin');
Expand Down Expand Up @@ -71,9 +74,50 @@ function resumeThread() {
cy.get('.step').eq(8).should('contain', 'chat_profile');
}

function restartServer(
mode: ExecutionMode = undefined,
env?: Record<string, string>
) {
const pathItems = Cypress.spec.absolute.split(sep);
const testName = pathItems[pathItems.length - 2];
cy.exec(`pnpm exec ts-node ./cypress/support/run.ts ${testName} ${mode}`, {
env
});
}

function continueThread() {
cy.get('.step').eq(7).should('contain', 'Welcome back to Hello');

submitMessage('Hello after restart');

// Verify that new step counter messages have been added
cy.get('.step').eq(11).should('contain', 'Create step counter: 14');
cy.get('.step').eq(14).should('contain', 'Create step counter: 17');
}

function newThread() {
cy.get('#new-chat-button').click();
cy.get('#confirm').click();
}

describe('Data Layer', () => {
before(() => {
runTestServer();
beforeEach(() => {
// Set up the thread history file
const pathItems = Cypress.spec.absolute.split(sep);
pathItems[pathItems.length - 1] = 'thread_history.pickle';
const threadHistoryFile = pathItems.join(sep);
cy.wrap(threadHistoryFile).as('threadHistoryFile');

runTestServer(undefined, {
THREAD_HISTORY_PICKLE_PATH: threadHistoryFile
});
});

afterEach(() => {
cy.get('@threadHistoryFile').then((threadHistoryFile) => {
// Clean up the thread history file
cy.exec(`rm ${threadHistoryFile}`);
});
});

describe('Data Features with persistence', () => {
Expand All @@ -84,5 +128,24 @@ describe('Data Layer', () => {
threadList();
resumeThread();
});

it('should continue the thread after backend restarts and work with new thread as usual', () => {
login();
feedback();
threadQueue();

cy.get('@threadHistoryFile').then((threadHistoryFile) => {
restartServer(undefined, {
THREAD_HISTORY_PICKLE_PATH: `${threadHistoryFile}`
});
});
// Continue the thread and verify that the step counter is not reset
continueThread();

// Create a new thread and verify that the step counter is reset
newThread();
feedback();
threadQueue();
});
});
});
14 changes: 12 additions & 2 deletions libs/react-client/src/useChatSession.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { debounce } from 'lodash';
import { useCallback, useContext } from 'react';
import { useCallback, useContext, useEffect } from 'react';
import {
useRecoilState,
useRecoilValue,
Expand Down Expand Up @@ -63,7 +63,17 @@ const useChatSession = () => {
const setTokenCount = useSetRecoilState(tokenCountState);
const [chatProfile, setChatProfile] = useRecoilState(chatProfileState);
const idToResume = useRecoilValue(threadIdToResumeState);
const setCurrentThreadId = useSetRecoilState(currentThreadIdState);
const [currentThreadId, setCurrentThreadId] =
useRecoilState(currentThreadIdState);

// Use currentThreadId as thread id in websocket header
useEffect(() => {
if (session?.socket) {
session.socket.io.opts.extraHeaders!['X-Chainlit-Thread-Id'] =
currentThreadId || '';
}
}, [currentThreadId]);

const _connect = useCallback(
({
userEnv,
Expand Down

0 comments on commit 86798bc

Please sign in to comment.