diff --git a/src/core/request-body-search.js b/src/core/request-body-search.js index 7031494..87a9232 100644 --- a/src/core/request-body-search.js +++ b/src/core/request-body-search.js @@ -884,7 +884,11 @@ class RequestBodySearch { toJSON() { const dsl = recursiveToJSON(this._body); - if (!isEmpty(this._knn)) dsl.knn = this._knn; + if (!isEmpty(this._knn)) + dsl.knn = + this._knn.length == 1 + ? recMerge(this._knn) + : this._knn.map(knn => recursiveToJSON(knn)); if (!isEmpty(this._aggs)) dsl.aggs = recMerge(this._aggs); diff --git a/test/core-test/knn.test.js b/test/core-test/knn.test.js index 6f717a0..25f5719 100644 --- a/test/core-test/knn.test.js +++ b/test/core-test/knn.test.js @@ -43,6 +43,15 @@ test('knn filter method adds queries correctly', t => { t.deepEqual(json.filter, [query.toJSON()]); }); +test('knn filter method adds queries as array correctly', t => { + const knn = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); + const query1 = new TermQuery('field1', 'value1'); + const query2 = new TermQuery('field2', 'value2'); + knn.filter([query1, query2]); + const json = knn.toJSON(); + t.deepEqual(json.filter, [query1.toJSON(), query2.toJSON()]); +}); + test('knn boost method sets correctly', t => { const boostValue = 1.5; const knn = new KNN('my_field', 5, 10) @@ -85,3 +94,37 @@ test('knn toJSON throws error if neither query_vector nor query_vector_builder i 'either query_vector_builder or query_vector must be provided' ); }); + +test('knn throws error when first queryVector and then queryVectorBuilder are set', t => { + const knn = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); + const error = t.throws(() => { + knn.queryVectorBuilder('model_123', 'Sample model text'); + }); + t.is( + error.message, + 'cannot provide both query_vector_builder and query_vector' + ); +}); + +test('knn throws error when first queryVectorBuilder and then queryVector are set', t => { + const knn = new KNN('my_field', 5, 10).queryVectorBuilder( + 'model_123', + 'Sample model text' + ); + const error = t.throws(() => { + knn.queryVector([1, 2, 3]); + }); + t.is( + error.message, + 'cannot provide both query_vector_builder and query_vector' + ); +}); + +test('knn filter throws TypeError if non-Query type is passed', t => { + const knn = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); + const error = t.throws(() => { + knn.filter('not_a_query'); + }, TypeError); + + t.is(error.message, 'Argument must be an instance of Query'); +}); diff --git a/test/core-test/request-body-search.test.js b/test/core-test/request-body-search.test.js index 47b1191..2327606 100644 --- a/test/core-test/request-body-search.test.js +++ b/test/core-test/request-body-search.test.js @@ -16,7 +16,8 @@ import { Highlight, Rescore, InnerHits, - RuntimeField + RuntimeField, + KNN } from '../../src'; import { illegalParamType, makeSetsOptionMacro } from '../_macros'; @@ -71,6 +72,11 @@ const innerHits = new InnerHits() .name('last_tweets') .size(5) .sort(new Sort('date', 'desc')); +const kNNVectorBuilder = new KNN('my_field', 5, 10) + .similarity(0.6) + .filter(new TermQuery('field', 'value')) + .queryVectorBuilder('model_123', 'Sample model text'); +const kNNVector = new KNN('my_field', 5, 10).queryVector([1, 2, 3]); const instance = new RequestBodySearch(); @@ -83,9 +89,11 @@ test(illegalParamType, instance, 'scriptFields', 'Object'); test(illegalParamType, instance, 'highlight', 'Highlight'); test(illegalParamType, instance, 'rescore', 'Rescore'); test(illegalParamType, instance, 'postFilter', 'Query'); +test(illegalParamType, instance, 'kNN', 'KNN'); test(setsOption, 'query', { param: searchQry }); test(setsOption, 'aggregation', { param: aggA, keyName: 'aggs' }); test(setsOption, 'agg', { param: aggA, keyName: 'aggs' }); +test(setsOption, 'kNN', { param: kNNVectorBuilder, keyName: 'knn' }); test(setsOption, 'suggest', { param: suggest }); test(setsOption, 'suggestText', { param: 'suggest-text', @@ -347,3 +355,83 @@ test('sets multiple indices_boost', t => { }; t.deepEqual(value, expected); }); + +test('kNN setup query vector builder', t => { + const value = new RequestBodySearch().kNN(kNNVectorBuilder).toJSON(); + const expected = { + knn: { + field: 'my_field', + k: 5, + filter: [ + { + term: { + field: 'value' + } + } + ], + num_candidates: 10, + query_vector_builder: { + text_embeddings: { + model_id: 'model_123', + model_text: 'Sample model text' + } + }, + similarity: 0.6 + } + }; + + t.deepEqual(value, expected); +}); + +test('kNN setup query vector', t => { + const value = new RequestBodySearch().kNN(kNNVector).toJSON(); + const expected = { + knn: { + field: 'my_field', + k: 5, + filter: [], + num_candidates: 10, + query_vector: [1, 2, 3] + } + }; + + t.deepEqual(value, expected); +}); + +test('kNN setup query vector array', t => { + const value = new RequestBodySearch() + .kNN([kNNVector, kNNVectorBuilder]) + .toJSON(); + const expected = { + knn: [ + { + field: 'my_field', + k: 5, + filter: [], + num_candidates: 10, + query_vector: [1, 2, 3] + }, + { + field: 'my_field', + filter: [ + { + term: { + field: 'value' + } + } + ], + k: 5, + num_candidates: 10, + query_vector_builder: { + text_embeddings: { + model_id: 'model_123', + model_text: 'Sample model text' + } + }, + similarity: 0.6 + } + ] + }; + + t.deepEqual(value, expected); +});