Skip to content

Commit

Permalink
reset chat or reload history after data source change (#194) (#204)
Browse files Browse the repository at this point in the history
* reset chat or reload history when data source change

Signed-off-by: Lin Wang <[email protected]>

* Add change log

Signed-off-by: Lin Wang <[email protected]>

* Address PR comments

Signed-off-by: Lin Wang <[email protected]>

* Set search and page after data source change

Signed-off-by: Lin Wang <[email protected]>

* Remove skip first value when subscribe dataSourceId$

Signed-off-by: Lin Wang <[email protected]>

* Add dataSourceIdUdpates$ and finalDataSourceId

Signed-off-by: Lin Wang <[email protected]>

* Remove data source service mock in chat header button

Signed-off-by: Lin Wang <[email protected]>

* Remove no need useRef

Signed-off-by: Lin Wang <[email protected]>

* Refactor load history after data source change

Signed-off-by: Lin Wang <[email protected]>

---------

Signed-off-by: Lin Wang <[email protected]>
(cherry picked from commit a2a98f6)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

# Conflicts:
#	CHANGELOG.md

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent bd1a617 commit 7efc32c
Show file tree
Hide file tree
Showing 18 changed files with 375 additions and 93 deletions.
1 change: 1 addition & 0 deletions public/chat_header_button.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { EuiBadge, EuiFieldText, EuiIcon } from '@elastic/eui';
import classNames from 'classnames';
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useEffectOnce } from 'react-use';

import { ApplicationStart, SIDECAR_DOCKED_MODE } from '../../../src/core/public';
// TODO: Replace with getChrome().logos.Chat.url
import chatIcon from './assets/chat.svg';
Expand Down
33 changes: 32 additions & 1 deletion public/components/agent_framework_traces_flyout_body.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

import React from 'react';
import '@testing-library/jest-dom/extend-expect';
import { act, waitFor, render, screen, fireEvent } from '@testing-library/react';
import { waitFor, render, screen, fireEvent } from '@testing-library/react';
import * as chatContextExports from '../contexts/chat_context';
import * as coreContextExports from '../contexts/core_context';
import { AgentFrameworkTracesFlyoutBody } from './agent_framework_traces_flyout_body';
import { TAB_ID } from '../utils/constants';
import { BehaviorSubject, Subject } from 'rxjs';

