Add '_name' field support to score functions and provide it back in explanation response (#2244)

* Add '_name' field support to score functions and provide it back in explanation response

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>

* Address code review comments

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>
This commit is contained in:
Andriy Redko 2022-03-04 11:12:27 -05:00 committed by GitHub
parent ae14259a2c
commit 5f90227a05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 1126 additions and 78 deletions

View File

@ -51,6 +51,7 @@ import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder; import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder;
import org.opensearch.index.query.functionscore.ScoreFunctionBuilders; import org.opensearch.index.query.functionscore.ScoreFunctionBuilders;
import org.opensearch.search.MultiValueMode; import org.opensearch.search.MultiValueMode;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits; import org.opensearch.search.SearchHits;
import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.test.VersionUtils; import org.opensearch.test.VersionUtils;
@ -77,7 +78,9 @@ import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertOrderedSearchHits; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertOrderedSearchHits;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchHits; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchHits;
import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThan;
@ -616,6 +619,76 @@ public class DecayFunctionScoreIT extends OpenSearchIntegTestCase {
} }
public void testCombineModesExplain() throws Exception {
assertAcked(
prepareCreate("test").addMapping(
"type1",
jsonBuilder().startObject()
.startObject("type1")
.startObject("properties")
.startObject("test")
.field("type", "text")
.endObject()
.startObject("num")
.field("type", "double")
.endObject()
.endObject()
.endObject()
.endObject()
)
);
client().prepareIndex()
.setId("1")
.setIndex("test")
.setRefreshPolicy(IMMEDIATE)
.setSource(jsonBuilder().startObject().field("test", "value value").field("num", 1.0).endObject())
.get();
FunctionScoreQueryBuilder baseQuery = functionScoreQuery(
constantScoreQuery(termQuery("test", "value")).queryName("query1"),
ScoreFunctionBuilders.weightFactorFunction(2, "weight1")
);
// decay score should return 0.5 for this function and baseQuery should return 2.0f as it's score
ActionFuture<SearchResponse> response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(baseQuery, gaussDecayFunction("num", 0.0, 1.0, null, 0.5, "func2")).boostMode(
CombineFunction.MULTIPLY
)
)
)
);
SearchResponse sr = response.actionGet();
SearchHits sh = sr.getHits();
assertThat(sh.getTotalHits().value, equalTo((long) (1)));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(0).getExplanation().getDetails(), arrayWithSize(2));
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails(), arrayWithSize(2));
// "description": "ConstantScore(test:value) (_name: query1)"
assertThat(
sh.getAt(0).getExplanation().getDetails()[0].getDetails()[0].getDescription(),
equalTo("ConstantScore(test:value) (_name: query1)")
);
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails(), arrayWithSize(2));
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(2));
// "description": "constant score 1.0(_name: func1) - no function provided"
assertThat(
sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails()[0].getDescription(),
equalTo("constant score 1.0(_name: weight1) - no function provided")
);
// "description": "exp(-0.5*pow(MIN[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.7213475204444817,
// _name: func2)"
assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1));
assertThat(
sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription(),
containsString("_name: func2")
);
}
public void testExceptionThrownIfScaleLE0() throws Exception { public void testExceptionThrownIfScaleLE0() throws Exception {
assertAcked( assertAcked(
prepareCreate("test").addMapping( prepareCreate("test").addMapping(
@ -1195,4 +1268,132 @@ public class DecayFunctionScoreIT extends OpenSearchIntegTestCase {
sh = sr.getHits(); sh = sr.getHits();
assertThat((double) (sh.getAt(0).getScore()), closeTo((sh.getAt(1).getScore()), 1.e-6d)); assertThat((double) (sh.getAt(0).getScore()), closeTo((sh.getAt(1).getScore()), 1.e-6d));
} }
public void testDistanceScoreGeoLinGaussExplain() throws Exception {
assertAcked(
prepareCreate("test").addMapping(
"type1",
jsonBuilder().startObject()
.startObject("type1")
.startObject("properties")
.startObject("test")
.field("type", "text")
.endObject()
.startObject("loc")
.field("type", "geo_point")
.endObject()
.endObject()
.endObject()
.endObject()
)
);
List<IndexRequestBuilder> indexBuilders = new ArrayList<>();
indexBuilders.add(
client().prepareIndex()
.setId("1")
.setIndex("test")
.setSource(
jsonBuilder().startObject()
.field("test", "value")
.startObject("loc")
.field("lat", 10)
.field("lon", 20)
.endObject()
.endObject()
)
);
indexBuilders.add(
client().prepareIndex()
.setId("2")
.setIndex("test")
.setSource(
jsonBuilder().startObject()
.field("test", "value")
.startObject("loc")
.field("lat", 11)
.field("lon", 22)
.endObject()
.endObject()
)
);
indexRandom(true, indexBuilders);
// Test Gauss
List<Float> lonlat = new ArrayList<>();
lonlat.add(20f);
lonlat.add(11f);
final String queryName = "query1";
final String functionName = "func1";
ActionFuture<SearchResponse> response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(baseQuery.queryName(queryName), gaussDecayFunction("loc", lonlat, "1000km", functionName))
)
)
);
SearchResponse sr = response.actionGet();
SearchHits sh = sr.getHits();
assertThat(sh.getTotalHits().value, equalTo(2L));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(1).getId(), equalTo("2"));
assertExplain(queryName, functionName, sr);
response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(baseQuery.queryName(queryName), linearDecayFunction("loc", lonlat, "1000km", functionName))
)
)
);
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits().value, equalTo(2L));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(1).getId(), equalTo("2"));
assertExplain(queryName, functionName, sr);
response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(
baseQuery.queryName(queryName),
exponentialDecayFunction("loc", lonlat, "1000km", functionName)
)
)
)
);
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits().value, equalTo(2L));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat(sh.getAt(1).getId(), equalTo("2"));
assertExplain(queryName, functionName, sr);
}
private void assertExplain(final String queryName, final String functionName, SearchResponse sr) {
SearchHit firstHit = sr.getHits().getAt(0);
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
// "description": "*:* (_name: query1)"
assertThat(firstHit.getExplanation().getDetails()[0].getDescription().toString(), containsString("_name: " + queryName));
assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
// "description": "random score function (seed: 12345678, field: _seq_no, _name: func1)"
assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1));
// "description": "exp(-0.5*pow(MIN of: [Math.max(arcDistance(10.999999972991645, 21.99999994598329(=doc value),11.0, 20.0(=origin))
// - 0.0(=offset), 0)],2.0)/7.213475204444817E11, _name: func1)"
assertThat(
firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription().toString(),
containsString("_name: " + functionName)
);
}
} }

View File

@ -38,6 +38,7 @@ import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchType; import org.opensearch.action.search.SearchType;
import org.opensearch.common.lucene.search.function.CombineFunction; import org.opensearch.common.lucene.search.function.CombineFunction;
import org.opensearch.common.lucene.search.function.Functions;
import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.Settings;
import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.plugins.Plugin; import org.opensearch.plugins.Plugin;
@ -72,6 +73,7 @@ import static org.opensearch.index.query.QueryBuilders.functionScoreQuery;
import static org.opensearch.index.query.QueryBuilders.termQuery; import static org.opensearch.index.query.QueryBuilders.termQuery;
import static org.opensearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction; import static org.opensearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction;
import static org.opensearch.search.builder.SearchSourceBuilder.searchSource; import static org.opensearch.search.builder.SearchSourceBuilder.searchSource;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@ -121,8 +123,17 @@ public class ExplainableScriptIT extends OpenSearchIntegTestCase {
@Override @Override
public Explanation explain(Explanation subQueryScore) throws IOException { public Explanation explain(Explanation subQueryScore) throws IOException {
return explain(subQueryScore, null);
}
@Override
public Explanation explain(Explanation subQueryScore, String functionName) throws IOException {
Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore); Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore);
return Explanation.match((float) (execute(null)), "This script returned " + execute(null), scoreExp); return Explanation.match(
(float) (execute(null)),
"This script" + Functions.nameOrEmptyFunc(functionName) + " returned " + execute(null),
scoreExp
);
} }
@Override @Override
@ -174,4 +185,36 @@ public class ExplainableScriptIT extends OpenSearchIntegTestCase {
idCounter--; idCounter--;
} }
} }
public void testExplainScriptWithName() throws InterruptedException, IOException, ExecutionException {
List<IndexRequestBuilder> indexRequests = new ArrayList<>();
indexRequests.add(
client().prepareIndex("test")
.setId(Integer.toString(1))
.setSource(jsonBuilder().startObject().field("number_field", 1).field("text", "text").endObject())
);
indexRandom(true, true, indexRequests);
client().admin().indices().prepareRefresh().get();
ensureYellow();
SearchResponse response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
.source(
searchSource().explain(true)
.query(
functionScoreQuery(
termQuery("text", "text"),
scriptFunction(new Script(ScriptType.INLINE, "test", "explainable_script", Collections.emptyMap()), "func1")
).boostMode(CombineFunction.REPLACE)
)
)
).actionGet();
OpenSearchAssertions.assertNoFailures(response);
SearchHits hits = response.getHits();
assertThat(hits.getTotalHits().value, equalTo(1L));
assertThat(hits.getHits()[0].getId(), equalTo("1"));
assertThat(hits.getHits()[0].getExplanation().getDetails(), arrayWithSize(2));
assertThat(hits.getHits()[0].getExplanation().getDetails()[0].getDescription(), containsString("_name: func1"));
}
} }

View File

@ -35,10 +35,13 @@ package org.opensearch.search.functionscore;
import org.opensearch.action.search.SearchPhaseExecutionException; import org.opensearch.action.search.SearchPhaseExecutionException;
import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.lucene.search.function.FieldValueFactorFunction; import org.opensearch.common.lucene.search.function.FieldValueFactorFunction;
import org.opensearch.search.SearchHit;
import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.test.OpenSearchIntegTestCase;
import java.io.IOException; import java.io.IOException;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.opensearch.index.query.QueryBuilders.functionScoreQuery; import static org.opensearch.index.query.QueryBuilders.functionScoreQuery;
import static org.opensearch.index.query.QueryBuilders.matchAllQuery; import static org.opensearch.index.query.QueryBuilders.matchAllQuery;
@ -163,4 +166,47 @@ public class FunctionScoreFieldValueIT extends OpenSearchIntegTestCase {
// locally, instead of just having failures // locally, instead of just having failures
} }
} }
public void testFieldValueFactorExplain() throws IOException {
assertAcked(
prepareCreate("test").addMapping(
"type1",
jsonBuilder().startObject()
.startObject("type1")
.startObject("properties")
.startObject("test")
.field("type", randomFrom(new String[] { "short", "float", "long", "integer", "double" }))
.endObject()
.startObject("body")
.field("type", "text")
.endObject()
.endObject()
.endObject()
.endObject()
).get()
);
client().prepareIndex("test").setId("1").setSource("test", 5, "body", "foo").get();
client().prepareIndex("test").setId("2").setSource("test", 17, "body", "foo").get();
client().prepareIndex("test").setId("3").setSource("body", "bar").get();
refresh();
// document 2 scores higher because 17 > 5
final String functionName = "func1";
final String queryName = "query";
SearchResponse response = client().prepareSearch("test")
.setExplain(true)
.setQuery(
functionScoreQuery(simpleQueryStringQuery("foo").queryName(queryName), fieldValueFactorFunction("test", functionName))
)
.get();
assertOrderedSearchHits(response, "2", "1");
SearchHit firstHit = response.getHits().getAt(0);
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
// "description": "sum of: (_name: query)"
assertThat(firstHit.getExplanation().getDetails()[0].getDescription(), containsString("_name: " + queryName));
// "description": "field value function(_name: func1): none(doc['test'].value * factor=1.0)"
assertThat(firstHit.getExplanation().getDetails()[1].toString(), containsString("_name: " + functionName));
}
} }

