Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed limit param from vector_match() #2080

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 20 additions & 34 deletions LiteCore/Query/QueryParser+VectorSearch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,54 +29,40 @@ using namespace litecore::qp;

namespace litecore {

# ifndef REQUIRE_LIMIT
static constexpr unsigned kDefaultMaxResults = 3;
# endif
static constexpr unsigned kMaxMaxResults = 10000;
static constexpr unsigned kMaxMaxResults = 10000;

// Scans the entire query for vector_match() calls, and adds join tables for ones that are
// indexed.
void QueryParser::addVectorSearchJoins(const Dict* select) {
auto whereClause = getCaseInsensitive(select, "WHERE");
findNodes(select, kVectorMatchFnNameWithParens, 1, [&](const Array* matchExpr) {
// Arguments to vector_match are index name, target vector, and optional max-results.
string tableName = FTSTableName(matchExpr->get(1), true).first;
const Value* limitVal = matchExpr->get(3);
indexJoinInfo* info = indexJoinTable(tableName, "vector");
if ( matchExpr == whereClause || limitVal ) {
// Arguments to vector_match are index name and target vector.
string tableName = FTSTableName(matchExpr->get(1), true).first;
auto targetVectorParam = matchExpr->get(2);
indexJoinInfo* info = indexJoinTable(tableName, "vector");
if ( matchExpr == getCaseInsensitive(select, "WHERE") ) {
// If vector_match is the entire WHERE clause, this is a simple non-hybrid query.
// This is implemented by a nested SELECT that finds the nearest vectors in
// the entire collection. Isolating this in a nested SELECT ensures SQLite doesn't
// see the outer JOIN against the collection; if it did, the vectorsearch extension's
// planner would see a constraint against `rowid` and interpret it as a hybrid search.
// https://github.com/couchbaselabs/mobile-vector-search/blob/main/docs/Extension.md
auto targetVectorParam = matchExpr->get(2);

// Figure out the limit to use in the vector query:
const char* limitName = "3rd max_results argument";
if ( !limitVal ) {
// If no limit param is given, check the LIMIT on the SELECT itself:
limitVal = getCaseInsensitive(select, "LIMIT");
limitName = "LIMIT";
}
int64_t maxResults;
if ( limitVal ) {
if ( auto limitVal = getCaseInsensitive(select, "LIMIT") ) {
maxResults = limitVal->asInt();
require(limitVal->isInteger() && maxResults > 0, "vector_match()'s %s must be a positive integer",
limitName);
require(maxResults <= kMaxMaxResults, "vector_match()'s %s must not exceed %u", limitName,
require(limitVal->isInteger() && maxResults > 0,
"LIMIT must be a positive integer when using vector_match()");
require(maxResults <= kMaxMaxResults, "LIMIT must not exceed %u when using vector_match()",
kMaxMaxResults);
} else {
# ifdef REQUIRE_LIMIT
fail("vector_match() requires a 3rd max_results argument or a LIMIT");
# else
maxResults = kDefaultMaxResults;
# endif
}

// Register a callback to write the nested SELECT in place of a table name:
info->writeTableSQL = [=] {
_sql << "(SELECT rowid, distance FROM \"" << tableName << "\" WHERE vector LIKE encode_vector(";
_sql << "(SELECT rowid, distance FROM \"" << tableName << "\" WHERE vector MATCH encode_vector(";
parseNode(targetVectorParam);
_sql << ") LIMIT " << maxResults << ")";
};
Expand All @@ -87,20 +73,20 @@ namespace litecore {
// Writes a `vector_match()` expression.
void QueryParser::writeVectorMatchFn(ArrayIterator& params) {
requireTopLevelConjunction("VECTOR_MATCH");
auto parentCtx = _context.rbegin() + 1;
auto parentOp = (*parentCtx)->op;
const Value* limitVal = params[2];
if ( parentOp == "SELECT"_sl || parentOp == nullslice || limitVal ) {
auto parentCtx = _context.rbegin() + 1;
auto parentOp = (*parentCtx)->op;
if ( parentOp == "SELECT"_sl || parentOp == nullslice ) {
// In a simple query the work of `vector_match` is done by the JOIN, which limits the results to the
// rowids produced by the nested query of the vector table.
// Since there's nothing to do here, replace the call with a `true`.
_sql << "true";
} else {
// In a hybrid query we do write the LIKE test at the point of the match call:
string tableName = FTSTableName(params[0], true).first;
const string& alias = indexJoinTableAlias(tableName);
auto targetVectorParam = params[1];
_sql << sqlIdentifier(alias) << ".vector LIKE encode_vector(";
string tableName = FTSTableName(params[0], true).first;
const string& alias = indexJoinTableAlias(tableName, "vector");
Assert(!alias.empty());
auto targetVectorParam = params[1];
_sql << sqlIdentifier(alias) << ".vector MATCH encode_vector(";
parseNode(targetVectorParam);
_sql << ")";
}
Expand All @@ -109,7 +95,7 @@ namespace litecore {
// Writes the SQL translation of the `vector_distance(...)` call.
void QueryParser::writeVectorDistanceFn(ArrayIterator& params) {
string tableName = FTSTableName(params[0], true).first;
_sql << indexJoinTableAlias(tableName) << ".distance";
_sql << indexJoinTableAlias(tableName, "vector") << ".distance";
}

// Given the expression to index from a vector index spec, returns the SQL of a
Expand Down
2 changes: 1 addition & 1 deletion LiteCore/Query/QueryParserTables.hh
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ namespace litecore {
{"cosine_distance", 2, 2},

// Vector search:
{"vector_match", 2, 3},
{"vector_match", 2, 2},
{"vector_distance", 1, 1},
#endif

Expand Down
5 changes: 4 additions & 1 deletion LiteCore/tests/LazyVectorQueryTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ class LazyVectorQueryTest : public VectorQueryTest {

string queryStr = R"(
['SELECT', {
WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target'], 5],
WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target']],
WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'factorsindex'], 'distance'] ],
ORDER_BY: [ ['.distance'] ],
LIMIT: 5
}] )";
_query = store->compileQuery(json5(queryStr), QueryLanguage::kJSON);
REQUIRE(_query != nullptr);
Expand Down Expand Up @@ -141,6 +142,7 @@ TEST_CASE_METHOD(LazyVectorQueryTest, "Lazy Vector Index", "[Query][.VectorSearc
Retained<QueryEnumerator> e;
e = (_query->createEnumerator(&_options));
REQUIRE(e->getRowCount() == 0); // index is empty so far
++expectedWarningsLogged; // "Untrained index; queries may be slow."

REQUIRE(updateVectorIndex(200, alwaysUpdate) == 200);
REQUIRE(updateVectorIndex(999, alwaysUpdate) == 200);
Expand Down Expand Up @@ -169,6 +171,7 @@ TEST_CASE_METHOD(LazyVectorQueryTest, "Lazy Vector Index Skipping", "[Query][.Ve

// rec-291, rec-171 and rec-081 are missing because unindexed
checkQueryReturns({"rec-039", "rec-249", "rec-345", "rec-159", "rec-369"});
++expectedWarningsLogged; // "Untrained index; queries may be slow."

// Update the index again; only the skipped docs will appear this time.
size_t nIndexed = 0;
Expand Down
6 changes: 3 additions & 3 deletions LiteCore/tests/N1QLParserTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,10 @@ TEST_CASE_METHOD(N1QLParserTest, "N1QL Vector Search", "[Query][N1QL][VectorSear
tableNames.emplace("kv_.scope.coll:vector:vecIndex");

CHECK(translate("SELECT META().id, VECTOR_DISTANCE(vecIndex) AS distance "
"WHERE VECTOR_MATCH(vecIndex, $target, 5) ORDER BY distance")
== "{'ORDER_BY':[['.distance']],'WHAT':[['_.',['meta()'],'.id'],"
"WHERE VECTOR_MATCH(vecIndex, $target) ORDER BY distance LIMIT 5")
== "{'LIMIT':5,'ORDER_BY':[['.distance']],'WHAT':[['_.',['meta()'],'.id'],"
"['AS',['VECTOR_DISTANCE()','vecIndex'],'distance']],"
"'WHERE':['VECTOR_MATCH()','vecIndex',['$target'],5]}");
"'WHERE':['VECTOR_MATCH()','vecIndex',['$target']]}");

CHECK(translate("SELECT META().id, VECTOR_DISTANCE(coll.vecIndex) AS distance "
"FROM coll "
Expand Down
3 changes: 2 additions & 1 deletion LiteCore/tests/PredictiveVectorQueryTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ N_WAY_TEST_CASE_METHOD(PredictiveVectorQueryTest, "Vector Index Of Prediction",
}
string queryStr = R"(
['SELECT', {
WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target'], 5],
WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target']],
WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'factorsindex'], 'distance'] ],
ORDER_BY: [ ['.distance'] ],
LIMIT: 5
}] )";
Retained<Query> query{store->compileQuery(json5(queryStr), QueryLanguage::kJSON)};
REQUIRE(query != nullptr);
Expand Down
12 changes: 2 additions & 10 deletions LiteCore/tests/QueryParserTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -679,23 +679,15 @@ TEST_CASE_METHOD(QueryParserTest, "QueryParser Vector Search", "[Query][QueryPar
"ORDER_BY: [ ['VECTOR_DISTANCE()', 'vecIndex'] ],"
"LIMIT: 3}]")
== "SELECT key, sequence FROM kv_default AS _doc JOIN (SELECT rowid, distance FROM "
"\"kv_default:vector:vecIndex\" WHERE vector LIKE encode_vector(array_of(12, 34)) LIMIT 3) AS vector1 ON "
"\"kv_default:vector:vecIndex\" WHERE vector MATCH encode_vector(array_of(12, 34)) LIMIT 3) AS vector1 ON "
"vector1.rowid = _doc.rowid WHERE (true) AND (_doc.flags & 1 = 0) ORDER BY vector1.distance LIMIT MAX(0, "
"3)");
// Pure vector search (explicit limit given):
CHECK(parse("['SELECT', {WHERE: ['AND', ['VECTOR_MATCH()', 'vecIndex', ['[]', 12, 34], 10],"
"['>', ['._id'], 'x'] ],"
"ORDER_BY: [ ['VECTOR_DISTANCE()', 'vecIndex'] ]}]")
== "SELECT key, sequence FROM kv_default AS _doc JOIN (SELECT rowid, distance FROM "
"\"kv_default:vector:vecIndex\" WHERE vector LIKE encode_vector(array_of(12, 34)) LIMIT 10) AS vector1 ON "
"vector1.rowid = _doc.rowid WHERE (true AND _doc.key > 'x') AND (_doc.flags & 1 = 0) ORDER BY "
"vector1.distance");
// Hybrid search:
CHECK(parse("['SELECT', {WHERE: ['AND', ['VECTOR_MATCH()', 'vecIndex', ['[]', 12, 34]],"
"['>', ['._id'], 'x'] ],"
"ORDER_BY: [ ['VECTOR_DISTANCE()', 'vecIndex'] ]}]")
== "SELECT key, sequence FROM kv_default AS _doc JOIN \"kv_default:vector:vecIndex\" AS vector1 ON "
"vector1.rowid = _doc.rowid WHERE (vector1.vector LIKE encode_vector((array_of(12, 34))) AND _doc.key > "
"vector1.rowid = _doc.rowid WHERE (vector1.vector MATCH encode_vector((array_of(12, 34))) AND _doc.key > "
"'x') AND (_doc.flags & 1 = 0) ORDER BY vector1.distance");
}

Expand Down
15 changes: 9 additions & 6 deletions LiteCore/tests/VectorQueryTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index", "[Query][.Vect
// Number of results = 10
string queryStr = R"(
['SELECT', {
WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target'], 10],
WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target']],
WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'vecIndex'], 'distance'] ],
ORDER_BY: [ ['.distance'] ],
LIMIT: 10
}] )";
Retained<Query> query{store->compileQuery(json5(queryStr), QueryLanguage::kJSON)};

Expand All @@ -137,9 +138,10 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index", "[Query][.Vect
// Number of Results = 5
string queryStr = R"(
['SELECT', {
WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target'], 5],
WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target']],
WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'vecIndex'], 'distance'] ],
ORDER_BY: [ ['.distance'] ],
LIMIT: 5
}] )";
Retained<Query> query{store->compileQuery(json5(queryStr), QueryLanguage::kJSON)};

Expand Down Expand Up @@ -242,7 +244,7 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index with Join", "[Qu

string queryStr = R"(SELECT META(a).id, other.publisher FROM )"s + collectionName;
queryStr += R"( AS a JOIN other ON META(a).id = other.refID )"
R"(WHERE VECTOR_MATCH(a.vecIndex, $target, 5) )";
R"(WHERE VECTOR_MATCH(a.vecIndex, $target) LIMIT 5 )";

Retained<Query> query{store->compileQuery(queryStr, QueryLanguage::kN1QL)};
REQUIRE(query != nullptr);
Expand Down Expand Up @@ -444,8 +446,8 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index and AND with FTS

string queryStr =
R"(SELECT META(a).id, VECTOR_DISTANCE(a.vecIndex) AS distance, a.sentence FROM )"s + collectionName;
queryStr += R"( AS a WHERE VECTOR_MATCH(a.vecIndex, $target, 5))";
queryStr += R"( AND MATCH(a.sentence, "search"))";
queryStr += R"( AS a WHERE VECTOR_MATCH(a.vecIndex, $target))";
queryStr += R"( AND MATCH(a.sentence, "search") ORDER BY distance LIMIT 4)";

Retained<Query> query{store->compileQuery(queryStr, QueryLanguage::kN1QL)};
REQUIRE(query != nullptr);
Expand Down Expand Up @@ -573,7 +575,7 @@ TEST_CASE_METHOD(SIFTVectorQueryTest, "Index isTrained API", "[Query][.VectorSea

// Need to run an arbitrary query to actually train the index
string queryStr =
R"(SELECT META().id, publisher FROM )"s + collectionName + R"( WHERE VECTOR_MATCH(vecIndex, $target, 5) )";
R"(SELECT META().id, publisher FROM )"s + collectionName + R"( WHERE VECTOR_MATCH(vecIndex, $target) )";

Retained<Query> query{store->compileQuery(queryStr, QueryLanguage::kN1QL)};

Expand All @@ -587,6 +589,7 @@ TEST_CASE_METHOD(SIFTVectorQueryTest, "Index isTrained API", "[Query][.VectorSea

bool isTrained = collection->isIndexTrained("vecIndex"_sl);
CHECK(isTrained == expectedTrained);
if ( !isTrained ) ++expectedWarningsLogged; // "Untrained index; queries may be slow."
}

N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Inspect Vector Index", "[Query][.VectorSearch]") {
Expand Down
Loading