diff --git a/connectors/migrations/20240802_table_parents.ts b/connectors/migrations/20240802_table_parents.ts new file mode 100644 index 000000000000..549a1d776845 --- /dev/null +++ b/connectors/migrations/20240802_table_parents.ts @@ -0,0 +1,184 @@ +import { getGoogleSheetTableId } from "@dust-tt/types"; +import { makeScript } from "scripts/helpers"; +import { Op } from "sequelize"; +import { v4 as uuidv4 } from "uuid"; + +import { getLocalParents as getGoogleParents } from "@connectors/connectors/google_drive/lib"; +import { getParents as getMicrosoftParents } from "@connectors/connectors/microsoft/temporal/file"; +import { getParents as getNotionParents } from "@connectors/connectors/notion/lib/parents"; +import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config"; +import { + getTable, + updateTableParentsField, +} from "@connectors/lib/data_sources"; +import { GoogleDriveSheet } from "@connectors/lib/models/google_drive"; +import { MicrosoftNodeModel } from "@connectors/lib/models/microsoft"; +import { NotionDatabase } from "@connectors/lib/models/notion"; +import type { Logger } from "@connectors/logger/logger"; +import { ConnectorResource } from "@connectors/resources/connector_resource"; + +export async function googleTables( + connector: ConnectorResource, + execute: boolean, + logger: Logger +): Promise { + logger.info(`Processing Google Drive connector ${connector.id}`); + const memo = uuidv4(); + const csvGoogleSheets = await GoogleDriveSheet.findAll({ + where: { connectorId: connector.id }, + }); + for (const sheet of csvGoogleSheets) { + const { driveFileId, driveSheetId, connectorId } = sheet; + + const dataSourceConfig = dataSourceConfigFromConnector(connector); + + const tableId = getGoogleSheetTableId(driveFileId, driveSheetId); + + const parents = await getGoogleParents(connectorId, tableId, memo); + + const table = await getTable({ + dataSourceConfig, + tableId, + }); + if (table && JSON.stringify(table.parents) !== JSON.stringify(parents)) { + logger.info(`Parents for ${tableId}: ${parents}`); + if (execute) { + await updateTableParentsField({ tableId, parents, dataSourceConfig }); + } + } + } +} + +export async function microsoftTables( + connector: ConnectorResource, + execute: boolean, + logger: Logger +): Promise { + logger.info(`Processing Microsoft connector ${connector.id}`); + const microsoftSheets = await MicrosoftNodeModel.findAll({ + where: { + nodeType: "worksheet", + connectorId: connector.id, + }, + }); + for (const sheet of microsoftSheets) { + const { internalId, connectorId } = sheet; + + const dataSourceConfig = dataSourceConfigFromConnector(connector); + + const parents = await getMicrosoftParents({ + connectorId, + internalId, + startSyncTs: 0, + }); + + const table = await getTable({ + dataSourceConfig, + tableId: internalId, + }); + + if (table && JSON.stringify(table.parents) !== JSON.stringify(parents)) { + logger.info(`Parents for ${internalId}: ${parents}`); + if (execute) { + await updateTableParentsField({ + tableId: internalId, + parents, + dataSourceConfig, + }); + } + } + } +} + +export async function notionTables( + connector: ConnectorResource, + execute: boolean, + logger: Logger +): Promise { + logger.info(`Processing Notion connector ${connector.id}`); + const notionDatabases = await NotionDatabase.findAll({ + where: { + connectorId: connector.id, + structuredDataUpsertedTs: { + [Op.not]: null, + }, + }, + }); + + const memo = uuidv4(); + + for (const database of notionDatabases) { + const { notionDatabaseId, connectorId } = database; + if (!connectorId) { + continue; + } + + const dataSourceConfig = dataSourceConfigFromConnector(connector); + const parents = await getNotionParents( + connectorId as number, + notionDatabaseId as string, + new Set(), + memo + ); + const table = await getTable({ + dataSourceConfig, + tableId: "notion-" + notionDatabaseId, + }); + if (table && JSON.stringify(table.parents) !== JSON.stringify(parents)) { + logger.info(`Parents for notion-${notionDatabaseId}: ${parents}`); + if (execute) { + await updateTableParentsField({ + tableId: "notion-" + notionDatabaseId, + parents, + dataSourceConfig, + }); + } + } + } +} + +export async function handleConnector( + connector: ConnectorResource, + execute: boolean, + logger: Logger +): Promise { + switch (connector.type) { + case "google_drive": + return googleTables(connector, execute, logger); + case "microsoft": + return microsoftTables(connector, execute, logger); + case "notion": + return notionTables(connector, execute, logger); + } +} + +makeScript( + { + connectorId: { type: "number", demandOption: false }, + }, + async ({ connectorId, execute }, logger) => { + if (connectorId) { + const connector = await ConnectorResource.fetchById(connectorId); + if (!connector) { + throw new Error( + `Could not find connector for connectorId ${connectorId}` + ); + } + await handleConnector(connector, execute, logger); + } else { + for (const connectorType of [ + "google_drive", + "microsoft", + "notion", + ] as const) { + const connectors = await ConnectorResource.listByType( + connectorType, + {} + ); + for (const connector of connectors) { + await handleConnector(connector, execute, logger); + } + } + } + } +); diff --git a/connectors/src/connectors/google_drive/lib/cli.ts b/connectors/src/connectors/google_drive/lib/cli.ts index 25135d2c2472..e64046fd36e9 100644 --- a/connectors/src/connectors/google_drive/lib/cli.ts +++ b/connectors/src/connectors/google_drive/lib/cli.ts @@ -1,6 +1,6 @@ import type { AdminSuccessResponseType, - GoogleDriveCheckFileResponseType, + CheckFileGenericResponseType, GoogleDriveCommandType, } from "@dust-tt/types"; import { googleDriveIncrementalSyncWorkflowId } from "@dust-tt/types"; @@ -24,7 +24,7 @@ export const google_drive = async ({ command, args, }: GoogleDriveCommandType): Promise< - AdminSuccessResponseType | GoogleDriveCheckFileResponseType + AdminSuccessResponseType | CheckFileGenericResponseType > => { const logger = topLogger.child({ majorCommand: "google_drive", diff --git a/connectors/src/connectors/google_drive/temporal/file.ts b/connectors/src/connectors/google_drive/temporal/file.ts index 35c158189746..71ce0f4064e2 100644 --- a/connectors/src/connectors/google_drive/temporal/file.ts +++ b/connectors/src/connectors/google_drive/temporal/file.ts @@ -111,7 +111,8 @@ async function handleFileExport( maxDocumentLen: number, localLogger: Logger, dataSourceConfig: DataSourceConfig, - connectorId: ModelId + connectorId: ModelId, + startSyncTs: number ): Promise { const drive = await getDriveClient(oauth2client); let res; @@ -157,6 +158,10 @@ async function handleFileExport( if (file.mimeType === "text/plain") { result = handleTextFile(res.data, maxDocumentLen); } else if (file.mimeType === "text/csv") { + const parents = ( + await getFileParentsMemoized(connectorId, oauth2client, file, startSyncTs) + ).map((f) => f.id); + result = await handleCsvFile({ data: res.data, file, @@ -164,6 +169,7 @@ async function handleFileExport( localLogger, dataSourceConfig, connectorId, + parents, }); } else { result = await handleTextExtraction(res.data, localLogger, file.mimeType); @@ -259,7 +265,8 @@ export async function syncOneFile( file, localLogger, dataSourceConfig, - maxDocumentLen + maxDocumentLen, + startSyncTs ); } else { return syncOneFileTextDocument( @@ -284,7 +291,8 @@ async function syncOneFileTable( file: GoogleDriveObjectType, localLogger: Logger, dataSourceConfig: DataSourceConfig, - maxDocumentLen: number + maxDocumentLen: number, + startSyncTs: number ) { let skipReason: string | undefined; const upsertTimestampMs = undefined; @@ -292,7 +300,12 @@ async function syncOneFileTable( const documentId = getDocumentId(file.id); if (isGoogleDriveSpreadSheetFile(file)) { - const res = await syncSpreadSheet(oauth2client, connectorId, file); + const res = await syncSpreadSheet( + oauth2client, + connectorId, + file, + startSyncTs + ); if (!res.isSupported) { return false; } @@ -310,7 +323,8 @@ async function syncOneFileTable( maxDocumentLen, localLogger, dataSourceConfig, - connectorId + connectorId, + startSyncTs ); } await updateGoogleDriveFiles( @@ -357,7 +371,8 @@ async function syncOneFileTextDocument( maxDocumentLen, localLogger, dataSourceConfig, - connectorId + connectorId, + startSyncTs ); } if (documentContent) { diff --git a/connectors/src/connectors/google_drive/temporal/spreadsheets.ts b/connectors/src/connectors/google_drive/temporal/spreadsheets.ts index de2192318313..3d7850d05cfd 100644 --- a/connectors/src/connectors/google_drive/temporal/spreadsheets.ts +++ b/connectors/src/connectors/google_drive/temporal/spreadsheets.ts @@ -12,6 +12,7 @@ import type { sheets_v4 } from "googleapis"; import { google } from "googleapis"; import type { OAuth2Client } from "googleapis-common"; +import { getFileParentsMemoized } from "@connectors/connectors/google_drive/lib/hierarchy"; import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config"; import { concurrentExecutor } from "@connectors/lib/async_utils"; import { MAX_FILE_SIZE_TO_DOWNLOAD } from "@connectors/lib/data_sources"; @@ -47,6 +48,7 @@ async function upsertSheetInDb(connector: ConnectorResource, sheet: Sheet) { async function upsertTable( connector: ConnectorResource, sheet: Sheet, + parents: string[], rows: string[][], loggerArgs: object ) { @@ -77,6 +79,7 @@ async function upsertTable( spreadsheetId: spreadsheet.id, }, truncate: true, + parents: [tableId, ...parents], }); logger.info(loggerArgs, "[Spreadsheet] Table upserted."); @@ -171,7 +174,8 @@ function getValidRows(allRows: string[][], loggerArgs: object): string[][] { async function processSheet( connector: ConnectorResource, - sheet: Sheet + sheet: Sheet, + parents: string[] ): Promise { if (!sheet.values) { return false; @@ -196,7 +200,7 @@ async function processSheet( const rows = await getValidRows(sheet.values, loggerArgs); // Assuming the first line as headers, at least one additional data line is required. if (rows.length > 1) { - await upsertTable(connector, sheet, rows, loggerArgs); + await upsertTable(connector, sheet, parents, rows, loggerArgs); await upsertSheetInDb(connector, sheet); @@ -365,7 +369,8 @@ async function getAllSheetsFromSpreadSheet( export async function syncSpreadSheet( oauth2client: OAuth2Client, connectorId: ModelId, - file: GoogleDriveObjectType + file: GoogleDriveObjectType, + startSyncTs: number ): Promise< | { isSupported: false; @@ -474,9 +479,21 @@ export async function syncSpreadSheet( }, }); + const parents = [ + file.id, + ...( + await getFileParentsMemoized( + connectorId, + oauth2client, + file, + startSyncTs + ) + ).map((f) => f.id), + ]; + const successfulSheetIdImports: number[] = []; for (const sheet of sheets) { - const isImported = await processSheet(connector, sheet); + const isImported = await processSheet(connector, sheet, parents); if (isImported) { successfulSheetIdImports.push(sheet.id); } diff --git a/connectors/src/connectors/microsoft/index.ts b/connectors/src/connectors/microsoft/index.ts index dcb5d4ed6e0d..0b9c218e4c73 100644 --- a/connectors/src/connectors/microsoft/index.ts +++ b/connectors/src/connectors/microsoft/index.ts @@ -315,7 +315,6 @@ export class MicrosoftConnectorManager extends BaseConnectorManager { nodesWithPermissions.filter((n) => n.permission === filterPermission) ); } - return new Ok(nodesWithPermissions); } diff --git a/connectors/src/connectors/microsoft/lib/cli.ts b/connectors/src/connectors/microsoft/lib/cli.ts index cbba08b4d154..c027e5b7e3ae 100644 --- a/connectors/src/connectors/microsoft/lib/cli.ts +++ b/connectors/src/connectors/microsoft/lib/cli.ts @@ -1,6 +1,6 @@ import type { AdminSuccessResponseType, - MicrosoftCheckFileResponseType, + CheckFileGenericResponseType, MicrosoftCommandType, } from "@dust-tt/types"; import { googleDriveIncrementalSyncWorkflowId } from "@dust-tt/types"; @@ -52,7 +52,7 @@ export const microsoft = async ({ command, args, }: MicrosoftCommandType): Promise< - AdminSuccessResponseType | MicrosoftCheckFileResponseType + AdminSuccessResponseType | CheckFileGenericResponseType > => { switch (command) { case "garbage-collect-all": { diff --git a/connectors/src/connectors/microsoft/temporal/activities.ts b/connectors/src/connectors/microsoft/temporal/activities.ts index ed8e66b2a0f4..239729de3198 100644 --- a/connectors/src/connectors/microsoft/temporal/activities.ts +++ b/connectors/src/connectors/microsoft/temporal/activities.ts @@ -790,7 +790,6 @@ async function updateParentsField({ const parents = await getParents({ connectorId: file.connectorId, internalId: file.internalId, - parentInternalId: file.parentInternalId, startSyncTs, }); diff --git a/connectors/src/connectors/microsoft/temporal/file.ts b/connectors/src/connectors/microsoft/temporal/file.ts index 06e95aa65d80..2d7fbbb08ce5 100644 --- a/connectors/src/connectors/microsoft/temporal/file.ts +++ b/connectors/src/connectors/microsoft/temporal/file.ts @@ -157,6 +157,13 @@ export async function syncOneFile({ if (mimeType === "application/vnd.ms-excel" || mimeType === "text/csv") { const data = Buffer.from(downloadRes.data); + + const parents = await getParents({ + connectorId, + internalId: documentId, + startSyncTs, + }); + result = await handleCsvFile({ dataSourceConfig, data, @@ -164,6 +171,7 @@ export async function syncOneFile({ localLogger, maxDocumentLen, connectorId, + parents, }); } else if ( mimeType === @@ -174,6 +182,7 @@ export async function syncOneFile({ file, parentInternalId, localLogger, + startSyncTs, }); } else if (mimeType === "text/plain") { result = handleTextFile(downloadRes.data, maxDocumentLen); @@ -247,7 +256,6 @@ export async function syncOneFile({ const parents = await getParents({ connectorId, internalId: documentId, - parentInternalId, startSyncTs, }); parents.reverse(); @@ -295,55 +303,48 @@ export async function syncOneFile({ export async function getParents({ connectorId, internalId, - parentInternalId, startSyncTs, }: { connectorId: ModelId; internalId: string; - parentInternalId: string | null; startSyncTs: number; }): Promise { - if (!parentInternalId) { - return [internalId]; - } - - const parentParentInternalId = await getParentParentId( + const parentInternalId = await getParentId( connectorId, - parentInternalId, + internalId, startSyncTs ); - return parentParentInternalId + return parentInternalId ? [ internalId, ...(await getParents({ connectorId, internalId: parentInternalId, - parentInternalId: parentParentInternalId, startSyncTs, })), ] - : [internalId, parentInternalId]; + : [internalId]; } /* Fetching parent's parent id queries the db for a resource; since those * fetches can be made a lot of times during a sync, cache for a while in a * per-sync basis (given by startSyncTs) */ -const getParentParentId = cacheWithRedis( +const getParentId = cacheWithRedis( // eslint-disable-next-line @typescript-eslint/no-unused-vars - async (connectorId, parentInternalId, startSyncTs) => { - const parent = await MicrosoftNodeResource.fetchByInternalId( + async (connectorId, internalId, startSyncTs) => { + const node = await MicrosoftNodeResource.fetchByInternalId( connectorId, - parentInternalId + internalId ); - if (!parent) { + if (!node) { return ""; } - return parent.parentInternalId; + return node.parentInternalId; }, - (connectorId, parentInternalId, startSyncTs) => - `microsoft-${connectorId}-parent-${parentInternalId}-syncms-${startSyncTs}`, + (connectorId, internalId, startSyncTs) => + `microsoft-${connectorId}-parent-${internalId}-syncms-${startSyncTs}`, PARENT_SYNC_CACHE_TTL_MS ); diff --git a/connectors/src/connectors/microsoft/temporal/spreadsheets.ts b/connectors/src/connectors/microsoft/temporal/spreadsheets.ts index 2674c04d161a..2c3c3477bc01 100644 --- a/connectors/src/connectors/microsoft/temporal/spreadsheets.ts +++ b/connectors/src/connectors/microsoft/temporal/spreadsheets.ts @@ -12,6 +12,7 @@ import { getWorksheets, wrapMicrosoftGraphAPIWithResult, } from "@connectors/connectors/microsoft/lib/graph_api"; +import { getParents } from "@connectors/connectors/microsoft/temporal/file"; import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config"; import { concurrentExecutor } from "@connectors/lib/async_utils"; import { deleteTable, upsertTableFromCsv } from "@connectors/lib/data_sources"; @@ -65,6 +66,7 @@ async function upsertTable( internalId: string, spreadsheet: microsoftgraph.DriveItem, worksheet: microsoftgraph.WorkbookWorksheet, + parents: string[], rows: string[][], loggerArgs: object ) { @@ -92,6 +94,7 @@ async function upsertTable( spreadsheetId: spreadsheet.id ?? "", }, truncate: true, + parents, }); logger.info(loggerArgs, "[Spreadsheet] Table upserted."); @@ -104,7 +107,8 @@ async function processSheet( internalId: string, worksheet: microsoftgraph.WorkbookWorksheet, spreadsheetId: string, - localLogger: Logger + localLogger: Logger, + startSyncTs: number ): Promise> { if (!worksheet.id) { return new Err(new Error("Worksheet has no id")); @@ -165,11 +169,18 @@ async function processSheet( if (rawHeaders && rows.length > 1) { const headers = getSanitizedHeaders(rawHeaders); + const parents = await getParents({ + connectorId: connector.id, + internalId: internalId, + startSyncTs, + }); + await upsertTable( connector, internalId, spreadsheet, worksheet, + parents, [headers, ...rest], loggerArgs ); @@ -192,11 +203,13 @@ export async function handleSpreadSheet({ file, parentInternalId, localLogger, + startSyncTs, }: { connectorId: number; file: microsoftgraph.DriveItem; parentInternalId: string; localLogger: Logger; + startSyncTs: number; }): Promise> { const connector = await ConnectorResource.fetchById(connectorId); @@ -249,7 +262,8 @@ export async function handleSpreadSheet({ internalWorkSheetId, worksheet, documentId, - localLogger + localLogger, + startSyncTs ); if (importResult.isOk()) { successfulSheetIdImports.push(internalWorkSheetId); diff --git a/connectors/src/connectors/notion/temporal/activities.ts b/connectors/src/connectors/notion/temporal/activities.ts index 74db7a6545c0..ea09edcbf118 100644 --- a/connectors/src/connectors/notion/temporal/activities.ts +++ b/connectors/src/connectors/notion/temporal/activities.ts @@ -1820,6 +1820,14 @@ export async function renderAndUpsertPageFromCache({ cellSeparator: ",", rowBoundary: "", }); + + const parents = await getParents( + connector.id, + parentDb.notionDatabaseId, + new Set(), + runTimestamp.toString() + ); + await upsertTableFromCsv({ dataSourceConfig: dataSourceConfigFromConnector(connector), tableId, @@ -1829,6 +1837,7 @@ export async function renderAndUpsertPageFromCache({ loggerArgs, // We only update the rowId of for the page without truncating the rest of the table (incremental sync). truncate: false, + parents, }); } else { localLogger.info( @@ -2472,6 +2481,13 @@ export async function upsertDatabaseStructuredDataFromCache({ const upsertAt = new Date(); + const parents = await getParents( + connector.id, + databaseId, + new Set(), + runTimestamp.toString() + ); + localLogger.info("Upserting Notion Database as Table."); await upsertTableFromCsv({ dataSourceConfig, @@ -2482,6 +2498,7 @@ export async function upsertDatabaseStructuredDataFromCache({ loggerArgs, // We overwrite the whole table since we just fetched all child pages. truncate: true, + parents, }); // Same as above, but without the `dustId` column const csvForDocument = await renderDatabaseFromPages({ @@ -2502,12 +2519,6 @@ export async function upsertDatabaseStructuredDataFromCache({ "Skipping document upsert as body is too long." ); } else { - const parents = await getParents( - connector.id, - databaseId, - new Set(), - runTimestamp.toString() - ); localLogger.info("Upserting Notion Database as Document."); const prefix = `${databaseName}\n${csvHeader}`; const prefixSection = await renderPrefixSection({ diff --git a/connectors/src/connectors/shared/file.ts b/connectors/src/connectors/shared/file.ts index 4987a669b7ae..89fcc7eec995 100644 --- a/connectors/src/connectors/shared/file.ts +++ b/connectors/src/connectors/shared/file.ts @@ -50,6 +50,7 @@ export async function handleCsvFile({ localLogger, dataSourceConfig, connectorId, + parents, }: { data: ArrayBuffer; file: GoogleDriveObjectType | DriveItem; @@ -57,6 +58,7 @@ export async function handleCsvFile({ localLogger: Logger; dataSourceConfig: DataSourceConfig; connectorId: ModelId; + parents: string[]; }): Promise> { if (data.byteLength > 4 * maxDocumentLen) { localLogger.info({}, "File too big to be chunked. Skipping"); @@ -84,6 +86,7 @@ export async function handleCsvFile({ fileName: tableName, }, truncate: true, + parents, }); } catch (err) { localLogger.warn({ error: err }, "Error while parsing or upserting table"); diff --git a/connectors/src/lib/cli.ts b/connectors/src/lib/cli.ts index 8ef21b1ec932..f6bc4074efba 100644 --- a/connectors/src/lib/cli.ts +++ b/connectors/src/lib/cli.ts @@ -4,6 +4,7 @@ import type { BatchCommandType, BatchRestartAllResponseType, ConnectorsCommandType, + GetParentsResponseType, Result, TemporalCheckQueueResponseType, TemporalCommandType, @@ -94,7 +95,9 @@ export async function throwOnError(p: Promise>) { export const connectors = async ({ command, args, -}: ConnectorsCommandType): Promise => { +}: ConnectorsCommandType): Promise< + AdminSuccessResponseType | GetParentsResponseType +> => { if (!args.wId) { throw new Error("Missing --wId argument"); } @@ -161,6 +164,22 @@ export const connectors = async ({ await throwOnError(manager.resume()); return { success: true }; } + + case "get-parents": { + if (!args.fileId) { + throw new Error("Missing --fileId argument"); + } + const parents = await manager.retrieveContentNodeParents({ + internalId: args.fileId, + }); + + if (parents.isErr()) { + throw new Error(`Cannot fetch parents: ${parents.error}`); + } + + return { parents: parents.value }; + } + default: throw new Error(`Unknown workspace command: ${command}`); } diff --git a/connectors/src/lib/data_sources.ts b/connectors/src/lib/data_sources.ts index ccaeee76f121..a9281eec09d3 100644 --- a/connectors/src/lib/data_sources.ts +++ b/connectors/src/lib/data_sources.ts @@ -1,5 +1,6 @@ import type { CoreAPIDataSourceDocumentSection, + CoreAPITable, PostDataSourceDocumentRequestBody, } from "@dust-tt/types"; import { @@ -9,7 +10,7 @@ import { sectionFullText, } from "@dust-tt/types"; import { MAX_CHUNK_SIZE } from "@dust-tt/types"; -import type { AxiosRequestConfig, AxiosResponse } from "axios"; +import type { AxiosError, AxiosRequestConfig, AxiosResponse } from "axios"; import axios from "axios"; import tracer from "dd-trace"; import http from "http"; @@ -249,19 +250,58 @@ export const updateDocumentParentsField = withRetries( ); async function _updateDocumentParentsField({ - dataSourceConfig, documentId, + ...params +}: { + dataSourceConfig: DataSourceConfig; + documentId: string; + parents: string[]; + loggerArgs?: Record; +}) { + return _updateDocumentOrTableParentsField({ + ...params, + tableOrDocument: "document", + id: documentId, + }); +} + +export const updateTableParentsField = withRetries(_updateTableParentsField); + +async function _updateTableParentsField({ + tableId, + ...params +}: { + dataSourceConfig: DataSourceConfig; + tableId: string; + parents: string[]; + loggerArgs?: Record; +}) { + return _updateDocumentOrTableParentsField({ + ...params, + tableOrDocument: "table", + id: tableId, + }); +} + +async function _updateDocumentOrTableParentsField({ + dataSourceConfig, + id, parents, loggerArgs = {}, + tableOrDocument, }: { dataSourceConfig: DataSourceConfig; - documentId: string; + id: string; parents: string[]; loggerArgs?: Record; + tableOrDocument: "document" | "table"; }) { - const localLogger = logger.child({ ...loggerArgs, documentId }); + const localLogger = + tableOrDocument === "document" + ? logger.child({ ...loggerArgs, documentId: id }) + : logger.child({ ...loggerArgs, tableId: id }); const urlSafeName = encodeURIComponent(dataSourceConfig.dataSourceName); - const endpoint = `${DUST_FRONT_API}/api/v1/w/${dataSourceConfig.workspaceId}/data_sources/${urlSafeName}/documents/${documentId}/parents`; + const endpoint = `${DUST_FRONT_API}/api/v1/w/${dataSourceConfig.workspaceId}/data_sources/${urlSafeName}/${tableOrDocument}s/${id}/parents`; const dustRequestConfig: AxiosRequestConfig = { headers: { Authorization: `Bearer ${dataSourceConfig.workspaceAPIKey}`, @@ -278,7 +318,10 @@ async function _updateDocumentParentsField({ dustRequestConfig ); } catch (e) { - localLogger.error({ error: e }, "Error updating document parents field."); + localLogger.error( + { error: e }, + `Error updating ${tableOrDocument} parents field.` + ); throw e; } @@ -290,10 +333,10 @@ async function _updateDocumentParentsField({ status: dustRequestResult.status, data: dustRequestResult.data, }, - "Error updating document parents field." + `Error updating ${tableOrDocument} parents field.` ); throw new Error( - `Error updating document parents field: ${dustRequestResult}` + `Error updating ${tableOrDocument} parents field: ${dustRequestResult}` ); } } @@ -513,6 +556,7 @@ export async function upsertTableFromCsv({ tableCsv, loggerArgs, truncate, + parents, }: { dataSourceConfig: DataSourceConfig; tableId: string; @@ -521,6 +565,7 @@ export async function upsertTableFromCsv({ tableCsv: string; loggerArgs?: Record; truncate: boolean; + parents: string[]; }) { const localLogger = logger.child({ ...loggerArgs, tableId, tableName }); const statsDTags = [ @@ -528,7 +573,7 @@ export async function upsertTableFromCsv({ `workspace_id:${dataSourceConfig.workspaceId}`, ]; - localLogger.info("Attempting to upload structured data to Dust."); + localLogger.info("Attempting to upload table to Dust."); statsDClient.increment( "data_source_structured_data_upserts_attempt.count", 1, @@ -541,6 +586,7 @@ export async function upsertTableFromCsv({ const endpoint = `${DUST_FRONT_API}/api/v1/w/${dataSourceConfig.workspaceId}/data_sources/${urlSafeName}/tables/csv`; const dustRequestPayload = { name: tableName, + parents, description: tableDescription, csv: tableCsv, tableId, @@ -586,7 +632,7 @@ export async function upsertTableFromCsv({ csv: dustRequestPayload.csv.substring(0, 100), }, }, - "Axios error uploading structured data to Dust." + "Axios error uploading table to Dust." ); } else if (e instanceof Error) { localLogger.error( @@ -597,13 +643,13 @@ export async function upsertTableFromCsv({ csv: dustRequestPayload.csv.substring(0, 100), }, }, - "Error uploading structured data to Dust." + "Error uploading table to Dust." ); } else { - localLogger.error("Unknown error uploading structured data to Dust."); + localLogger.error("Unknown error uploading table to Dust."); } - throw new Error("Error uploading structured data to Dust."); + throw new Error("Error uploading table to Dust."); } const elapsed = new Date().getTime() - now.getTime(); @@ -619,7 +665,7 @@ export async function upsertTableFromCsv({ elapsed, statsDTags ); - localLogger.info("Successfully uploaded structured data to Dust."); + localLogger.info("Successfully uploaded table to Dust."); } else { statsDClient.increment( "data_source_structured_data_upserts_error.count", @@ -636,7 +682,7 @@ export async function upsertTableFromCsv({ status: dustRequestResult.status, elapsed, }, - "Error uploading structured data to Dust." + "Error uploading table to Dust." ); throw new Error( `Error uploading to dust, got ${ @@ -667,7 +713,7 @@ export async function deleteTableRow({ `workspace_id:${dataSourceConfig.workspaceId}`, ]; - localLogger.info("Attempting to delete structured data from Dust."); + localLogger.info("Attempting to delete table from Dust."); statsDClient.increment( "data_source_structured_data_deletes_attempt.count", 1, @@ -703,17 +749,14 @@ export async function deleteTableRow({ elapsed, statsDTags ); - localLogger.error( - { error: e }, - "Error deleting structured data from Dust." - ); + localLogger.error({ error: e }, "Error deleting table from Dust."); throw e; } const elapsed = new Date().getTime() - now.getTime(); if (dustRequestResult.status === 404) { - localLogger.info("Structured data doesn't exist on Dust. Ignoring."); + localLogger.info("Table doesn't exist on Dust. Ignoring."); return; } @@ -724,7 +767,7 @@ export async function deleteTableRow({ statsDTags ); - localLogger.info("Successfully deleted structured data from Dust."); + localLogger.info("Successfully deleted table from Dust."); } else { statsDClient.increment( "data_source_structured_data_deletes_error.count", @@ -741,12 +784,47 @@ export async function deleteTableRow({ status: dustRequestResult.status, elapsed, }, - "Error deleting structured data from Dust." + "Error deleting table from Dust." ); throw new Error(`Error deleting from dust: ${dustRequestResult}`); } } +export async function getTable({ + dataSourceConfig, + tableId, +}: { + dataSourceConfig: DataSourceConfig; + tableId: string; +}): Promise { + const localLogger = logger.child({ + tableId, + }); + + const urlSafeName = encodeURIComponent(dataSourceConfig.dataSourceName); + const endpoint = `${DUST_FRONT_API}/api/v1/w/${dataSourceConfig.workspaceId}/data_sources/${urlSafeName}/tables/${tableId}`; + const dustRequestConfig: AxiosRequestConfig = { + headers: { + Authorization: `Bearer ${dataSourceConfig.workspaceAPIKey}`, + }, + }; + + let dustRequestResult: AxiosResponse; + try { + dustRequestResult = await axiosWithTimeout.get(endpoint, dustRequestConfig); + } catch (e) { + const axiosError = e as AxiosError; + if (axiosError?.response?.status === 404) { + localLogger.info("Table doesn't exist on Dust. Ignoring."); + return; + } + localLogger.error({ error: e }, "Error getting table from Dust."); + throw e; + } + + return dustRequestResult.data.table; +} + export async function deleteTable({ dataSourceConfig, tableId, @@ -765,7 +843,7 @@ export async function deleteTable({ `workspace_id:${dataSourceConfig.workspaceId}`, ]; - localLogger.info("Attempting to delete structured data from Dust."); + localLogger.info("Attempting to delete table from Dust."); statsDClient.increment( "data_source_structured_data_deletes_attempt.count", 1, @@ -801,17 +879,14 @@ export async function deleteTable({ elapsed, statsDTags ); - localLogger.error( - { error: e }, - "Error deleting structured data from Dust." - ); + localLogger.error({ error: e }, "Error deleting table from Dust."); throw e; } const elapsed = new Date().getTime() - now.getTime(); if (dustRequestResult.status === 404) { - localLogger.info("Structured data doesn't exist on Dust. Ignoring."); + localLogger.info("Table doesn't exist on Dust. Ignoring."); return; } @@ -822,7 +897,7 @@ export async function deleteTable({ statsDTags ); - localLogger.info("Successfully deleted structured data from Dust."); + localLogger.info("Successfully deleted table from Dust."); } else { statsDClient.increment( "data_source_structured_data_deletes_error.count", @@ -839,7 +914,7 @@ export async function deleteTable({ status: dustRequestResult.status, elapsed, }, - "Error deleting structured data from Dust." + "Error deleting table from Dust." ); throw new Error(`Error deleting from dust: ${dustRequestResult}`); } diff --git a/core/bin/dust_api.rs b/core/bin/dust_api.rs index 5096c5898257..f04804032e26 100644 --- a/core/bin/dust_api.rs +++ b/core/bin/dust_api.rs @@ -2084,6 +2084,53 @@ async fn tables_delete( } } +async fn tables_update_parents( + Path((project_id, data_source_id, table_id)): Path<(i64, String, String)>, + State(state): State>, + Json(payload): Json, +) -> (StatusCode, Json) { + let project = project::Project::new_from_id(project_id); + + match state + .store + .load_table(&project, &data_source_id, &table_id) + .await + { + Err(e) => error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_server_error", + "Failed to load table", + Some(e), + ), + Ok(None) => error_response( + StatusCode::NOT_FOUND, + "table_not_found", + &format!("No table found for id `{}`", table_id), + None, + ), + Ok(Some(table)) => match table + .update_parents(state.store.clone(), payload.parents.clone()) + .await + { + Err(e) => error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_server_error", + "Failed to update table parents", + Some(e), + ), + Ok(_) => ( + StatusCode::OK, + Json(APIResponse { + error: None, + response: Some(json!({ + "success": true, + })), + }), + ), + }, + } +} + #[derive(serde::Deserialize)] struct TablesRowsUpsertPayload { rows: Vec, @@ -2725,6 +2772,10 @@ fn main() { "/projects/:project_id/data_sources/:data_source_id/tables", post(tables_upsert), ) + .route( + "/projects/:project_id/data_sources/:data_source_id/tables/:table_id/parents", + patch(tables_update_parents), + ) .route( "/projects/:project_id/data_sources/:data_source_id/tables/:table_id", get(tables_retrieve), diff --git a/core/src/databases/database.rs b/core/src/databases/database.rs index 52af1e363a12..c4f9fe835a38 100644 --- a/core/src/databases/database.rs +++ b/core/src/databases/database.rs @@ -401,6 +401,22 @@ impl Table { .await } + pub async fn update_parents( + &self, + store: Box, + parents: Vec, + ) -> Result<()> { + store + .update_table_parents( + &self.project, + &self.data_source_id, + &&self.table_id, + &parents, + ) + .await?; + Ok(()) + } + async fn compute_schema( &self, databases_store: Box, diff --git a/core/src/stores/postgres.rs b/core/src/stores/postgres.rs index 2e959a5a347a..6718b513d95e 100644 --- a/core/src/stores/postgres.rs +++ b/core/src/stores/postgres.rs @@ -2273,7 +2273,8 @@ impl Store for PostgresStore { timestamp, tags_array, parents) \ VALUES (DEFAULT, $1, $2, $3, $4, $5, $6, $7, $8) \ ON CONFLICT (table_id, data_source) DO UPDATE \ - SET name = EXCLUDED.name, description = EXCLUDED.description \ + SET name = EXCLUDED.name, description = EXCLUDED.description, \ + timestamp = EXCLUDED.timestamp, tags_array = EXCLUDED.tags_array, parents = EXCLUDED.parents \ RETURNING id", ) .await?; @@ -2355,6 +2356,43 @@ impl Store for PostgresStore { Ok(()) } + async fn update_table_parents( + &self, + project: &Project, + data_source_id: &str, + table_id: &str, + parents: &Vec, + ) -> Result<()> { + let project_id = project.project_id(); + let data_source_id = data_source_id.to_string(); + let table_id = table_id.to_string(); + + let pool = self.pool.clone(); + let c = pool.get().await?; + + // Get the data source row id. + let stmt = c + .prepare( + "SELECT id FROM data_sources WHERE project = $1 AND data_source_id = $2 LIMIT 1", + ) + .await?; + let r = c.query(&stmt, &[&project_id, &data_source_id]).await?; + let data_source_row_id: i64 = match r.len() { + 0 => Err(anyhow!("Unknown DataSource: {}", data_source_id))?, + 1 => r[0].get(0), + _ => unreachable!(), + }; + + // Update parents. + let stmt = c + .prepare("UPDATE tables SET parents = $1 WHERE data_source = $2 AND table_id = $3") + .await?; + c.query(&stmt, &[&parents, &data_source_row_id, &table_id]) + .await?; + + Ok(()) + } + async fn invalidate_table_schema( &self, project: &Project, diff --git a/core/src/stores/store.rs b/core/src/stores/store.rs index 7cf748e5bd9e..63502e55ef5e 100644 --- a/core/src/stores/store.rs +++ b/core/src/stores/store.rs @@ -207,6 +207,13 @@ pub trait Store { table_id: &str, schema: &TableSchema, ) -> Result<()>; + async fn update_table_parents( + &self, + project: &Project, + data_source_id: &str, + table_id: &str, + parents: &Vec, + ) -> Result<()>; async fn invalidate_table_schema( &self, project: &Project, diff --git a/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/index.ts b/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/index.ts index 5fa1162e0b0a..d72889ed9282 100644 --- a/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/index.ts +++ b/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/index.ts @@ -15,6 +15,9 @@ export type GetTableResponseBody = { table_id: string; description: string; schema: CoreAPITableSchema | null; + timestamp: number; + tags: string[]; + parents: string[]; }; }; @@ -153,6 +156,15 @@ async function handler( tableId, }); if (tableRes.isErr()) { + if (tableRes.error.code === "table_not_found") { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "table_not_found", + message: "Failed to get table.", + }, + }); + } logger.error( { dataSourcename: dataSource.name, @@ -178,6 +190,9 @@ async function handler( table_id: table.table_id, description: table.description, schema: table.schema, + timestamp: table.timestamp, + tags: table.tags, + parents: table.parents, }, }); diff --git a/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/parents.ts b/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/parents.ts new file mode 100644 index 000000000000..4067f1a1091a --- /dev/null +++ b/front/pages/api/v1/w/[wId]/data_sources/[name]/tables/[tId]/parents.ts @@ -0,0 +1,115 @@ +import type { WithAPIErrorResponse } from "@dust-tt/types"; +import { CoreAPI } from "@dust-tt/types"; +import { isLeft } from "fp-ts/lib/Either"; +import * as t from "io-ts"; +import * as reporter from "io-ts-reporters"; +import type { NextApiRequest, NextApiResponse } from "next"; + +import config from "@app/lib/api/config"; +import { getDataSource } from "@app/lib/api/data_sources"; +import { Authenticator, getAPIKey } from "@app/lib/auth"; +import logger from "@app/logger/logger"; +import { apiError, withLogging } from "@app/logger/withlogging"; + +const ParentsBodySchema = t.type({ + parents: t.array(t.string), +}); + +export type PostParentsResponseBody = { + updated: true; +}; + +/** + * @ignoreswagger + * System API key only endpoint. Undocumented. + */ + +async function handler( + req: NextApiRequest, + res: NextApiResponse> +): Promise { + const keyRes = await getAPIKey(req); + if (keyRes.isErr()) { + return apiError(req, res, keyRes.error); + } + const { workspaceAuth } = await Authenticator.fromKey( + keyRes.value, + req.query.wId as string + ); + + const owner = workspaceAuth.workspace(); + const isSystemKey = keyRes.value.isSystem; + if (!owner || !workspaceAuth.isBuilder() || !isSystemKey) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "data_source_not_found", + message: "The data source you requested was not found.", + }, + }); + } + + const dataSource = await getDataSource( + workspaceAuth, + req.query.name as string + ); + + if (!dataSource) { + return apiError(req, res, { + status_code: 404, + api_error: { + type: "data_source_not_found", + message: "The data source you requested was not found.", + }, + }); + } + + switch (req.method) { + case "POST": + const bodyValidation = ParentsBodySchema.decode(req.body); + if (isLeft(bodyValidation)) { + const pathError = reporter.formatValidationErrors(bodyValidation.left); + return apiError(req, res, { + api_error: { + type: "invalid_request_error", + message: `Invalid request body: ${pathError}`, + }, + status_code: 400, + }); + } + const { parents } = bodyValidation.right; + + const coreAPI = new CoreAPI(config.getCoreAPIConfig(), logger); + const updateRes = await coreAPI.updateTableParents({ + projectId: dataSource.dustAPIProjectId, + dataSourceName: dataSource.name, + tableId: req.query.tId as string, + parents, + }); + + if (updateRes.isErr()) { + return apiError(req, res, { + status_code: 500, + api_error: { + type: "internal_server_error", + message: "There was an error updating the `parents` field.", + data_source_error: updateRes.error, + }, + }); + } + + res.status(200).json({ updated: true }); + return; + + default: + return apiError(req, res, { + status_code: 405, + api_error: { + type: "method_not_supported_error", + message: "The method passed is not supported, POST is expected.", + }, + }); + } +} + +export default withLogging(handler); diff --git a/types/src/connectors/admin/cli.ts b/types/src/connectors/admin/cli.ts index a9b53486d3ee..3dbd9c619495 100644 --- a/types/src/connectors/admin/cli.ts +++ b/types/src/connectors/admin/cli.ts @@ -11,6 +11,7 @@ export const ConnectorsCommandSchema = t.type({ t.literal("full-resync"), t.literal("set-error"), t.literal("restart"), + t.literal("get-parents"), ]), args: t.record( t.string, @@ -208,6 +209,7 @@ export const MicrosoftCommandSchema = t.type({ t.literal("start-incremental-sync"), t.literal("restart-all-incremental-sync-workflows"), t.literal("skip-file"), + t.literal("get-parents"), ]), args: t.record( t.string, @@ -217,7 +219,30 @@ export const MicrosoftCommandSchema = t.type({ export type MicrosoftCommandType = t.TypeOf; -export const MicrosoftCheckFileResponseSchema = t.type({ +export const AdminCommandSchema = t.union([ + BatchCommandSchema, + ConnectorsCommandSchema, + GithubCommandSchema, + GoogleDriveCommandSchema, + IntercomCommandSchema, + MicrosoftCommandSchema, + NotionCommandSchema, + SlackCommandSchema, + TemporalCommandSchema, + WebcrawlerCommandSchema, +]); + +export type AdminCommandType = t.TypeOf; + +export const AdminSuccessResponseSchema = t.type({ + success: t.literal(true), +}); + +export type AdminSuccessResponseType = t.TypeOf< + typeof AdminSuccessResponseSchema +>; + +export const CheckFileGenericResponseSchema = t.type({ status: t.number, // all literals from js `typeof` type: t.union([ @@ -233,32 +258,15 @@ export const MicrosoftCheckFileResponseSchema = t.type({ content: t.unknown, // google drive type, can't be iots'd }); -export type MicrosoftCheckFileResponseType = t.TypeOf< - typeof MicrosoftCheckFileResponseSchema +export type CheckFileGenericResponseType = t.TypeOf< + typeof CheckFileGenericResponseSchema >; -export const AdminCommandSchema = t.union([ - ConnectorsCommandSchema, - GithubCommandSchema, - NotionCommandSchema, - GoogleDriveCommandSchema, - SlackCommandSchema, - BatchCommandSchema, - WebcrawlerCommandSchema, - TemporalCommandSchema, - IntercomCommandSchema, - MicrosoftCommandSchema, -]); - -export type AdminCommandType = t.TypeOf; - -export const AdminSuccessResponseSchema = t.type({ - success: t.literal(true), +export const GetParentsResponseSchema = t.type({ + parents: t.array(t.string), }); -export type AdminSuccessResponseType = t.TypeOf< - typeof AdminSuccessResponseSchema ->; +export type GetParentsResponseType = t.TypeOf; export const NotionUpsertResponseSchema = t.type({ workflowId: t.string, @@ -310,26 +318,6 @@ export const NotionMeResponseSchema = t.type({ export type NotionMeResponseType = t.TypeOf; -export const GoogleDriveCheckFileResponseSchema = t.type({ - status: t.number, - // all literals from js `typeof` - type: t.union([ - t.literal("undefined"), - t.literal("object"), - t.literal("boolean"), - t.literal("number"), - t.literal("string"), - t.literal("function"), - t.literal("symbol"), - t.literal("bigint"), - ]), - content: t.unknown, // google drive type, can't be iots'd -}); - -export type GoogleDriveCheckFileResponseType = t.TypeOf< - typeof GoogleDriveCheckFileResponseSchema ->; - export const TemporalCheckQueueResponseSchema = t.type({ taskQueue: t.UnknownRecord, // temporal type, can't be iots'd }); @@ -350,17 +338,18 @@ export type TemporalUnprocessedWorkflowsResponseType = t.TypeOf< export const AdminResponseSchema = t.union([ AdminSuccessResponseSchema, BatchRestartAllResponseSchema, - NotionUpsertResponseSchema, - NotionSearchPagesResponseSchema, + CheckFileGenericResponseSchema, + GetParentsResponseSchema, + IntercomCheckConversationResponseSchema, + IntercomCheckMissingConversationsResponseSchema, + IntercomCheckTeamsResponseSchema, + IntercomFetchConversationResponseSchema, NotionCheckUrlResponseSchema, NotionMeResponseSchema, - GoogleDriveCheckFileResponseSchema, + NotionSearchPagesResponseSchema, + NotionUpsertResponseSchema, TemporalCheckQueueResponseSchema, TemporalUnprocessedWorkflowsResponseSchema, - IntercomCheckConversationResponseSchema, - IntercomFetchConversationResponseSchema, - IntercomCheckTeamsResponseSchema, - IntercomCheckMissingConversationsResponseSchema, ]); export type AdminResponseType = t.TypeOf; diff --git a/types/src/front/lib/core_api.ts b/types/src/front/lib/core_api.ts index ab51bb74f9c5..aec0ec3ce518 100644 --- a/types/src/front/lib/core_api.ts +++ b/types/src/front/lib/core_api.ts @@ -123,6 +123,9 @@ export type CoreAPITable = { name: string; description: string; schema: CoreAPITableSchema | null; + timestamp: number; + tags: string[]; + parents: string[]; }; export type CoreAPIRowValue = @@ -1116,6 +1119,37 @@ export class CoreAPI { return this._resultFromResponse(response); } + async updateTableParents({ + projectId, + dataSourceName, + tableId, + parents, + }: { + projectId: string; + dataSourceName: string; + tableId: string; + parents: string[]; + }): Promise> { + const response = await this._fetchWithError( + `${this._url}/projects/${encodeURIComponent( + projectId + )}/data_sources/${encodeURIComponent( + dataSourceName + )}/tables/${encodeURIComponent(tableId)}/parents`, + { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + parents: parents, + }), + } + ); + + return this._resultFromResponse(response); + } + async upsertTableRows({ projectId, dataSourceName,