Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add knnQuery #198

Merged
merged 7 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/core/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ exports.Aggregation = require('./aggregation');

exports.Query = require('./query');

exports.KNN = require('./knn');

exports.Suggester = require('./suggester');

exports.Script = require('./script');
Expand Down
138 changes: 138 additions & 0 deletions src/core/knn.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
'use strict';

const { recursiveToJSON, checkType } = require('./util');
const Query = require('./query');

/**
* Class representing a k-Nearest Neighbors (k-NN) query.
* This class extends the Query class to support the specifics of k-NN search, including setting up the field,
* query vector, number of neighbors (k), and number of candidates.
sudo-suhas marked this conversation as resolved.
Show resolved Hide resolved
*
sudo-suhas marked this conversation as resolved.
Show resolved Hide resolved
* @example
* const qry = esb.kNN('my_field', 100, 1000).vector([1,2,3]);
* const qry = esb.kNN('my_field', 100, 1000).queryVectorBuilder('model_123', 'Sample model text');
*
* NOTE: kNN search was added to Elasticsearch in v8.0
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html)
*/
class KNN {
// eslint-disable-next-line require-jsdoc
constructor(field, k, numCandidates) {
if (k > numCandidates)
throw new Error('KNN numCandidates cannot be less than k');
this._body = {};
this._body.field = field;
this._body.k = k;
this._body.filter = [];
this._body.num_candidates = numCandidates;
}

/**
* Sets the query vector for the k-NN search.
* @param {Array<number>} vector - The query vector.
* @returns {KNN} Returns the instance of KNN for method chaining.
*/
queryVector(vector) {
if (this._body.query_vector_builder)
throw new Error(
'cannot provide both query_vector_builder and query_vector'
);
this._body.query_vector = vector;
return this;
}

/**
* Sets the query vector builder for the k-NN search.
* This method configures a query vector builder using a specified model ID and model text.
* It's important to note that either a direct query vector or a query vector builder can be
* provided, but not both.
*
* @param {string} modelId - The ID of the model to be used for generating the query vector.
* @param {string} modelText - The text input based on which the query vector is generated.
* @returns {KNN} Returns the instance of KNN for method chaining.
* @throws {Error} Throws an error if both query_vector_builder and query_vector are provided.
*
* @example
* let knn = new esb.KNN().queryVectorBuilder('model_123', 'Sample model text');
*/
queryVectorBuilder(modelId, modelText) {
if (this._body.query_vector)
throw new Error(
'cannot provide both query_vector_builder and query_vector'
);
this._body.query_vector_builder = {
text_embeddings: {
model_id: modelId,
model_text: modelText
}
};
return this;
}

/**
* Adds one or more filter queries to the k-NN search.
*
* This method is designed to apply filters to the k-NN search. It accepts either a single
* query or an array of queries. Each query acts as a filter, refining the search results
* according to the specified conditions. These queries must be instances of the `Query` class.
* If any provided query is not an instance of `Query`, a TypeError is thrown.
*
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
* @returns {KNN} Returns `this` to allow method chaining.
* @throws {TypeError} If any of the provided queries is not an instance of `Query`.
*
* @example
* let knn = new esb.KNN().filter(new esb.TermQuery('field', 'value')); // Applying a single filter query
*
* @example
* let knn = new esb.KNN().filter([
* new esb.TermQuery('field1', 'value1'),
* new esb.TermQuery('field2', 'value2')
* ]); // Applying multiple filter queries
*/
filter(queries) {
const queryArray = Array.isArray(queries) ? queries : [queries];
queryArray.forEach(query => {
checkType(query, Query);
this._body.filter.push(query);
});
return this;
}

/**
* Sets the field to perform the k-NN search on.
* @param {number} boost - The number of the boost
* @returns {KNN} Returns the instance of KNN for method chaining.
*/
boost(boost) {
this._body.boost = boost;
return this;
}

/**
* Sets the field to perform the k-NN search on.
* @param {number} similarity - The number of the similarity
* @returns {KNN} Returns the instance of KNN for method chaining.
*/
similarity(similarity) {
this._body.similarity = similarity;
return this;
}

/**
* Override default `toJSON` to return DSL representation for the `query`
*
* @override
* @returns {Object} returns an Object which maps to the elasticsearch query DSL
*/
toJSON() {
if (!this._body.query_vector && !this._body.query_vector_builder)
throw new Error(
'either query_vector_builder or query_vector must be provided'
);
return recursiveToJSON(this._body);
}
}

