Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QueryTranslator: Implemented predictive index support #2141

Merged
merged 1 commit into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions LiteCore/Query/SQLiteKeyStore+PredictiveIndexes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ namespace litecore {
// Derive the table name from the expression (path) it unnests:
auto kvTableName = tableName();
auto q_kvTableName = quotedTableName();
QueryTranslator qp(db(), "", kvTableName);
QueryTranslator qp(db(), string(kDefaultCollectionName), kvTableName);
auto predTableName = qp.predictiveTableName((FLValue)expression);

// Create the index table, unless an identical one already exists:
Expand All @@ -75,23 +75,30 @@ namespace litecore {
if ( !db().schemaExistsWithSQL(predTableName, "table", predTableName, sql) ) {
LogTo(QueryLog, "Creating predictive table '%s' on %s", predTableName.c_str(),
expression->toJSONString().c_str());
// Capture the SQL of the `predict(...)` call, _before_ creating the table.
// (If we created the table first, the query translator would generate SQL that used it!)
string predictExpr = qp.expressionSQL((FLValue)expression);
qp.setBodyColumnName("new.body");
string triggerPredictExpr = qp.expressionSQL((FLValue)expression);

// Create the index-table:
LogTo(QueryLog, "Creating predictive index table: %s", sql.c_str());
db().exec(sql);

// Populate the index-table with data from existing documents:
string predictExpr = qp.expressionSQL((FLValue)expression);
db().exec(CONCAT("INSERT INTO " << sqlIdentifier(predTableName)
<< " (docid, body) "
"SELECT rowid, "
<< predictExpr << "FROM " << q_kvTableName << " WHERE (flags & 1) = 0"));
sql = CONCAT("INSERT INTO " << sqlIdentifier(predTableName)
<< " (docid, body) "
"SELECT rowid, "
<< predictExpr << "FROM " << q_kvTableName << " as _doc WHERE (flags & 1) = 0");
LogTo(QueryLog, "Populating predictive index table: %s", sql.c_str());
db().exec(sql);

// Set up triggers to keep the index-table up to date
// ...on insertion:
qp.setBodyColumnName("new.body");
predictExpr = qp.expressionSQL((FLValue)expression);
string insertTriggerExpr = CONCAT("INSERT INTO " << sqlIdentifier(predTableName)
<< " (docid, body) "
"VALUES (new.rowid, "
<< predictExpr << ")");
<< triggerPredictExpr << ")");
createTrigger(predTableName, "ins", "AFTER INSERT", "WHEN (new.flags & 1) = 0", insertTriggerExpr);

// ...on delete:
Expand Down
2 changes: 2 additions & 0 deletions LiteCore/Query/Translator/ExprNodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ namespace litecore::qt {
#ifdef COUCHBASE_ENTERPRISE
case OpType::vectorDistance:
return new (ctx) VectorDistanceNode(operands, ctx);
case OpType::prediction:
return PredictionNode::parse(operands, ctx);
#endif
default:
// A normal OpNode
Expand Down
105 changes: 74 additions & 31 deletions LiteCore/Query/Translator/IndexedNodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,25 @@ namespace litecore::qt {
using namespace fleece;

// indexed by IndexType:
constexpr const char* kOwnerFnName[2] = {"MATCH", "APPROX_VECTOR_DISTANCE"};
constexpr const char* kIndexTypeName[3] = {"FTS", "vector", "predictive"};
constexpr const char* kOwnerFnName[3] = {"MATCH", "APPROX_VECTOR_DISTANCE", "PREDICTION"};

void IndexedNode::setIndexedExpression(ExprNode* expression) {
_indexedExpr = expression;
expression->visitTree([&](Node& n, unsigned /*depth*/) {
if ( SourceNode* nodeSource = n.source() ) {
require(_sourceCollection == nullptr || _sourceCollection == nodeSource,
"1st argument to %s may only refer to a single collection", kOwnerFnName[int(_type)]);
_sourceCollection = nodeSource;
}
});
require(_sourceCollection, "unknown source collection for %s()", kOwnerFnName[int(_type)]);
}

void IndexedNode::writeSourceTable(SQLWriter& ctx, string_view tableName) const {
require(!tableName.empty(), "missing %s index", kIndexTypeName[int(_type)]);
ctx << sqlIdentifier(tableName);
}

#pragma mark - FTS:

Expand All @@ -38,13 +56,8 @@ namespace litecore::qt {
require(source, "unknown source collection for %s()", name);
require(source->isCollection(), "invalid source collection for %s()", name);
require(path.count() > 0, "missing property after collection alias in %s()", name);
_sourceCollection = source;
_indexExpressionJSON = string(path.toString());
}

void FTSNode::writeSourceTable(SQLWriter& ctx, string_view tableName) const {
require(!tableName.empty(), "missing FTS index");
ctx << sqlIdentifier(tableName);
_sourceCollection = source;
_indexID = ctx.newString(path.toString());
}

void FTSNode::writeIndex(SQLWriter& sql) const {
Expand All @@ -71,11 +84,10 @@ namespace litecore::qt {
ctx << "))";
}

#pragma mark - VECTOR:


#ifdef COUCHBASE_ENTERPRISE

# pragma mark - VECTOR:

// A SQLite vector MATCH expression; used by VectorDistanceNode to add a join condition.
class VectorMatchNode final : public ExprNode {
public:
Expand All @@ -90,19 +102,9 @@ namespace litecore::qt {
ExprNode* _vector;
};

VectorDistanceNode::VectorDistanceNode(Array::iterator& args, ParseContext& ctx)
: IndexedNode(IndexType::vector), _indexedExpr(parse(args[0], ctx)) {
VectorDistanceNode::VectorDistanceNode(Array::iterator& args, ParseContext& ctx) : IndexedNode(IndexType::vector) {
// Determine which collection the vector is based on:
SourceNode* source = nullptr;
_indexedExpr->visitTree([&](Node& n, unsigned /*depth*/) {
if ( SourceNode* nodeSource = n.source() ) {
require(source == nullptr || source == nodeSource,
"1st argument (vector) to APPROX_VECTOR_DISTANCE may only refer to a single collection");
source = nodeSource;
}
});
require(source, "unknown source collection for APPROX_VECTOR_DISTANCE()");
_sourceCollection = source;
setIndexedExpression(ExprNode::parse(args[0], ctx));

// Create the JSON expression used to locate the index:
string indexExpr(args[0].toJSON(false, true));
Expand All @@ -118,7 +120,7 @@ namespace litecore::qt {
replace(indexExpr, "[\"." + prefix + ".", "[\".");
}
}
_indexExpressionJSON = ctx.newString(indexExpr);
_indexID = ctx.newString(indexExpr);

_vector = ExprNode::parse(args[1], ctx);

Expand Down Expand Up @@ -183,8 +185,7 @@ namespace litecore::qt {
}

void VectorDistanceNode::writeSourceTable(SQLWriter& sql, string_view tableName) const {
require(!tableName.empty(), "missing vector index");
if ( _simple ) {
if ( _simple && !tableName.empty() ) {
// In a "simple" vector match, run the vector query as a nested SELECT:
sql << "(SELECT docid, distance FROM " << sqlIdentifier(tableName) << " WHERE vector MATCH encode_vector("
<< _vector << ")";
Expand All @@ -193,7 +194,7 @@ namespace litecore::qt {
require(limit, "a LIMIT must be given when using APPROX_VECTOR_DISTANCE()");
sql << " LIMIT " << limit << ")";
} else {
sql << sqlIdentifier(tableName);
IndexedNode::writeSourceTable(sql, tableName);
}
}

Expand All @@ -203,6 +204,49 @@ namespace litecore::qt {
ctx << sqlIdentifier(_indexSource->alias()) << ".distance";
}

# pragma mark - PREDICTION:

ExprNode* PredictionNode::parse(Array::iterator args, ParseContext& ctx) {
// Unlike a vector or FTS query, a prediction() is not required to have an index.
// Check whether one exists. Unfortunately, the index identifier is based on the entire
// expression array including the first item `PREDICTION()` which isn't in the iterator,
// so we have to reconstruct it:
auto expr = MutableArray::newArray();
expr.append("PREDICTION()");
expr.append(args[0]);
expr.append(args[1]);
string id = expressionIdentifier(expr);

if ( ctx.delegate.hasPredictiveIndex(id) ) {
return new (ctx) PredictionNode(args, ctx, id);
} else {
return FunctionNode::parse(kPredictionFnName, args, ctx);
}
}

PredictionNode::PredictionNode(Array::iterator& args, ParseContext& ctx, string_view indexID)
: IndexedNode(IndexType::prediction) {
_indexID = ctx.newString(indexID);
setIndexedExpression(ExprNode::parse(args[1], ctx));
if ( args.count() > 2 ) {
slice pathStr = requiredString(args[2], "property path of PREDICTION()");
KeyPath path = parsePath(pathStr);
require(path.count() > 0, "invalid property path in PREDICTION()");
_subProperty = ctx.newString(path.toString());
}
}

void PredictionNode::writeSQL(SQLWriter& out) const {
auto alias = sqlIdentifier(_indexSource->alias());
if ( _subProperty ) {
out << kUnnestedValueFnName << "(" << alias << ".body, " << sqlString(_subProperty);
out << ")";
} else {
out << kRootFnName << "(" << alias << ".body)";
}
}


#endif


Expand All @@ -221,15 +265,14 @@ namespace litecore::qt {
}

bool IndexSourceNode::matchesNode(const IndexedNode* node) const {
return _indexedNode->indexType() == node->indexType()
&& _indexedNode->indexExpressionJSON() == node->indexExpressionJSON()
return _indexedNode->indexType() == node->indexType() && _indexedNode->indexID() == node->indexID()
&& collection() == node->sourceCollection()->collection()
&& scope() == node->sourceCollection()->scope();
}

IndexType IndexSourceNode::indexType() const { return _indexedNode->indexType(); }

string_view IndexSourceNode::indexedExpressionJSON() const { return _indexedNode->indexExpressionJSON(); }
string_view IndexSourceNode::indexID() const { return _indexedNode->indexID(); }

void IndexSourceNode::addIndexedNode(IndexedNode* node) {
Assert(node != _indexedNode && node->indexType() == _indexedNode->indexType());
Expand Down Expand Up @@ -300,7 +343,7 @@ namespace litecore::qt {
/// Adds a SourceNode for an IndexedNode, or finds an existing one.
/// Sets the source as its indexSource.
void SelectNode::addIndexForNode(IndexedNode* node, ParseContext& ctx) {
DebugAssert(!node->indexExpressionJSON().empty());
DebugAssert(!node->indexID().empty());

// Look for an existing index source:
IndexSourceNode* indexSrc = nullptr;
Expand Down
28 changes: 21 additions & 7 deletions LiteCore/Query/Translator/IndexedNodes.hh
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ namespace litecore::qt {
public:
IndexType indexType() const { return _type; }

/// JSON of the indexed expression, usually a property
string_view indexExpressionJSON() const { return _indexExpressionJSON; }
/// A unique identifier of the indexed expression, used to match it with an IndexSourceNode.
string_view indexID() const { return _indexID; }

/// The collection being searched.
SourceNode* C4NULLABLE sourceCollection() const { return _sourceCollection; }
Expand All @@ -49,13 +49,16 @@ namespace litecore::qt {
bool isAuxiliary() const { return _isAuxiliary; }

/// Writes SQL for the index table name (or SELECT expression)
virtual void writeSourceTable(SQLWriter& ctx, string_view tableName) const = 0;
virtual void writeSourceTable(SQLWriter& ctx, string_view tableName) const;

protected:
IndexedNode(IndexType type) : _type(type) {}

void setIndexedExpression(ExprNode*);

IndexType const _type; // Index type
string _indexExpressionJSON; // Expression/property that's indexed, as JSON
ExprNode* _indexedExpr; // The indexed expression (usually a doc property)
string_view _indexID; // Expression/property that's indexed
SourceNode* C4NULLABLE _sourceCollection{}; // The collection being queried
IndexSourceNode* C4NULLABLE _indexSource{}; // Source representing the index
SelectNode* C4NULLABLE _select{}; // The containing SELECT statement
Expand All @@ -67,7 +70,6 @@ namespace litecore::qt {
protected:
FTSNode(Array::iterator& args, ParseContext&, const char* name);

void writeSourceTable(SQLWriter& ctx, string_view tableName) const override;
void writeIndex(SQLWriter&) const;
};

Expand Down Expand Up @@ -111,13 +113,25 @@ namespace litecore::qt {
void writeSQL(SQLWriter&) const override;

private:
ExprNode* _indexedExpr; // The indexed expression (usually a doc property)
ExprNode* _vector; // The vector being queried
int _metric; // Distance metric (actually vectorsearch::Metric)
unsigned _numProbes = 0; // Number of probes, or 0 for default
bool _simple = true; // True if this is a simple (non-hybrid) query
};

/** A `prediction()` function call that uses an index. */
class PredictionNode final : public IndexedNode {
public:
static ExprNode* parse(Array::iterator args, ParseContext&);

void writeSQL(SQLWriter&) const override;

private:
PredictionNode(Array::iterator& args, ParseContext& ctx, string_view indexID);

const char* _subProperty{};
};

#endif

#pragma mark - INDEX SOURCE:
Expand All @@ -128,7 +142,7 @@ namespace litecore::qt {
explicit IndexSourceNode(IndexedNode*, string_view alias, ParseContext& ctx);

IndexType indexType() const;
string_view indexedExpressionJSON() const;
string_view indexID() const;

bool matchesNode(IndexedNode const*) const;

Expand Down
3 changes: 2 additions & 1 deletion LiteCore/Query/Translator/Node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ namespace litecore::qt {
// Typical queries only allocate a few KB, not enough to fill a single chunk.
static constexpr size_t kArenaChunkSize = 4000;

RootContext::RootContext() : Arena(kArenaChunkSize), ParseContext(*static_cast<Arena*>(this)) {}
RootContext::RootContext()
: Arena(kArenaChunkSize), ParseContext(*static_cast<ParseDelegate*>(this), *static_cast<Arena*>(this)) {}

void* Node::operator new(size_t size, ParseContext& ctx) noexcept { return ctx.arena.alloc(size, alignof(Node)); }

Expand Down
26 changes: 22 additions & 4 deletions LiteCore/Query/Translator/Node.hh
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,32 @@ namespace litecore::qt {
};

/** Types of indexes. */
enum class IndexType { FTS, vector };
enum class IndexType {
FTS,
#ifdef COUCHBASE_ENTERPRISE
vector,
prediction,
#endif
};

#pragma mark - PARSE CONTEXT:

struct ParseDelegate {
#ifdef COUCHBASE_ENTERPRISE
std::function<bool(string_view id)> hasPredictiveIndex;
#endif
};

/** State used during parsing, passed down through the recursive descent. */
struct ParseContext {
ParseContext(Arena<>& a) : arena(a) {}
ParseContext(ParseDelegate& d, Arena<>& a) : delegate(d), arena(a) {}

// not a copy constructor! Creates a new child context.
explicit ParseContext(ParseContext& parent) : delegate(parent.delegate), arena(parent.arena){};

ParseContext(ParseContext const& parent) : arena(parent.arena){};
ParseContext(ParseContext&&) = default;

ParseDelegate& delegate;
Arena<>& arena; // The arena allocator
SelectNode* C4NULLABLE select{}; // The enclosing SELECT, if any
std::unordered_map<string, AliasedNode*> aliases; // All of the sources & named results
Expand All @@ -87,8 +103,10 @@ namespace litecore::qt {
/** Top-level Context that provides an Arena, and destructs all Nodes in its destructor. */
struct RootContext
: Arena<>
, public ParseDelegate
, public ParseContext {
RootContext();
explicit RootContext();
RootContext(RootContext&&) = default;
};

#pragma mark - NODE CLASS:
Expand Down
2 changes: 1 addition & 1 deletion LiteCore/Query/Translator/NodesToSQL.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ namespace litecore::qt {

void MetaNode::writeSQL(SQLWriter& ctx) const {
string aliasDot;
if ( _source ) aliasDot = CONCAT(sqlIdentifier(_source->alias()) << ".");
if ( _source && !_source->alias().empty() ) aliasDot = CONCAT(sqlIdentifier(_source->alias()) << ".");
writeMetaSQL(aliasDot, _property, ctx);
}

Expand Down
Loading
Loading