diff --git a/LiteCore/Query/QueryParser+VectorSearch.cc b/LiteCore/Query/QueryParser+VectorSearch.cc index 62257ce8f..b64b3800b 100644 --- a/LiteCore/Query/QueryParser+VectorSearch.cc +++ b/LiteCore/Query/QueryParser+VectorSearch.cc @@ -29,54 +29,40 @@ using namespace litecore::qp; namespace litecore { -# ifndef REQUIRE_LIMIT static constexpr unsigned kDefaultMaxResults = 3; -# endif - static constexpr unsigned kMaxMaxResults = 10000; + static constexpr unsigned kMaxMaxResults = 10000; // Scans the entire query for vector_match() calls, and adds join tables for ones that are // indexed. void QueryParser::addVectorSearchJoins(const Dict* select) { - auto whereClause = getCaseInsensitive(select, "WHERE"); findNodes(select, kVectorMatchFnNameWithParens, 1, [&](const Array* matchExpr) { - // Arguments to vector_match are index name, target vector, and optional max-results. - string tableName = FTSTableName(matchExpr->get(1), true).first; - const Value* limitVal = matchExpr->get(3); - indexJoinInfo* info = indexJoinTable(tableName, "vector"); - if ( matchExpr == whereClause || limitVal ) { + // Arguments to vector_match are index name and target vector. + string tableName = FTSTableName(matchExpr->get(1), true).first; + auto targetVectorParam = matchExpr->get(2); + indexJoinInfo* info = indexJoinTable(tableName, "vector"); + if ( matchExpr == getCaseInsensitive(select, "WHERE") ) { // If vector_match is the entire WHERE clause, this is a simple non-hybrid query. // This is implemented by a nested SELECT that finds the nearest vectors in // the entire collection. Isolating this in a nested SELECT ensures SQLite doesn't // see the outer JOIN against the collection; if it did, the vectorsearch extension's // planner would see a constraint against `rowid` and interpret it as a hybrid search. // https://github.com/couchbaselabs/mobile-vector-search/blob/main/docs/Extension.md - auto targetVectorParam = matchExpr->get(2); // Figure out the limit to use in the vector query: - const char* limitName = "3rd max_results argument"; - if ( !limitVal ) { - // If no limit param is given, check the LIMIT on the SELECT itself: - limitVal = getCaseInsensitive(select, "LIMIT"); - limitName = "LIMIT"; - } int64_t maxResults; - if ( limitVal ) { + if ( auto limitVal = getCaseInsensitive(select, "LIMIT") ) { maxResults = limitVal->asInt(); - require(limitVal->isInteger() && maxResults > 0, "vector_match()'s %s must be a positive integer", - limitName); - require(maxResults <= kMaxMaxResults, "vector_match()'s %s must not exceed %u", limitName, + require(limitVal->isInteger() && maxResults > 0, + "LIMIT must be a positive integer when using vector_match()"); + require(maxResults <= kMaxMaxResults, "LIMIT must not exceed %u when using vector_match()", kMaxMaxResults); } else { -# ifdef REQUIRE_LIMIT - fail("vector_match() requires a 3rd max_results argument or a LIMIT"); -# else maxResults = kDefaultMaxResults; -# endif } // Register a callback to write the nested SELECT in place of a table name: info->writeTableSQL = [=] { - _sql << "(SELECT rowid, distance FROM \"" << tableName << "\" WHERE vector LIKE encode_vector("; + _sql << "(SELECT rowid, distance FROM \"" << tableName << "\" WHERE vector MATCH encode_vector("; parseNode(targetVectorParam); _sql << ") LIMIT " << maxResults << ")"; }; @@ -87,20 +73,20 @@ namespace litecore { // Writes a `vector_match()` expression. void QueryParser::writeVectorMatchFn(ArrayIterator& params) { requireTopLevelConjunction("VECTOR_MATCH"); - auto parentCtx = _context.rbegin() + 1; - auto parentOp = (*parentCtx)->op; - const Value* limitVal = params[2]; - if ( parentOp == "SELECT"_sl || parentOp == nullslice || limitVal ) { + auto parentCtx = _context.rbegin() + 1; + auto parentOp = (*parentCtx)->op; + if ( parentOp == "SELECT"_sl || parentOp == nullslice ) { // In a simple query the work of `vector_match` is done by the JOIN, which limits the results to the // rowids produced by the nested query of the vector table. // Since there's nothing to do here, replace the call with a `true`. _sql << "true"; } else { // In a hybrid query we do write the LIKE test at the point of the match call: - string tableName = FTSTableName(params[0], true).first; - const string& alias = indexJoinTableAlias(tableName); - auto targetVectorParam = params[1]; - _sql << sqlIdentifier(alias) << ".vector LIKE encode_vector("; + string tableName = FTSTableName(params[0], true).first; + const string& alias = indexJoinTableAlias(tableName, "vector"); + Assert(!alias.empty()); + auto targetVectorParam = params[1]; + _sql << sqlIdentifier(alias) << ".vector MATCH encode_vector("; parseNode(targetVectorParam); _sql << ")"; } @@ -109,7 +95,7 @@ namespace litecore { // Writes the SQL translation of the `vector_distance(...)` call. void QueryParser::writeVectorDistanceFn(ArrayIterator& params) { string tableName = FTSTableName(params[0], true).first; - _sql << indexJoinTableAlias(tableName) << ".distance"; + _sql << indexJoinTableAlias(tableName, "vector") << ".distance"; } // Given the expression to index from a vector index spec, returns the SQL of a diff --git a/LiteCore/Query/QueryParserTables.hh b/LiteCore/Query/QueryParserTables.hh index b1773376d..291ae1099 100644 --- a/LiteCore/Query/QueryParserTables.hh +++ b/LiteCore/Query/QueryParserTables.hh @@ -241,7 +241,7 @@ namespace litecore { {"cosine_distance", 2, 2}, // Vector search: - {"vector_match", 2, 3}, + {"vector_match", 2, 2}, {"vector_distance", 1, 1}, #endif diff --git a/LiteCore/tests/LazyVectorQueryTest.cc b/LiteCore/tests/LazyVectorQueryTest.cc index 07d285a34..7da9989b9 100644 --- a/LiteCore/tests/LazyVectorQueryTest.cc +++ b/LiteCore/tests/LazyVectorQueryTest.cc @@ -50,9 +50,10 @@ class LazyVectorQueryTest : public VectorQueryTest { string queryStr = R"( ['SELECT', { - WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target'], 5], + WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target']], WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'factorsindex'], 'distance'] ], ORDER_BY: [ ['.distance'] ], + LIMIT: 5 }] )"; _query = store->compileQuery(json5(queryStr), QueryLanguage::kJSON); REQUIRE(_query != nullptr); diff --git a/LiteCore/tests/N1QLParserTest.cc b/LiteCore/tests/N1QLParserTest.cc index b4f1e6935..abe51a3a7 100644 --- a/LiteCore/tests/N1QLParserTest.cc +++ b/LiteCore/tests/N1QLParserTest.cc @@ -578,10 +578,10 @@ TEST_CASE_METHOD(N1QLParserTest, "N1QL Vector Search", "[Query][N1QL][VectorSear tableNames.emplace("kv_.scope.coll:vector:vecIndex"); CHECK(translate("SELECT META().id, VECTOR_DISTANCE(vecIndex) AS distance " - "WHERE VECTOR_MATCH(vecIndex, $target, 5) ORDER BY distance") - == "{'ORDER_BY':[['.distance']],'WHAT':[['_.',['meta()'],'.id']," + "WHERE VECTOR_MATCH(vecIndex, $target) ORDER BY distance LIMIT 5") + == "{'LIMIT':5,'ORDER_BY':[['.distance']],'WHAT':[['_.',['meta()'],'.id']," "['AS',['VECTOR_DISTANCE()','vecIndex'],'distance']]," - "'WHERE':['VECTOR_MATCH()','vecIndex',['$target'],5]}"); + "'WHERE':['VECTOR_MATCH()','vecIndex',['$target']]}"); CHECK(translate("SELECT META().id, VECTOR_DISTANCE(coll.vecIndex) AS distance " "FROM coll " diff --git a/LiteCore/tests/PredictiveVectorQueryTest.cc b/LiteCore/tests/PredictiveVectorQueryTest.cc index 75c30c318..ac773df00 100644 --- a/LiteCore/tests/PredictiveVectorQueryTest.cc +++ b/LiteCore/tests/PredictiveVectorQueryTest.cc @@ -132,9 +132,10 @@ N_WAY_TEST_CASE_METHOD(PredictiveVectorQueryTest, "Vector Index Of Prediction", } string queryStr = R"( ['SELECT', { - WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target'], 5], + WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target']], WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'factorsindex'], 'distance'] ], ORDER_BY: [ ['.distance'] ], + LIMIT: 5 }] )"; Retained query{store->compileQuery(json5(queryStr), QueryLanguage::kJSON)}; REQUIRE(query != nullptr); diff --git a/LiteCore/tests/QueryParserTest.cc b/LiteCore/tests/QueryParserTest.cc index faf93e20b..c6e224895 100644 --- a/LiteCore/tests/QueryParserTest.cc +++ b/LiteCore/tests/QueryParserTest.cc @@ -679,23 +679,15 @@ TEST_CASE_METHOD(QueryParserTest, "QueryParser Vector Search", "[Query][QueryPar "ORDER_BY: [ ['VECTOR_DISTANCE()', 'vecIndex'] ]," "LIMIT: 3}]") == "SELECT key, sequence FROM kv_default AS _doc JOIN (SELECT rowid, distance FROM " - "\"kv_default:vector:vecIndex\" WHERE vector LIKE encode_vector(array_of(12, 34)) LIMIT 3) AS vector1 ON " + "\"kv_default:vector:vecIndex\" WHERE vector MATCH encode_vector(array_of(12, 34)) LIMIT 3) AS vector1 ON " "vector1.rowid = _doc.rowid WHERE (true) AND (_doc.flags & 1 = 0) ORDER BY vector1.distance LIMIT MAX(0, " "3)"); - // Pure vector search (explicit limit given): - CHECK(parse("['SELECT', {WHERE: ['AND', ['VECTOR_MATCH()', 'vecIndex', ['[]', 12, 34], 10]," - "['>', ['._id'], 'x'] ]," - "ORDER_BY: [ ['VECTOR_DISTANCE()', 'vecIndex'] ]}]") - == "SELECT key, sequence FROM kv_default AS _doc JOIN (SELECT rowid, distance FROM " - "\"kv_default:vector:vecIndex\" WHERE vector LIKE encode_vector(array_of(12, 34)) LIMIT 10) AS vector1 ON " - "vector1.rowid = _doc.rowid WHERE (true AND _doc.key > 'x') AND (_doc.flags & 1 = 0) ORDER BY " - "vector1.distance"); // Hybrid search: CHECK(parse("['SELECT', {WHERE: ['AND', ['VECTOR_MATCH()', 'vecIndex', ['[]', 12, 34]]," "['>', ['._id'], 'x'] ]," "ORDER_BY: [ ['VECTOR_DISTANCE()', 'vecIndex'] ]}]") == "SELECT key, sequence FROM kv_default AS _doc JOIN \"kv_default:vector:vecIndex\" AS vector1 ON " - "vector1.rowid = _doc.rowid WHERE (vector1.vector LIKE encode_vector((array_of(12, 34))) AND _doc.key > " + "vector1.rowid = _doc.rowid WHERE (vector1.vector MATCH encode_vector((array_of(12, 34))) AND _doc.key > " "'x') AND (_doc.flags & 1 = 0) ORDER BY vector1.distance"); } diff --git a/LiteCore/tests/VectorQueryTest.cc b/LiteCore/tests/VectorQueryTest.cc index 91066ffe3..b1b28d641 100644 --- a/LiteCore/tests/VectorQueryTest.cc +++ b/LiteCore/tests/VectorQueryTest.cc @@ -120,9 +120,10 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index", "[Query][.Vect // Number of results = 10 string queryStr = R"( ['SELECT', { - WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target'], 10], + WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target']], WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'vecIndex'], 'distance'] ], ORDER_BY: [ ['.distance'] ], + LIMIT: 10 }] )"; Retained query{store->compileQuery(json5(queryStr), QueryLanguage::kJSON)}; @@ -137,9 +138,10 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index", "[Query][.Vect // Number of Results = 5 string queryStr = R"( ['SELECT', { - WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target'], 5], + WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target']], WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'vecIndex'], 'distance'] ], ORDER_BY: [ ['.distance'] ], + LIMIT: 5 }] )"; Retained query{store->compileQuery(json5(queryStr), QueryLanguage::kJSON)}; @@ -242,7 +244,7 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index with Join", "[Qu string queryStr = R"(SELECT META(a).id, other.publisher FROM )"s + collectionName; queryStr += R"( AS a JOIN other ON META(a).id = other.refID )" - R"(WHERE VECTOR_MATCH(a.vecIndex, $target, 5) )"; + R"(WHERE VECTOR_MATCH(a.vecIndex, $target) LIMIT 5 )"; Retained query{store->compileQuery(queryStr, QueryLanguage::kN1QL)}; REQUIRE(query != nullptr); @@ -444,8 +446,8 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index and AND with FTS string queryStr = R"(SELECT META(a).id, VECTOR_DISTANCE(a.vecIndex) AS distance, a.sentence FROM )"s + collectionName; - queryStr += R"( AS a WHERE VECTOR_MATCH(a.vecIndex, $target, 5))"; - queryStr += R"( AND MATCH(a.sentence, "search"))"; + queryStr += R"( AS a WHERE VECTOR_MATCH(a.vecIndex, $target))"; + queryStr += R"( AND MATCH(a.sentence, "search") ORDER BY distance LIMIT 4)"; Retained query{store->compileQuery(queryStr, QueryLanguage::kN1QL)}; REQUIRE(query != nullptr); @@ -573,7 +575,7 @@ TEST_CASE_METHOD(SIFTVectorQueryTest, "Index isTrained API", "[Query][.VectorSea // Need to run an arbitrary query to actually train the index string queryStr = - R"(SELECT META().id, publisher FROM )"s + collectionName + R"( WHERE VECTOR_MATCH(vecIndex, $target, 5) )"; + R"(SELECT META().id, publisher FROM )"s + collectionName + R"( WHERE VECTOR_MATCH(vecIndex, $target) )"; Retained query{store->compileQuery(queryStr, QueryLanguage::kN1QL)};