module.exports = KNN;
25 changes: 24 additions & 1 deletion src/core/request-body-search.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
Rescore = require('./rescore'),
Sort = require('./sort'),
Highlight = require('./highlight'),
InnerHits = require('./inner-hits');
InnerHits = require('./inner-hits'),
KNN = require('./knn');

const { checkType, setDefault, recursiveToJSON } = require('./util');
const RuntimeField = require('./runtime-field');
Expand Down Expand Up @@ -70,6 +71,7 @@
constructor() {
// Maybe accept some optional parameter?
this._body = {};
this._knn = [];
this._aggs = [];
this._suggests = [];
this._suggestText = null;
Expand All @@ -88,6 +90,21 @@
return this;
}

/**
* Sets knn on the search request body.
*
* @param {Knn|Knn[]} knn
* @returns {RequestBodySearch} returns `this` so that calls can be chained.
*/
kNN(knn) {
sudo-suhas marked this conversation as resolved.
Show resolved Hide resolved
const knns = Array.isArray(knn) ? knn : [knn];
knns.forEach(_knn => {
checkType(_knn, KNN);
this._knn.push(_knn);
});
return this;
}

/**
* Sets aggregation on the request body.
* Alias for method `aggregation`
Expand Down Expand Up @@ -694,7 +711,7 @@
return this;
}

// TODO: Scroll related changes

Check warning on line 714 in src/core/request-body-search.js

View workflow job for this annotation

GitHub Actions / check (10.x)

Unexpected 'todo' comment

Check warning on line 714 in src/core/request-body-search.js

View workflow job for this annotation

GitHub Actions / check (12.x)

Unexpected 'todo' comment

Check warning on line 714 in src/core/request-body-search.js

View workflow job for this annotation

GitHub Actions / check (14.x)

Unexpected 'todo' comment
// Maybe only slice needs to be supported.

/**
Expand Down Expand Up @@ -867,6 +884,12 @@
toJSON() {
const dsl = recursiveToJSON(this._body);

if (!isEmpty(this._knn))
dsl.knn =
this._knn.length == 1
? recMerge(this._knn)
: this._knn.map(knn => recursiveToJSON(knn));

if (!isEmpty(this._aggs)) dsl.aggs = recMerge(this._aggs);

if (!isEmpty(this._suggests) || !isNil(this._suggestText)) {
Expand Down
103 changes: 94 additions & 9 deletions src/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ declare namespace esb {
*/
query(query: Query): this;

/**
* Sets knn on the request body.
*
* @param {KNN|KNN[]} knn
*/
kNN(knn: KNN | KNN[]): this;

/**
* Sets aggregation on the request body.
* Alias for method `aggregation`
Expand Down Expand Up @@ -3141,7 +3148,7 @@ declare namespace esb {

/**
* Sets the script used to compute the score of documents returned by the query.
*
*
* @param {Script} script A valid `Script` object
*/
script(script: Script): this;
Expand Down Expand Up @@ -3761,6 +3768,84 @@ declare namespace esb {
spanQry?: SpanQueryBase
): SpanFieldMaskingQuery;