View File

@ -43,6 +43,7 @@ import org.opensearch.plugins.Plugin;
import org.opensearch.script.MockScriptPlugin; import org.opensearch.script.MockScriptPlugin;
import org.opensearch.script.Script; import org.opensearch.script.Script;
import org.opensearch.script.ScriptType; import org.opensearch.script.ScriptType;
import org.opensearch.search.SearchHit;
import org.opensearch.search.aggregations.bucket.terms.Terms; import org.opensearch.search.aggregations.bucket.terms.Terms;
import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.OpenSearchTestCase;
@ -66,6 +67,8 @@ import static org.opensearch.search.aggregations.AggregationBuilders.terms;
import static org.opensearch.search.builder.SearchSourceBuilder.searchSource; import static org.opensearch.search.builder.SearchSourceBuilder.searchSource;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
@ -140,6 +143,35 @@ public class FunctionScoreIT extends OpenSearchIntegTestCase {
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1L)); assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1L));
} }
public void testScriptScoresWithAggWithExplain() throws IOException {
createIndex(INDEX);
index(INDEX, TYPE, "1", jsonBuilder().startObject().field("dummy_field", 1).endObject());
refresh();
Script script = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "get score value", Collections.emptyMap());
SearchResponse response = client().search(
searchRequest().source(
searchSource().explain(true)
.query(functionScoreQuery(scriptFunction(script, "func1"), "query1"))
.aggregation(terms("score_agg").script(script))
)
).actionGet();
assertSearchResponse(response);
final SearchHit firstHit = response.getHits().getAt(0);
assertThat(firstHit.getScore(), equalTo(1.0f));
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
// "description": "*:* (_name: query1)"
assertThat(firstHit.getExplanation().getDetails()[0].getDescription(), containsString("_name: query1"));
assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
// "description": "script score function(_name: func1), computed with script:\"Script{ ... }\""
assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDescription(), containsString("_name: func1"));
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getKeyAsString(), equalTo("1.0"));
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1L));
}
public void testMinScoreFunctionScoreBasic() throws IOException { public void testMinScoreFunctionScoreBasic() throws IOException {
float score = randomValueOtherThanMany((f) -> Float.compare(f, 0) < 0, OpenSearchTestCase::randomFloat); float score = randomValueOtherThanMany((f) -> Float.compare(f, 0) < 0, OpenSearchTestCase::randomFloat);
float minScore = randomValueOtherThanMany((f) -> Float.compare(f, 0) < 0, OpenSearchTestCase::randomFloat); float minScore = randomValueOtherThanMany((f) -> Float.compare(f, 0) < 0, OpenSearchTestCase::randomFloat);

View File

@ -171,7 +171,7 @@ public class FunctionScorePluginIT extends OpenSearchIntegTestCase {
} }
@Override @Override
public Explanation explainFunction(String distanceString, double distanceVal, double scale) { public Explanation explainFunction(String distanceString, double distanceVal, double scale, String functionName) {
return Explanation.match((float) distanceVal, "" + distanceVal); return Explanation.match((float) distanceVal, "" + distanceVal);
} }

View File

@ -63,6 +63,7 @@ import static org.opensearch.script.MockScriptPlugin.NAME;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures;
import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
@ -289,6 +290,37 @@ public class RandomScoreFunctionIT extends OpenSearchIntegTestCase {
assertThat(firstHit.getExplanation().toString(), containsString("" + seed)); assertThat(firstHit.getExplanation().toString(), containsString("" + seed));
} }
public void testSeedAndNameReportedInExplain() throws Exception {
createIndex("test");
ensureGreen();
index("test", "type", "1", jsonBuilder().startObject().endObject());
flush();
refresh();
int seed = 12345678;
final String queryName = "query1";
final String functionName = "func1";
SearchResponse resp = client().prepareSearch("test")
.setQuery(
functionScoreQuery(
matchAllQuery().queryName(queryName),
randomFunction(functionName).seed(seed).setField(SeqNoFieldMapper.NAME)
)
)
.setExplain(true)
.get();
assertNoFailures(resp);
assertEquals(1, resp.getHits().getTotalHits().value);
SearchHit firstHit = resp.getHits().getAt(0);
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
// "description": "*:* (_name: query1)"
assertThat(firstHit.getExplanation().getDetails()[0].getDescription().toString(), containsString("_name: " + queryName));
assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
// "description": "random score function (seed: 12345678, field: _seq_no, _name: func1)"
assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDescription().toString(), containsString("seed: " + seed));
}
public void testNoDocs() throws Exception { public void testNoDocs() throws Exception {
createIndex("test"); createIndex("test");
ensureGreen(); ensureGreen();

View File

@ -35,6 +35,7 @@ package org.opensearch.common.lucene.search.function;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchException;
import org.opensearch.common.Nullable;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable; import org.opensearch.common.io.stream.Writeable;
@ -55,6 +56,8 @@ public class FieldValueFactorFunction extends ScoreFunction {
private final String field; private final String field;
private final float boostFactor; private final float boostFactor;
private final Modifier modifier; private final Modifier modifier;
private final String functionName;
/** /**
* Value used if the document is missing the field. * Value used if the document is missing the field.
*/ */
@ -67,6 +70,17 @@ public class FieldValueFactorFunction extends ScoreFunction {
Modifier modifierType, Modifier modifierType,
Double missing, Double missing,
IndexNumericFieldData indexFieldData IndexNumericFieldData indexFieldData
) {
this(field, boostFactor, modifierType, missing, indexFieldData, null);
}
public FieldValueFactorFunction(
String field,
float boostFactor,
Modifier modifierType,
Double missing,
IndexNumericFieldData indexFieldData,
@Nullable String functionName
) { ) {
super(CombineFunction.MULTIPLY); super(CombineFunction.MULTIPLY);
this.field = field; this.field = field;
@ -74,6 +88,7 @@ public class FieldValueFactorFunction extends ScoreFunction {
this.modifier = modifierType; this.modifier = modifierType;
this.indexFieldData = indexFieldData; this.indexFieldData = indexFieldData;
this.missing = missing; this.missing = missing;
this.functionName = functionName;
} }
@Override @Override
@ -127,7 +142,7 @@ public class FieldValueFactorFunction extends ScoreFunction {
(float) score, (float) score,
String.format( String.format(
Locale.ROOT, Locale.ROOT,
"field value function: %s(doc['%s'].value%s * factor=%s)", "field value function" + Functions.nameOrEmptyFunc(functionName) + ": %s(doc['%s'].value%s * factor=%s)",
modifierStr, modifierStr,
field, field,
defaultStr, defaultStr,

View File

@ -46,6 +46,7 @@ import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight; import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchException;
import org.opensearch.common.Nullable;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable; import org.opensearch.common.io.stream.Writeable;
@ -70,11 +71,28 @@ public class FunctionScoreQuery extends Query {
public static class FilterScoreFunction extends ScoreFunction { public static class FilterScoreFunction extends ScoreFunction {
public final Query filter; public final Query filter;
public final ScoreFunction function; public final ScoreFunction function;
public final String queryName;
/**
* Creates a FilterScoreFunction with query and function.
* @param filter filter query
* @param function score function
*/
public FilterScoreFunction(Query filter, ScoreFunction function) { public FilterScoreFunction(Query filter, ScoreFunction function) {
this(filter, function, null);
}
/**
* Creates a FilterScoreFunction with query and function.
* @param filter filter query
* @param function score function
* @param queryName filter query name
*/
public FilterScoreFunction(Query filter, ScoreFunction function, @Nullable String queryName) {
super(function.getDefaultScoreCombiner()); super(function.getDefaultScoreCombiner());
this.filter = filter; this.filter = filter;
this.function = function; this.function = function;
this.queryName = queryName;
} }
@Override @Override
@ -93,12 +111,14 @@ public class FunctionScoreQuery extends Query {
return false; return false;
} }
FilterScoreFunction that = (FilterScoreFunction) other; FilterScoreFunction that = (FilterScoreFunction) other;
return Objects.equals(this.filter, that.filter) && Objects.equals(this.function, that.function); return Objects.equals(this.filter, that.filter)
&& Objects.equals(this.function, that.function)
&& Objects.equals(this.queryName, that.queryName);
} }
@Override @Override
protected int doHashCode() { protected int doHashCode() {
return Objects.hash(filter, function); return Objects.hash(filter, function, queryName);
} }
@Override @Override
@ -107,7 +127,7 @@ public class FunctionScoreQuery extends Query {
if (newFilter == filter) { if (newFilter == filter) {
return this; return this;
} }
return new FilterScoreFunction(newFilter, function); return new FilterScoreFunction(newFilter, function, queryName);
} }
@Override @Override
@ -144,6 +164,7 @@ public class FunctionScoreQuery extends Query {
final float maxBoost; final float maxBoost;
private final Float minScore; private final Float minScore;
private final CombineFunction combineFunction; private final CombineFunction combineFunction;
private final String queryName;
/** /**
* Creates a FunctionScoreQuery without function. * Creates a FunctionScoreQuery without function.
@ -152,7 +173,18 @@ public class FunctionScoreQuery extends Query {
* @param maxBoost The maximum applicable boost. * @param maxBoost The maximum applicable boost.
*/ */
public FunctionScoreQuery(Query subQuery, Float minScore, float maxBoost) { public FunctionScoreQuery(Query subQuery, Float minScore, float maxBoost) {
this(subQuery, ScoreMode.FIRST, new ScoreFunction[0], CombineFunction.MULTIPLY, minScore, maxBoost); this(subQuery, null, minScore, maxBoost);
}
/**
* Creates a FunctionScoreQuery without function.
* @param subQuery The query to match.
* @param queryName filter query name
* @param minScore The minimum score to consider a document.
* @param maxBoost The maximum applicable boost.
*/
public FunctionScoreQuery(Query subQuery, @Nullable String queryName, Float minScore, float maxBoost) {
this(subQuery, queryName, ScoreMode.FIRST, new ScoreFunction[0], CombineFunction.MULTIPLY, minScore, maxBoost);
} }
/** /**
@ -161,7 +193,17 @@ public class FunctionScoreQuery extends Query {
* @param function The {@link ScoreFunction} to apply. * @param function The {@link ScoreFunction} to apply.
*/ */
public FunctionScoreQuery(Query subQuery, ScoreFunction function) { public FunctionScoreQuery(Query subQuery, ScoreFunction function) {
this(subQuery, function, CombineFunction.MULTIPLY, null, DEFAULT_MAX_BOOST); this(subQuery, null, function);
}
/**
* Creates a FunctionScoreQuery with a single {@link ScoreFunction}
* @param subQuery The query to match.
* @param queryName filter query name
* @param function The {@link ScoreFunction} to apply.
*/
public FunctionScoreQuery(Query subQuery, @Nullable String queryName, ScoreFunction function) {
this(subQuery, queryName, function, CombineFunction.MULTIPLY, null, DEFAULT_MAX_BOOST);
} }
/** /**
@ -173,7 +215,27 @@ public class FunctionScoreQuery extends Query {
* @param maxBoost The maximum applicable boost. * @param maxBoost The maximum applicable boost.
*/ */
public FunctionScoreQuery(Query subQuery, ScoreFunction function, CombineFunction combineFunction, Float minScore, float maxBoost) { public FunctionScoreQuery(Query subQuery, ScoreFunction function, CombineFunction combineFunction, Float minScore, float maxBoost) {
this(subQuery, ScoreMode.FIRST, new ScoreFunction[] { function }, combineFunction, minScore, maxBoost); this(subQuery, null, function, combineFunction, minScore, maxBoost);
}
/**
* Creates a FunctionScoreQuery with a single function
* @param subQuery The query to match.
* @param queryName filter query name
* @param function The {@link ScoreFunction} to apply.
* @param combineFunction Defines how the query and function score should be applied.
* @param minScore The minimum score to consider a document.
* @param maxBoost The maximum applicable boost.
*/
public FunctionScoreQuery(
Query subQuery,
@Nullable String queryName,
ScoreFunction function,
CombineFunction combineFunction,
Float minScore,
float maxBoost
) {
this(subQuery, queryName, ScoreMode.FIRST, new ScoreFunction[] { function }, combineFunction, minScore, maxBoost);
} }
/** /**
@ -192,11 +254,34 @@ public class FunctionScoreQuery extends Query {
CombineFunction combineFunction, CombineFunction combineFunction,
Float minScore, Float minScore,
float maxBoost float maxBoost
) {
this(subQuery, null, scoreMode, functions, combineFunction, minScore, maxBoost);
}
/**
* Creates a FunctionScoreQuery with multiple score functions
* @param subQuery The query to match.
* @param queryName filter query name
* @param scoreMode Defines how the different score functions should be combined.
* @param functions The {@link ScoreFunction}s to apply.
* @param combineFunction Defines how the query and function score should be applied.
* @param minScore The minimum score to consider a document.
* @param maxBoost The maximum applicable boost.
*/
public FunctionScoreQuery(
Query subQuery,
@Nullable String queryName,
ScoreMode scoreMode,
ScoreFunction[] functions,
CombineFunction combineFunction,
Float minScore,
float maxBoost
) { ) {
if (Arrays.stream(functions).anyMatch(func -> func == null)) { if (Arrays.stream(functions).anyMatch(func -> func == null)) {
throw new IllegalArgumentException("Score function should not be null"); throw new IllegalArgumentException("Score function should not be null");
} }
this.subQuery = subQuery; this.subQuery = subQuery;
this.queryName = queryName;
this.scoreMode = scoreMode; this.scoreMode = scoreMode;
this.functions = functions; this.functions = functions;
this.maxBoost = maxBoost; this.maxBoost = maxBoost;
@ -240,7 +325,7 @@ public class FunctionScoreQuery extends Query {
needsRewrite |= (newFunctions[i] != functions[i]); needsRewrite |= (newFunctions[i] != functions[i]);
} }
if (needsRewrite) { if (needsRewrite) {
return new FunctionScoreQuery(newQ, scoreMode, newFunctions, combineFunction, minScore, maxBoost); return new FunctionScoreQuery(newQ, queryName, scoreMode, newFunctions, combineFunction, minScore, maxBoost);
} }
return this; return this;
} }
@ -332,8 +417,7 @@ public class FunctionScoreQuery extends Query {
@Override @Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException { public Explanation explain(LeafReaderContext context, int doc) throws IOException {
Explanation expl = Functions.explainWithName(subQueryWeight.explain(context, doc), queryName);
Explanation expl = subQueryWeight.explain(context, doc);
if (!expl.isMatch()) { if (!expl.isMatch()) {
return expl; return expl;
} }
@ -355,11 +439,15 @@ public class FunctionScoreQuery extends Query {
Explanation functionExplanation = function.getLeafScoreFunction(context).explainScore(doc, expl); Explanation functionExplanation = function.getLeafScoreFunction(context).explainScore(doc, expl);
if (function instanceof FilterScoreFunction) { if (function instanceof FilterScoreFunction) {
float factor = functionExplanation.getValue().floatValue(); float factor = functionExplanation.getValue().floatValue();
Query filterQuery = ((FilterScoreFunction) function).filter; final FilterScoreFunction filterScoreFunction = (FilterScoreFunction) function;
Query filterQuery = filterScoreFunction.filter;
Explanation filterExplanation = Explanation.match( Explanation filterExplanation = Explanation.match(
factor, factor,
"function score, product of:", "function score, product of:",
Explanation.match(1.0f, "match filter: " + filterQuery.toString()), Explanation.match(
1.0f,
"match filter" + Functions.nameOrEmptyFunc(filterScoreFunction.queryName) + ": " + filterQuery.toString()
),
functionExplanation functionExplanation
); );
functionsExplanations.add(filterExplanation); functionsExplanations.add(filterExplanation);
@ -543,11 +631,12 @@ public class FunctionScoreQuery extends Query {
&& Objects.equals(this.combineFunction, other.combineFunction) && Objects.equals(this.combineFunction, other.combineFunction)
&& Objects.equals(this.minScore, other.minScore) && Objects.equals(this.minScore, other.minScore)
&& Objects.equals(this.scoreMode, other.scoreMode) && Objects.equals(this.scoreMode, other.scoreMode)
&& Arrays.equals(this.functions, other.functions); && Arrays.equals(this.functions, other.functions)
&& Objects.equals(this.queryName, other.queryName);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(classHash(), subQuery, maxBoost, combineFunction, minScore, scoreMode, Arrays.hashCode(functions)); return Objects.hash(classHash(), subQuery, maxBoost, combineFunction, minScore, scoreMode, Arrays.hashCode(functions), queryName);
} }
} }

View File

@ -0,0 +1,66 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.common.lucene.search.function;
import org.apache.lucene.search.Explanation;
import org.opensearch.common.Strings;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder;
/**
* Helper utility class for functions
*/
public final class Functions {
private Functions() {}
/**
* Return function name wrapped into brackets or empty string, for example: '(_name: func1)'
* @param functionName function name
* @return function name wrapped into brackets or empty string
*/
public static String nameOrEmptyFunc(final String functionName) {
if (!Strings.isNullOrEmpty(functionName)) {
return "(" + AbstractQueryBuilder.NAME_FIELD.getPreferredName() + ": " + functionName + ")";
} else {
return "";
}
}
/**
* Return function name as an argument or empty string, for example: ', _name: func1'
* @param functionName function name
* @return function name as an argument or empty string
*/
public static String nameOrEmptyArg(final String functionName) {
if (!Strings.isNullOrEmpty(functionName)) {
return ", " + FunctionScoreQueryBuilder.NAME_FIELD.getPreferredName() + ": " + functionName;
} else {
return "";
}
}
/**
* Enrich explanation with query name
* @param explanation explanation
* @param queryName query name
* @return explanation enriched with query name
*/
public static Explanation explainWithName(Explanation explanation, String queryName) {
if (Strings.isNullOrEmpty(queryName)) {
return explanation;
} else {
final String description = explanation.getDescription() + " " + nameOrEmptyFunc(queryName);
if (explanation.isMatch()) {
return Explanation.match(explanation.getValue(), description, explanation.getDetails());
} else {
return Explanation.noMatch(description, explanation.getDetails());
}
}
}
}

View File

@ -35,6 +35,7 @@ import com.carrotsearch.hppc.BitMixer;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.StringHelper;
import org.opensearch.common.Nullable;
import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.fielddata.LeafFieldData; import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.SortedBinaryDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues;
@ -50,6 +51,7 @@ public class RandomScoreFunction extends ScoreFunction {
private final int originalSeed; private final int originalSeed;
private final int saltedSeed; private final int saltedSeed;
private final IndexFieldData<?> fieldData; private final IndexFieldData<?> fieldData;
private final String functionName;
/** /**
* Creates a RandomScoreFunction. * Creates a RandomScoreFunction.
@ -59,10 +61,23 @@ public class RandomScoreFunction extends ScoreFunction {
* @param uidFieldData The field data for _uid to use for generating consistent random values for the same id * @param uidFieldData The field data for _uid to use for generating consistent random values for the same id
*/ */
public RandomScoreFunction(int seed, int salt, IndexFieldData<?> uidFieldData) { public RandomScoreFunction(int seed, int salt, IndexFieldData<?> uidFieldData) {
this(seed, salt, uidFieldData, null);
}
/**
* Creates a RandomScoreFunction.
*
* @param seed A seed for randomness
* @param salt A value to salt the seed with, ideally unique to the running node/index
* @param uidFieldData The field data for _uid to use for generating consistent random values for the same id
* @param functionName The function name
*/
public RandomScoreFunction(int seed, int salt, IndexFieldData<?> uidFieldData, @Nullable String functionName) {
super(CombineFunction.MULTIPLY); super(CombineFunction.MULTIPLY);
this.originalSeed = seed; this.originalSeed = seed;
this.saltedSeed = BitMixer.mix(seed, salt); this.saltedSeed = BitMixer.mix(seed, salt);
this.fieldData = uidFieldData; this.fieldData = uidFieldData;
this.functionName = functionName;
} }
@Override @Override
@ -97,7 +112,7 @@ public class RandomScoreFunction extends ScoreFunction {
String field = fieldData == null ? null : fieldData.getFieldName(); String field = fieldData == null ? null : fieldData.getFieldName();
return Explanation.match( return Explanation.match(
(float) score(docId, subQueryScore.getValue().floatValue()), (float) score(docId, subQueryScore.getValue().floatValue()),
"random score function (seed: " + originalSeed + ", field: " + field + ")" "random score function (seed: " + originalSeed + ", field: " + field + Functions.nameOrEmptyArg(functionName) + ")"
); );
} }
}; };

View File

@ -39,6 +39,7 @@ import org.opensearch.script.ExplainableScoreScript;
import org.opensearch.script.ScoreScript; import org.opensearch.script.ScoreScript;
import org.opensearch.script.Script; import org.opensearch.script.Script;
import org.opensearch.Version; import org.opensearch.Version;
import org.opensearch.common.Nullable;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
@ -67,14 +68,23 @@ public class ScriptScoreFunction extends ScoreFunction {
private final int shardId; private final int shardId;
private final String indexName; private final String indexName;
private final Version indexVersion; private final Version indexVersion;
private final String functionName;
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) { public ScriptScoreFunction(
Script sScript,
ScoreScript.LeafFactory script,
String indexName,
int shardId,
Version indexVersion,
@Nullable String functionName
) {
super(CombineFunction.REPLACE); super(CombineFunction.REPLACE);
this.sScript = sScript; this.sScript = sScript;
this.script = script; this.script = script;
this.indexName = indexName; this.indexName = indexName;
this.shardId = shardId; this.shardId = shardId;
this.indexVersion = indexVersion; this.indexVersion = indexVersion;
this.functionName = functionName;
} }
@Override @Override
@ -105,11 +115,15 @@ public class ScriptScoreFunction extends ScoreFunction {
leafScript.setDocument(docId); leafScript.setDocument(docId);
scorer.docid = docId; scorer.docid = docId;
scorer.score = subQueryScore.getValue().floatValue(); scorer.score = subQueryScore.getValue().floatValue();
exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore); exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore, functionName);
} else { } else {
double score = score(docId, subQueryScore.getValue().floatValue()); double score = score(docId, subQueryScore.getValue().floatValue());
// info about params already included in sScript // info about params already included in sScript
String explanation = "script score function, computed with script:\"" + sScript + "\""; String explanation = "script score function"
+ Functions.nameOrEmptyFunc(functionName)
+ ", computed with script:\""
+ sScript
+ "\"";
Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore); Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore);
return Explanation.match((float) score, explanation, scoreExp); return Explanation.match((float) score, explanation, scoreExp);
} }

