diff --git a/common/constants/llm.ts b/common/constants/llm.ts index 6e21fd94..63ac5a59 100644 --- a/common/constants/llm.ts +++ b/common/constants/llm.ts @@ -19,6 +19,11 @@ export const ASSISTANT_API = { ACCOUNT: `${API_BASE}/account`, } as const; +export const TEXT2VIZ_API = { + TEXT2PPL: `${API_BASE}/text2ppl`, + TEXT2VEGA: `${API_BASE}/text2vega`, +}; + export const NOTEBOOK_API = { CREATE_NOTEBOOK: `${NOTEBOOK_PREFIX}/note`, SET_PARAGRAPH: `${NOTEBOOK_PREFIX}/set_paragraphs/`, diff --git a/opensearch_dashboards.json b/opensearch_dashboards.json index 84e7bffb..17370acd 100644 --- a/opensearch_dashboards.json +++ b/opensearch_dashboards.json @@ -5,16 +5,19 @@ "server": true, "ui": true, "requiredPlugins": [ + "data", "dashboard", "embeddable", "opensearchDashboardsReact", - "opensearchDashboardsUtils" + "opensearchDashboardsUtils", + "visualizations" ], "optionalPlugins": [ "dataSource", "dataSourceManagement" ], + "requiredBundles": [], "configPath": [ "assistant" ] -} \ No newline at end of file +} diff --git a/public/components/visualization/source_selector.tsx b/public/components/visualization/source_selector.tsx new file mode 100644 index 00000000..09dc5aca --- /dev/null +++ b/public/components/visualization/source_selector.tsx @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useMemo, useState, useEffect } from 'react'; +import { i18n } from '@osd/i18n'; + +import { useOpenSearchDashboards } from '../../../../../src/plugins/opensearch_dashboards_react/public'; +import { + DataSource, + DataSourceGroup, + DataSourceSelectable, + DataSourceOption, +} from '../../../../../src/plugins/data/public'; +import { StartServices } from '../../types'; + +export const SourceSelector = ({ + selectedSourceId, + onChange, +}: { + selectedSourceId: string; + onChange: (ds: DataSourceOption) => void; +}) => { + const { + services: { + data: { dataSources }, + notifications: { toasts }, + }, + } = useOpenSearchDashboards(); + const [currentDataSources, setCurrentDataSources] = useState([]); + const [dataSourceOptions, setDataSourceOptions] = useState([]); + + const selectedSources = useMemo(() => { + if (selectedSourceId) { + for (const group of dataSourceOptions) { + for (const item of group.options) { + if (item.value === selectedSourceId) { + return [item]; + } + } + } + } + return []; + }, [selectedSourceId, dataSourceOptions]); + + useEffect(() => { + if ( + !selectedSourceId && + dataSourceOptions.length > 0 && + dataSourceOptions[0].options.length > 0 + ) { + onChange(dataSourceOptions[0].options[0]); + } + }, [selectedSourceId, dataSourceOptions]); + + useEffect(() => { + const subscription = dataSources.dataSourceService.getDataSources$().subscribe((ds) => { + setCurrentDataSources(Object.values(ds)); + }); + + return () => { + subscription.unsubscribe(); + }; + }, [dataSources]); + + const onDataSourceSelect = useCallback( + (selectedDataSources: DataSourceOption[]) => { + onChange(selectedDataSources[0]); + }, + [onChange] + ); + + const handleGetDataSetError = useCallback( + () => (error: Error) => { + toasts.addError(error, { + title: + i18n.translate('visualize.vega.failedToGetDataSetErrorDescription', { + defaultMessage: 'Failed to get data set: ', + }) + (error.message || error.name), + }); + }, + [toasts] + ); + + const memorizedReload = useCallback(() => { + dataSources.dataSourceService.reload(); + }, [dataSources.dataSourceService]); + + return ( + + ); +}; diff --git a/public/components/visualization/text2vega.ts b/public/components/visualization/text2vega.ts new file mode 100644 index 00000000..9621bd2b --- /dev/null +++ b/public/components/visualization/text2vega.ts @@ -0,0 +1,163 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { BehaviorSubject, Observable, of } from 'rxjs'; +import { debounceTime, switchMap, tap, filter, catchError } from 'rxjs/operators'; +import { TEXT2VIZ_API } from '.../../../common/constants/llm'; +import { HttpSetup } from '../../../../../src/core/public'; +import { DataPublicPluginStart } from '../../../../../src/plugins/data/public'; + +const DATA_SOURCE_DELIMITER = '::'; + +const topN = (ppl: string, n: number) => `${ppl} | head ${n}`; + +const getDataSourceAndIndexFromLabel = (label: string) => { + if (label.includes(DATA_SOURCE_DELIMITER)) { + return [ + label.slice(0, label.indexOf(DATA_SOURCE_DELIMITER)), + label.slice(label.indexOf(DATA_SOURCE_DELIMITER) + DATA_SOURCE_DELIMITER.length), + ] as const; + } + return [, label] as const; +}; + +interface Input { + prompt: string; + index: string; + dataSourceId?: string; +} + +export class Text2Vega { + input$ = new BehaviorSubject({ prompt: '', index: '' }); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + result$: Observable | { error: any }>; + status$ = new BehaviorSubject<'RUNNING' | 'STOPPED'>('STOPPED'); + http: HttpSetup; + searchClient: DataPublicPluginStart['search']; + + constructor(http: HttpSetup, searchClient: DataPublicPluginStart['search']) { + this.http = http; + this.searchClient = searchClient; + this.result$ = this.input$ + .pipe( + filter((v) => v.prompt.length > 0), + debounceTime(200), + tap(() => this.status$.next('RUNNING')) + ) + .pipe( + switchMap((v) => + of(v).pipe( + // text to ppl + switchMap(async (value) => { + const [, indexName] = getDataSourceAndIndexFromLabel(value.index); + const pplQuestion = value.prompt.split('//')[0]; + const ppl = await this.text2ppl(pplQuestion, indexName, value.dataSourceId); + return { + ...value, + ppl, + }; + }), + // query sample data with ppl + switchMap(async (value) => { + const ppl = topN(value.ppl, 2); + const res = await this.searchClient + .search( + { params: { body: { query: ppl } }, dataSourceId: value.dataSourceId }, + { strategy: 'pplraw' } + ) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + .toPromise(); + return { ...value, sample: res.rawResponse }; + }), + // call llm to generate vega + switchMap(async (value) => { + const result = await this.text2vega({ + input: value.prompt, + ppl: value.ppl, + sampleData: JSON.stringify(value.sample.jsonData), + dataSchema: JSON.stringify(value.sample.schema), + dataSourceId: value.dataSourceId, + }); + const [dataSourceName] = getDataSourceAndIndexFromLabel(value.index); + result.data = { + url: { + '%type%': 'ppl', + body: { query: value.ppl }, + data_source_name: dataSourceName, + }, + }; + return result; + }), + catchError((e) => of({ error: e })) + ) + ) + ) + .pipe(tap(() => this.status$.next('STOPPED'))); + } + + async text2vega({ + input, + ppl, + sampleData, + dataSchema, + dataSourceId, + }: { + input: string; + ppl: string; + sampleData: string; + dataSchema: string; + dataSourceId?: string; + }) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const escapeField = (json: any, field: string) => { + if (json[field]) { + if (typeof json[field] === 'string') { + json[field] = json[field].replace(/\./g, '\\.'); + } + if (typeof json[field] === 'object') { + Object.keys(json[field]).forEach((p) => { + escapeField(json[field], p); + }); + } + } + }; + const res = await this.http.post(TEXT2VIZ_API.TEXT2VEGA, { + body: JSON.stringify({ + input, + ppl, + sampleData: JSON.stringify(sampleData), + dataSchema: JSON.stringify(dataSchema), + }), + query: { dataSourceId }, + }); + + // need to escape field: geo.city -> field: geo\\.city + escapeField(res, 'encoding'); + return res; + } + + async text2ppl(query: string, index: string, dataSourceId?: string) { + const pplResponse = await this.http.post(TEXT2VIZ_API.TEXT2PPL, { + body: JSON.stringify({ + question: query, + index, + }), + query: { dataSourceId }, + }); + return pplResponse.ppl; + } + + invoke(value: Input) { + this.input$.next(value); + } + + getStatus$() { + return this.status$; + } + + getResult$() { + return this.result$; + } +} diff --git a/public/components/visualization/text2viz.scss b/public/components/visualization/text2viz.scss new file mode 100644 index 00000000..ead70f30 --- /dev/null +++ b/public/components/visualization/text2viz.scss @@ -0,0 +1,10 @@ +.text2viz__page { + .visualize { + height: 400px; + } + + .text2viz__right { + padding-top: 15px; + padding-left: 30px; + } +} diff --git a/public/components/visualization/text2viz.tsx b/public/components/visualization/text2viz.tsx new file mode 100644 index 00000000..8a4ead6c --- /dev/null +++ b/public/components/visualization/text2viz.tsx @@ -0,0 +1,282 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + EuiPageBody, + EuiPage, + EuiPageContent, + EuiPageContentBody, + EuiFlexGroup, + EuiFlexItem, + EuiFieldText, + EuiIcon, + EuiButtonIcon, + EuiButton, + EuiBreadcrumb, + EuiHeaderLinks, +} from '@elastic/eui'; +import React, { useEffect, useRef, useState } from 'react'; +import { i18n } from '@osd/i18n'; + +import { useCallback } from 'react'; +import { useObservable } from 'react-use'; +import { useMemo } from 'react'; +import { SourceSelector } from './source_selector'; +import type { DataSourceOption } from '../../../../../src/plugins/data/public'; +import chatIcon from '../../assets/chat.svg'; +import { EmbeddableRenderer } from '../../../../../src/plugins/embeddable/public'; +import { + useOpenSearchDashboards, + MountPointPortal, + toMountPoint, +} from '../../../../../src/plugins/opensearch_dashboards_react/public'; +import { StartServices } from '../../types'; +import { + VISUALIZE_EMBEDDABLE_TYPE, + VisSavedObject, + VisualizeInput, +} from '../../../../../src/plugins/visualizations/public'; +import './text2viz.scss'; +import { Text2VizEmpty } from './text2viz_empty'; +import { Text2VizLoading } from './text2viz_loading'; +import { Text2Vega } from './text2vega'; +import { + OnSaveProps, + SavedObjectSaveModalOrigin, +} from '../../../../../src/plugins/saved_objects/public'; + +export const Text2Viz = () => { + const [selectedSource, setSelectedSource] = useState(); + const { + services: { + application, + chrome, + embeddable, + visualizations, + http, + notifications, + setHeaderActionMenu, + overlays, + data, + }, + } = useOpenSearchDashboards(); + const [input, setInput] = useState(''); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const [vegaSpec, setVegaSpec] = useState>(); + const text2vegaRef = useRef(new Text2Vega(http, data.search)); + const status = useObservable(text2vegaRef.current.status$); + + useEffect(() => { + const text2vega = text2vegaRef.current; + const subscription = text2vega.getResult$().subscribe((result) => { + if (result) { + if (result.error) { + notifications.toasts.addError(result.error, { + title: i18n.translate('dashboardAssistant.feature.text2viz.error', { + defaultMessage: 'Error while executing text to vega', + }), + }); + } else { + setVegaSpec(result); + } + } + }); + + return () => { + subscription.unsubscribe(); + }; + }, [http, notifications]); + + const onInputChange = useCallback((e: React.ChangeEvent) => { + setInput(e.target.value); + }, []); + + const onSubmit = useCallback(async () => { + setVegaSpec(undefined); + const text2vega = text2vegaRef.current; + if (selectedSource?.label) { + const dataSource = (await selectedSource.ds.getDataSet()).dataSets.find( + (ds) => ds.title === selectedSource.label + ); + text2vega.invoke({ + index: selectedSource.label, + prompt: input, + dataSourceId: dataSource?.dataSourceId, + }); + } + }, [selectedSource, input]); + + const factory = embeddable.getEmbeddableFactory(VISUALIZE_EMBEDDABLE_TYPE); + const vis = useMemo(() => { + return vegaSpec + ? visualizations.convertToSerializedVis({ + title: vegaSpec?.title ?? 'vega', + description: vegaSpec?.description ?? '', + visState: { + title: vegaSpec?.title ?? 'vega', + type: 'vega', + aggs: [], + params: { + spec: JSON.stringify(vegaSpec, null, 4), + }, + }, + }) + : null; + }, [vegaSpec]); + + const onSaveClick = useCallback(async () => { + if (!vis) return; + + const doSave = async (onSaveProps: OnSaveProps) => { + const savedVis: VisSavedObject = await visualizations.savedVisualizationsLoader.get(); // .createVis('vega', vis) + savedVis.visState = { + title: onSaveProps.newTitle, + type: vis.type, + params: vis.params, + aggs: [], + }; + savedVis.title = onSaveProps.newTitle; + savedVis.description = onSaveProps.newDescription; + savedVis.copyOnSave = onSaveProps.newCopyOnSave; + try { + const id = await savedVis.save({ + isTitleDuplicateConfirmed: onSaveProps.isTitleDuplicateConfirmed, + onTitleDuplicate: onSaveProps.onTitleDuplicate, + }); + if (id) { + notifications.toasts.addSuccess({ + title: i18n.translate('dashboardAssistant.feature.text2viz.saveSuccess', { + defaultMessage: `Saved '{title}'`, + values: { + title: savedVis.title, + }, + }), + }); + dialog.close(); + } + } catch (e) { + notifications.toasts.addDanger({ + title: i18n.translate('dashboardAssistant.feature.text2viz.saveFail', { + defaultMessage: `Error on saving '{title}'`, + values: { + title: savedVis.title, + }, + }), + }); + } + }; + + const dialog = overlays.openModal( + toMountPoint( + dialog.close()} + onSave={doSave} + /> + ) + ); + }, [vis, visualizations, notifications]); + + useEffect(() => { + const breadcrumbs: EuiBreadcrumb[] = [ + { + text: 'Visualize', + onClick: () => { + application.navigateToApp('visualize'); + }, + }, + { + text: 'Create', + }, + ]; + chrome.setBreadcrumbs(breadcrumbs); + }, [chrome, application]); + + return ( + + + + + {i18n.translate('dashboardAssistant.feature.text2viz.save', { + defaultMessage: 'Save', + })} + + + + + + + + + setSelectedSource(ds)} + /> + + + } + placeholder="Generate visualization with a natural language question." + /> + + + + + + {status === 'STOPPED' && !vegaSpec && ( + + + + + + )} + {status === 'RUNNING' && ( + + + + + + )} + {status === 'STOPPED' && vis && ( + + + {factory && ( + + )} + + + )} + + + + + ); +}; diff --git a/public/components/visualization/text2viz_empty.tsx b/public/components/visualization/text2viz_empty.tsx new file mode 100644 index 00000000..f158f6a1 --- /dev/null +++ b/public/components/visualization/text2viz_empty.tsx @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiEmptyPrompt } from '@elastic/eui'; +import { i18n } from '@osd/i18n'; + +export const Text2VizEmpty = () => { + return ( + + {i18n.translate('dashboardAssistant.feature.text2viz.getStarted', { + defaultMessage: 'Get started', + })} + + } + body={ + <> +

+ {i18n.translate('dashboardAssistant.feature.text2viz.body', { + defaultMessage: + 'Use the Natural Language Query form field to automatically generate visualizations using simple conversational prompts.', + })} +

+ + } + /> + ); +}; diff --git a/public/components/visualization/text2viz_loading.tsx b/public/components/visualization/text2viz_loading.tsx new file mode 100644 index 00000000..8d21fd80 --- /dev/null +++ b/public/components/visualization/text2viz_loading.tsx @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { EuiEmptyPrompt, EuiLoadingLogo } from '@elastic/eui'; +import { i18n } from '@osd/i18n'; + +export const Text2VizLoading = () => { + return ( + } + title={ +

+ {i18n.translate('dashboardAssistant.feature.text2viz.loading', { + defaultMessage: 'Generating Visualization', + })} +

+ } + /> + ); +}; diff --git a/public/plugin.tsx b/public/plugin.tsx index 7585e0e5..e1309fe0 100644 --- a/public/plugin.tsx +++ b/public/plugin.tsx @@ -3,10 +3,18 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { i18n } from '@osd/i18n'; import { EuiLoadingSpinner } from '@elastic/eui'; import React, { lazy, Suspense } from 'react'; import { Subscription } from 'rxjs'; -import { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '../../../src/core/public'; +import { + AppMountParameters, + AppNavLinkStatus, + CoreSetup, + CoreStart, + Plugin, + PluginInitializerContext, +} from '../../../src/core/public'; import { createOpenSearchDashboardsReactContext, toMountPoint, @@ -94,6 +102,48 @@ export class AssistantPlugin dataSourceManagement: setupDeps.dataSourceManagement, }); + if (this.config.next.enabled) { + setupDeps.visualizations.registerAlias({ + name: 'text2viz', + aliasPath: '#/', + aliasApp: 'text2viz', + title: i18n.translate('dashboardAssistant.feature.text2viz.title', { + defaultMessage: 'Natural language', + }), + description: i18n.translate('dashboardAssistant.feature.text2viz.description', { + defaultMessage: 'Generate visualization with a natural language question.', + }), + icon: 'chatRight', + stage: 'experimental', + promotion: { + buttonText: i18n.translate('dashboardAssistant.feature.text2viz.promotion.buttonText', { + defaultMessage: 'Natural language previewer', + }), + description: i18n.translate('dashboardAssistant.feature.text2viz.promotion.description', { + defaultMessage: + 'Not sure which visualization to choose? Generate visualization previews with a natural language question.', + }), + }, + }); + + core.application.register({ + id: 'text2viz', + title: i18n.translate('dashboardAssistant.feature.text2viz', { + defaultMessage: 'Natural language previewer', + }), + navLinkStatus: AppNavLinkStatus.hidden, + mount: async (params: AppMountParameters) => { + const [coreStart, pluginsStart] = await core.getStartServices(); + const { renderText2VizApp } = await import('./text2viz'); + return renderText2VizApp(params, { + ...coreStart, + ...pluginsStart, + setHeaderActionMenu: params.setHeaderActionMenu, + }); + }, + }); + } + if (this.config.chat.enabled) { const setupChat = async () => { const [coreStart, startDeps] = await core.getStartServices(); diff --git a/public/text2viz.tsx b/public/text2viz.tsx new file mode 100644 index 00000000..634cc191 --- /dev/null +++ b/public/text2viz.tsx @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import ReactDOM from 'react-dom'; +import { AppMountParameters } from '../../../src/core/public'; +import { Text2Viz } from './components/visualization/text2viz'; +import { OpenSearchDashboardsContextProvider } from '../../../src/plugins/opensearch_dashboards_react/public'; +import { StartServices } from './types'; + +export const renderText2VizApp = (params: AppMountParameters, services: StartServices) => { + ReactDOM.render( + + + , + + params.element + ); + return () => { + ReactDOM.unmountComponentAtNode(params.element); + }; +}; diff --git a/public/types.ts b/public/types.ts index c9417680..aa95eef6 100644 --- a/public/types.ts +++ b/public/types.ts @@ -10,6 +10,12 @@ import { IChatContext } from './contexts/chat_context'; import { MessageContentProps } from './tabs/chat/messages/message_content'; import { DataSourceServiceContract, IncontextInsightRegistry } from './services'; import { DataSourceManagementPluginSetup } from '../../../src/plugins/data_source_management/public'; +import { + VisualizationsSetup, + VisualizationsStart, +} from '../../../src/plugins/visualizations/public'; +import { DataPublicPluginSetup, DataPublicPluginStart } from '../../../src/plugins/data/public'; +import { AppMountParameters, CoreStart } from '../../../src/core/public'; export interface RenderProps { props: MessageContentProps; @@ -29,11 +35,15 @@ export interface AssistantActions { } export interface AssistantPluginStartDependencies { + data: DataPublicPluginStart; + visualizations: VisualizationsStart; embeddable: EmbeddableStart; dashboard: DashboardStart; } export interface AssistantPluginSetupDependencies { + data: DataPublicPluginSetup; + visualizations: VisualizationsSetup; embeddable: EmbeddableSetup; dataSourceManagement?: DataSourceManagementPluginSetup; } @@ -59,6 +69,11 @@ export interface AssistantStart { dataSource: DataSourceServiceContract; } +export type StartServices = CoreStart & + AssistantPluginStartDependencies & { + setHeaderActionMenu: AppMountParameters['setHeaderActionMenu']; + }; + export interface UserAccount { username: string; } diff --git a/server/plugin.ts b/server/plugin.ts index 5918864a..48777819 100644 --- a/server/plugin.ts +++ b/server/plugin.ts @@ -16,6 +16,7 @@ import { AssistantPluginSetup, AssistantPluginStart, MessageParser } from './typ import { BasicInputOutputParser } from './parsers/basic_input_output_parser'; import { VisualizationCardParser } from './parsers/visualization_card_parser'; import { registerChatRoutes } from './routes/chat_routes'; +import { registerText2VizRoutes } from './routes/text2viz_routes'; export class AssistantPlugin implements Plugin { private readonly logger: Logger; @@ -47,6 +48,11 @@ export class AssistantPlugin implements Plugin ({ observability: { show: true, diff --git a/server/routes/get_agent.ts b/server/routes/get_agent.ts new file mode 100644 index 00000000..a8d285a6 --- /dev/null +++ b/server/routes/get_agent.ts @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { OpenSearchClient } from '../../../../src/core/server'; +import { ML_COMMONS_BASE_API } from '../utils/constants'; + +export const getAgent = async (id: string, client: OpenSearchClient['transport']) => { + try { + const path = `${ML_COMMONS_BASE_API}/config/${id}`; + const response = await client.request({ + method: 'GET', + path, + }); + + if (!response || !response.body.configuration?.agent_id) { + throw new Error(`cannot get agent ${id} by calling the api: ${path}`); + } + return response.body.configuration.agent_id; + } catch (error) { + const errorMessage = JSON.stringify(error.meta?.body) || error; + throw new Error(`get agent ${id} failed, reason: ${errorMessage}`); + } +}; diff --git a/server/routes/text2viz_routes.ts b/server/routes/text2viz_routes.ts new file mode 100644 index 00000000..bf7d58ef --- /dev/null +++ b/server/routes/text2viz_routes.ts @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { schema } from '@osd/config-schema'; +import { IRouter } from '../../../../src/core/server'; +import { TEXT2VIZ_API } from '../../common/constants/llm'; +import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport'; +import { ML_COMMONS_BASE_API } from '../utils/constants'; +import { getAgent } from './get_agent'; + +const TEXT2VEGA_AGENT_CONFIG_ID = 'text2vega'; +const TEXT2PPL_AGENT_CONFIG_ID = 'text2ppl'; + +export function registerText2VizRoutes(router: IRouter) { + router.post( + { + path: TEXT2VIZ_API.TEXT2VEGA, + validate: { + body: schema.object({ + input: schema.string(), + ppl: schema.string(), + dataSchema: schema.string(), + sampleData: schema.string(), + }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), + }, + }, + router.handleLegacyErrors(async (context, req, res) => { + const client = await getOpenSearchClientTransport({ + context, + dataSourceId: req.query.dataSourceId, + }); + const agentId = await getAgent(TEXT2VEGA_AGENT_CONFIG_ID, client); + const response = await client.request({ + method: 'POST', + path: `${ML_COMMONS_BASE_API}/agents/${agentId}/_execute`, + body: { + parameters: { + input: req.body.input, + ppl: req.body.ppl, + dataSchema: req.body.dataSchema, + sampleData: req.body.sampleData, + }, + }, + }); + + try { + // let result = response.body.inference_results[0].output[0].dataAsMap; + let result = JSON.parse(response.body.inference_results[0].output[0].result); + // sometimes llm returns {response: } instead of + if (result.response) { + result = JSON.parse(result.response); + } + // Sometimes the response contains width and height which is not needed, here delete the these fields + delete result.width; + delete result.height; + + return res.ok({ body: result }); + } catch (e) { + return res.internalError(); + } + }) + ); + + router.post( + { + path: TEXT2VIZ_API.TEXT2PPL, + validate: { + body: schema.object({ + index: schema.string(), + question: schema.string(), + }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), + }, + }, + router.handleLegacyErrors(async (context, req, res) => { + const client = await getOpenSearchClientTransport({ + context, + dataSourceId: req.query.dataSourceId, + }); + const agentId = await getAgent(TEXT2PPL_AGENT_CONFIG_ID, client); + const response = await client.request({ + method: 'POST', + path: `${ML_COMMONS_BASE_API}/agents/${agentId}/_execute`, + body: { + parameters: { + question: req.body.question, + index: req.body.index, + }, + }, + }); + try { + const result = JSON.parse(response.body.inference_results[0].output[0].result); + return res.ok({ body: result }); + } catch (e) { + return res.internalError(); + } + }) + ); +} diff --git a/server/services/chat/olly_chat_service.test.ts b/server/services/chat/olly_chat_service.test.ts index 49362dcb..c42fca7f 100644 --- a/server/services/chat/olly_chat_service.test.ts +++ b/server/services/chat/olly_chat_service.test.ts @@ -233,7 +233,7 @@ describe('OllyChatService', () => { interactionId: 'interactionId', }) ).rejects.toMatchInlineSnapshot( - `[Error: get root agent failed, reason: Error: cannot get root agent by calling the api: /_plugins/_ml/config/os_chat]` + `[Error: get agent os_chat failed, reason: Error: cannot get agent os_chat by calling the api: /_plugins/_ml/config/os_chat]` ); }); }); diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts index bdebec5a..830f6c4f 100644 --- a/server/services/chat/olly_chat_service.ts +++ b/server/services/chat/olly_chat_service.ts @@ -8,6 +8,7 @@ import { OpenSearchClient } from '../../../../../src/core/server'; import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes'; import { ChatService } from './chat_service'; import { ML_COMMONS_BASE_API, ROOT_AGENT_CONFIG_ID } from '../../utils/constants'; +import { getAgent } from '../../routes/get_agent'; interface AgentRunPayload { question?: string; @@ -25,21 +26,7 @@ export class OllyChatService implements ChatService { constructor(private readonly opensearchClientTransport: OpenSearchClient['transport']) {} private async getRootAgent(): Promise { - try { - const path = `${ML_COMMONS_BASE_API}/config/${ROOT_AGENT_CONFIG_ID}`; - const response = await this.opensearchClientTransport.request({ - method: 'GET', - path, - }); - - if (!response || !response.body.configuration?.agent_id) { - throw new Error(`cannot get root agent by calling the api: ${path}`); - } - return response.body.configuration.agent_id; - } catch (error) { - const errorMessage = JSON.stringify(error.meta?.body) || error; - throw new Error('get root agent failed, reason: ' + errorMessage); - } + return await getAgent(ROOT_AGENT_CONFIG_ID, this.opensearchClientTransport); } private async requestAgentRun(payload: AgentRunPayload) {