Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fork Sync: Update from parent repository #132

Merged
merged 1 commit into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Make SearchTemplateRequest implement IndicesRequest.Replaceable ([#9122]()https://github.com/opensearch-project/OpenSearch/pull/9122)
- [BWC and API enforcement] Define the initial set of annotations, their meaning and relations between them ([#9223](https://github.com/opensearch-project/OpenSearch/pull/9223))
- [Segment Replication] Support realtime reads for GET requests ([#9212](https://github.com/opensearch-project/OpenSearch/pull/9212))
- [Feature] Expose term frequency in Painless script score context ([#9081](https://github.com/opensearch-project/OpenSearch/pull/9081))

### Dependencies
- Bump `org.apache.logging.log4j:log4j-core` from 2.17.1 to 2.20.0 ([#8307](https://github.com/opensearch-project/OpenSearch/pull/8307))
Expand Down Expand Up @@ -164,4 +165,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Security

[Unreleased 3.0]: https://github.com/opensearch-project/OpenSearch/compare/2.x...HEAD
[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.10...2.x
[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.10...2.x
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public boolean needs_score() {

@Override
public ScoreScript newInstance(final LeafReaderContext leaf) throws IOException {
return new ScoreScript(null, null, null) {
return new ScoreScript(null, null, null, null) {
// Fake the scorer until setScorer is called.
DoubleValues values = source.getValues(leaf, new DoubleValues() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.lucene.expressions.js.JavascriptCompiler;
import org.apache.lucene.expressions.js.VariableContext;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.SpecialPermission;
import org.opensearch.common.Nullable;
import org.opensearch.index.fielddata.IndexFieldData;
Expand Down Expand Up @@ -110,7 +111,7 @@ public FilterScript.LeafFactory newFactory(Map<String, Object> params, SearchLoo

contexts.put(ScoreScript.CONTEXT, (Expression expr) -> new ScoreScript.Factory() {
@Override
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
return newScoreScript(expr, lookup, params);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,11 @@ static Response innerShardOperation(Request request, ScriptService scriptService
} else if (scriptContext == ScoreScript.CONTEXT) {
return prepareRamIndex(request, (context, leafReaderContext) -> {
ScoreScript.Factory factory = scriptService.compile(request.script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory leafFactory = factory.newFactory(request.getScript().getParams(), context.lookup());
ScoreScript.LeafFactory leafFactory = factory.newFactory(
request.getScript().getParams(),
context.lookup(),
context.searcher()
);
ScoreScript scoreScript = leafFactory.newInstance(leafReaderContext);
scoreScript.setDocument(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class org.opensearch.script.ScoreScript @no_import {
}

static_import {
int termFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TermFreq
float tf(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TF
long totalTermFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TotalTermFreq
long sumTotalTermFreq(org.opensearch.script.ScoreScript, String) bound_to org.opensearch.script.ScoreScriptUtils$SumTotalTermFreq
double saturation(double, double) from_class org.opensearch.script.ScoreScriptUtils
double sigmoid(double, double, double) from_class org.opensearch.script.ScoreScriptUtils
double randomScore(org.opensearch.script.ScoreScript, int, String) bound_to org.opensearch.script.ScoreScriptUtils$RandomScoreField
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
---
setup:
- skip:
version: " - 2.9.99"
reason: "termFreq functions for script_score was introduced in 2.10.0"
- do:
indices.create:
index: test
body:
settings:
number_of_shards: 1
mappings:
properties:
f1:
type: keyword
f2:
type: text
- do:
bulk:
refresh: true
body:
- '{"index": {"_index": "test", "_id": "doc1"}}'
- '{"f1": "v0", "f2": "v1"}'
- '{"index": {"_index": "test", "_id": "doc2"}}'
- '{"f2": "v2"}'

---
"Script score function using the termFreq function":
- do:
search:
index: test
rest_total_hits_as_int: true
body:
query:
function_score:
query:
match_all: {}
script_score:
script:
source: "termFreq(params.field, params.term)"
params:
field: "f1"
term: "v0"
- match: { hits.total: 2 }
- match: { hits.hits.0._id: "doc1" }
- match: { hits.hits.1._id: "doc2" }
- match: { hits.hits.0._score: 1.0 }
- match: { hits.hits.1._score: 0.0 }

---
"Script score function using the totalTermFreq function":
- do:
search:
index: test
rest_total_hits_as_int: true
body:
query:
function_score:
query:
match_all: {}
script_score:
script:
source: "if (doc[params.field].size() == 0) return params.default_value; else { return totalTermFreq(params.field, params.term); }"
params:
default_value: 0.5
field: "f1"
term: "v0"
- match: { hits.total: 2 }
- match: { hits.hits.0._id: "doc1" }
- match: { hits.hits.1._id: "doc2" }
- match: { hits.hits.0._score: 1.0 }
- match: { hits.hits.1._score: 0.5 }

---
"Script score function using the sumTotalTermFreq function":
- do:
search:
index: test
rest_total_hits_as_int: true
body:
query:
function_score:
query:
match_all: {}
script_score:
script:
source: "if (doc[params.field].size() == 0) return params.default_value; else { return sumTotalTermFreq(params.field); }"
params:
default_value: 0.5
field: "f1"
- match: { hits.total: 2 }
- match: { hits.hits.0._id: "doc1" }
- match: { hits.hits.1._id: "doc2" }
- match: { hits.hits.0._score: 1.0 }
- match: { hits.hits.1._score: 0.5 }
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.settings.Settings;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.ScriptPlugin;
Expand Down Expand Up @@ -120,20 +121,22 @@ public boolean isResultDeterministic() {
@Override
public LeafFactory newFactory(
Map<String, Object> params,
SearchLookup lookup
SearchLookup lookup,
IndexSearcher indexSearcher
) {
return new PureDfLeafFactory(params, lookup);
return new PureDfLeafFactory(params, lookup, indexSearcher);
}
}

private static class PureDfLeafFactory implements LeafFactory {
private final Map<String, Object> params;
private final SearchLookup lookup;
private final IndexSearcher indexSearcher;
private final String field;
private final String term;

private PureDfLeafFactory(
Map<String, Object> params, SearchLookup lookup) {
Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
if (params.containsKey("field") == false) {
throw new IllegalArgumentException(
"Missing parameter [field]");
Expand All @@ -144,6 +147,7 @@ private PureDfLeafFactory(
}
this.params = params;
this.lookup = lookup;
this.indexSearcher = indexSearcher;
field = params.get("field").toString();
term = params.get("term").toString();
}
Expand All @@ -163,7 +167,7 @@ public ScoreScript newInstance(LeafReaderContext context)
* the field and/or term don't exist in this segment,
* so always return 0
*/
return new ScoreScript(params, lookup, context) {
return new ScoreScript(params, lookup, indexSearcher, context) {
@Override
public double execute(
ExplanationHolder explanation
Expand All @@ -172,7 +176,7 @@ public double execute(
}
};
}
return new ScoreScript(params, lookup, context) {
return new ScoreScript(params, lookup, indexSearcher, context) {
int currentDocid = -1;
@Override
public void setDocument(int docid) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchType;
Expand Down Expand Up @@ -93,15 +94,15 @@ public String getType() {
public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
assert scriptSource.equals("explainable_script");
assert context == ScoreScript.CONTEXT;
ScoreScript.Factory factory = (params1, lookup) -> new ScoreScript.LeafFactory() {
ScoreScript.Factory factory = (params1, lookup, indexSearcher) -> new ScoreScript.LeafFactory() {
@Override
public boolean needs_score() {
return false;
}

@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return new MyScript(params1, lookup, ctx);
return new MyScript(params1, lookup, indexSearcher, ctx);
}
};
return context.factoryClazz.cast(factory);
Expand All @@ -117,8 +118,8 @@ public Set<ScriptContext<?>> getSupportedContexts() {

static class MyScript extends ScoreScript implements ExplainableScoreScript {

MyScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
super(params, lookup, leafContext);
MyScript(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) {
super(params, lookup, indexSearcher, leafContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ protected int doHashCode() {
protected ScoreFunction doToFunction(QueryShardContext context) {
try {
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
return new ScriptScoreFunction(
script,
searchScript,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
);
}
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup());
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
final QueryBuilder queryBuilder = this.query;
Query query = queryBuilder.toQuery(context);
return new ScriptScoreQuery(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.query.functionscore;

import java.io.IOException;

/**
* An interface representing a term frequency function used to compute document scores
* based on specific term frequency calculations. Implementations of this interface should
* provide a way to execute the term frequency function for a given document ID.
*
* @opensearch.internal
*/
public interface TermFrequencyFunction {
Object execute(int docId) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.query.functionscore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TFValueSource;
import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.lucene.BytesRefs;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
* A factory class for creating instances of {@link TermFrequencyFunction}.
* This class provides methods for creating different term frequency functions based on
* the specified function name, field, and term. Each term frequency function is designed
* to compute document scores based on specific term frequency calculations.
*
* @opensearch.internal
*/
public class TermFrequencyFunctionFactory {
public static TermFrequencyFunction createFunction(
TermFrequencyFunctionName functionName,
String field,
String term,
LeafReaderContext readerContext,
IndexSearcher indexSearcher
) throws IOException {
switch (functionName) {
case TERM_FREQ:
TermFreqValueSource termFreqValueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
FunctionValues functionValues = termFreqValueSource.getValues(null, readerContext);
return docId -> functionValues.intVal(docId);
case TF:
TFValueSource tfValueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
Map<Object, Object> tfContext = new HashMap<>() {
{
put("searcher", indexSearcher);
}
};
functionValues = tfValueSource.getValues(tfContext, readerContext);
return docId -> functionValues.floatVal(docId);
case TOTAL_TERM_FREQ:
TotalTermFreqValueSource totalTermFreqValueSource = new TotalTermFreqValueSource(
field,
term,
field,
BytesRefs.toBytesRef(term)
);
Map<Object, Object> ttfContext = new HashMap<>();
totalTermFreqValueSource.createWeight(ttfContext, indexSearcher);
functionValues = totalTermFreqValueSource.getValues(ttfContext, readerContext);
return docId -> functionValues.longVal(docId);
case SUM_TOTAL_TERM_FREQ:
SumTotalTermFreqValueSource sumTotalTermFreqValueSource = new SumTotalTermFreqValueSource(field);
Map<Object, Object> sttfContext = new HashMap<>();
sumTotalTermFreqValueSource.createWeight(sttfContext, indexSearcher);
functionValues = sumTotalTermFreqValueSource.getValues(sttfContext, readerContext);
return docId -> functionValues.longVal(docId);
default:
throw new IllegalArgumentException("Unsupported function: " + functionName);
}
}

/**
* An enumeration representing the names of supported term frequency functions.
*/
public enum TermFrequencyFunctionName {
TERM_FREQ("termFreq"),
TF("tf"),
TOTAL_TERM_FREQ("totalTermFreq"),
SUM_TOTAL_TERM_FREQ("sumTotalTermFreq");

private final String termFrequencyFunctionName;

TermFrequencyFunctionName(String termFrequencyFunctionName) {
this.termFrequencyFunctionName = termFrequencyFunctionName;
}

public String getTermFrequencyFunctionName() {
return termFrequencyFunctionName;
}
}
}
Loading
Loading