View File

@ -50,6 +50,7 @@ import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.opensearch.Version; import org.opensearch.Version;
import org.opensearch.common.Nullable;
import org.opensearch.script.ScoreScript; import org.opensearch.script.ScoreScript;
import org.opensearch.script.ScoreScript.ExplanationHolder; import org.opensearch.script.ScoreScript.ExplanationHolder;
import org.opensearch.script.Script; import org.opensearch.script.Script;
@ -69,6 +70,7 @@ public class ScriptScoreQuery extends Query {
private final String indexName; private final String indexName;
private final int shardId; private final int shardId;
private final Version indexVersion; private final Version indexVersion;
private final String queryName;
public ScriptScoreQuery( public ScriptScoreQuery(
Query subQuery, Query subQuery,
@ -78,8 +80,22 @@ public class ScriptScoreQuery extends Query {
String indexName, String indexName,
int shardId, int shardId,
Version indexVersion Version indexVersion
) {
this(subQuery, null, script, scriptBuilder, minScore, indexName, shardId, indexVersion);
}
public ScriptScoreQuery(
Query subQuery,
@Nullable String queryName,
Script script,
ScoreScript.LeafFactory scriptBuilder,
Float minScore,
String indexName,
int shardId,
Version indexVersion
) { ) {
this.subQuery = subQuery; this.subQuery = subQuery;
this.queryName = queryName;
this.script = script; this.script = script;
this.scriptBuilder = scriptBuilder; this.scriptBuilder = scriptBuilder;
this.minScore = minScore; this.minScore = minScore;
@ -92,7 +108,7 @@ public class ScriptScoreQuery extends Query {
public Query rewrite(IndexReader reader) throws IOException { public Query rewrite(IndexReader reader) throws IOException {
Query newQ = subQuery.rewrite(reader); Query newQ = subQuery.rewrite(reader);
if (newQ != subQuery) { if (newQ != subQuery) {
return new ScriptScoreQuery(newQ, script, scriptBuilder, minScore, indexName, shardId, indexVersion); return new ScriptScoreQuery(newQ, queryName, script, scriptBuilder, minScore, indexName, shardId, indexVersion);
} }
return super.rewrite(reader); return super.rewrite(reader);
} }
@ -140,7 +156,7 @@ public class ScriptScoreQuery extends Query {
@Override @Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException { public Explanation explain(LeafReaderContext context, int doc) throws IOException {
Explanation subQueryExplanation = subQueryWeight.explain(context, doc); Explanation subQueryExplanation = Functions.explainWithName(subQueryWeight.explain(context, doc), queryName);
if (subQueryExplanation.isMatch() == false) { if (subQueryExplanation.isMatch() == false) {
return subQueryExplanation; return subQueryExplanation;
} }
@ -210,7 +226,8 @@ public class ScriptScoreQuery extends Query {
@Override @Override
public String toString(String field) { public String toString(String field) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
sb.append("script_score (").append(subQuery.toString(field)).append(", script: "); sb.append("script_score (").append(subQuery.toString(field));
sb.append(Functions.nameOrEmptyArg(queryName)).append(", script: ");
sb.append("{" + script.toString() + "}"); sb.append("{" + script.toString() + "}");
return sb.toString(); return sb.toString();
} }
@ -225,12 +242,13 @@ public class ScriptScoreQuery extends Query {
&& script.equals(that.script) && script.equals(that.script)
&& Objects.equals(minScore, that.minScore) && Objects.equals(minScore, that.minScore)
&& indexName.equals(that.indexName) && indexName.equals(that.indexName)
&& indexVersion.equals(that.indexVersion); && indexVersion.equals(that.indexVersion)
&& Objects.equals(queryName, that.queryName);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(subQuery, script, minScore, indexName, shardId, indexVersion); return Objects.hash(subQuery, script, minScore, indexName, shardId, indexVersion, queryName);
} }
private static class ScriptScorer extends Scorer { private static class ScriptScorer extends Scorer {

View File

@ -34,6 +34,8 @@ package org.opensearch.common.lucene.search.function;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.opensearch.common.Nullable;
import org.opensearch.common.Strings;
import java.io.IOException; import java.io.IOException;
import java.util.Objects; import java.util.Objects;
@ -45,9 +47,17 @@ public class WeightFactorFunction extends ScoreFunction {
private float weight = 1.0f; private float weight = 1.0f;
public WeightFactorFunction(float weight, ScoreFunction scoreFunction) { public WeightFactorFunction(float weight, ScoreFunction scoreFunction) {
this(weight, scoreFunction, null);
}
public WeightFactorFunction(float weight, ScoreFunction scoreFunction, @Nullable String functionName) {
super(CombineFunction.MULTIPLY); super(CombineFunction.MULTIPLY);
if (scoreFunction == null) { if (scoreFunction == null) {
this.scoreFunction = SCORE_ONE; if (Strings.isNullOrEmpty(functionName)) {
this.scoreFunction = SCORE_ONE;
} else {
this.scoreFunction = new ScoreOne(CombineFunction.MULTIPLY, functionName);
}
} else { } else {
this.scoreFunction = scoreFunction; this.scoreFunction = scoreFunction;
} }
@ -55,9 +65,11 @@ public class WeightFactorFunction extends ScoreFunction {
} }
public WeightFactorFunction(float weight) { public WeightFactorFunction(float weight) {
super(CombineFunction.MULTIPLY); this(weight, null, null);
this.scoreFunction = SCORE_ONE; }
this.weight = weight;
public WeightFactorFunction(float weight, @Nullable String functionName) {
this(weight, null, functionName);
} }
@Override @Override
@ -112,9 +124,15 @@ public class WeightFactorFunction extends ScoreFunction {
} }
private static class ScoreOne extends ScoreFunction { private static class ScoreOne extends ScoreFunction {
private final String functionName;
protected ScoreOne(CombineFunction scoreCombiner) { protected ScoreOne(CombineFunction scoreCombiner) {
this(scoreCombiner, null);
}
protected ScoreOne(CombineFunction scoreCombiner, @Nullable String functionName) {
super(scoreCombiner); super(scoreCombiner);
this.functionName = functionName;
} }
@Override @Override
@ -127,7 +145,10 @@ public class WeightFactorFunction extends ScoreFunction {
@Override @Override
public Explanation explainScore(int docId, Explanation subQueryScore) { public Explanation explainScore(int docId, Explanation subQueryScore) {
return Explanation.match(1.0f, "constant score 1.0 - no function provided"); return Explanation.match(
1.0f,
"constant score 1.0" + Functions.nameOrEmptyFunc(functionName) + " - no function provided"
);
} }
}; };
} }

View File

@ -33,6 +33,7 @@
package org.opensearch.index.query; package org.opensearch.index.query;
import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.common.Nullable;
import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.geo.GeoPoint; import org.opensearch.common.geo.GeoPoint;
import org.opensearch.common.geo.ShapeRelation; import org.opensearch.common.geo.ShapeRelation;
@ -452,7 +453,17 @@ public final class QueryBuilders {
* @param function The function builder used to custom score * @param function The function builder used to custom score
*/ */
public static FunctionScoreQueryBuilder functionScoreQuery(ScoreFunctionBuilder function) { public static FunctionScoreQueryBuilder functionScoreQuery(ScoreFunctionBuilder function) {
return new FunctionScoreQueryBuilder(function); return functionScoreQuery(function, null);
}
/**
* A query that allows to define a custom scoring function.
*
* @param function The function builder used to custom score
* @param queryName The query name
*/
public static FunctionScoreQueryBuilder functionScoreQuery(ScoreFunctionBuilder function, @Nullable String queryName) {
return new FunctionScoreQueryBuilder(function, queryName);
} }
/** /**

View File

@ -43,9 +43,11 @@ import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight; import org.apache.lucene.search.Weight;
import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchException;
import org.opensearch.common.Nullable;
import org.opensearch.common.ParsingException; import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.lucene.search.function.Functions;
import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.script.FilterScript; import org.opensearch.script.FilterScript;
@ -153,17 +155,19 @@ public class ScriptQueryBuilder extends AbstractQueryBuilder<ScriptQueryBuilder>
} }
FilterScript.Factory factory = context.compile(script, FilterScript.CONTEXT); FilterScript.Factory factory = context.compile(script, FilterScript.CONTEXT);
FilterScript.LeafFactory filterScript = factory.newFactory(script.getParams(), context.lookup()); FilterScript.LeafFactory filterScript = factory.newFactory(script.getParams(), context.lookup());
return new ScriptQuery(script, filterScript); return new ScriptQuery(script, filterScript, queryName);
} }
static class ScriptQuery extends Query { static class ScriptQuery extends Query {
final Script script; final Script script;
final FilterScript.LeafFactory filterScript; final FilterScript.LeafFactory filterScript;
final String queryName;
ScriptQuery(Script script, FilterScript.LeafFactory filterScript) { ScriptQuery(Script script, FilterScript.LeafFactory filterScript, @Nullable String queryName) {
this.script = script; this.script = script;
this.filterScript = filterScript; this.filterScript = filterScript;
this.queryName = queryName;
} }
@Override @Override
@ -171,6 +175,7 @@ public class ScriptQueryBuilder extends AbstractQueryBuilder<ScriptQueryBuilder>
StringBuilder buffer = new StringBuilder(); StringBuilder buffer = new StringBuilder();
buffer.append("ScriptQuery("); buffer.append("ScriptQuery(");
buffer.append(script); buffer.append(script);
buffer.append(Functions.nameOrEmptyArg(queryName));
buffer.append(")"); buffer.append(")");
return buffer.toString(); return buffer.toString();
} }

View File

@ -33,6 +33,7 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.opensearch.common.Nullable;
/** /**
* Implement this interface to provide a decay function that is executed on a * Implement this interface to provide a decay function that is executed on a
@ -45,7 +46,7 @@ public interface DecayFunction {
double evaluate(double value, double scale); double evaluate(double value, double scale);
Explanation explainFunction(String valueString, double value, double scale); Explanation explainFunction(String valueString, double value, double scale, @Nullable String functionName);
/** /**
* The final scale parameter is computed from the scale parameter given by * The final scale parameter is computed from the scale parameter given by

View File

@ -35,6 +35,7 @@ package org.opensearch.index.query.functionscore;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.opensearch.OpenSearchParseException; import org.opensearch.OpenSearchParseException;
import org.opensearch.common.Nullable;
import org.opensearch.common.ParsingException; import org.opensearch.common.ParsingException;
import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.geo.GeoDistance; import org.opensearch.common.geo.GeoDistance;
@ -93,10 +94,31 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
this(fieldName, origin, scale, offset, DEFAULT_DECAY); this(fieldName, origin, scale, offset, DEFAULT_DECAY);
} }
/**
* Convenience constructor that converts its parameters into json to parse on the data nodes.
*/
protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) {
this(fieldName, origin, scale, offset, DEFAULT_DECAY, functionName);
}
/** /**
* Convenience constructor that converts its parameters into json to parse on the data nodes. * Convenience constructor that converts its parameters into json to parse on the data nodes.
*/ */
protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) { protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) {
this(fieldName, origin, scale, offset, decay, null);
}
/**
* Convenience constructor that converts its parameters into json to parse on the data nodes.
*/
protected DecayFunctionBuilder(
String fieldName,
Object origin,
Object scale,
Object offset,
double decay,
@Nullable String functionName
) {
if (fieldName == null) { if (fieldName == null) {
throw new IllegalArgumentException("decay function: field name must not be null"); throw new IllegalArgumentException("decay function: field name must not be null");
} }
@ -123,6 +145,7 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
} catch (IOException e) { } catch (IOException e) {
throw new IllegalArgumentException("unable to build inner function object", e); throw new IllegalArgumentException("unable to build inner function object", e);
} }
setFunctionName(functionName);
} }
protected DecayFunctionBuilder(String fieldName, BytesReference functionBytes) { protected DecayFunctionBuilder(String fieldName, BytesReference functionBytes) {
@ -285,7 +308,16 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
); );
} }
IndexNumericFieldData numericFieldData = context.getForField(fieldType); IndexNumericFieldData numericFieldData = context.getForField(fieldType);
return new NumericFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), numericFieldData, mode); return new NumericFieldDataScoreFunction(
origin,
scale,
decay,
offset,
getDecayFunction(),
numericFieldData,
mode,
getFunctionName()
);
} }
private AbstractDistanceScoreFunction parseGeoVariable( private AbstractDistanceScoreFunction parseGeoVariable(
@ -325,7 +357,7 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
double scale = DistanceUnit.DEFAULT.parse(scaleString, DistanceUnit.DEFAULT); double scale = DistanceUnit.DEFAULT.parse(scaleString, DistanceUnit.DEFAULT);
double offset = DistanceUnit.DEFAULT.parse(offsetString, DistanceUnit.DEFAULT); double offset = DistanceUnit.DEFAULT.parse(offsetString, DistanceUnit.DEFAULT);
IndexGeoPointFieldData indexFieldData = context.getForField(fieldType); IndexGeoPointFieldData indexFieldData = context.getForField(fieldType);
return new GeoFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), indexFieldData, mode); return new GeoFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), indexFieldData, mode, getFunctionName());
} }
@ -375,7 +407,16 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
val = TimeValue.parseTimeValue(offsetString, TimeValue.timeValueHours(24), DecayFunctionParser.class.getSimpleName() + ".offset"); val = TimeValue.parseTimeValue(offsetString, TimeValue.timeValueHours(24), DecayFunctionParser.class.getSimpleName() + ".offset");
double offset = val.getMillis(); double offset = val.getMillis();
IndexNumericFieldData numericFieldData = context.getForField(dateFieldType); IndexNumericFieldData numericFieldData = context.getForField(dateFieldType);
return new NumericFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), numericFieldData, mode); return new NumericFieldDataScoreFunction(
origin,
scale,
decay,
offset,
getDecayFunction(),
numericFieldData,
mode,
getFunctionName()
);
} }
static class GeoFieldDataScoreFunction extends AbstractDistanceScoreFunction { static class GeoFieldDataScoreFunction extends AbstractDistanceScoreFunction {
@ -392,9 +433,10 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
double offset, double offset,
DecayFunction func, DecayFunction func,
IndexGeoPointFieldData fieldData, IndexGeoPointFieldData fieldData,
MultiValueMode mode MultiValueMode mode,
@Nullable String functionName
) { ) {
super(scale, decay, offset, func, mode); super(scale, decay, offset, func, mode, functionName);
this.origin = origin; this.origin = origin;
this.fieldData = fieldData; this.fieldData = fieldData;
} }
@ -485,9 +527,10 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
double offset, double offset,
DecayFunction func, DecayFunction func,
IndexNumericFieldData fieldData, IndexNumericFieldData fieldData,
MultiValueMode mode MultiValueMode mode,
@Nullable String functionName
) { ) {
super(scale, decay, offset, func, mode); super(scale, decay, offset, func, mode, functionName);
this.fieldData = fieldData; this.fieldData = fieldData;
this.origin = origin; this.origin = origin;
} }
@ -569,13 +612,15 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
protected final double offset; protected final double offset;
private final DecayFunction func; private final DecayFunction func;
protected final MultiValueMode mode; protected final MultiValueMode mode;
protected final String functionName;
public AbstractDistanceScoreFunction( public AbstractDistanceScoreFunction(
double userSuppiedScale, double userSuppiedScale,
double decay, double decay,
double offset, double offset,
DecayFunction func, DecayFunction func,
MultiValueMode mode MultiValueMode mode,
@Nullable String functionName
) { ) {
super(CombineFunction.MULTIPLY); super(CombineFunction.MULTIPLY);
this.mode = mode; this.mode = mode;
@ -591,6 +636,7 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
throw new IllegalArgumentException(FunctionScoreQueryBuilder.NAME + " : offset must be > 0.0"); throw new IllegalArgumentException(FunctionScoreQueryBuilder.NAME + " : offset must be > 0.0");
} }
this.offset = offset; this.offset = offset;
this.functionName = functionName;
} }
/** /**
@ -624,7 +670,7 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
return Explanation.match( return Explanation.match(
(float) score(docId, subQueryScore.getValue().floatValue()), (float) score(docId, subQueryScore.getValue().floatValue()),
"Function for field " + getFieldName() + ":", "Function for field " + getFieldName() + ":",
func.explainFunction(getDistanceString(ctx, docId), value, scale) func.explainFunction(getDistanceString(ctx, docId), value, scale, functionName)
); );
} }
}; };

View File

@ -33,8 +33,10 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.opensearch.common.Nullable;
import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.lucene.search.function.Functions;
import java.io.IOException; import java.io.IOException;
@ -45,6 +47,10 @@ public class ExponentialDecayFunctionBuilder extends DecayFunctionBuilder<Expone
); );
public static final DecayFunction EXP_DECAY_FUNCTION = new ExponentialDecayScoreFunction(); public static final DecayFunction EXP_DECAY_FUNCTION = new ExponentialDecayScoreFunction();
public ExponentialDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) {
super(fieldName, origin, scale, offset, functionName);
}
public ExponentialDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset) { public ExponentialDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset) {
super(fieldName, origin, scale, offset); super(fieldName, origin, scale, offset);
} }
@ -53,6 +59,17 @@ public class ExponentialDecayFunctionBuilder extends DecayFunctionBuilder<Expone
super(fieldName, origin, scale, offset, decay); super(fieldName, origin, scale, offset, decay);
} }
public ExponentialDecayFunctionBuilder(
String fieldName,
Object origin,
Object scale,
Object offset,
double decay,
@Nullable String functionName
) {
super(fieldName, origin, scale, offset, decay, functionName);
}
ExponentialDecayFunctionBuilder(String fieldName, BytesReference functionBytes) { ExponentialDecayFunctionBuilder(String fieldName, BytesReference functionBytes) {
super(fieldName, functionBytes); super(fieldName, functionBytes);
} }
@ -82,8 +99,11 @@ public class ExponentialDecayFunctionBuilder extends DecayFunctionBuilder<Expone
} }
@Override @Override
public Explanation explainFunction(String valueExpl, double value, double scale) { public Explanation explainFunction(String valueExpl, double value, double scale, @Nullable String functionName) {
return Explanation.match((float) evaluate(value, scale), "exp(- " + valueExpl + " * " + -1 * scale + ")"); return Explanation.match(
(float) evaluate(value, scale),
"exp(- " + valueExpl + " * " + -1 * scale + Functions.nameOrEmptyArg(functionName) + ")"
);
} }
@Override @Override

View File

@ -33,6 +33,7 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchException;
import org.opensearch.common.Nullable;
import org.opensearch.common.ParsingException; import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
@ -63,10 +64,15 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
private FieldValueFactorFunction.Modifier modifier = DEFAULT_MODIFIER; private FieldValueFactorFunction.Modifier modifier = DEFAULT_MODIFIER;
public FieldValueFactorFunctionBuilder(String fieldName) { public FieldValueFactorFunctionBuilder(String fieldName) {
this(fieldName, null);
}
public FieldValueFactorFunctionBuilder(String fieldName, @Nullable String functionName) {
if (fieldName == null) { if (fieldName == null) {
throw new IllegalArgumentException("field_value_factor: field must not be null"); throw new IllegalArgumentException("field_value_factor: field must not be null");
} }
this.field = fieldName; this.field = fieldName;
setFunctionName(functionName);
} }
/** /**
@ -166,7 +172,7 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
} else { } else {
fieldData = context.getForField(fieldType); fieldData = context.getForField(fieldType);
} }
return new FieldValueFactorFunction(field, factor, modifier, missing, fieldData); return new FieldValueFactorFunction(field, factor, modifier, missing, fieldData, getFunctionName());
} }
public static FieldValueFactorFunctionBuilder fromXContent(XContentParser parser) throws IOException, ParsingException { public static FieldValueFactorFunctionBuilder fromXContent(XContentParser parser) throws IOException, ParsingException {
@ -176,6 +182,7 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
FieldValueFactorFunction.Modifier modifier = FieldValueFactorFunction.Modifier.NONE; FieldValueFactorFunction.Modifier modifier = FieldValueFactorFunction.Modifier.NONE;
Double missing = null; Double missing = null;
XContentParser.Token token; XContentParser.Token token;
String functionName = null;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) { if (token == XContentParser.Token.FIELD_NAME) {
currentFieldName = parser.currentName(); currentFieldName = parser.currentName();
@ -188,6 +195,8 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
modifier = FieldValueFactorFunction.Modifier.fromString(parser.text()); modifier = FieldValueFactorFunction.Modifier.fromString(parser.text());
} else if ("missing".equals(currentFieldName)) { } else if ("missing".equals(currentFieldName)) {
missing = parser.doubleValue(); missing = parser.doubleValue();
} else if (FunctionScoreQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
functionName = parser.text();
} else { } else {
throw new ParsingException(parser.getTokenLocation(), NAME + " query does not support [" + currentFieldName + "]"); throw new ParsingException(parser.getTokenLocation(), NAME + " query does not support [" + currentFieldName + "]");
} }
@ -204,8 +213,9 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] required field 'field' missing"); throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] required field 'field' missing");
} }
FieldValueFactorFunctionBuilder fieldValueFactorFunctionBuilder = new FieldValueFactorFunctionBuilder(field).factor(boostFactor) FieldValueFactorFunctionBuilder fieldValueFactorFunctionBuilder = new FieldValueFactorFunctionBuilder(field, functionName).factor(
.modifier(modifier); boostFactor
).modifier(modifier);
if (missing != null) { if (missing != null) {
fieldValueFactorFunctionBuilder.missing(missing); fieldValueFactorFunctionBuilder.missing(missing);
} }

View File

@ -34,6 +34,7 @@ package org.opensearch.index.query.functionscore;
import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.opensearch.common.Nullable;
import org.opensearch.common.ParseField; import org.opensearch.common.ParseField;
import org.opensearch.common.ParsingException; import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
@ -111,7 +112,17 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
* @param filterFunctionBuilders the filters and functions * @param filterFunctionBuilders the filters and functions
*/ */
public FunctionScoreQueryBuilder(FilterFunctionBuilder[] filterFunctionBuilders) { public FunctionScoreQueryBuilder(FilterFunctionBuilder[] filterFunctionBuilders) {
this(new MatchAllQueryBuilder(), filterFunctionBuilders); this(filterFunctionBuilders, null);
}
/**
* Creates a function_score query that executes the provided filters and functions on all documents
*
* @param filterFunctionBuilders the filters and functions
* @param queryName the query name
*/
public FunctionScoreQueryBuilder(FilterFunctionBuilder[] filterFunctionBuilders, @Nullable String queryName) {
this(new MatchAllQueryBuilder().queryName(queryName), filterFunctionBuilders);
} }
/** /**
@ -120,7 +131,20 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
* @param scoreFunctionBuilder score function that is executed * @param scoreFunctionBuilder score function that is executed
*/ */
public FunctionScoreQueryBuilder(ScoreFunctionBuilder<?> scoreFunctionBuilder) { public FunctionScoreQueryBuilder(ScoreFunctionBuilder<?> scoreFunctionBuilder) {
this(new MatchAllQueryBuilder(), new FilterFunctionBuilder[] { new FilterFunctionBuilder(scoreFunctionBuilder) }); this(scoreFunctionBuilder, null);
}
/**
* Creates a function_score query that will execute the function provided on all documents
*
* @param scoreFunctionBuilder score function that is executed
* @param queryName the query name
*/
public FunctionScoreQueryBuilder(ScoreFunctionBuilder<?> scoreFunctionBuilder, @Nullable String queryName) {
this(
new MatchAllQueryBuilder().queryName(queryName),
new FilterFunctionBuilder[] { new FilterFunctionBuilder(scoreFunctionBuilder) }
);
} }
/** /**
@ -316,15 +340,17 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
int i = 0; int i = 0;
for (FilterFunctionBuilder filterFunctionBuilder : filterFunctionBuilders) { for (FilterFunctionBuilder filterFunctionBuilder : filterFunctionBuilders) {
ScoreFunction scoreFunction = filterFunctionBuilder.getScoreFunction().toFunction(context); ScoreFunction scoreFunction = filterFunctionBuilder.getScoreFunction().toFunction(context);
if (filterFunctionBuilder.getFilter().getName().equals(MatchAllQueryBuilder.NAME)) { final QueryBuilder builder = filterFunctionBuilder.getFilter();
if (builder.getName().equals(MatchAllQueryBuilder.NAME)) {
filterFunctions[i++] = scoreFunction; filterFunctions[i++] = scoreFunction;
} else { } else {
Query filter = filterFunctionBuilder.getFilter().toQuery(context); Query filter = builder.toQuery(context);
filterFunctions[i++] = new FunctionScoreQuery.FilterScoreFunction(filter, scoreFunction); filterFunctions[i++] = new FunctionScoreQuery.FilterScoreFunction(filter, scoreFunction, builder.queryName());
} }
} }
Query query = this.query.toQuery(context); final QueryBuilder builder = this.query;
Query query = builder.toQuery(context);
if (query == null) { if (query == null) {
query = new MatchAllDocsQuery(); query = new MatchAllDocsQuery();
} }
@ -332,12 +358,12 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
CombineFunction boostMode = this.boostMode == null ? DEFAULT_BOOST_MODE : this.boostMode; CombineFunction boostMode = this.boostMode == null ? DEFAULT_BOOST_MODE : this.boostMode;
// handle cases where only one score function and no filter was provided. In this case we create a FunctionScoreQuery. // handle cases where only one score function and no filter was provided. In this case we create a FunctionScoreQuery.
if (filterFunctions.length == 0) { if (filterFunctions.length == 0) {
return new FunctionScoreQuery(query, minScore, maxBoost); return new FunctionScoreQuery(query, builder.queryName(), minScore, maxBoost);
} else if (filterFunctions.length == 1 && filterFunctions[0] instanceof FunctionScoreQuery.FilterScoreFunction == false) { } else if (filterFunctions.length == 1 && filterFunctions[0] instanceof FunctionScoreQuery.FilterScoreFunction == false) {
return new FunctionScoreQuery(query, filterFunctions[0], boostMode, minScore, maxBoost); return new FunctionScoreQuery(query, builder.queryName(), filterFunctions[0], boostMode, minScore, maxBoost);
} }
// in all other cases we create a FunctionScoreQuery with filters // in all other cases we create a FunctionScoreQuery with filters
return new FunctionScoreQuery(query, scoreMode, filterFunctions, boostMode, minScore, maxBoost); return new FunctionScoreQuery(query, builder.queryName(), scoreMode, filterFunctions, boostMode, minScore, maxBoost);
} }
/** /**
@ -606,6 +632,7 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
QueryBuilder filter = null; QueryBuilder filter = null;
ScoreFunctionBuilder<?> scoreFunction = null; ScoreFunctionBuilder<?> scoreFunction = null;
Float functionWeight = null; Float functionWeight = null;
String functionName = null;
if (token != XContentParser.Token.START_OBJECT) { if (token != XContentParser.Token.START_OBJECT) {
throw new ParsingException( throw new ParsingException(
parser.getTokenLocation(), parser.getTokenLocation(),
@ -635,6 +662,8 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
} else if (token.isValue()) { } else if (token.isValue()) {
if (WEIGHT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { if (WEIGHT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
functionWeight = parser.floatValue(); functionWeight = parser.floatValue();
} else if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
functionName = parser.text();
} else { } else {
throw new ParsingException( throw new ParsingException(
parser.getTokenLocation(), parser.getTokenLocation(),
@ -652,6 +681,10 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
scoreFunction.setWeight(functionWeight); scoreFunction.setWeight(functionWeight);
} }
} }
if (functionName != null && scoreFunction != null) {
scoreFunction.setFunctionName(functionName);
}
} }
if (filter == null) { if (filter == null) {
filter = new MatchAllQueryBuilder(); filter = new MatchAllQueryBuilder();

View File

@ -33,9 +33,11 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.opensearch.common.Nullable;
import org.opensearch.common.ParseField; import org.opensearch.common.ParseField;
import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.lucene.search.function.Functions;
import java.io.IOException; import java.io.IOException;
@ -49,10 +51,25 @@ public class GaussDecayFunctionBuilder extends DecayFunctionBuilder<GaussDecayFu
super(fieldName, origin, scale, offset); super(fieldName, origin, scale, offset);
} }
public GaussDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) {
super(fieldName, origin, scale, offset, functionName);
}
public GaussDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) { public GaussDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) {
super(fieldName, origin, scale, offset, decay); super(fieldName, origin, scale, offset, decay);
} }
public GaussDecayFunctionBuilder(
String fieldName,
Object origin,
Object scale,
Object offset,
double decay,
@Nullable String functionName
) {
super(fieldName, origin, scale, offset, decay, functionName);
}
GaussDecayFunctionBuilder(String fieldName, BytesReference functionBytes) { GaussDecayFunctionBuilder(String fieldName, BytesReference functionBytes) {
super(fieldName, functionBytes); super(fieldName, functionBytes);
} }
@ -75,7 +92,6 @@ public class GaussDecayFunctionBuilder extends DecayFunctionBuilder<GaussDecayFu
} }
private static final class GaussScoreFunction implements DecayFunction { private static final class GaussScoreFunction implements DecayFunction {
@Override @Override
public double evaluate(double value, double scale) { public double evaluate(double value, double scale) {
// note that we already computed scale^2 in processScale() so we do // note that we already computed scale^2 in processScale() so we do
@ -84,8 +100,11 @@ public class GaussDecayFunctionBuilder extends DecayFunctionBuilder<GaussDecayFu
} }
@Override @Override
public Explanation explainFunction(String valueExpl, double value, double scale) { public Explanation explainFunction(String valueExpl, double value, double scale, @Nullable String functionName) {
return Explanation.match((float) evaluate(value, scale), "exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1 * scale + ")"); return Explanation.match(
(float) evaluate(value, scale),
"exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1 * scale + Functions.nameOrEmptyArg(functionName) + ")"
);
} }
@Override @Override

View File

@ -33,8 +33,10 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.opensearch.common.Nullable;
import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.lucene.search.function.Functions;
import java.io.IOException; import java.io.IOException;
@ -47,10 +49,25 @@ public class LinearDecayFunctionBuilder extends DecayFunctionBuilder<LinearDecay
super(fieldName, origin, scale, offset); super(fieldName, origin, scale, offset);
} }
public LinearDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) {
super(fieldName, origin, scale, offset, functionName);
}
public LinearDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) { public LinearDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) {
super(fieldName, origin, scale, offset, decay); super(fieldName, origin, scale, offset, decay);
} }
public LinearDecayFunctionBuilder(
String fieldName,
Object origin,
Object scale,
Object offset,
double decay,
@Nullable String functionName
) {
super(fieldName, origin, scale, offset, decay, functionName);
}
LinearDecayFunctionBuilder(String fieldName, BytesReference functionBytes) { LinearDecayFunctionBuilder(String fieldName, BytesReference functionBytes) {
super(fieldName, functionBytes); super(fieldName, functionBytes);
} }
@ -80,8 +97,11 @@ public class LinearDecayFunctionBuilder extends DecayFunctionBuilder<LinearDecay
} }
@Override @Override
public Explanation explainFunction(String valueExpl, double value, double scale) { public Explanation explainFunction(String valueExpl, double value, double scale, @Nullable String functionName) {
return Explanation.match((float) evaluate(value, scale), "max(0.0, ((" + scale + " - " + valueExpl + ")/" + scale + ")"); return Explanation.match(
(float) evaluate(value, scale),
"max(0.0, ((" + scale + " - " + valueExpl + ")/" + scale + Functions.nameOrEmptyArg(functionName) + ")"
);
} }
@Override @Override

View File

@ -31,6 +31,7 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.opensearch.common.Nullable;
import org.opensearch.common.ParsingException; import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
@ -58,6 +59,10 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
public RandomScoreFunctionBuilder() {} public RandomScoreFunctionBuilder() {}
public RandomScoreFunctionBuilder(@Nullable String functionName) {
setFunctionName(functionName);
}
/** /**
* Read from a stream. * Read from a stream.
*/ */
@ -166,7 +171,7 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
final int salt = (context.index().getName().hashCode() << 10) | context.getShardId(); final int salt = (context.index().getName().hashCode() << 10) | context.getShardId();
if (seed == null) { if (seed == null) {
// DocID-based random score generation // DocID-based random score generation
return new RandomScoreFunction(hash(context.nowInMillis()), salt, null); return new RandomScoreFunction(hash(context.nowInMillis()), salt, null, getFunctionName());
} else { } else {
final MappedFieldType fieldType; final MappedFieldType fieldType;
if (field != null) { if (field != null) {
@ -181,7 +186,7 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
if (fieldType == null) { if (fieldType == null) {
if (context.getMapperService().documentMapper() == null) { if (context.getMapperService().documentMapper() == null) {
// no mappings: the index is empty anyway // no mappings: the index is empty anyway
return new RandomScoreFunction(hash(context.nowInMillis()), salt, null); return new RandomScoreFunction(hash(context.nowInMillis()), salt, null, getFunctionName());
} }
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Field [" + field + "] is not mapped on [" + context.index() + "] and cannot be used as a source of random numbers." "Field [" + field + "] is not mapped on [" + context.index() + "] and cannot be used as a source of random numbers."
@ -193,7 +198,7 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
} else { } else {
seed = hash(context.nowInMillis()); seed = hash(context.nowInMillis());
} }
return new RandomScoreFunction(seed, salt, context.getForField(fieldType)); return new RandomScoreFunction(seed, salt, context.getForField(fieldType), getFunctionName());
} }
} }
@ -231,6 +236,8 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
} }
} else if ("field".equals(currentFieldName)) { } else if ("field".equals(currentFieldName)) {
randomScoreFunctionBuilder.setField(parser.text()); randomScoreFunctionBuilder.setField(parser.text());
} else if (FunctionScoreQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
randomScoreFunctionBuilder.setFunctionName(parser.text());
} else { } else {
throw new ParsingException(parser.getTokenLocation(), NAME + " query does not support [" + currentFieldName + "]"); throw new ParsingException(parser.getTokenLocation(), NAME + " query does not support [" + currentFieldName + "]");
} }

