Skip to content

Commit

Permalink
Removed limit param from vector_match()
Browse files Browse the repository at this point in the history
vector_match() no longer takes an optional 3rd parameter.
The limit should now be passed as the overall LIMIT of the SELECT.

Also switched to using MATCH instead of LIKE in the generated SQL,
because I found MATCH is slightly more efficient.
  • Loading branch information
snej committed Jun 24, 2024
1 parent 2b53828 commit cf41b43
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 47 deletions.
42 changes: 14 additions & 28 deletions LiteCore/Query/QueryParser+VectorSearch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,54 +29,40 @@ using namespace litecore::qp;

namespace litecore {

# ifndef REQUIRE_LIMIT
static constexpr unsigned kDefaultMaxResults = 3;
# endif
static constexpr unsigned kMaxMaxResults = 10000;
static constexpr unsigned kMaxMaxResults = 10000;

// Scans the entire query for vector_match() calls, and adds join tables for ones that are
// indexed.
void QueryParser::addVectorSearchJoins(const Dict* select) {
auto whereClause = getCaseInsensitive(select, "WHERE");
findNodes(select, kVectorMatchFnNameWithParens, 1, [&](const Array* matchExpr) {
// Arguments to vector_match are index name, target vector, and optional max-results.
// Arguments to vector_match are index name and target vector.
string tableName = FTSTableName(matchExpr->get(1), true).first;
const Value* limitVal = matchExpr->get(3);
auto targetVectorParam = matchExpr->get(2);
indexJoinInfo* info = indexJoinTable(tableName, "vector");
if ( matchExpr == whereClause || limitVal ) {
if ( matchExpr == getCaseInsensitive(select, "WHERE") ) {
// If vector_match is the entire WHERE clause, this is a simple non-hybrid query.
// This is implemented by a nested SELECT that finds the nearest vectors in
// the entire collection. Isolating this in a nested SELECT ensures SQLite doesn't
// see the outer JOIN against the collection; if it did, the vectorsearch extension's
// planner would see a constraint against `rowid` and interpret it as a hybrid search.
// https://github.com/couchbaselabs/mobile-vector-search/blob/main/docs/Extension.md
auto targetVectorParam = matchExpr->get(2);

// Figure out the limit to use in the vector query:
const char* limitName = "3rd max_results argument";
if ( !limitVal ) {
// If no limit param is given, check the LIMIT on the SELECT itself:
limitVal = getCaseInsensitive(select, "LIMIT");
limitName = "LIMIT";
}
int64_t maxResults;
if ( limitVal ) {
if ( auto limitVal = getCaseInsensitive(select, "LIMIT") ) {
maxResults = limitVal->asInt();
require(limitVal->isInteger() && maxResults > 0, "vector_match()'s %s must be a positive integer",
limitName);
require(maxResults <= kMaxMaxResults, "vector_match()'s %s must not exceed %u", limitName,
require(limitVal->isInteger() && maxResults > 0,
"LIMIT must be a positive integer when using vector_match()");
require(maxResults <= kMaxMaxResults, "LIMIT must not exceed %u when using vector_match()",
kMaxMaxResults);
} else {
# ifdef REQUIRE_LIMIT
fail("vector_match() requires a 3rd max_results argument or a LIMIT");
# else
maxResults = kDefaultMaxResults;
# endif
}

// Register a callback to write the nested SELECT in place of a table name:
info->writeTableSQL = [=] {
_sql << "(SELECT rowid, distance FROM \"" << tableName << "\" WHERE vector LIKE encode_vector(";
_sql << "(SELECT rowid, distance FROM \"" << tableName << "\" WHERE vector MATCH encode_vector(";
parseNode(targetVectorParam);
_sql << ") LIMIT " << maxResults << ")";
};
Expand All @@ -89,18 +75,18 @@ namespace litecore {
requireTopLevelConjunction("VECTOR_MATCH");
auto parentCtx = _context.rbegin() + 1;
auto parentOp = (*parentCtx)->op;
const Value* limitVal = params[2];
if ( parentOp == "SELECT"_sl || parentOp == nullslice || limitVal ) {
if ( parentOp == "SELECT"_sl || parentOp == nullslice ) {
// In a simple query the work of `vector_match` is done by the JOIN, which limits the results to the
// rowids produced by the nested query of the vector table.
// Since there's nothing to do here, replace the call with a `true`.
_sql << "true";
} else {
// In a hybrid query we do write the LIKE test at the point of the match call:
string tableName = FTSTableName(params[0], true).first;
const string& alias = indexJoinTableAlias(tableName);
const string& alias = indexJoinTableAlias(tableName, "vector");
Assert(!alias.empty());
auto targetVectorParam = params[1];
_sql << sqlIdentifier(alias) << ".vector LIKE encode_vector(";
_sql << sqlIdentifier(alias) << ".vector MATCH encode_vector(";
parseNode(targetVectorParam);
_sql << ")";
}
Expand All @@ -109,7 +95,7 @@ namespace litecore {
// Writes the SQL translation of the `vector_distance(...)` call.
void QueryParser::writeVectorDistanceFn(ArrayIterator& params) {
string tableName = FTSTableName(params[0], true).first;
_sql << indexJoinTableAlias(tableName) << ".distance";
_sql << indexJoinTableAlias(tableName, "vector") << ".distance";
}

// Given the expression to index from a vector index spec, returns the SQL of a
Expand Down
2 changes: 1 addition & 1 deletion LiteCore/Query/QueryParserTables.hh
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ namespace litecore {
{"cosine_distance", 2, 2},

// Vector search:
{"vector_match", 2, 3},
{"vector_match", 2, 2},
{"vector_distance", 1, 1},
#endif

Expand Down
3 changes: 2 additions & 1 deletion LiteCore/tests/LazyVectorQueryTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ class LazyVectorQueryTest : public VectorQueryTest {

string queryStr = R"(
['SELECT', {
WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target'], 5],
WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target']],
WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'factorsindex'], 'distance'] ],
ORDER_BY: [ ['.distance'] ],
LIMIT: 5
}] )";
_query = store->compileQuery(json5(queryStr), QueryLanguage::kJSON);
REQUIRE(_query != nullptr);
Expand Down
3 changes: 2 additions & 1 deletion LiteCore/tests/PredictiveVectorQueryTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ N_WAY_TEST_CASE_METHOD(PredictiveVectorQueryTest, "Vector Index Of Prediction",
}
string queryStr = R"(
['SELECT', {
WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target'], 5],
WHERE: ['VECTOR_MATCH()', 'factorsindex', ['$target']],
WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'factorsindex'], 'distance'] ],
ORDER_BY: [ ['.distance'] ],
LIMIT: 5
}] )";
Retained<Query> query{store->compileQuery(json5(queryStr), QueryLanguage::kJSON)};
REQUIRE(query != nullptr);
Expand Down
12 changes: 2 additions & 10 deletions LiteCore/tests/QueryParserTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -679,23 +679,15 @@ TEST_CASE_METHOD(QueryParserTest, "QueryParser Vector Search", "[Query][QueryPar
"ORDER_BY: [ ['VECTOR_DISTANCE()', 'vecIndex'] ],"
"LIMIT: 3}]")
== "SELECT key, sequence FROM kv_default AS _doc JOIN (SELECT rowid, distance FROM "
"\"kv_default:vector:vecIndex\" WHERE vector LIKE encode_vector(array_of(12, 34)) LIMIT 3) AS vector1 ON "
"\"kv_default:vector:vecIndex\" WHERE vector MATCH encode_vector(array_of(12, 34)) LIMIT 3) AS vector1 ON "
"vector1.rowid = _doc.rowid WHERE (true) AND (_doc.flags & 1 = 0) ORDER BY vector1.distance LIMIT MAX(0, "
"3)");
// Pure vector search (explicit limit given):
CHECK(parse("['SELECT', {WHERE: ['AND', ['VECTOR_MATCH()', 'vecIndex', ['[]', 12, 34], 10],"
"['>', ['._id'], 'x'] ],"
"ORDER_BY: [ ['VECTOR_DISTANCE()', 'vecIndex'] ]}]")
== "SELECT key, sequence FROM kv_default AS _doc JOIN (SELECT rowid, distance FROM "
"\"kv_default:vector:vecIndex\" WHERE vector LIKE encode_vector(array_of(12, 34)) LIMIT 10) AS vector1 ON "
"vector1.rowid = _doc.rowid WHERE (true AND _doc.key > 'x') AND (_doc.flags & 1 = 0) ORDER BY "
"vector1.distance");
// Hybrid search:
CHECK(parse("['SELECT', {WHERE: ['AND', ['VECTOR_MATCH()', 'vecIndex', ['[]', 12, 34]],"
"['>', ['._id'], 'x'] ],"
"ORDER_BY: [ ['VECTOR_DISTANCE()', 'vecIndex'] ]}]")
== "SELECT key, sequence FROM kv_default AS _doc JOIN \"kv_default:vector:vecIndex\" AS vector1 ON "
"vector1.rowid = _doc.rowid WHERE (vector1.vector LIKE encode_vector((array_of(12, 34))) AND _doc.key > "
"vector1.rowid = _doc.rowid WHERE (vector1.vector MATCH encode_vector((array_of(12, 34))) AND _doc.key > "
"'x') AND (_doc.flags & 1 = 0) ORDER BY vector1.distance");
}

Expand Down
14 changes: 8 additions & 6 deletions LiteCore/tests/VectorQueryTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index", "[Query][.Vect
// Number of results = 10
string queryStr = R"(
['SELECT', {
WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target'], 10],
WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target']],
WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'vecIndex'], 'distance'] ],
ORDER_BY: [ ['.distance'] ],
LIMIT: 10
}] )";
Retained<Query> query{store->compileQuery(json5(queryStr), QueryLanguage::kJSON)};