jest.mock('./agent_framework_traces', () => {
return {
Expand All @@ -17,6 +19,20 @@ jest.mock('./agent_framework_traces', () => {
});

describe('<AgentFrameworkTracesFlyout/> spec', () => {
let dataSourceIdUpdates$: Subject<string | null>;
beforeEach(() => {
dataSourceIdUpdates$ = new Subject<string | null>();
jest.spyOn(coreContextExports, 'useCore').mockImplementation(() => {
return {
services: {
dataSource: {
dataSourceIdUpdates$,
},
},
};
});
});

it('show back button if interactionId exists', async () => {
const onCloseMock = jest.fn();
jest.spyOn(chatContextExports, 'useChatContext').mockReturnValue({
Expand Down Expand Up @@ -70,4 +86,19 @@ describe('<AgentFrameworkTracesFlyout/> spec', () => {
expect(onCloseMock).toHaveBeenCalledWith(TAB_ID.HISTORY);
});
});

it('should set tab to chat after data source changed', () => {
const setSelectedTabIdMock = jest.fn();
jest.spyOn(chatContextExports, 'useChatContext').mockReturnValue({
interactionId: 'test-interaction-id',
flyoutFullScreen: true,
setSelectedTabId: setSelectedTabIdMock,
preSelectedTabId: TAB_ID.HISTORY,
});
render(<AgentFrameworkTracesFlyoutBody />);

expect(setSelectedTabIdMock).not.toHaveBeenCalled();
dataSourceIdUpdates$.next('foo');
expect(setSelectedTabIdMock).toHaveBeenCalled();
});
});
14 changes: 13 additions & 1 deletion public/components/agent_framework_traces_flyout_body.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,26 @@ import {
EuiButtonIcon,
EuiPageHeaderSection,
} from '@elastic/eui';
import React from 'react';
import React, { useEffect } from 'react';
import { useChatContext } from '../contexts/chat_context';
import { useCore } from '../../public/contexts';
import { AgentFrameworkTraces } from './agent_framework_traces';
import { TAB_ID } from '../utils/constants';

export const AgentFrameworkTracesFlyoutBody: React.FC = () => {
const core = useCore();
const chatContext = useChatContext();
const interactionId = chatContext.interactionId;

useEffect(() => {
const subscription = core.services.dataSource.dataSourceIdUpdates$.subscribe(() => {
chatContext.setSelectedTabId(TAB_ID.CHAT);
});
return () => {
subscription.unsubscribe();
};
}, [core.services.dataSource, chatContext.setSelectedTabId]);

if (!interactionId) {
return null;
}
Expand Down
60 changes: 54 additions & 6 deletions public/hooks/use_chat_actions.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@ jest.mock('../services/conversations_service', () => {
jest.mock('../services/conversation_load_service', () => {
return {
ConversationLoadService: jest.fn().mockImplementation(() => {
return { load: jest.fn().mockReturnValue({ messages: [], interactions: [] }) };
const conversationLoadMock = {
abortController: new AbortController(),
load: jest.fn().mockImplementation(async () => {
conversationLoadMock.abortController = new AbortController();
return { messages: [], interactions: [] };
}),
};
return conversationLoadMock;
}),
};
});
Expand Down Expand Up @@ -126,7 +133,7 @@ describe('useChatActions hook', () => {
messages: [SEND_MESSAGE_RESPONSE.messages[0]],
input: INPUT_MESSAGE,
}),
query: await dataSourceServiceMock.getDataSourceQuery(),
query: dataSourceServiceMock.getDataSourceQuery(),
});

// it should send dispatch `receive` action to remove the message without messageId
Expand Down Expand Up @@ -201,7 +208,7 @@ describe('useChatActions hook', () => {
messages: [],
input: { type: 'input', content: 'message that send as input', contentType: 'text' },
}),
query: await dataSourceServiceMock.getDataSourceQuery(),
query: dataSourceServiceMock.getDataSourceQuery(),
});
});

Expand Down Expand Up @@ -264,7 +271,7 @@ describe('useChatActions hook', () => {
expect(chatStateDispatchMock).toHaveBeenCalledWith({ type: 'abort' });
expect(httpMock.post).toHaveBeenCalledWith(ASSISTANT_API.ABORT_AGENT_EXECUTION, {
body: JSON.stringify({ conversationId: 'conversation_id_to_abort' }),
query: await dataSourceServiceMock.getDataSourceQuery(),
query: dataSourceServiceMock.getDataSourceQuery(),
});
});

Expand Down Expand Up @@ -292,7 +299,7 @@ describe('useChatActions hook', () => {
conversationId: 'conversation_id_mock',
interactionId: 'interaction_id_mock',
}),
query: await dataSourceServiceMock.getDataSourceQuery(),
query: dataSourceServiceMock.getDataSourceQuery(),
});
expect(chatStateDispatchMock).toHaveBeenCalledWith(
expect.objectContaining({ type: 'receive', payload: { messages: [], interactions: [] } })
Expand All @@ -312,6 +319,7 @@ describe('useChatActions hook', () => {
it('should not handle regenerate response if the regenerate operation has already aborted', async () => {
const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({
signal: { aborted: true },
abort: jest.fn(),
}));

httpMock.put.mockResolvedValue(SEND_MESSAGE_RESPONSE);
Expand All @@ -328,7 +336,7 @@ describe('useChatActions hook', () => {
conversationId: 'conversation_id_mock',
interactionId: 'interaction_id_mock',
}),
query: await dataSourceServiceMock.getDataSourceQuery(),
query: dataSourceServiceMock.getDataSourceQuery(),
});
expect(chatStateDispatchMock).not.toHaveBeenCalledWith(
expect.objectContaining({ type: 'receive' })
Expand All @@ -353,6 +361,7 @@ describe('useChatActions hook', () => {
it('should not handle regenerate error if the regenerate operation has already aborted', async () => {
const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({
signal: { aborted: true },
abort: jest.fn(),
}));
httpMock.put.mockImplementationOnce(() => {
throw new Error();
Expand All @@ -369,4 +378,43 @@ describe('useChatActions hook', () => {
);
AbortControllerMock.mockRestore();
});

it('should clear chat title, conversation id, flyoutComponent and call reset action', async () => {
const { result } = renderHook(() => useChatActions());
result.current.resetChat();

expect(chatContextMock.setConversationId).toHaveBeenLastCalledWith(undefined);
expect(chatContextMock.setTitle).toHaveBeenLastCalledWith(undefined);
expect(chatContextMock.setFlyoutComponent).toHaveBeenLastCalledWith(null);

expect(chatStateDispatchMock).toHaveBeenLastCalledWith({ type: 'reset' });
});

it('should abort send action after reset chat', async () => {
const abortFn = jest.fn();
const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({
signal: { aborted: true },
abort: abortFn,
}));
const { result } = renderHook(() => useChatActions());
await result.current.send(INPUT_MESSAGE);
result.current.resetChat();

expect(abortFn).toHaveBeenCalled();
AbortControllerMock.mockRestore();
});

it('should abort load action after reset chat', async () => {
const abortFn = jest.fn();
const AbortControllerMock = jest.spyOn(window, 'AbortController').mockImplementation(() => ({
signal: { aborted: true },
abort: abortFn,
}));
const { result } = renderHook(() => useChatActions());
await result.current.loadChat('conversation_id_mock');
result.current.resetChat();

expect(abortFn).toHaveBeenCalled();
AbortControllerMock.mockRestore();
});
});
17 changes: 13 additions & 4 deletions public/hooks/use_chat_actions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export const useChatActions = (): AssistantActions => {
...(!chatContext.conversationId && { messages: chatState.messages }), // include all previous messages for new chats
input,
}),
query: await core.services.dataSource.getDataSourceQuery(),
query: core.services.dataSource.getDataSourceQuery(),
});
if (abortController.signal.aborted) return;
// Refresh history list after new conversation created if new conversation saved and history list page visible
Expand Down Expand Up @@ -106,6 +106,15 @@ export const useChatActions = (): AssistantActions => {
}
};