View File

@ -32,6 +32,7 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.opensearch.Version;
import org.opensearch.common.io.stream.NamedWriteable; import org.opensearch.common.io.stream.NamedWriteable;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
@ -47,6 +48,7 @@ import java.util.Objects;
public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>> implements ToXContentFragment, NamedWriteable { public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>> implements ToXContentFragment, NamedWriteable {
private Float weight; private Float weight;
private String functionName;
/** /**
* Standard empty constructor. * Standard empty constructor.
@ -58,11 +60,17 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
*/ */
public ScoreFunctionBuilder(StreamInput in) throws IOException { public ScoreFunctionBuilder(StreamInput in) throws IOException {
weight = checkWeight(in.readOptionalFloat()); weight = checkWeight(in.readOptionalFloat());
if (in.getVersion().onOrAfter(Version.V_2_0_0)) {
functionName = in.readOptionalString();
}
} }
@Override @Override
public final void writeTo(StreamOutput out) throws IOException { public final void writeTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(weight); out.writeOptionalFloat(weight);
if (out.getVersion().onOrAfter(Version.V_2_0_0)) {
out.writeOptionalString(functionName);
}
doWriteTo(out); doWriteTo(out);
} }
@ -99,11 +107,30 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
return weight; return weight;
} }
/**
* The name of this function
*/
public String getFunctionName() {
return functionName;
}
/**
* Set the name of this function
*/
public void setFunctionName(String functionName) {
this.functionName = functionName;
}
@Override @Override
public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (weight != null) { if (weight != null) {
builder.field(FunctionScoreQueryBuilder.WEIGHT_FIELD.getPreferredName(), weight); builder.field(FunctionScoreQueryBuilder.WEIGHT_FIELD.getPreferredName(), weight);
} }
if (functionName != null) {
builder.field(FunctionScoreQueryBuilder.NAME_FIELD.getPreferredName(), functionName);
}
doXContent(builder, params); doXContent(builder, params);
return builder; return builder;
} }
@ -128,7 +155,7 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
FB other = (FB) obj; FB other = (FB) obj;
return Objects.equals(weight, other.getWeight()) && doEquals(other); return Objects.equals(weight, other.getWeight()) && Objects.equals(functionName, other.getFunctionName()) && doEquals(other);
} }
/** /**
@ -139,7 +166,7 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
@Override @Override
public final int hashCode() { public final int hashCode() {
return Objects.hash(getClass(), weight, doHashCode()); return Objects.hash(getClass(), weight, functionName, doHashCode());
} }
/** /**
@ -156,7 +183,7 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
if (weight == null) { if (weight == null) {
return scoreFunction; return scoreFunction;
} }
return new WeightFactorFunction(weight, scoreFunction); return new WeightFactorFunction(weight, scoreFunction, getFunctionName());
} }
/** /**

View File

@ -32,6 +32,7 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.opensearch.common.Nullable;
import org.opensearch.script.Script; import org.opensearch.script.Script;
import org.opensearch.script.ScriptType; import org.opensearch.script.ScriptType;
@ -46,10 +47,29 @@ public class ScoreFunctionBuilders {
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, null); return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, null);
} }
public static ExponentialDecayFunctionBuilder exponentialDecayFunction(
String fieldName,
Object origin,
Object scale,
@Nullable String functionName
) {
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, null, functionName);
}
public static ExponentialDecayFunctionBuilder exponentialDecayFunction(String fieldName, Object origin, Object scale, Object offset) { public static ExponentialDecayFunctionBuilder exponentialDecayFunction(String fieldName, Object origin, Object scale, Object offset) {
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset); return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset);
} }
public static ExponentialDecayFunctionBuilder exponentialDecayFunction(
String fieldName,
Object origin,
Object scale,
Object offset,
@Nullable String functionName
) {
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, functionName);
}
public static ExponentialDecayFunctionBuilder exponentialDecayFunction( public static ExponentialDecayFunctionBuilder exponentialDecayFunction(
String fieldName, String fieldName,
Object origin, Object origin,
@ -60,10 +80,30 @@ public class ScoreFunctionBuilders {
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, decay); return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, decay);
} }
public static ExponentialDecayFunctionBuilder exponentialDecayFunction(
String fieldName,
Object origin,
Object scale,
Object offset,
double decay,
@Nullable String functionName
) {
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, decay, functionName);
}
public static GaussDecayFunctionBuilder gaussDecayFunction(String fieldName, Object origin, Object scale) { public static GaussDecayFunctionBuilder gaussDecayFunction(String fieldName, Object origin, Object scale) {
return new GaussDecayFunctionBuilder(fieldName, origin, scale, null); return new GaussDecayFunctionBuilder(fieldName, origin, scale, null);
} }
public static GaussDecayFunctionBuilder gaussDecayFunction(
String fieldName,
Object origin,
Object scale,
@Nullable String functionName
) {
return new GaussDecayFunctionBuilder(fieldName, origin, scale, null, functionName);
}
public static GaussDecayFunctionBuilder gaussDecayFunction(String fieldName, Object origin, Object scale, Object offset) { public static GaussDecayFunctionBuilder gaussDecayFunction(String fieldName, Object origin, Object scale, Object offset) {
return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset); return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset);
} }
@ -72,6 +112,26 @@ public class ScoreFunctionBuilders {
return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset, decay); return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset, decay);
} }
public static GaussDecayFunctionBuilder gaussDecayFunction(
String fieldName,
Object origin,
Object scale,
Object offset,
double decay,
@Nullable String functionName
) {
return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset, decay, functionName);
}
public static LinearDecayFunctionBuilder linearDecayFunction(
String fieldName,
Object origin,
Object scale,
@Nullable String functionName
) {
return new LinearDecayFunctionBuilder(fieldName, origin, scale, null, functionName);
}
public static LinearDecayFunctionBuilder linearDecayFunction(String fieldName, Object origin, Object scale) { public static LinearDecayFunctionBuilder linearDecayFunction(String fieldName, Object origin, Object scale) {
return new LinearDecayFunctionBuilder(fieldName, origin, scale, null); return new LinearDecayFunctionBuilder(fieldName, origin, scale, null);
} }
@ -80,6 +140,16 @@ public class ScoreFunctionBuilders {
return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset); return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset);
} }
public static LinearDecayFunctionBuilder linearDecayFunction(
String fieldName,
Object origin,
Object scale,
Object offset,
@Nullable String functionName
) {
return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, functionName);
}
public static LinearDecayFunctionBuilder linearDecayFunction( public static LinearDecayFunctionBuilder linearDecayFunction(
String fieldName, String fieldName,
Object origin, Object origin,
@ -90,23 +160,54 @@ public class ScoreFunctionBuilders {
return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, decay); return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, decay);
} }
public static LinearDecayFunctionBuilder linearDecayFunction(
String fieldName,
Object origin,
Object scale,
Object offset,
double decay,
@Nullable String functionName
) {
return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, decay, functionName);
}
public static ScriptScoreFunctionBuilder scriptFunction(Script script) { public static ScriptScoreFunctionBuilder scriptFunction(Script script) {
return (new ScriptScoreFunctionBuilder(script)); return scriptFunction(script, null);
} }
public static ScriptScoreFunctionBuilder scriptFunction(String script) { public static ScriptScoreFunctionBuilder scriptFunction(String script) {
return (new ScriptScoreFunctionBuilder(new Script(ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG, script, emptyMap()))); return scriptFunction(script, null);
} }
public static RandomScoreFunctionBuilder randomFunction() { public static RandomScoreFunctionBuilder randomFunction() {
return new RandomScoreFunctionBuilder(); return randomFunction(null);
} }
public static WeightBuilder weightFactorFunction(float weight) { public static WeightBuilder weightFactorFunction(float weight) {
return (WeightBuilder) (new WeightBuilder().setWeight(weight)); return weightFactorFunction(weight, null);
} }
public static FieldValueFactorFunctionBuilder fieldValueFactorFunction(String fieldName) { public static FieldValueFactorFunctionBuilder fieldValueFactorFunction(String fieldName) {
return new FieldValueFactorFunctionBuilder(fieldName); return fieldValueFactorFunction(fieldName, null);
}
public static ScriptScoreFunctionBuilder scriptFunction(Script script, @Nullable String functionName) {
return new ScriptScoreFunctionBuilder(script, functionName);
}
public static ScriptScoreFunctionBuilder scriptFunction(String script, @Nullable String functionName) {
return new ScriptScoreFunctionBuilder(new Script(ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG, script, emptyMap()), functionName);
}
public static RandomScoreFunctionBuilder randomFunction(@Nullable String functionName) {
return new RandomScoreFunctionBuilder(functionName);
}
public static WeightBuilder weightFactorFunction(float weight, @Nullable String functionName) {
return (WeightBuilder) (new WeightBuilder(functionName).setWeight(weight));
}
public static FieldValueFactorFunctionBuilder fieldValueFactorFunction(String fieldName, @Nullable String functionName) {
return new FieldValueFactorFunctionBuilder(fieldName, functionName);
} }
} }

View File

@ -32,6 +32,7 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.opensearch.common.Nullable;
import org.opensearch.common.ParsingException; import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
@ -57,10 +58,15 @@ public class ScriptScoreFunctionBuilder extends ScoreFunctionBuilder<ScriptScore
private final Script script; private final Script script;
public ScriptScoreFunctionBuilder(Script script) { public ScriptScoreFunctionBuilder(Script script) {
this(script, null);
}
public ScriptScoreFunctionBuilder(Script script, @Nullable String functionName) {
if (script == null) { if (script == null) {
throw new IllegalArgumentException("script must not be null"); throw new IllegalArgumentException("script must not be null");
} }
this.script = script; this.script = script;
setFunctionName(functionName);
} }
/** /**
@ -112,7 +118,8 @@ public class ScriptScoreFunctionBuilder extends ScoreFunctionBuilder<ScriptScore
searchScript, searchScript,
context.index().getName(), context.index().getName(),
context.getShardId(), context.getShardId(),
context.indexVersionCreated() context.indexVersionCreated(),
getFunctionName()
); );
} catch (Exception e) { } catch (Exception e) {
throw new QueryShardException(context, "script_score: the script could not be loaded", e); throw new QueryShardException(context, "script_score: the script could not be loaded", e);

View File

@ -195,9 +195,11 @@ public class ScriptScoreQueryBuilder extends AbstractQueryBuilder<ScriptScoreQue
} }
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT); ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup()); ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup());
Query query = this.query.toQuery(context); final QueryBuilder queryBuilder = this.query;
Query query = queryBuilder.toQuery(context);
return new ScriptScoreQuery( return new ScriptScoreQuery(
query, query,
queryBuilder.queryName(),
script, script,
scoreScriptFactory, scoreScriptFactory,
minScore, minScore,

View File

@ -32,6 +32,7 @@
package org.opensearch.index.query.functionscore; package org.opensearch.index.query.functionscore;
import org.opensearch.common.Nullable;
import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.lucene.search.function.ScoreFunction; import org.opensearch.common.lucene.search.function.ScoreFunction;
@ -51,6 +52,13 @@ public class WeightBuilder extends ScoreFunctionBuilder<WeightBuilder> {
*/ */
public WeightBuilder() {} public WeightBuilder() {}
/**
* Standard constructor.
*/
public WeightBuilder(@Nullable String functionName) {
setFunctionName(functionName);
}
/** /**
* Read from a stream. * Read from a stream.
*/ */

