Skip to content

Commit

Permalink
remove caching agent id (opensearch-project#1529)
Browse files Browse the repository at this point in the history
Signed-off-by: Joshua Li <[email protected]>
Signed-off-by: Paul Sebastian <[email protected]>
Co-authored-by: Paul Sebastian <[email protected]>
  • Loading branch information
joshuali925 and paulstn authored Mar 15, 2024
1 parent 9a4e532 commit 01bec05
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 82 deletions.
8 changes: 4 additions & 4 deletions server/routes/query_assist/routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
import { isResponseError } from '../../../../../src/core/server/opensearch/client/errors';
import { ERROR_DETAILS, QUERY_ASSIST_API } from '../../../common/constants/query_assist';
import { generateFieldContext } from '../../common/helpers/query_assist/generate_field_context';
import { getAgentIdByConfig, requestWithRetryAgentSearch } from './utils/agents';
import { getAgentIdByConfig, getAgentIdAndRequest } from './utils/agents';
import { AGENT_CONFIGS } from './utils/constants';

export function registerQueryAssistRoutes(router: IRouter) {
Expand Down Expand Up @@ -57,7 +57,7 @@ export function registerQueryAssistRoutes(router: IRouter) {
): Promise<IOpenSearchDashboardsResponse<any | ResponseError>> => {
const client = context.core.opensearch.client.asCurrentUser;
try {
const pplRequest = await requestWithRetryAgentSearch({
const pplRequest = await getAgentIdAndRequest({
client,
configName: AGENT_CONFIGS.PPL_AGENT,
body: {
Expand Down Expand Up @@ -118,7 +118,7 @@ export function registerQueryAssistRoutes(router: IRouter) {

try {
if (!isError) {
summaryRequest = await requestWithRetryAgentSearch({
summaryRequest = await getAgentIdAndRequest({
client,
configName: AGENT_CONFIGS.RESPONSE_SUMMARY_AGENT,
body: {
Expand All @@ -131,7 +131,7 @@ export function registerQueryAssistRoutes(router: IRouter) {
client.search({ index, size: 1 }),
]);
const fields = generateFieldContext(mappings, sampleDoc);
summaryRequest = await requestWithRetryAgentSearch({
summaryRequest = await getAgentIdAndRequest({
client,
configName: AGENT_CONFIGS.ERROR_SUMMARY_AGENT,
body: {
Expand Down
53 changes: 3 additions & 50 deletions server/routes/query_assist/utils/__tests__/agents.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { ApiResponse } from '@opensearch-project/opensearch';
import { ResponseError } from '@opensearch-project/opensearch/lib/errors';
import { CoreRouteHandlerContext } from '../../../../../../../src/core/server/core_route_handler_context';
import { coreMock, httpServerMock } from '../../../../../../../src/core/server/mocks';
import { agentIdMap, getAgentIdByConfig, requestWithRetryAgentSearch } from '../agents';
import { getAgentIdByConfig, getAgentIdAndRequest } from '../agents';

describe('Agents helper functions', () => {
const coreContext = new CoreRouteHandlerContext(
Expand Down Expand Up @@ -75,27 +75,7 @@ describe('Agents helper functions', () => {
);
});

it('requests with valid agent id', async () => {
agentIdMap.test_agent = 'test-id';
mockedTransport.mockResolvedValueOnce({
body: { inference_results: [{ output: [{ result: 'test response' }] }] },
});
const response = await requestWithRetryAgentSearch({
client,
configName: 'test_agent',
shouldRetryAgentSearch: true,
body: { parameters: { param1: 'value1' } },
});
expect(mockedTransport).toBeCalledWith(
expect.objectContaining({
path: '/_plugins/_ml/agents/test-id/_execute',
}),
expect.anything()
);
expect(response.body.inference_results[0].output[0].result).toEqual('test response');
});

it('searches for agent id if id is undefined', async () => {
it('searches for agent id and sends request', async () => {
mockedTransport
.mockResolvedValueOnce({
body: {
Expand All @@ -106,36 +86,9 @@ describe('Agents helper functions', () => {
.mockResolvedValueOnce({
body: { inference_results: [{ output: [{ result: 'test response' }] }] },
});
const response = await requestWithRetryAgentSearch({
const response = await getAgentIdAndRequest({
client,
configName: 'new_agent',
shouldRetryAgentSearch: true,
body: { parameters: { param1: 'value1' } },
});
expect(mockedTransport).toBeCalledWith(
expect.objectContaining({ path: '/_plugins/_ml/agents/new-id/_execute' }),
expect.anything()
);
expect(response.body.inference_results[0].output[0].result).toEqual('test response');
});

it('searches for agent id if id is not found', async () => {
agentIdMap.test_agent = 'non-exist-agent';
mockedTransport
.mockRejectedValueOnce({ statusCode: 404, body: {}, headers: {} })
.mockResolvedValueOnce({
body: {
type: 'agent',
configuration: { agent_id: 'new-id' },
},
})
.mockResolvedValueOnce({
body: { inference_results: [{ output: [{ result: 'test response' }] }] },
});
const response = await requestWithRetryAgentSearch({
client,
configName: 'test_agent',
shouldRetryAgentSearch: true,
body: { parameters: { param1: 'value1' } },
});
expect(mockedTransport).toBeCalledWith(
Expand Down
40 changes: 12 additions & 28 deletions server/routes/query_assist/utils/agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
*/

import { ApiResponse } from '@opensearch-project/opensearch/.';
import { RequestBody } from '@opensearch-project/opensearch/lib/Transport';
import { RequestBody, TransportRequestPromise } from '@opensearch-project/opensearch/lib/Transport';
import { OpenSearchClient } from '../../../../../../src/core/server';
import { isResponseError } from '../../../../../../src/core/server/opensearch/client/errors';
import { ML_COMMONS_API_PREFIX } from '../../../../common/constants/query_assist';

const AGENT_REQUEST_OPTIONS = {
Expand All @@ -27,8 +26,6 @@ type AgentResponse = ApiResponse<{
}>;
}>;

export const agentIdMap: Record<string, string> = {};

export const getAgentIdByConfig = async (
opensearchClient: OpenSearchClient,
configName: string
Expand All @@ -49,32 +46,19 @@ export const getAgentIdByConfig = async (
}
};

export const requestWithRetryAgentSearch = async (options: {
export const getAgentIdAndRequest = async (options: {
client: OpenSearchClient;
configName: string;
shouldRetryAgentSearch?: boolean;
body: RequestBody;
}): Promise<AgentResponse> => {
const { client, configName, shouldRetryAgentSearch = true, body } = options;
let retry = shouldRetryAgentSearch;
if (!agentIdMap[configName]) {
agentIdMap[configName] = await getAgentIdByConfig(client, configName);
retry = false;
}
return client.transport
.request(
{
method: 'POST',
path: `${ML_COMMONS_API_PREFIX}/agents/${agentIdMap[configName]}/_execute`,
body,
},
AGENT_REQUEST_OPTIONS
)
.catch(async (error) => {
if (retry && isResponseError(error) && error.statusCode === 404) {
agentIdMap[configName] = await getAgentIdByConfig(client, configName);
return requestWithRetryAgentSearch({ ...options, shouldRetryAgentSearch: false });
}
return Promise.reject(error);
}) as Promise<AgentResponse>;
const { client, configName, body } = options;
const agentId = await getAgentIdByConfig(client, configName);
return client.transport.request(
{
method: 'POST',
path: `${ML_COMMONS_API_PREFIX}/agents/${agentId}/_execute`,
body,
},
AGENT_REQUEST_OPTIONS
) as TransportRequestPromise<AgentResponse>;
};

0 comments on commit 01bec05

Please sign in to comment.