const resetChat = () => {
abortControllerRef?.abort();
core.services.conversationLoad.abortController?.abort();
chatContext.setConversationId(undefined);
chatContext.setTitle(undefined);
chatContext.setFlyoutComponent(null);
chatStateDispatch({ type: 'reset' });
};

const openChatUI = () => {
chatContext.setFlyoutVisible(true);
chatContext.setSelectedTabId(TAB_ID.CHAT);
Expand Down Expand Up @@ -163,7 +172,7 @@ export const useChatActions = (): AssistantActions => {
// abort agent execution
await core.services.http.post(`${ASSISTANT_API.ABORT_AGENT_EXECUTION}`, {
body: JSON.stringify({ conversationId }),
query: await core.services.dataSource.getDataSourceQuery(),
query: core.services.dataSource.getDataSourceQuery(),
});
}
};
Expand All @@ -180,7 +189,7 @@ export const useChatActions = (): AssistantActions => {
conversationId: chatContext.conversationId,
interactionId,
}),
query: await core.services.dataSource.getDataSourceQuery(),
query: core.services.dataSource.getDataSourceQuery(),
});

if (abortController.signal.aborted) {
Expand Down Expand Up @@ -225,5 +234,5 @@ export const useChatActions = (): AssistantActions => {
}
};