View File

@ -33,6 +33,7 @@
package org.opensearch.script; package org.opensearch.script;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.opensearch.common.Nullable;
import java.io.IOException; import java.io.IOException;
@ -49,7 +50,21 @@ public interface ExplainableScoreScript {
* want to explain how that was computed. * want to explain how that was computed.
* *
* @param subQueryScore the Explanation for _score * @param subQueryScore the Explanation for _score
* @deprecated please use {@code explain(Explanation subQueryScore, @Nullable String scriptName)}
*/ */
@Deprecated
Explanation explain(Explanation subQueryScore) throws IOException; Explanation explain(Explanation subQueryScore) throws IOException;
/**
* Build the explanation of the current document being scored
* The script score needs the Explanation of the sub query score because it might use _score and
* want to explain how that was computed.
*
* @param subQueryScore the Explanation for _score
* @param scriptName the script name
*/
default Explanation explain(Explanation subQueryScore, @Nullable String scriptName) throws IOException {
return explain(subQueryScore);
}
} }

View File

@ -88,6 +88,7 @@ import java.util.Collection;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.endsWith;
import static org.hamcrest.core.Is.is; import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsEqual.equalTo; import static org.hamcrest.core.IsEqual.equalTo;
import static org.hamcrest.core.IsNot.not; import static org.hamcrest.core.IsNot.not;
@ -283,7 +284,8 @@ public class FunctionScoreTests extends OpenSearchTestCase {
0, 0,
GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION,
new IndexNumericFieldDataStub(), new IndexNumericFieldDataStub(),
MultiValueMode.MAX MultiValueMode.MAX,
null
); );
private static final ScoreFunction EXP_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction( private static final ScoreFunction EXP_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
0, 0,
@ -292,7 +294,8 @@ public class FunctionScoreTests extends OpenSearchTestCase {
0, 0,
ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION, ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION,
new IndexNumericFieldDataStub(), new IndexNumericFieldDataStub(),
MultiValueMode.MAX MultiValueMode.MAX,
null
); );
private static final ScoreFunction LIN_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction( private static final ScoreFunction LIN_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
0, 0,
@ -301,7 +304,48 @@ public class FunctionScoreTests extends OpenSearchTestCase {
0, 0,
LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION, LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION,
new IndexNumericFieldDataStub(), new IndexNumericFieldDataStub(),
MultiValueMode.MAX MultiValueMode.MAX,
null
);
private static final ScoreFunction RANDOM_SCORE_FUNCTION_NAMED = new RandomScoreFunction(0, 0, new IndexFieldDataStub(), "func1");
private static final ScoreFunction FIELD_VALUE_FACTOR_FUNCTION_NAMED = new FieldValueFactorFunction(
"test",
1,
FieldValueFactorFunction.Modifier.LN,
1.0,
null,
"func1"
);
private static final ScoreFunction GAUSS_DECAY_FUNCTION_NAMED = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
0,
1,
0.1,
0,
GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION,
new IndexNumericFieldDataStub(),
MultiValueMode.MAX,
"func1"
);
private static final ScoreFunction EXP_DECAY_FUNCTION_NAMED = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
0,
1,
0.1,
0,
ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION,
new IndexNumericFieldDataStub(),
MultiValueMode.MAX,
"func1"
);
private static final ScoreFunction LIN_DECAY_FUNCTION_NAMED = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
0,
1,
0.1,
0,
LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION,
new IndexNumericFieldDataStub(),
MultiValueMode.MAX,
"func1"
); );
private static final ScoreFunction WEIGHT_FACTOR_FUNCTION = new WeightFactorFunction(4); private static final ScoreFunction WEIGHT_FACTOR_FUNCTION = new WeightFactorFunction(4);
private static final String TEXT = "The way out is through."; private static final String TEXT = "The way out is through.";
@ -383,6 +427,58 @@ public class FunctionScoreTests extends OpenSearchTestCase {
assertThat(functionExplanation.getDetails()[0].getDetails()[1].getDetails().length, equalTo(0)); assertThat(functionExplanation.getDetails()[0].getDetails()[1].getDetails().length, equalTo(0));
} }
public void testExplainFunctionScoreQueryWithName() throws IOException {
Explanation functionExplanation = getFunctionScoreExplanation(searcher, RANDOM_SCORE_FUNCTION_NAMED);
checkFunctionScoreExplanation(functionExplanation, "random score function (seed: 0, field: test, _name: func1)");
assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, FIELD_VALUE_FACTOR_FUNCTION_NAMED);
checkFunctionScoreExplanation(functionExplanation, "field value function(_name: func1): ln(doc['test'].value?:1.0 * factor=1.0)");
assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, GAUSS_DECAY_FUNCTION_NAMED);
checkFunctionScoreExplanation(functionExplanation, "Function for field test:");
assertThat(
functionExplanation.getDetails()[0].getDetails()[0].toString(),
equalTo(
"0.1 = exp(-0.5*pow(MAX[Math.max(Math.abs"
+ "(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.21714724095162594, _name: func1)\n"
)
);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, EXP_DECAY_FUNCTION_NAMED);
checkFunctionScoreExplanation(functionExplanation, "Function for field test:");
assertThat(
functionExplanation.getDetails()[0].getDetails()[0].toString(),
equalTo(
"0.1 = exp(- MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)] * 2.3025850929940455, _name: func1)\n"
)
);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, LIN_DECAY_FUNCTION_NAMED);
checkFunctionScoreExplanation(functionExplanation, "Function for field test:");
assertThat(
functionExplanation.getDetails()[0].getDetails()[0].toString(),
equalTo(
"0.1 = max(0.0, ((1.1111111111111112"
+ " - MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)])/1.1111111111111112, _name: func1)\n"
)
);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, new WeightFactorFunction(4, RANDOM_SCORE_FUNCTION_NAMED));
checkFunctionScoreExplanation(functionExplanation, "product of:");
assertThat(
functionExplanation.getDetails()[0].getDetails()[0].toString(),
endsWith("random score function (seed: 0, field: test, _name: func1)\n")
);
assertThat(functionExplanation.getDetails()[0].getDetails()[1].toString(), equalTo("4.0 = weight\n"));
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
assertThat(functionExplanation.getDetails()[0].getDetails()[1].getDetails().length, equalTo(0));
}
public Explanation getFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction scoreFunction) throws IOException { public Explanation getFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction scoreFunction) throws IOException {
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new TermQuery(TERM), scoreFunction, CombineFunction.AVG, 0.0f, 100); FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new TermQuery(TERM), scoreFunction, CombineFunction.AVG, 0.0f, 100);
Weight weight = searcher.createWeight(searcher.rewrite(functionScoreQuery), org.apache.lucene.search.ScoreMode.COMPLETE, 1f); Weight weight = searcher.createWeight(searcher.rewrite(functionScoreQuery), org.apache.lucene.search.ScoreMode.COMPLETE, 1f);

