From 9a804d3f77c032d5fffc88360535c0158957a486 Mon Sep 17 00:00:00 2001 From: Tristen Harr Date: Sat, 8 Jun 2024 00:03:47 -0500 Subject: [PATCH] add support for multi-vector points --- .gitignore | 3 +- CHANGELOG.md | 3 + connector-definition/connector-metadata.yaml | 4 +- generate-config.ts | 75 ++++++--- package-lock.json | 4 +- package.json | 2 +- src/constants.ts | 166 +++++++++++++++++++ src/handlers/explain.ts | 4 +- src/handlers/mutation.ts | 73 ++++++-- src/handlers/query.ts | 16 +- src/handlers/schema.ts | 94 +++++++---- src/index.ts | 6 +- 12 files changed, 376 insertions(+), 74 deletions(-) diff --git a/.gitignore b/.gitignore index 75c45df..3eab2cd 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ dist/ .DS_Store configuration.json PRIVATE_REMOTE_CONFIGURATION.http -config.json \ No newline at end of file +config.json +TMP.md \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index d16172e..2e5518f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # Qdrant Connector Changelog This changelog documents changes between release tags. 
+## [0.2.0] - 2024-06-08 +* Added support for multi-vector points + ## [0.1.9] - 2024-05-6 * Add multi-arch build diff --git a/connector-definition/connector-metadata.yaml b/connector-definition/connector-metadata.yaml index 08ee9fc..f834368 100644 --- a/connector-definition/connector-metadata.yaml +++ b/connector-definition/connector-metadata.yaml @@ -1,6 +1,6 @@ packagingDefinition: type: PrebuiltDockerImage - dockerImage: ghcr.io/hasura/ndc-qdrant:v0.1.9 + dockerImage: ghcr.io/hasura/ndc-qdrant:v0.2.0 supportedEnvironmentVariables: - name: QDRANT_URL description: The url for the Qdrant database @@ -9,7 +9,7 @@ supportedEnvironmentVariables: commands: update: type: Dockerized - dockerImage: ghcr.io/hasura/ndc-qdrant:v0.1.9 + dockerImage: ghcr.io/hasura/ndc-qdrant:v0.2.0 commandArgs: - update dockerComposeWatch: diff --git a/generate-config.ts b/generate-config.ts index ac50bfa..5c07da4 100644 --- a/generate-config.ts +++ b/generate-config.ts @@ -2,7 +2,7 @@ import { getQdrantClient } from "./src/qdrant"; import fs from "fs"; import { promisify } from "util"; import { insertion } from "./src/utilities"; -import { RESTRICTED_OBJECTS, BASE_FIELDS, BASE_TYPES, INSERT_FIELDS } from "./src/constants"; +import { RESTRICTED_OBJECTS, BASE_FIELDS, BASE_TYPES, INSERT_FIELDS, INSERT_FIELDS_VECTOR } from "./src/constants"; const readFile = promisify(fs.readFile); const writeFile = promisify(fs.writeFile); let HASURA_CONFIGURATION_DIRECTORY = process.env["HASURA_CONFIGURATION_DIRECTORY"] as string | undefined; @@ -21,7 +21,7 @@ let client = getQdrantClient(QDRANT_URL, QDRANT_API_KEY); async function main() { const collections = await client.getCollections(); const collectionNames = collections.collections.map((c) => c.name); - + let collectionVectors: Record = {}; let objectTypes: Record = { ...BASE_TYPES, }; @@ -32,6 +32,7 @@ async function main() { const { points: records } = await client.scroll(cn, { limit: 1, with_payload: true, + with_vector: true }); let fieldDict = {}; 
let baseFields = {}; @@ -50,16 +51,29 @@ async function main() { }, ...BASE_FIELDS }; - insertFields = { - id: { - description: null, - type: { - type: "named", - name: "Int", + if (Array.isArray(records[0].vector)){ + insertFields = { + id: { + description: null, + type: { + type: "named", + name: "Int", + }, }, - }, - ...INSERT_FIELDS - }; + ...INSERT_FIELDS + }; + } else { + insertFields = { + id: { + description: null, + type: { + type: "named", + name: "Int", + }, + }, + ...INSERT_FIELDS_VECTOR + }; + } } else { baseFields = { id: { @@ -71,16 +85,36 @@ async function main() { }, ...BASE_FIELDS }; - insertFields = { - id: { - description: null, - type: { - type: "named", - name: "String", + + if (Array.isArray(records[0].vector)){ + insertFields = { + id: { + description: null, + type: { + type: "named", + name: "String", + }, }, - }, - ...INSERT_FIELDS - }; + ...INSERT_FIELDS + }; + } else { + insertFields = { + id: { + description: null, + type: { + type: "named", + name: "String", + }, + }, + ...INSERT_FIELDS_VECTOR + }; + } + } + + if (!Array.isArray(records[0].vector)){ + collectionVectors[cn] = true; + } else { + collectionVectors[cn] = false; } } @@ -112,6 +146,7 @@ async function main() { collection_names: collectionNames, object_fields: objectFields, object_types: objectTypes, + collection_vectors: collectionVectors, functions: [], procedures: [], } diff --git a/package-lock.json b/package-lock.json index 239386a..cb1b93f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "ndc-qdrant", - "version": "0.1.9", + "version": "0.2.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "ndc-qdrant", - "version": "0.1.9", + "version": "0.2.0", "dependencies": { "@hasura/ndc-sdk-typescript": "^4.5.0", "@qdrant/js-client-rest": "^1.5.0" diff --git a/package.json b/package.json index 15b387c..9935bf6 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "ndc-qdrant", - "version": "0.1.9", + 
"version": "0.2.0", "main": "index.js", "scripts": { "start": "ts-node ./src/index.ts serve --configuration=.", diff --git a/src/constants.ts b/src/constants.ts index 527847b..aad28e3 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -136,6 +136,32 @@ export const INSERT_FIELDS: Record = { }, }; +export const INSERT_FIELDS_VECTOR: Record = { + vectors: { + description: null, + type: { + type: "array", + element_type: { + type: "array", + element_type: { + type: "named", + name: "Float", + } + } + }, + }, + vector_names: { + description: null, + type: { + type: "array", + element_type: { + type: "named", + name: "String" + } + }, + }, +}; + export const BASE_FIELDS: Record = { score: { description: null, @@ -268,6 +294,44 @@ export const BASE_TYPES: { [k: string]: ObjectType } = { }, }, }, + _searchVector: { + description: "Search the vector database for similar vectors", + fields: { + vector: { + type: { + type: "array", + element_type: { + type: "named", + name: "Float", + }, + }, + }, + name: { + type: { + type: "named", + name: "String" + } + }, + params: { + type: { + type: "nullable", + underlying_type: { + type: "named", + name: "_params", + }, + }, + }, + score_threshold: { + type: { + type: "nullable", + underlying_type: { + type: "named", + name: "Float", + }, + }, + }, + }, + }, _recommendInt: { description: "Provide an array of positive and negative example points and get a recommendation", @@ -357,6 +421,108 @@ export const BASE_TYPES: { [k: string]: ObjectType } = { }, }, }, + }, + _recommendIntVector: { + description: + "Provide an array of positive and negative example points and get a recommendation", + fields: { + positive: { + type: { + type: "array", + element_type: { + type: "named", + name: "Int", + }, + }, + }, + negative: { + type: { + type: "nullable", + underlying_type: { + type: "array", + element_type: { + type: "named", + name: "Int", + }, + }, + }, + }, + using: { + type: { + type: "named", + name: "String" + } + }, + params: { + 
type: { + type: "nullable", + underlying_type: { + type: "named", + name: "_params", + }, + }, + }, + score_threshold: { + type: { + type: "nullable", + underlying_type: { + type: "named", + name: "Float", + }, + }, + }, + }, + }, + _recommendStringVector: { + description: + "Provide an array of positive and negative example points and get a recommendation", + fields: { + positive: { + type: { + type: "array", + element_type: { + type: "named", + name: "String", + }, + }, + }, + negative: { + type: { + type: "nullable", + underlying_type: { + type: "array", + element_type: { + type: "named", + name: "String", + }, + }, + }, + }, + using: { + type: { + type: "named", + name: "String" + } + }, + params: { + type: { + type: "nullable", + underlying_type: { + type: "named", + name: "_params", + }, + }, + }, + score_threshold: { + type: { + type: "nullable", + underlying_type: { + type: "named", + name: "Float", + }, + }, + }, + }, } }; diff --git a/src/handlers/explain.ts b/src/handlers/explain.ts index b79e4e1..38c51d5 100644 --- a/src/handlers/explain.ts +++ b/src/handlers/explain.ts @@ -1,10 +1,10 @@ import { ExplainResponse, QueryRequest } from "@hasura/ndc-sdk-typescript"; import { QueryPlan, planQueries } from "./query"; -export async function doExplain(query: QueryRequest, collectionNames: string[], collectionFields: {[key: string]: string[]}): Promise{ +export async function doExplain(query: QueryRequest, collectionNames: string[], collectionFields: {[key: string]: string[]}, collectionVectors: {[key: string]: boolean}): Promise{ let explainResponse: ExplainResponse; try { - let queryPlan: QueryPlan = await planQueries(query, collectionNames, collectionFields); + let queryPlan: QueryPlan = await planQueries(query, collectionNames, collectionFields, collectionVectors); let isScroll: boolean = queryPlan.scrollQueries.length > 0; explainResponse = {details:{ queryRequest: JSON.stringify(query), diff --git a/src/handlers/mutation.ts b/src/handlers/mutation.ts index 
16aada3..6307a28 100644 --- a/src/handlers/mutation.ts +++ b/src/handlers/mutation.ts @@ -23,7 +23,7 @@ export async function do_mutation(configuration: Configuration, state: State, mu if (procedure.arguments.object){ let args = procedure.arguments.object as any; let id: number = 0; - let vector: number[] = []; + let vector: (number)[] | ({[key: string]: (number)[] | undefined}) = []; const payload: any = {}; for (let field of Object.keys(args)){ if (field === "id"){ @@ -34,6 +34,15 @@ export async function do_mutation(configuration: Configuration, state: State, mu payload[field] = args[field]; } } + if (vector.length === 0){ + vector = {}; + if (args.vectors && args.vector_names){ + for (let i = 0; i < args.vector_names.length; i++){ + console.log(args.vector_names[i]); + vector[args.vector_names[i]] = args.vectors[i]; + } + } + } let existing_point = await state.client.retrieve(collection, {ids: [id]}); if (existing_point.length > 0){ operation_results.push({ @@ -71,7 +80,7 @@ export async function do_mutation(configuration: Configuration, state: State, mu if (procedure.arguments.object){ let args = procedure.arguments.object as any; let id: number = 0; - let vector: number[] = []; + let vector: (number)[] | ({[key: string]: (number)[] | undefined}) = []; const payload: any = {}; for (let field of Object.keys(args)){ if (field === "id"){ @@ -82,6 +91,15 @@ export async function do_mutation(configuration: Configuration, state: State, mu payload[field] = args[field]; } } + if (vector.length === 0){ + vector = {}; + if (args.vectors && args.vector_names){ + for (let i = 0; i < args.vector_names.length; i++){ + console.log(args.vector_names[i]); + vector[args.vector_names[i]] = args.vectors[i]; + } + } + } let point = await state.client.upsert(collection, { wait: true, points: [ @@ -125,7 +143,7 @@ export async function do_mutation(configuration: Configuration, state: State, mu if (procedure.arguments.object){ let args = procedure.arguments.object as any; let id: 
number = 0; - let vector: number[] = []; + let vector: (number)[] | ({[key: string]: (number)[] | undefined}) = []; const payload: any = {}; for (let field of Object.keys(args)){ if (field === "id"){ @@ -136,6 +154,15 @@ export async function do_mutation(configuration: Configuration, state: State, mu payload[field] = args[field]; } } + if (vector.length === 0){ + vector = {}; + if (args.vectors && args.vector_names){ + for (let i = 0; i < args.vector_names.length; i++){ + console.log(args.vector_names[i]); + vector[args.vector_names[i]] = args.vectors[i]; + } + } + } let existing_point = await state.client.retrieve(collection, {ids: [id]}); if (existing_point.length == 0){ operation_results.push({ @@ -174,13 +201,22 @@ export async function do_mutation(configuration: Configuration, state: State, mu let args = procedure.arguments.objects as any[]; let pointsToInsert = args.map(arg => { let id: number = arg.id; - let vector: number[] = arg.vector || []; + let vector: (number)[] | ({[key: string]: (number)[] | undefined}) = arg.vector || []; const payload: any = {}; for (let field of Object.keys(arg)) { - if (field !== "id" && field !== "vector") { + if (field !== "id" && field !== "vector" && field !== "vectors" && field !== "vector_names") { payload[field] = arg[field]; } } + if (vector.length === 0){ + vector = {}; + if (arg.vectors && arg.vector_names){ + for (let i = 0; i < arg.vector_names.length; i++){ + console.log(arg.vector_names[i]); + vector[arg.vector_names[i]] = arg.vectors[i]; + } + } + } return { id, vector, payload }; }); let existingIds = await state.client.retrieve(collection, { ids: pointsToInsert.map(p => p.id) }); @@ -214,16 +250,24 @@ export async function do_mutation(configuration: Configuration, state: State, mu // Map each argument object to a format suitable for upsert operation let pointsToUpsert = args.map(arg => { let id: number = arg.id; // Extract the id - let vector: number[] = arg.vector || []; // Extract the vector if available, or 
default to an empty array + let vector: (number)[] | ({[key: string]: (number)[] | undefined}) = arg.vector || []; const payload: any = {}; // Initialize an empty payload object // Populate the payload with fields from arg, excluding 'id' and 'vector' for (let field of Object.keys(arg)) { - if (field !== "id" && field !== "vector") { + if (field !== "id" && field !== "vector" && field !== "vectors" && field !== "vector_names") { payload[field] = arg[field]; } } - + if (vector.length === 0){ + vector = {}; + if (arg.vectors && arg.vector_names){ + for (let i = 0; i < arg.vector_names.length; i++){ + console.log(arg.vector_names[i]); + vector[arg.vector_names[i]] = arg.vectors[i]; + } + } + } // Return a structure suitable for the upsert operation return { id, vector, payload }; }); @@ -250,13 +294,22 @@ export async function do_mutation(configuration: Configuration, state: State, mu let args = procedure.arguments.objects as any[]; let pointsToInsert = args.map(arg => { let id: number = arg.id; - let vector: number[] = arg.vector || []; + let vector: (number)[] | ({[key: string]: (number)[] | undefined}) = arg.vector || []; const payload: any = {}; for (let field of Object.keys(arg)) { - if (field !== "id" && field !== "vector") { + if (field !== "id" && field !== "vector" && field !== "vectors" && field !== "vector_names") { payload[field] = arg[field]; } } + if (vector.length === 0){ + vector = {}; + if (arg.vectors && arg.vector_names){ + for (let i = 0; i < arg.vector_names.length; i++){ + console.log(arg.vector_names[i]); + vector[arg.vector_names[i]] = arg.vectors[i]; + } + } + } return { id, vector, payload }; }); let existingIds = await state.client.retrieve(collection, { ids: pointsToInsert.map(p => p.id) }); diff --git a/src/handlers/query.ts b/src/handlers/query.ts index 30b963d..41b8ac4 100644 --- a/src/handlers/query.ts +++ b/src/handlers/query.ts @@ -5,6 +5,7 @@ import { State } from ".."; type QueryFilter = components["schemas"]["Filter"]; type 
SearchRequest = components["schemas"]["SearchRequest"]; +type NamedVectorStruct = components["schemas"]["NamedVectorStruct"]; type ScrollRequest = components["schemas"]["ScrollRequest"]; type RecommendRequest = components["schemas"]["RecommendRequest"]; @@ -46,7 +47,7 @@ type SearchParams = { type SearchArguments = { vector: number[]; - vector_name?: string; + name?: string; params?: SearchParams; score_threshold?: number; }; @@ -54,9 +55,9 @@ type SearchArguments = { type RecommendArguments = { positive: number[]; negative?: number[]; - vector_name?: string; params?: SearchParams; score_threshold?: number; + using?: string; }; // Helper function to determine if a value is a float @@ -336,7 +337,7 @@ async function collectQuery(query: QueryRequest, throw new Forbidden("Unknown search argument type", {}); } searchRequest = { - vector: searchArgs.vector, + vector: searchArgs.name ? {vector: searchArgs.vector, name: searchArgs.name} : searchArgs.vector, with_vector: includeVector, with_payload: { include: includedPayloadFields @@ -377,6 +378,9 @@ async function collectQuery(query: QueryRequest, if (recommendArgs.score_threshold){ recommendRequest.score_threshold = recommendArgs.score_threshold; } + if (recommendArgs.using){ + recommendRequest.using = recommendArgs.using; + } } else { scrollRequest = { with_vector: includeVector, @@ -504,7 +508,7 @@ function rowAggregate(aggResults: { [key: string]: any }, aggVars: { [key: strin * @example * const myQueryPlan = await planQueries(myQuery, availableCollections, availableFields); */ -export async function planQueries(query: QueryRequest, collectionNames: string[], collectionFields: { [key: string]: string[] }): Promise { +export async function planQueries(query: QueryRequest, collectionNames: string[], collectionFields: { [key: string]: string[] }, collectionVectors: {[k: string]: boolean}): Promise { // Assert that the collection is registered in the schema if (!collectionNames.includes(query.collection)) { throw new 
Conflict("Collection not found in schema!", {}); @@ -732,8 +736,8 @@ export async function performQueries( * @param {string | null} qdrantApiKey - The API key for the qdrant service (can be null). * @returns {Promise} - A promise resolving to the query response. */ -export async function doQuery(state: State, query: QueryRequest, collectionNames: string[], collectionFields: { [key: string]: string[] }): Promise { - let queryPlan = await planQueries(query, collectionNames, collectionFields); +export async function doQuery(state: State, query: QueryRequest, collectionNames: string[], collectionFields: { [key: string]: string[] }, collection_vectors: {[k: string]: boolean}): Promise { + let queryPlan = await planQueries(query, collectionNames, collectionFields, collection_vectors); return await performQueries( state, query, diff --git a/src/handlers/schema.ts b/src/handlers/schema.ts index cee30eb..843f18b 100644 --- a/src/handlers/schema.ts +++ b/src/handlers/schema.ts @@ -1,49 +1,87 @@ import { ObjectType, SchemaResponse, CollectionInfo, FunctionInfo, ProcedureInfo, ArgumentInfo, Forbidden } from "@hasura/ndc-sdk-typescript"; import { SCALAR_TYPES } from "../constants"; -export function doGetSchema(objectTypes: { [k: string]: ObjectType }, collectionNames: string[], functions: FunctionInfo[], procedures: ProcedureInfo[]): SchemaResponse { +export function doGetSchema(objectTypes: { [k: string]: ObjectType }, collection_names: string[], functions: FunctionInfo[], procedures: ProcedureInfo[], collection_vectors: {[k: string]: boolean}): SchemaResponse { let collectionInfos: CollectionInfo[] = []; let functionsInfo: FunctionInfo[] = []; let proceduresInfo: ProcedureInfo[] = []; + console.log("GETTING SCHEMA"); + console.log(collection_vectors); for (const cn of Object.keys(objectTypes)){ - if (collectionNames.includes(cn)){ + if (collection_names.includes(cn)){ let ID_FIELD_TYPE = "Int"; if (objectTypes[cn].fields["id"]["type"]["type"] === "named"){ ID_FIELD_TYPE = 
(objectTypes[cn].fields["id"]["type"] as any)["name"]; } else { throw new Forbidden("Invalid ID type", {}); } - collectionInfos.push({ - name: `${cn}`, - description: null, - arguments: { - search: { - type: { - type: "nullable", - underlying_type: { - type: "named", - name: "_search" + if (!collection_vectors[cn]){ + collectionInfos.push({ + name: `${cn}`, + description: null, + arguments: { + search: { + type: { + type: "nullable", + underlying_type: { + type: "named", + name: "_search" + } + } + }, + recommend: { + type: { + type: "nullable", + underlying_type: { + type: "named", + name: `_recommend${ID_FIELD_TYPE}` + } } } }, - recommend: { - type: { - type: "nullable", - underlying_type: { - type: "named", - name: `_recommend${ID_FIELD_TYPE}` + type: cn, + uniqueness_constraints: { + [`${cn.charAt(0).toUpperCase() + cn.slice(1)}ByID`]: { + unique_columns: ["id"] + } + }, + foreign_keys: {} + }); + } else { + collectionInfos.push({ + name: `${cn}`, + description: null, + arguments: { + search: { + type: { + type: "nullable", + underlying_type: { + type: "named", + name: "_searchVector" + } + } + }, + recommend: { + type: { + type: "nullable", + underlying_type: { + type: "named", + name: `_recommend${ID_FIELD_TYPE}Vector` + } } } - } - }, - type: cn, - uniqueness_constraints: { - [`${cn.charAt(0).toUpperCase() + cn.slice(1)}ByID`]: { - unique_columns: ["id"] - } - }, - foreign_keys: {} - }); + }, + type: cn, + uniqueness_constraints: { + [`${cn.charAt(0).toUpperCase() + cn.slice(1)}ByID`]: { + unique_columns: ["id"] + } + }, + foreign_keys: {} + }); + } + + const proc_insert_one: ProcedureInfo = { name: `insert_${cn}_one`, description: `Insert a single record into the ${cn} collection`, diff --git a/src/index.ts b/src/index.ts index 6b9f490..d76b260 100644 --- a/src/index.ts +++ b/src/index.ts @@ -32,6 +32,7 @@ export type ConfigurationSchema = { collection_names: string[]; object_fields: {[k: string]: string[]}; object_types: { [k: string]: ObjectType}; + 
collection_vectors: {[k: string]: boolean}; functions: FunctionInfo[]; procedures: ProcedureInfo[]; } @@ -113,7 +114,7 @@ const connector: Connector = { if (!configuration.config){ throw new Forbidden("Internal Server Error, server configuration is invalid", {}); } - return Promise.resolve(doGetSchema(configuration.config.object_types, configuration.config.collection_names, configuration.config.functions, configuration.config.procedures)); + return Promise.resolve(doGetSchema(configuration.config.object_types, configuration.config.collection_names, configuration.config.functions, configuration.config.procedures, configuration.config.collection_vectors)); }, /** @@ -133,7 +134,7 @@ const connector: Connector = { if (!configuration.config){ throw new Forbidden("Internal Server Error, server configuration is invalid", {}); } - return doExplain(request, configuration.config.collection_names, configuration.config.object_fields); + return doExplain(request, configuration.config.collection_names, configuration.config.object_fields, configuration.config.collection_vectors); }, /** @@ -178,6 +179,7 @@ const connector: Connector = { request, configuration.config.collection_names, configuration.config.object_fields, + configuration.config.collection_vectors ); },