return { send, loadChat, executeAction, openChatUI, abortAction, regenerate };
return { send, loadChat, resetChat, executeAction, openChatUI, abortAction, regenerate };
};
8 changes: 4 additions & 4 deletions public/hooks/use_conversations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ export const useDeleteConversation = () => {
const abortControllerRef = useRef<AbortController>();

const deleteConversation = useCallback(
async (conversationId: string) => {
(conversationId: string) => {
abortControllerRef.current = new AbortController();
dispatch({ type: 'request' });
return core.services.http
.delete(`${ASSISTANT_API.CONVERSATION}/${conversationId}`, {
signal: abortControllerRef.current.signal,
query: await core.services.dataSource.getDataSourceQuery(),
query: core.services.dataSource.getDataSourceQuery(),
})
.then((payload) => {
dispatch({ type: 'success', payload });
Expand Down Expand Up @@ -53,15 +53,15 @@ export const usePatchConversation = () => {
const abortControllerRef = useRef<AbortController>();

const patchConversation = useCallback(
async (conversationId: string, title: string) => {
(conversationId: string, title: string) => {
abortControllerRef.current = new AbortController();
dispatch({ type: 'request' });
return core.services.http
.put(`${ASSISTANT_API.CONVERSATION}/${conversationId}`, {
body: JSON.stringify({
title,
}),
query: await core.services.dataSource.getDataSourceQuery(),
query: core.services.dataSource.getDataSourceQuery(),
signal: abortControllerRef.current.signal,
})
.then((payload) => dispatch({ type: 'success', payload }))
Expand Down
4 changes: 2 additions & 2 deletions public/hooks/use_feed_back.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ describe('useFeedback hook', () => {
body: JSON.stringify({
satisfaction: true,
}),
query: await dataSourceMock.getDataSourceQuery(),
query: dataSourceMock.getDataSourceQuery(),
}
);
expect(result.current.feedbackResult).toBe(true);
Expand Down Expand Up @@ -119,7 +119,7 @@ describe('useFeedback hook', () => {
body: JSON.stringify({
satisfaction: true,
}),
query: await dataSourceMock.getDataSourceQuery(),
query: dataSourceMock.getDataSourceQuery(),
}
);
expect(result.current.feedbackResult).toBe(undefined);
Expand Down
2 changes: 1 addition & 1 deletion public/hooks/use_feed_back.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export const useFeedback = (interaction?: Interaction | null) => {
try {
await core.services.http.put(`${ASSISTANT_API.FEEDBACK}/${message.interactionId}`, {
body: JSON.stringify(body),
query: await core.services.dataSource.getDataSourceQuery(),
query: core.services.dataSource.getDataSourceQuery(),
});
setFeedbackResult(correct);
} catch (error) {
Expand Down
32 changes: 15 additions & 17 deletions public/hooks/use_fetch_agentframework_traces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,24 @@ export const useFetchAgentFrameworkTraces = (interactionId: string) => {
return;
}

core.services.dataSource.getDataSourceQuery().then((query) => {
core.services.http
.get<AgentFrameworkTrace[]>(`${ASSISTANT_API.TRACE}/${interactionId}`, {
signal: abortController.signal,
query,
core.services.http
.get<AgentFrameworkTrace[]>(`${ASSISTANT_API.TRACE}/${interactionId}`, {
signal: abortController.signal,
query: core.services.dataSource.getDataSourceQuery(),
})
.then((payload) =>
dispatch({
type: 'success',
payload,
})
.then((payload) =>
dispatch({
type: 'success',
payload,
})
)
.catch((error) => {
if (error.name === 'AbortError') return;
dispatch({ type: 'failure', error });
});
});
)
.catch((error) => {
if (error.name === 'AbortError') return;
dispatch({ type: 'failure', error });
});

return () => abortController.abort();
}, [core.services.http, interactionId]);
}, [core.services.http, interactionId, core.services.dataSource]);

return { ...state };
};
Loading

0 comments on commit 7efc32c

Please sign in to comment.