View File

@ -110,6 +110,34 @@ public class ScriptScoreQueryTests extends OpenSearchTestCase {
assertThat(explanation.getValue(), equalTo(1.0)); assertThat(explanation.getValue(), equalTo(1.0));
} }
public void testExplainWithName() throws IOException {
Script script = new Script("script using explain");
ScoreScript.LeafFactory factory = newFactory(script, true, explanation -> {
assertNotNull(explanation);
explanation.set("this explains the score");
return 1.0;
});
ScriptScoreQuery query = new ScriptScoreQuery(
Queries.newMatchAllQuery(),
"query1",
script,
factory,
null,
"index",
0,
Version.CURRENT
);
Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
Explanation explanation = weight.explain(leafReaderContext, 0);
assertNotNull(explanation);
assertThat(explanation.getDescription(), equalTo("this explains the score"));
assertThat(explanation.getValue(), equalTo(1.0));
assertThat(explanation.getDetails(), arrayWithSize(1));
assertThat(explanation.getDetails()[0].getDescription(), equalTo("*:* (_name: query1)"));
}
public void testExplainDefault() throws IOException { public void testExplainDefault() throws IOException {
Script script = new Script("script without setting explanation"); Script script = new Script("script without setting explanation");
ScoreScript.LeafFactory factory = newFactory(script, true, explanation -> 1.5); ScoreScript.LeafFactory factory = newFactory(script, true, explanation -> 1.5);