Expand All @@ -137,9 +138,10 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index", "[Query][.Vect
// Number of Results = 5
string queryStr = R"(
['SELECT', {
WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target'], 5],
WHERE: ['VECTOR_MATCH()', 'vecIndex', ['$target']],
WHAT: [ ['._id'], ['AS', ['VECTOR_DISTANCE()', 'vecIndex'], 'distance'] ],
ORDER_BY: [ ['.distance'] ],
LIMIT: 5
}] )";
Retained<Query> query{store->compileQuery(json5(queryStr), QueryLanguage::kJSON)};

Expand Down Expand Up @@ -242,7 +244,7 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index with Join", "[Qu

string queryStr = R"(SELECT META(a).id, other.publisher FROM )"s + collectionName;
queryStr += R"( AS a JOIN other ON META(a).id = other.refID )"
R"(WHERE VECTOR_MATCH(a.vecIndex, $target, 5) )";
R"(WHERE VECTOR_MATCH(a.vecIndex, $target) LIMIT 5 )";

Retained<Query> query{store->compileQuery(queryStr, QueryLanguage::kN1QL)};
REQUIRE(query != nullptr);
Expand Down Expand Up @@ -444,8 +446,8 @@ N_WAY_TEST_CASE_METHOD(SIFTVectorQueryTest, "Query Vector Index and AND with FTS

string queryStr =
R"(SELECT META(a).id, VECTOR_DISTANCE(a.vecIndex) AS distance, a.sentence FROM )"s + collectionName;
queryStr += R"( AS a WHERE VECTOR_MATCH(a.vecIndex, $target, 5))";
queryStr += R"( AND MATCH(a.sentence, "search"))";
queryStr += R"( AS a WHERE VECTOR_MATCH(a.vecIndex, $target))";
queryStr += R"( AND MATCH(a.sentence, "search") ORDER BY distance LIMIT 4)";

Retained<Query> query{store->compileQuery(queryStr, QueryLanguage::kN1QL)};
REQUIRE(query != nullptr);
Expand Down Expand Up @@ -573,7 +575,7 @@ TEST_CASE_METHOD(SIFTVectorQueryTest, "Index isTrained API", "[Query][.VectorSea

// Need to run an arbitrary query to actually train the index
string queryStr =
R"(SELECT META().id, publisher FROM )"s + collectionName + R"( WHERE VECTOR_MATCH(vecIndex, $target, 5) )";
R"(SELECT META().id, publisher FROM )"s + collectionName + R"( WHERE VECTOR_MATCH(vecIndex, $target) )";

Retained<Query> query{store->compileQuery(queryStr, QueryLanguage::kN1QL)};

Expand Down

0 comments on commit cf41b43

Please sign in to comment.