/**
* Knn performs k-nearest neighbor (KNN) searches.
* This class allows configuring the KNN search with various parameters such as field, query vector,
* number of nearest neighbors (k), number of candidates, boost factor, and similarity metric.
*
* NOTE: Only available in Elasticsearch v8.0+
*/
export class KNN {
/**
* Creates an instance of Knn, initializing the internal state for the k-NN search.
*
* @param {string} field - (Optional) The field against which to perform the k-NN search.
* @param {number} k - (Optional) The number of nearest neighbors to retrieve.
* @param {number} numCandidates - (Optional) The number of candidate neighbors to consider during the search.
* @throws {Error} If the number of candidates (numCandidates) is less than the number of neighbors (k).
*/
constructor(field: string, k: number, numCandidates: number);

/**
* Sets the query vector for the KNN search, an array of numbers representing the reference point.
*
* @param {number[]} vector
*/
queryVector(vector: number[]): this;

/**
* Sets the query vector builder for the k-NN search.
* This method configures a query vector builder using a specified model ID and model text.
* Note that either a direct query vector or a query vector builder can be provided, but not both.
*
* @param {string} modelId - The ID of the model used for generating the query vector.
* @param {string} modelText - The text input based on which the query vector is generated.
* @returns {KNN} Returns the instance of Knn for method chaining.
* @throws {Error} If both query_vector_builder and query_vector are provided.
*/
queryVectorBuilder(modelId: string, modelText: string): this;

/**
* Adds one or more filter queries to the k-NN search.
* This method is designed to apply filters to the k-NN search. It accepts either a single
* query or an array of queries. Each query acts as a filter, refining the search results
* according to the specified conditions. These queries must be instances of the `Query` class.
*
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
* @returns {KNN} Returns `this` to allow method chaining.
* @throws {TypeError} If any of the provided queries is not an instance of `Query`.
*/
filter(queries: Query | Query[]): this;

/**
* Applies a boost factor to the query to influence the relevance score of returned documents.
*
* @param {number} boost
*/
boost(boost: number): this;

/**
* Sets the similarity metric used in the KNN algorithm to calculate similarity.
*
* @param {number} similarity
*/
similarity(similarity: number): this;

/**
* Override default `toJSON` to return DSL representation for the `query`
*
* @override
*/
toJSON(): object;
}

/**
* Factory function to instantiate a new Knn object.
*
* @returns {KNN}
*/
export function kNN(field: string, k: number, numCandidates: number): KNN;

/**
* Base class implementation for all aggregation types.
*
Expand Down Expand Up @@ -3913,9 +3998,9 @@ declare namespace esb {
/**
* A single-value metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents.
* These values can be extracted either from specific numeric fields in the documents.
*
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-weight-avg-aggregation.html)
*
*
* Added in Elasticsearch v6.4.0
* [Release notes](https://www.elastic.co/guide/en/elasticsearch/reference/6.4/release-notes-6.4.0.html)
*
Expand All @@ -3929,7 +4014,7 @@ declare namespace esb {

/**
* Sets the value
*
*
* @param {string | Script} value Field name or script to be used as the value
* @param {number=} missing Sets the missing parameter which defines how documents
* that are missing a value should be treated.
Expand All @@ -3939,7 +4024,7 @@ declare namespace esb {

/**
* Sets the weight
*
*
* @param {string | Script} weight Field name or script to be used as the weight
* @param {number=} missing Sets the missing parameter which defines how documents
* that are missing a value should be treated.
Expand Down Expand Up @@ -3969,9 +4054,9 @@ declare namespace esb {
/**
* A single-value metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents.
* These values can be extracted either from specific numeric fields in the documents.
*
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-metrics-weight-avg-aggregation.html)
*
*
* Added in Elasticsearch v6.4.0
* [Release notes](https://www.elastic.co/guide/en/elasticsearch/reference/6.4/release-notes-6.4.0.html)
*
Expand Down Expand Up @@ -8922,15 +9007,15 @@ declare namespace esb {

/**
* Sets the type of the runtime field.
*
*
* @param {string} type One of `boolean`, `composite`, `date`, `double`, `geo_point`, `ip`, `keyword`, `long`, `lookup`.
* @returns {void}
*/
type(type: 'boolean' | 'composite' | 'date' | 'double' | 'geo_point' | 'ip' | 'keyword' | 'long' | 'lookup'): void;

/**
* Sets the source of the script.
*
*
* @param {string} script
* @returns {void}
*/
Expand Down
8 changes: 8 additions & 0 deletions src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const {
RuntimeField,
SearchTemplate,
Query,
KNN,
util: { constructorWrapper }
} = require('./core');

Expand Down Expand Up @@ -343,6 +344,13 @@ exports.spanWithinQuery = constructorWrapper(SpanWithinQuery);

exports.SpanFieldMaskingQuery = SpanFieldMaskingQuery;
exports.spanFieldMaskingQuery = constructorWrapper(SpanFieldMaskingQuery);

/* ============ ============ ============ */
/* ======== KNN ======== */
/* ============ ============ ============ */
exports.KNN = KNN;
exports.kNN = constructorWrapper(KNN);

/* ============ ============ ============ */
/* ======== Metrics Aggregations ======== */
/* ============ ============ ============ */
Expand Down
Loading
Loading