Skip to content

Commit

Permalink
feat: Add kNN Query (#198)
Browse files Browse the repository at this point in the history
kNN search was added to Elasticsearch in v8.0

Co-authored-by: kennylindahl <[email protected]>
Co-authored-by: Andreas Franzon <[email protected]>
  • Loading branch information
3 people authored May 5, 2024
1 parent 8f73d34 commit 78de179
Show file tree
Hide file tree
Showing 7 changed files with 485 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/core/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ exports.Aggregation = require('./aggregation');

exports.Query = require('./query');

exports.KNN = require('./knn');

exports.Suggester = require('./suggester');

exports.Script = require('./script');
Expand Down
138 changes: 138 additions & 0 deletions src/core/knn.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
'use strict';

const { recursiveToJSON, checkType } = require('./util');
const Query = require('./query');

/**
* Class representing a k-Nearest Neighbors (k-NN) query.
* This class extends the Query class to support the specifics of k-NN search, including setting up the field,
* query vector, number of neighbors (k), and number of candidates.
*
* @example
* const qry = esb.kNN('my_field', 100, 1000).vector([1,2,3]);
* const qry = esb.kNN('my_field', 100, 1000).queryVectorBuilder('model_123', 'Sample model text');
*
* NOTE: kNN search was added to Elasticsearch in v8.0
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html)
*/
class KNN {
// eslint-disable-next-line require-jsdoc
constructor(field, k, numCandidates) {
if (k > numCandidates)
throw new Error('KNN numCandidates cannot be less than k');
this._body = {};
this._body.field = field;
this._body.k = k;
this._body.filter = [];
this._body.num_candidates = numCandidates;
}

/**
* Sets the query vector for the k-NN search.
* @param {Array<number>} vector - The query vector.
* @returns {KNN} Returns the instance of KNN for method chaining.
*/
queryVector(vector) {
if (this._body.query_vector_builder)
throw new Error(
'cannot provide both query_vector_builder and query_vector'
);
this._body.query_vector = vector;
return this;
}

/**
* Sets the query vector builder for the k-NN search.
* This method configures a query vector builder using a specified model ID and model text.
* It's important to note that either a direct query vector or a query vector builder can be
* provided, but not both.
*
* @param {string} modelId - The ID of the model to be used for generating the query vector.
* @param {string} modelText - The text input based on which the query vector is generated.
* @returns {KNN} Returns the instance of KNN for method chaining.
* @throws {Error} Throws an error if both query_vector_builder and query_vector are provided.
*
* @example
* let knn = new esb.KNN().queryVectorBuilder('model_123', 'Sample model text');
*/
queryVectorBuilder(modelId, modelText) {
if (this._body.query_vector)
throw new Error(
'cannot provide both query_vector_builder and query_vector'
);
this._body.query_vector_builder = {
text_embeddings: {
model_id: modelId,
model_text: modelText
}
};
return this;
}

/**
* Adds one or more filter queries to the k-NN search.
*
* This method is designed to apply filters to the k-NN search. It accepts either a single
* query or an array of queries. Each query acts as a filter, refining the search results
* according to the specified conditions. These queries must be instances of the `Query` class.
* If any provided query is not an instance of `Query`, a TypeError is thrown.
*
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
* @returns {KNN} Returns `this` to allow method chaining.
* @throws {TypeError} If any of the provided queries is not an instance of `Query`.
*
* @example
* let knn = new esb.KNN().filter(new esb.TermQuery('field', 'value')); // Applying a single filter query
*
* @example
* let knn = new esb.KNN().filter([
* new esb.TermQuery('field1', 'value1'),
* new esb.TermQuery('field2', 'value2')
* ]); // Applying multiple filter queries
*/
filter(queries) {
const queryArray = Array.isArray(queries) ? queries : [queries];
queryArray.forEach(query => {
checkType(query, Query);
this._body.filter.push(query);
});
return this;
}

/**
* Sets the field to perform the k-NN search on.
* @param {number} boost - The number of the boost
* @returns {KNN} Returns the instance of KNN for method chaining.
*/
boost(boost) {
this._body.boost = boost;
return this;
}

/**
* Sets the field to perform the k-NN search on.
* @param {number} similarity - The number of the similarity
* @returns {KNN} Returns the instance of KNN for method chaining.
*/
similarity(similarity) {
this._body.similarity = similarity;
return this;
}

/**
* Override default `toJSON` to return DSL representation for the `query`
*
* @override
* @returns {Object} returns an Object which maps to the elasticsearch query DSL
*/
toJSON() {
if (!this._body.query_vector && !this._body.query_vector_builder)
throw new Error(
'either query_vector_builder or query_vector must be provided'
);
return recursiveToJSON(this._body);
}
}

module.exports = KNN;
25 changes: 24 additions & 1 deletion src/core/request-body-search.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ const Query = require('./query'),
Rescore = require('./rescore'),
Sort = require('./sort'),
Highlight = require('./highlight'),
InnerHits = require('./inner-hits');
InnerHits = require('./inner-hits'),
KNN = require('./knn');

const { checkType, setDefault, recursiveToJSON } = require('./util');
const RuntimeField = require('./runtime-field');
Expand Down Expand Up @@ -70,6 +71,7 @@ class RequestBodySearch {
constructor() {
// Maybe accept some optional parameter?
this._body = {};
this._knn = [];
this._aggs = [];
this._suggests = [];
this._suggestText = null;
Expand All @@ -88,6 +90,21 @@ class RequestBodySearch {
return this;
}

/**
* Sets knn on the search request body.
*
* @param {Knn|Knn[]} knn
* @returns {RequestBodySearch} returns `this` so that calls can be chained.
*/
kNN(knn) {
const knns = Array.isArray(knn) ? knn : [knn];
knns.forEach(_knn => {
checkType(_knn, KNN);
this._knn.push(_knn);
});
return this;
}

/**
* Sets aggregation on the request body.
* Alias for method `aggregation`
Expand Down Expand Up @@ -867,6 +884,12 @@ class RequestBodySearch {
toJSON() {
const dsl = recursiveToJSON(this._body);

if (!isEmpty(this._knn))
dsl.knn =
this._knn.length == 1
? recMerge(this._knn)
: this._knn.map(knn => recursiveToJSON(knn));

if (!isEmpty(this._aggs)) dsl.aggs = recMerge(this._aggs);

if (!isEmpty(this._suggests) || !isNil(this._suggestText)) {
Expand Down
103 changes: 94 additions & 9 deletions src/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ declare namespace esb {
*/
query(query: Query): this;

/**
* Sets knn on the request body.
*
* @param {KNN|KNN[]} knn
*/
kNN(knn: KNN | KNN[]): this;

/**
* Sets aggregation on the request body.
* Alias for method `aggregation`
Expand Down Expand Up @@ -3141,7 +3148,7 @@ declare namespace esb {

/**
* Sets the script used to compute the score of documents returned by the query.
*
*
* @param {Script} script A valid `Script` object
*/
script(script: Script): this;
Expand Down Expand Up @@ -3761,6 +3768,84 @@ declare namespace esb {
spanQry?: SpanQueryBase
): SpanFieldMaskingQuery;

/**
* Knn performs k-nearest neighbor (KNN) searches.
* This class allows configuring the KNN search with various parameters such as field, query vector,
* number of nearest neighbors (k), number of candidates, boost factor, and similarity metric.
*
* NOTE: Only available in Elasticsearch v8.0+
*/
export class KNN {
/**
* Creates an instance of Knn, initializing the internal state for the k-NN search.
*
* @param {string} field - (Optional) The field against which to perform the k-NN search.
* @param {number} k - (Optional) The number of nearest neighbors to retrieve.
* @param {number} numCandidates - (Optional) The number of candidate neighbors to consider during the search.
* @throws {Error} If the number of candidates (numCandidates) is less than the number of neighbors (k).
*/
constructor(field: string, k: number, numCandidates: number);

/**
* Sets the query vector for the KNN search, an array of numbers representing the reference point.
*
* @param {number[]} vector
*/
queryVector(vector: number[]): this;

/**
* Sets the query vector builder for the k-NN search.
* This method configures a query vector builder using a specified model ID and model text.
* Note that either a direct query vector or a query vector builder can be provided, but not both.
*
* @param {string} modelId - The ID of the model used for generating the query vector.
* @param {string} modelText - The text input based on which the query vector is generated.
* @returns {KNN} Returns the instance of Knn for method chaining.
* @throws {Error} If both query_vector_builder and query_vector are provided.
*/
queryVectorBuilder(modelId: string, modelText: string): this;

/**
* Adds one or more filter queries to the k-NN search.
* This method is designed to apply filters to the k-NN search. It accepts either a single
* query or an array of queries. Each query acts as a filter, refining the search results
* according to the specified conditions. These queries must be instances of the `Query` class.
*
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
* @returns {KNN} Returns `this` to allow method chaining.
* @throws {TypeError} If any of the provided queries is not an instance of `Query`.
*/
filter(queries: Query | Query[]): this;

/**
* Applies a boost factor to the query to influence the relevance score of returned documents.
*
* @param {number} boost
*/
boost(boost: number): this;

/**
* Sets the similarity metric used in the KNN algorithm to calculate similarity.
*
* @param {number} similarity
*/
similarity(similarity: number): this;

/**
* Override default `toJSON` to return DSL representation for the `query`
*
* @override
*/
toJSON(): object;
}

/**
* Factory function to instantiate a new Knn object.
*
* @returns {KNN}
*/
export function kNN(field: string, k: number, numCandidates: number): KNN;

/**
* Base class implementation for all aggregation types.
*
Expand Down Expand Up @@ -3913,9 +3998,9 @@ declare namespace esb {
/**
* A single-value metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents.
* These values can be extracted either from specific numeric fields in the documents.
*
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-weight-avg-aggregation.html)
*
*
* Added in Elasticsearch v6.4.0
* [Release notes](https://www.elastic.co/guide/en/elasticsearch/reference/6.4/release-notes-6.4.0.html)
*
Expand All @@ -3929,7 +4014,7 @@ declare namespace esb {

/**
* Sets the value
*
*
* @param {string | Script} value Field name or script to be used as the value
* @param {number=} missing Sets the missing parameter which defines how documents
* that are missing a value should be treated.
Expand All @@ -3939,7 +4024,7 @@ declare namespace esb {

/**
* Sets the weight
*
*
* @param {string | Script} weight Field name or script to be used as the weight
* @param {number=} missing Sets the missing parameter which defines how documents
* that are missing a value should be treated.
Expand Down Expand Up @@ -3969,9 +4054,9 @@ declare namespace esb {
/**
* A single-value metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents.
* These values can be extracted either from specific numeric fields in the documents.
*
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-weight-avg-aggregation.html)
*
*
* Added in Elasticsearch v6.4.0
* [Release notes](https://www.elastic.co/guide/en/elasticsearch/reference/6.4/release-notes-6.4.0.html)
*
Expand Down Expand Up @@ -8922,15 +9007,15 @@ declare namespace esb {

/**
* Sets the type of the runtime field.
*
*
* @param {string} type One of `boolean`, `composite`, `date`, `double`, `geo_point`, `ip`, `keyword`, `long`, `lookup`.
* @returns {void}
*/
type(type: 'boolean' | 'composite' | 'date' | 'double' | 'geo_point' | 'ip' | 'keyword' | 'long' | 'lookup'): void;

/**
* Sets the source of the script.
*
*
* @param {string} script
* @returns {void}
*/
Expand Down
8 changes: 8 additions & 0 deletions src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const {
RuntimeField,
SearchTemplate,
Query,
KNN,
util: { constructorWrapper }
} = require('./core');

Expand Down Expand Up @@ -343,6 +344,13 @@ exports.spanWithinQuery = constructorWrapper(SpanWithinQuery);

exports.SpanFieldMaskingQuery = SpanFieldMaskingQuery;
exports.spanFieldMaskingQuery = constructorWrapper(SpanFieldMaskingQuery);

/* ============ ============ ============ */
/* ======== KNN ======== */
/* ============ ============ ============ */
exports.KNN = KNN;
exports.kNN = constructorWrapper(KNN);

/* ============ ============ ============ */
/* ======== Metrics Aggregations ======== */
/* ============ ============ ============ */
Expand Down
Loading

0 comments on commit 78de179

Please sign in to comment.