Add '_name' field support to score functions and provide it back in explanation response (#2244)
* Add '_name' field support to score functions and provide it back in explanation response Signed-off-by: Andriy Redko <andriy.redko@aiven.io> * Address code review comments Signed-off-by: Andriy Redko <andriy.redko@aiven.io>
This commit is contained in:
parent
ae14259a2c
commit
5f90227a05
|
@ -51,6 +51,7 @@ import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder;
|
|||
import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder;
|
||||
import org.opensearch.index.query.functionscore.ScoreFunctionBuilders;
|
||||
import org.opensearch.search.MultiValueMode;
|
||||
import org.opensearch.search.SearchHit;
|
||||
import org.opensearch.search.SearchHits;
|
||||
import org.opensearch.test.OpenSearchIntegTestCase;
|
||||
import org.opensearch.test.VersionUtils;
|
||||
|
@ -77,7 +78,9 @@ import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures
|
|||
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertOrderedSearchHits;
|
||||
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchHits;
|
||||
import static org.hamcrest.Matchers.anyOf;
|
||||
import static org.hamcrest.Matchers.arrayWithSize;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThan;
|
||||
|
@ -616,6 +619,76 @@ public class DecayFunctionScoreIT extends OpenSearchIntegTestCase {
|
|||
|
||||
}
|
||||
|
||||
public void testCombineModesExplain() throws Exception {
|
||||
assertAcked(
|
||||
prepareCreate("test").addMapping(
|
||||
"type1",
|
||||
jsonBuilder().startObject()
|
||||
.startObject("type1")
|
||||
.startObject("properties")
|
||||
.startObject("test")
|
||||
.field("type", "text")
|
||||
.endObject()
|
||||
.startObject("num")
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
)
|
||||
);
|
||||
|
||||
client().prepareIndex()
|
||||
.setId("1")
|
||||
.setIndex("test")
|
||||
.setRefreshPolicy(IMMEDIATE)
|
||||
.setSource(jsonBuilder().startObject().field("test", "value value").field("num", 1.0).endObject())
|
||||
.get();
|
||||
|
||||
FunctionScoreQueryBuilder baseQuery = functionScoreQuery(
|
||||
constantScoreQuery(termQuery("test", "value")).queryName("query1"),
|
||||
ScoreFunctionBuilders.weightFactorFunction(2, "weight1")
|
||||
);
|
||||
// decay score should return 0.5 for this function and baseQuery should return 2.0f as it's score
|
||||
ActionFuture<SearchResponse> response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
|
||||
.source(
|
||||
searchSource().explain(true)
|
||||
.query(
|
||||
functionScoreQuery(baseQuery, gaussDecayFunction("num", 0.0, 1.0, null, 0.5, "func2")).boostMode(
|
||||
CombineFunction.MULTIPLY
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
SearchResponse sr = response.actionGet();
|
||||
SearchHits sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits().value, equalTo((long) (1)));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat(sh.getAt(0).getExplanation().getDetails(), arrayWithSize(2));
|
||||
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails(), arrayWithSize(2));
|
||||
// "description": "ConstantScore(test:value) (_name: query1)"
|
||||
assertThat(
|
||||
sh.getAt(0).getExplanation().getDetails()[0].getDetails()[0].getDescription(),
|
||||
equalTo("ConstantScore(test:value) (_name: query1)")
|
||||
);
|
||||
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails(), arrayWithSize(2));
|
||||
assertThat(sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(2));
|
||||
// "description": "constant score 1.0(_name: func1) - no function provided"
|
||||
assertThat(
|
||||
sh.getAt(0).getExplanation().getDetails()[0].getDetails()[1].getDetails()[0].getDetails()[0].getDescription(),
|
||||
equalTo("constant score 1.0(_name: weight1) - no function provided")
|
||||
);
|
||||
// "description": "exp(-0.5*pow(MIN[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.7213475204444817,
|
||||
// _name: func2)"
|
||||
assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
|
||||
assertThat(sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1));
|
||||
assertThat(
|
||||
sh.getAt(0).getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription(),
|
||||
containsString("_name: func2")
|
||||
);
|
||||
}
|
||||
|
||||
public void testExceptionThrownIfScaleLE0() throws Exception {
|
||||
assertAcked(
|
||||
prepareCreate("test").addMapping(
|
||||
|
@ -1195,4 +1268,132 @@ public class DecayFunctionScoreIT extends OpenSearchIntegTestCase {
|
|||
sh = sr.getHits();
|
||||
assertThat((double) (sh.getAt(0).getScore()), closeTo((sh.getAt(1).getScore()), 1.e-6d));
|
||||
}
|
||||
|
||||
public void testDistanceScoreGeoLinGaussExplain() throws Exception {
|
||||
assertAcked(
|
||||
prepareCreate("test").addMapping(
|
||||
"type1",
|
||||
jsonBuilder().startObject()
|
||||
.startObject("type1")
|
||||
.startObject("properties")
|
||||
.startObject("test")
|
||||
.field("type", "text")
|
||||
.endObject()
|
||||
.startObject("loc")
|
||||
.field("type", "geo_point")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
)
|
||||
);
|
||||
|
||||
List<IndexRequestBuilder> indexBuilders = new ArrayList<>();
|
||||
indexBuilders.add(
|
||||
client().prepareIndex()
|
||||
.setId("1")
|
||||
.setIndex("test")
|
||||
.setSource(
|
||||
jsonBuilder().startObject()
|
||||
.field("test", "value")
|
||||
.startObject("loc")
|
||||
.field("lat", 10)
|
||||
.field("lon", 20)
|
||||
.endObject()
|
||||
.endObject()
|
||||
)
|
||||
);
|
||||
indexBuilders.add(
|
||||
client().prepareIndex()
|
||||
.setId("2")
|
||||
.setIndex("test")
|
||||
.setSource(
|
||||
jsonBuilder().startObject()
|
||||
.field("test", "value")
|
||||
.startObject("loc")
|
||||
.field("lat", 11)
|
||||
.field("lon", 22)
|
||||
.endObject()
|
||||
.endObject()
|
||||
)
|
||||
);
|
||||
|
||||
indexRandom(true, indexBuilders);
|
||||
|
||||
// Test Gauss
|
||||
List<Float> lonlat = new ArrayList<>();
|
||||
lonlat.add(20f);
|
||||
lonlat.add(11f);
|
||||
|
||||
final String queryName = "query1";
|
||||
final String functionName = "func1";
|
||||
ActionFuture<SearchResponse> response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
|
||||
.source(
|
||||
searchSource().explain(true)
|
||||
.query(
|
||||
functionScoreQuery(baseQuery.queryName(queryName), gaussDecayFunction("loc", lonlat, "1000km", functionName))
|
||||
)
|
||||
)
|
||||
);
|
||||
SearchResponse sr = response.actionGet();
|
||||
SearchHits sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits().value, equalTo(2L));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat(sh.getAt(1).getId(), equalTo("2"));
|
||||
assertExplain(queryName, functionName, sr);
|
||||
|
||||
response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
|
||||
.source(
|
||||
searchSource().explain(true)
|
||||
.query(
|
||||
functionScoreQuery(baseQuery.queryName(queryName), linearDecayFunction("loc", lonlat, "1000km", functionName))
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits().value, equalTo(2L));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat(sh.getAt(1).getId(), equalTo("2"));
|
||||
assertExplain(queryName, functionName, sr);
|
||||
|
||||
response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
|
||||
.source(
|
||||
searchSource().explain(true)
|
||||
.query(
|
||||
functionScoreQuery(
|
||||
baseQuery.queryName(queryName),
|
||||
exponentialDecayFunction("loc", lonlat, "1000km", functionName)
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits().value, equalTo(2L));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat(sh.getAt(1).getId(), equalTo("2"));
|
||||
assertExplain(queryName, functionName, sr);
|
||||
}
|
||||
|
||||
private void assertExplain(final String queryName, final String functionName, SearchResponse sr) {
|
||||
SearchHit firstHit = sr.getHits().getAt(0);
|
||||
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
|
||||
// "description": "*:* (_name: query1)"
|
||||
assertThat(firstHit.getExplanation().getDetails()[0].getDescription().toString(), containsString("_name: " + queryName));
|
||||
assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
|
||||
// "description": "random score function (seed: 12345678, field: _seq_no, _name: func1)"
|
||||
assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails(), arrayWithSize(1));
|
||||
// "description": "exp(-0.5*pow(MIN of: [Math.max(arcDistance(10.999999972991645, 21.99999994598329(=doc value),11.0, 20.0(=origin))
|
||||
// - 0.0(=offset), 0)],2.0)/7.213475204444817E11, _name: func1)"
|
||||
assertThat(
|
||||
firstHit.getExplanation().getDetails()[1].getDetails()[0].getDetails()[0].getDescription().toString(),
|
||||
containsString("_name: " + functionName)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.opensearch.action.index.IndexRequestBuilder;
|
|||
import org.opensearch.action.search.SearchResponse;
|
||||
import org.opensearch.action.search.SearchType;
|
||||
import org.opensearch.common.lucene.search.function.CombineFunction;
|
||||
import org.opensearch.common.lucene.search.function.Functions;
|
||||
import org.opensearch.common.settings.Settings;
|
||||
import org.opensearch.index.fielddata.ScriptDocValues;
|
||||
import org.opensearch.plugins.Plugin;
|
||||
|
@ -72,6 +73,7 @@ import static org.opensearch.index.query.QueryBuilders.functionScoreQuery;
|
|||
import static org.opensearch.index.query.QueryBuilders.termQuery;
|
||||
import static org.opensearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction;
|
||||
import static org.opensearch.search.builder.SearchSourceBuilder.searchSource;
|
||||
import static org.hamcrest.Matchers.arrayWithSize;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
|
@ -121,8 +123,17 @@ public class ExplainableScriptIT extends OpenSearchIntegTestCase {
|
|||
|
||||
@Override
|
||||
public Explanation explain(Explanation subQueryScore) throws IOException {
|
||||
return explain(subQueryScore, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(Explanation subQueryScore, String functionName) throws IOException {
|
||||
Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore);
|
||||
return Explanation.match((float) (execute(null)), "This script returned " + execute(null), scoreExp);
|
||||
return Explanation.match(
|
||||
(float) (execute(null)),
|
||||
"This script" + Functions.nameOrEmptyFunc(functionName) + " returned " + execute(null),
|
||||
scoreExp
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -174,4 +185,36 @@ public class ExplainableScriptIT extends OpenSearchIntegTestCase {
|
|||
idCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
public void testExplainScriptWithName() throws InterruptedException, IOException, ExecutionException {
|
||||
List<IndexRequestBuilder> indexRequests = new ArrayList<>();
|
||||
indexRequests.add(
|
||||
client().prepareIndex("test")
|
||||
.setId(Integer.toString(1))
|
||||
.setSource(jsonBuilder().startObject().field("number_field", 1).field("text", "text").endObject())
|
||||
);
|
||||
indexRandom(true, true, indexRequests);
|
||||
client().admin().indices().prepareRefresh().get();
|
||||
ensureYellow();
|
||||
SearchResponse response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH)
|
||||
.source(
|
||||
searchSource().explain(true)
|
||||
.query(
|
||||
functionScoreQuery(
|
||||
termQuery("text", "text"),
|
||||
scriptFunction(new Script(ScriptType.INLINE, "test", "explainable_script", Collections.emptyMap()), "func1")
|
||||
).boostMode(CombineFunction.REPLACE)
|
||||
)
|
||||
)
|
||||
).actionGet();
|
||||
|
||||
OpenSearchAssertions.assertNoFailures(response);
|
||||
SearchHits hits = response.getHits();
|
||||
assertThat(hits.getTotalHits().value, equalTo(1L));
|
||||
assertThat(hits.getHits()[0].getId(), equalTo("1"));
|
||||
assertThat(hits.getHits()[0].getExplanation().getDetails(), arrayWithSize(2));
|
||||
assertThat(hits.getHits()[0].getExplanation().getDetails()[0].getDescription(), containsString("_name: func1"));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -35,10 +35,13 @@ package org.opensearch.search.functionscore;
|
|||
import org.opensearch.action.search.SearchPhaseExecutionException;
|
||||
import org.opensearch.action.search.SearchResponse;
|
||||
import org.opensearch.common.lucene.search.function.FieldValueFactorFunction;
|
||||
import org.opensearch.search.SearchHit;
|
||||
import org.opensearch.test.OpenSearchIntegTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.arrayWithSize;
|
||||
import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
|
||||
import static org.opensearch.index.query.QueryBuilders.functionScoreQuery;
|
||||
import static org.opensearch.index.query.QueryBuilders.matchAllQuery;
|
||||
|
@ -163,4 +166,47 @@ public class FunctionScoreFieldValueIT extends OpenSearchIntegTestCase {
|
|||
// locally, instead of just having failures
|
||||
}
|
||||
}
|
||||
|
||||
public void testFieldValueFactorExplain() throws IOException {
|
||||
assertAcked(
|
||||
prepareCreate("test").addMapping(
|
||||
"type1",
|
||||
jsonBuilder().startObject()
|
||||
.startObject("type1")
|
||||
.startObject("properties")
|
||||
.startObject("test")
|
||||
.field("type", randomFrom(new String[] { "short", "float", "long", "integer", "double" }))
|
||||
.endObject()
|
||||
.startObject("body")
|
||||
.field("type", "text")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
).get()
|
||||
);
|
||||
|
||||
client().prepareIndex("test").setId("1").setSource("test", 5, "body", "foo").get();
|
||||
client().prepareIndex("test").setId("2").setSource("test", 17, "body", "foo").get();
|
||||
client().prepareIndex("test").setId("3").setSource("body", "bar").get();
|
||||
|
||||
refresh();
|
||||
|
||||
// document 2 scores higher because 17 > 5
|
||||
final String functionName = "func1";
|
||||
final String queryName = "query";
|
||||
SearchResponse response = client().prepareSearch("test")
|
||||
.setExplain(true)
|
||||
.setQuery(
|
||||
functionScoreQuery(simpleQueryStringQuery("foo").queryName(queryName), fieldValueFactorFunction("test", functionName))
|
||||
)
|
||||
.get();
|
||||
assertOrderedSearchHits(response, "2", "1");
|
||||
SearchHit firstHit = response.getHits().getAt(0);
|
||||
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
|
||||
// "description": "sum of: (_name: query)"
|
||||
assertThat(firstHit.getExplanation().getDetails()[0].getDescription(), containsString("_name: " + queryName));
|
||||
// "description": "field value function(_name: func1): none(doc['test'].value * factor=1.0)"
|
||||
assertThat(firstHit.getExplanation().getDetails()[1].toString(), containsString("_name: " + functionName));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,6 +43,7 @@ import org.opensearch.plugins.Plugin;
|
|||
import org.opensearch.script.MockScriptPlugin;
|
||||
import org.opensearch.script.Script;
|
||||
import org.opensearch.script.ScriptType;
|
||||
import org.opensearch.search.SearchHit;
|
||||
import org.opensearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.opensearch.test.OpenSearchIntegTestCase;
|
||||
import org.opensearch.test.OpenSearchTestCase;
|
||||
|
@ -66,6 +67,8 @@ import static org.opensearch.search.aggregations.AggregationBuilders.terms;
|
|||
import static org.opensearch.search.builder.SearchSourceBuilder.searchSource;
|
||||
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
|
||||
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse;
|
||||
import static org.hamcrest.Matchers.arrayWithSize;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
|
@ -140,6 +143,35 @@ public class FunctionScoreIT extends OpenSearchIntegTestCase {
|
|||
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1L));
|
||||
}
|
||||
|
||||
public void testScriptScoresWithAggWithExplain() throws IOException {
|
||||
createIndex(INDEX);
|
||||
index(INDEX, TYPE, "1", jsonBuilder().startObject().field("dummy_field", 1).endObject());
|
||||
refresh();
|
||||
|
||||
Script script = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "get score value", Collections.emptyMap());
|
||||
|
||||
SearchResponse response = client().search(
|
||||
searchRequest().source(
|
||||
searchSource().explain(true)
|
||||
.query(functionScoreQuery(scriptFunction(script, "func1"), "query1"))
|
||||
.aggregation(terms("score_agg").script(script))
|
||||
)
|
||||
).actionGet();
|
||||
assertSearchResponse(response);
|
||||
|
||||
final SearchHit firstHit = response.getHits().getAt(0);
|
||||
assertThat(firstHit.getScore(), equalTo(1.0f));
|
||||
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
|
||||
// "description": "*:* (_name: query1)"
|
||||
assertThat(firstHit.getExplanation().getDetails()[0].getDescription(), containsString("_name: query1"));
|
||||
assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
|
||||
// "description": "script score function(_name: func1), computed with script:\"Script{ ... }\""
|
||||
assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDescription(), containsString("_name: func1"));
|
||||
|
||||
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getKeyAsString(), equalTo("1.0"));
|
||||
assertThat(((Terms) response.getAggregations().asMap().get("score_agg")).getBuckets().get(0).getDocCount(), is(1L));
|
||||
}
|
||||
|
||||
public void testMinScoreFunctionScoreBasic() throws IOException {
|
||||
float score = randomValueOtherThanMany((f) -> Float.compare(f, 0) < 0, OpenSearchTestCase::randomFloat);
|
||||
float minScore = randomValueOtherThanMany((f) -> Float.compare(f, 0) < 0, OpenSearchTestCase::randomFloat);
|
||||
|
|
|
@ -171,7 +171,7 @@ public class FunctionScorePluginIT extends OpenSearchIntegTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Explanation explainFunction(String distanceString, double distanceVal, double scale) {
|
||||
public Explanation explainFunction(String distanceString, double distanceVal, double scale, String functionName) {
|
||||
return Explanation.match((float) distanceVal, "" + distanceVal);
|
||||
}
|
||||
|
||||
|
|
|
@ -63,6 +63,7 @@ import static org.opensearch.script.MockScriptPlugin.NAME;
|
|||
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
|
||||
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures;
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.arrayWithSize;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
|
@ -289,6 +290,37 @@ public class RandomScoreFunctionIT extends OpenSearchIntegTestCase {
|
|||
assertThat(firstHit.getExplanation().toString(), containsString("" + seed));
|
||||
}
|
||||
|
||||
public void testSeedAndNameReportedInExplain() throws Exception {
|
||||
createIndex("test");
|
||||
ensureGreen();
|
||||
index("test", "type", "1", jsonBuilder().startObject().endObject());
|
||||
flush();
|
||||
refresh();
|
||||
|
||||
int seed = 12345678;
|
||||
|
||||
final String queryName = "query1";
|
||||
final String functionName = "func1";
|
||||
SearchResponse resp = client().prepareSearch("test")
|
||||
.setQuery(
|
||||
functionScoreQuery(
|
||||
matchAllQuery().queryName(queryName),
|
||||
randomFunction(functionName).seed(seed).setField(SeqNoFieldMapper.NAME)
|
||||
)
|
||||
)
|
||||
.setExplain(true)
|
||||
.get();
|
||||
assertNoFailures(resp);
|
||||
assertEquals(1, resp.getHits().getTotalHits().value);
|
||||
SearchHit firstHit = resp.getHits().getAt(0);
|
||||
assertThat(firstHit.getExplanation().getDetails(), arrayWithSize(2));
|
||||
// "description": "*:* (_name: query1)"
|
||||
assertThat(firstHit.getExplanation().getDetails()[0].getDescription().toString(), containsString("_name: " + queryName));
|
||||
assertThat(firstHit.getExplanation().getDetails()[1].getDetails(), arrayWithSize(2));
|
||||
// "description": "random score function (seed: 12345678, field: _seq_no, _name: func1)"
|
||||
assertThat(firstHit.getExplanation().getDetails()[1].getDetails()[0].getDescription().toString(), containsString("seed: " + seed));
|
||||
}
|
||||
|
||||
public void testNoDocs() throws Exception {
|
||||
createIndex("test");
|
||||
ensureGreen();
|
||||
|
|
|
@ -35,6 +35,7 @@ package org.opensearch.common.lucene.search.function;
|
|||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.opensearch.OpenSearchException;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.io.stream.StreamOutput;
|
||||
import org.opensearch.common.io.stream.Writeable;
|
||||
|
@ -55,6 +56,8 @@ public class FieldValueFactorFunction extends ScoreFunction {
|
|||
private final String field;
|
||||
private final float boostFactor;
|
||||
private final Modifier modifier;
|
||||
private final String functionName;
|
||||
|
||||
/**
|
||||
* Value used if the document is missing the field.
|
||||
*/
|
||||
|
@ -67,6 +70,17 @@ public class FieldValueFactorFunction extends ScoreFunction {
|
|||
Modifier modifierType,
|
||||
Double missing,
|
||||
IndexNumericFieldData indexFieldData
|
||||
) {
|
||||
this(field, boostFactor, modifierType, missing, indexFieldData, null);
|
||||
}
|
||||
|
||||
public FieldValueFactorFunction(
|
||||
String field,
|
||||
float boostFactor,
|
||||
Modifier modifierType,
|
||||
Double missing,
|
||||
IndexNumericFieldData indexFieldData,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
super(CombineFunction.MULTIPLY);
|
||||
this.field = field;
|
||||
|
@ -74,6 +88,7 @@ public class FieldValueFactorFunction extends ScoreFunction {
|
|||
this.modifier = modifierType;
|
||||
this.indexFieldData = indexFieldData;
|
||||
this.missing = missing;
|
||||
this.functionName = functionName;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -127,7 +142,7 @@ public class FieldValueFactorFunction extends ScoreFunction {
|
|||
(float) score,
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"field value function: %s(doc['%s'].value%s * factor=%s)",
|
||||
"field value function" + Functions.nameOrEmptyFunc(functionName) + ": %s(doc['%s'].value%s * factor=%s)",
|
||||
modifierStr,
|
||||
field,
|
||||
defaultStr,
|
||||
|
|
|
@ -46,6 +46,7 @@ import org.apache.lucene.search.ScorerSupplier;
|
|||
import org.apache.lucene.search.Weight;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.opensearch.OpenSearchException;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.io.stream.StreamOutput;
|
||||
import org.opensearch.common.io.stream.Writeable;
|
||||
|
@ -70,11 +71,28 @@ public class FunctionScoreQuery extends Query {
|
|||
public static class FilterScoreFunction extends ScoreFunction {
|
||||
public final Query filter;
|
||||
public final ScoreFunction function;
|
||||
public final String queryName;
|
||||
|
||||
/**
|
||||
* Creates a FilterScoreFunction with query and function.
|
||||
* @param filter filter query
|
||||
* @param function score function
|
||||
*/
|
||||
public FilterScoreFunction(Query filter, ScoreFunction function) {
|
||||
this(filter, function, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a FilterScoreFunction with query and function.
|
||||
* @param filter filter query
|
||||
* @param function score function
|
||||
* @param queryName filter query name
|
||||
*/
|
||||
public FilterScoreFunction(Query filter, ScoreFunction function, @Nullable String queryName) {
|
||||
super(function.getDefaultScoreCombiner());
|
||||
this.filter = filter;
|
||||
this.function = function;
|
||||
this.queryName = queryName;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -93,12 +111,14 @@ public class FunctionScoreQuery extends Query {
|
|||
return false;
|
||||
}
|
||||
FilterScoreFunction that = (FilterScoreFunction) other;
|
||||
return Objects.equals(this.filter, that.filter) && Objects.equals(this.function, that.function);
|
||||
return Objects.equals(this.filter, that.filter)
|
||||
&& Objects.equals(this.function, that.function)
|
||||
&& Objects.equals(this.queryName, that.queryName);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int doHashCode() {
|
||||
return Objects.hash(filter, function);
|
||||
return Objects.hash(filter, function, queryName);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -107,7 +127,7 @@ public class FunctionScoreQuery extends Query {
|
|||
if (newFilter == filter) {
|
||||
return this;
|
||||
}
|
||||
return new FilterScoreFunction(newFilter, function);
|
||||
return new FilterScoreFunction(newFilter, function, queryName);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -144,6 +164,7 @@ public class FunctionScoreQuery extends Query {
|
|||
final float maxBoost;
|
||||
private final Float minScore;
|
||||
private final CombineFunction combineFunction;
|
||||
private final String queryName;
|
||||
|
||||
/**
|
||||
* Creates a FunctionScoreQuery without function.
|
||||
|
@ -152,7 +173,18 @@ public class FunctionScoreQuery extends Query {
|
|||
* @param maxBoost The maximum applicable boost.
|
||||
*/
|
||||
public FunctionScoreQuery(Query subQuery, Float minScore, float maxBoost) {
|
||||
this(subQuery, ScoreMode.FIRST, new ScoreFunction[0], CombineFunction.MULTIPLY, minScore, maxBoost);
|
||||
this(subQuery, null, minScore, maxBoost);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a FunctionScoreQuery without function.
|
||||
* @param subQuery The query to match.
|
||||
* @param queryName filter query name
|
||||
* @param minScore The minimum score to consider a document.
|
||||
* @param maxBoost The maximum applicable boost.
|
||||
*/
|
||||
public FunctionScoreQuery(Query subQuery, @Nullable String queryName, Float minScore, float maxBoost) {
|
||||
this(subQuery, queryName, ScoreMode.FIRST, new ScoreFunction[0], CombineFunction.MULTIPLY, minScore, maxBoost);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -161,7 +193,17 @@ public class FunctionScoreQuery extends Query {
|
|||
* @param function The {@link ScoreFunction} to apply.
|
||||
*/
|
||||
public FunctionScoreQuery(Query subQuery, ScoreFunction function) {
|
||||
this(subQuery, function, CombineFunction.MULTIPLY, null, DEFAULT_MAX_BOOST);
|
||||
this(subQuery, null, function);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a FunctionScoreQuery with a single {@link ScoreFunction}
|
||||
* @param subQuery The query to match.
|
||||
* @param queryName filter query name
|
||||
* @param function The {@link ScoreFunction} to apply.
|
||||
*/
|
||||
public FunctionScoreQuery(Query subQuery, @Nullable String queryName, ScoreFunction function) {
|
||||
this(subQuery, queryName, function, CombineFunction.MULTIPLY, null, DEFAULT_MAX_BOOST);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -173,7 +215,27 @@ public class FunctionScoreQuery extends Query {
|
|||
* @param maxBoost The maximum applicable boost.
|
||||
*/
|
||||
public FunctionScoreQuery(Query subQuery, ScoreFunction function, CombineFunction combineFunction, Float minScore, float maxBoost) {
|
||||
this(subQuery, ScoreMode.FIRST, new ScoreFunction[] { function }, combineFunction, minScore, maxBoost);
|
||||
this(subQuery, null, function, combineFunction, minScore, maxBoost);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a FunctionScoreQuery with a single function
|
||||
* @param subQuery The query to match.
|
||||
* @param queryName filter query name
|
||||
* @param function The {@link ScoreFunction} to apply.
|
||||
* @param combineFunction Defines how the query and function score should be applied.
|
||||
* @param minScore The minimum score to consider a document.
|
||||
* @param maxBoost The maximum applicable boost.
|
||||
*/
|
||||
public FunctionScoreQuery(
|
||||
Query subQuery,
|
||||
@Nullable String queryName,
|
||||
ScoreFunction function,
|
||||
CombineFunction combineFunction,
|
||||
Float minScore,
|
||||
float maxBoost
|
||||
) {
|
||||
this(subQuery, queryName, ScoreMode.FIRST, new ScoreFunction[] { function }, combineFunction, minScore, maxBoost);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -192,11 +254,34 @@ public class FunctionScoreQuery extends Query {
|
|||
CombineFunction combineFunction,
|
||||
Float minScore,
|
||||
float maxBoost
|
||||
) {
|
||||
this(subQuery, null, scoreMode, functions, combineFunction, minScore, maxBoost);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a FunctionScoreQuery with multiple score functions
|
||||
* @param subQuery The query to match.
|
||||
* @param queryName filter query name
|
||||
* @param scoreMode Defines how the different score functions should be combined.
|
||||
* @param functions The {@link ScoreFunction}s to apply.
|
||||
* @param combineFunction Defines how the query and function score should be applied.
|
||||
* @param minScore The minimum score to consider a document.
|
||||
* @param maxBoost The maximum applicable boost.
|
||||
*/
|
||||
public FunctionScoreQuery(
|
||||
Query subQuery,
|
||||
@Nullable String queryName,
|
||||
ScoreMode scoreMode,
|
||||
ScoreFunction[] functions,
|
||||
CombineFunction combineFunction,
|
||||
Float minScore,
|
||||
float maxBoost
|
||||
) {
|
||||
if (Arrays.stream(functions).anyMatch(func -> func == null)) {
|
||||
throw new IllegalArgumentException("Score function should not be null");
|
||||
}
|
||||
this.subQuery = subQuery;
|
||||
this.queryName = queryName;
|
||||
this.scoreMode = scoreMode;
|
||||
this.functions = functions;
|
||||
this.maxBoost = maxBoost;
|
||||
|
@ -240,7 +325,7 @@ public class FunctionScoreQuery extends Query {
|
|||
needsRewrite |= (newFunctions[i] != functions[i]);
|
||||
}
|
||||
if (needsRewrite) {
|
||||
return new FunctionScoreQuery(newQ, scoreMode, newFunctions, combineFunction, minScore, maxBoost);
|
||||
return new FunctionScoreQuery(newQ, queryName, scoreMode, newFunctions, combineFunction, minScore, maxBoost);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
@ -332,8 +417,7 @@ public class FunctionScoreQuery extends Query {
|
|||
|
||||
@Override
|
||||
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
||||
|
||||
Explanation expl = subQueryWeight.explain(context, doc);
|
||||
Explanation expl = Functions.explainWithName(subQueryWeight.explain(context, doc), queryName);
|
||||
if (!expl.isMatch()) {
|
||||
return expl;
|
||||
}
|
||||
|
@ -355,11 +439,15 @@ public class FunctionScoreQuery extends Query {
|
|||
Explanation functionExplanation = function.getLeafScoreFunction(context).explainScore(doc, expl);
|
||||
if (function instanceof FilterScoreFunction) {
|
||||
float factor = functionExplanation.getValue().floatValue();
|
||||
Query filterQuery = ((FilterScoreFunction) function).filter;
|
||||
final FilterScoreFunction filterScoreFunction = (FilterScoreFunction) function;
|
||||
Query filterQuery = filterScoreFunction.filter;
|
||||
Explanation filterExplanation = Explanation.match(
|
||||
factor,
|
||||
"function score, product of:",
|
||||
Explanation.match(1.0f, "match filter: " + filterQuery.toString()),
|
||||
Explanation.match(
|
||||
1.0f,
|
||||
"match filter" + Functions.nameOrEmptyFunc(filterScoreFunction.queryName) + ": " + filterQuery.toString()
|
||||
),
|
||||
functionExplanation
|
||||
);
|
||||
functionsExplanations.add(filterExplanation);
|
||||
|
@ -543,11 +631,12 @@ public class FunctionScoreQuery extends Query {
|
|||
&& Objects.equals(this.combineFunction, other.combineFunction)
|
||||
&& Objects.equals(this.minScore, other.minScore)
|
||||
&& Objects.equals(this.scoreMode, other.scoreMode)
|
||||
&& Arrays.equals(this.functions, other.functions);
|
||||
&& Arrays.equals(this.functions, other.functions)
|
||||
&& Objects.equals(this.queryName, other.queryName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(classHash(), subQuery, maxBoost, combineFunction, minScore, scoreMode, Arrays.hashCode(functions));
|
||||
return Objects.hash(classHash(), subQuery, maxBoost, combineFunction, minScore, scoreMode, Arrays.hashCode(functions), queryName);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* The OpenSearch Contributors require contributions made to
|
||||
* this file be licensed under the Apache-2.0 license or a
|
||||
* compatible open source license.
|
||||
*/
|
||||
|
||||
package org.opensearch.common.lucene.search.function;
|
||||
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.opensearch.common.Strings;
|
||||
import org.opensearch.index.query.AbstractQueryBuilder;
|
||||
import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder;
|
||||
|
||||
/**
|
||||
* Helper utility class for functions
|
||||
*/
|
||||
public final class Functions {
|
||||
private Functions() {}
|
||||
|
||||
/**
|
||||
* Return function name wrapped into brackets or empty string, for example: '(_name: func1)'
|
||||
* @param functionName function name
|
||||
* @return function name wrapped into brackets or empty string
|
||||
*/
|
||||
public static String nameOrEmptyFunc(final String functionName) {
|
||||
if (!Strings.isNullOrEmpty(functionName)) {
|
||||
return "(" + AbstractQueryBuilder.NAME_FIELD.getPreferredName() + ": " + functionName + ")";
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return function name as an argument or empty string, for example: ', _name: func1'
|
||||
* @param functionName function name
|
||||
* @return function name as an argument or empty string
|
||||
*/
|
||||
public static String nameOrEmptyArg(final String functionName) {
|
||||
if (!Strings.isNullOrEmpty(functionName)) {
|
||||
return ", " + FunctionScoreQueryBuilder.NAME_FIELD.getPreferredName() + ": " + functionName;
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Enrich explanation with query name
|
||||
* @param explanation explanation
|
||||
* @param queryName query name
|
||||
* @return explanation enriched with query name
|
||||
*/
|
||||
public static Explanation explainWithName(Explanation explanation, String queryName) {
|
||||
if (Strings.isNullOrEmpty(queryName)) {
|
||||
return explanation;
|
||||
} else {
|
||||
final String description = explanation.getDescription() + " " + nameOrEmptyFunc(queryName);
|
||||
if (explanation.isMatch()) {
|
||||
return Explanation.match(explanation.getValue(), description, explanation.getDetails());
|
||||
} else {
|
||||
return Explanation.noMatch(description, explanation.getDetails());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -35,6 +35,7 @@ import com.carrotsearch.hppc.BitMixer;
|
|||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.util.StringHelper;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.index.fielddata.IndexFieldData;
|
||||
import org.opensearch.index.fielddata.LeafFieldData;
|
||||
import org.opensearch.index.fielddata.SortedBinaryDocValues;
|
||||
|
@ -50,6 +51,7 @@ public class RandomScoreFunction extends ScoreFunction {
|
|||
private final int originalSeed;
|
||||
private final int saltedSeed;
|
||||
private final IndexFieldData<?> fieldData;
|
||||
private final String functionName;
|
||||
|
||||
/**
|
||||
* Creates a RandomScoreFunction.
|
||||
|
@ -59,10 +61,23 @@ public class RandomScoreFunction extends ScoreFunction {
|
|||
* @param uidFieldData The field data for _uid to use for generating consistent random values for the same id
|
||||
*/
|
||||
public RandomScoreFunction(int seed, int salt, IndexFieldData<?> uidFieldData) {
|
||||
this(seed, salt, uidFieldData, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a RandomScoreFunction.
|
||||
*
|
||||
* @param seed A seed for randomness
|
||||
* @param salt A value to salt the seed with, ideally unique to the running node/index
|
||||
* @param uidFieldData The field data for _uid to use for generating consistent random values for the same id
|
||||
* @param functionName The function name
|
||||
*/
|
||||
public RandomScoreFunction(int seed, int salt, IndexFieldData<?> uidFieldData, @Nullable String functionName) {
|
||||
super(CombineFunction.MULTIPLY);
|
||||
this.originalSeed = seed;
|
||||
this.saltedSeed = BitMixer.mix(seed, salt);
|
||||
this.fieldData = uidFieldData;
|
||||
this.functionName = functionName;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -97,7 +112,7 @@ public class RandomScoreFunction extends ScoreFunction {
|
|||
String field = fieldData == null ? null : fieldData.getFieldName();
|
||||
return Explanation.match(
|
||||
(float) score(docId, subQueryScore.getValue().floatValue()),
|
||||
"random score function (seed: " + originalSeed + ", field: " + field + ")"
|
||||
"random score function (seed: " + originalSeed + ", field: " + field + Functions.nameOrEmptyArg(functionName) + ")"
|
||||
);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -39,6 +39,7 @@ import org.opensearch.script.ExplainableScoreScript;
|
|||
import org.opensearch.script.ScoreScript;
|
||||
import org.opensearch.script.Script;
|
||||
import org.opensearch.Version;
|
||||
import org.opensearch.common.Nullable;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
@ -67,14 +68,23 @@ public class ScriptScoreFunction extends ScoreFunction {
|
|||
private final int shardId;
|
||||
private final String indexName;
|
||||
private final Version indexVersion;
|
||||
private final String functionName;
|
||||
|
||||
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) {
|
||||
public ScriptScoreFunction(
|
||||
Script sScript,
|
||||
ScoreScript.LeafFactory script,
|
||||
String indexName,
|
||||
int shardId,
|
||||
Version indexVersion,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
super(CombineFunction.REPLACE);
|
||||
this.sScript = sScript;
|
||||
this.script = script;
|
||||
this.indexName = indexName;
|
||||
this.shardId = shardId;
|
||||
this.indexVersion = indexVersion;
|
||||
this.functionName = functionName;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -105,11 +115,15 @@ public class ScriptScoreFunction extends ScoreFunction {
|
|||
leafScript.setDocument(docId);
|
||||
scorer.docid = docId;
|
||||
scorer.score = subQueryScore.getValue().floatValue();
|
||||
exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore);
|
||||
exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore, functionName);
|
||||
} else {
|
||||
double score = score(docId, subQueryScore.getValue().floatValue());
|
||||
// info about params already included in sScript
|
||||
String explanation = "script score function, computed with script:\"" + sScript + "\"";
|
||||
String explanation = "script score function"
|
||||
+ Functions.nameOrEmptyFunc(functionName)
|
||||
+ ", computed with script:\""
|
||||
+ sScript
|
||||
+ "\"";
|
||||
Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore);
|
||||
return Explanation.match((float) score, explanation, scoreExp);
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@ import org.apache.lucene.search.Scorer;
|
|||
import org.apache.lucene.search.BulkScorer;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.opensearch.Version;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.script.ScoreScript;
|
||||
import org.opensearch.script.ScoreScript.ExplanationHolder;
|
||||
import org.opensearch.script.Script;
|
||||
|
@ -69,6 +70,7 @@ public class ScriptScoreQuery extends Query {
|
|||
private final String indexName;
|
||||
private final int shardId;
|
||||
private final Version indexVersion;
|
||||
private final String queryName;
|
||||
|
||||
public ScriptScoreQuery(
|
||||
Query subQuery,
|
||||
|
@ -78,8 +80,22 @@ public class ScriptScoreQuery extends Query {
|
|||
String indexName,
|
||||
int shardId,
|
||||
Version indexVersion
|
||||
) {
|
||||
this(subQuery, null, script, scriptBuilder, minScore, indexName, shardId, indexVersion);
|
||||
}
|
||||
|
||||
public ScriptScoreQuery(
|
||||
Query subQuery,
|
||||
@Nullable String queryName,
|
||||
Script script,
|
||||
ScoreScript.LeafFactory scriptBuilder,
|
||||
Float minScore,
|
||||
String indexName,
|
||||
int shardId,
|
||||
Version indexVersion
|
||||
) {
|
||||
this.subQuery = subQuery;
|
||||
this.queryName = queryName;
|
||||
this.script = script;
|
||||
this.scriptBuilder = scriptBuilder;
|
||||
this.minScore = minScore;
|
||||
|
@ -92,7 +108,7 @@ public class ScriptScoreQuery extends Query {
|
|||
public Query rewrite(IndexReader reader) throws IOException {
|
||||
Query newQ = subQuery.rewrite(reader);
|
||||
if (newQ != subQuery) {
|
||||
return new ScriptScoreQuery(newQ, script, scriptBuilder, minScore, indexName, shardId, indexVersion);
|
||||
return new ScriptScoreQuery(newQ, queryName, script, scriptBuilder, minScore, indexName, shardId, indexVersion);
|
||||
}
|
||||
return super.rewrite(reader);
|
||||
}
|
||||
|
@ -140,7 +156,7 @@ public class ScriptScoreQuery extends Query {
|
|||
|
||||
@Override
|
||||
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
||||
Explanation subQueryExplanation = subQueryWeight.explain(context, doc);
|
||||
Explanation subQueryExplanation = Functions.explainWithName(subQueryWeight.explain(context, doc), queryName);
|
||||
if (subQueryExplanation.isMatch() == false) {
|
||||
return subQueryExplanation;
|
||||
}
|
||||
|
@ -210,7 +226,8 @@ public class ScriptScoreQuery extends Query {
|
|||
@Override
|
||||
public String toString(String field) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("script_score (").append(subQuery.toString(field)).append(", script: ");
|
||||
sb.append("script_score (").append(subQuery.toString(field));
|
||||
sb.append(Functions.nameOrEmptyArg(queryName)).append(", script: ");
|
||||
sb.append("{" + script.toString() + "}");
|
||||
return sb.toString();
|
||||
}
|
||||
|
@ -225,12 +242,13 @@ public class ScriptScoreQuery extends Query {
|
|||
&& script.equals(that.script)
|
||||
&& Objects.equals(minScore, that.minScore)
|
||||
&& indexName.equals(that.indexName)
|
||||
&& indexVersion.equals(that.indexVersion);
|
||||
&& indexVersion.equals(that.indexVersion)
|
||||
&& Objects.equals(queryName, that.queryName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(subQuery, script, minScore, indexName, shardId, indexVersion);
|
||||
return Objects.hash(subQuery, script, minScore, indexName, shardId, indexVersion, queryName);
|
||||
}
|
||||
|
||||
private static class ScriptScorer extends Scorer {
|
||||
|
|
|
@ -34,6 +34,8 @@ package org.opensearch.common.lucene.search.function;
|
|||
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.Strings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
@ -45,9 +47,17 @@ public class WeightFactorFunction extends ScoreFunction {
|
|||
private float weight = 1.0f;
|
||||
|
||||
public WeightFactorFunction(float weight, ScoreFunction scoreFunction) {
|
||||
this(weight, scoreFunction, null);
|
||||
}
|
||||
|
||||
public WeightFactorFunction(float weight, ScoreFunction scoreFunction, @Nullable String functionName) {
|
||||
super(CombineFunction.MULTIPLY);
|
||||
if (scoreFunction == null) {
|
||||
if (Strings.isNullOrEmpty(functionName)) {
|
||||
this.scoreFunction = SCORE_ONE;
|
||||
} else {
|
||||
this.scoreFunction = new ScoreOne(CombineFunction.MULTIPLY, functionName);
|
||||
}
|
||||
} else {
|
||||
this.scoreFunction = scoreFunction;
|
||||
}
|
||||
|
@ -55,9 +65,11 @@ public class WeightFactorFunction extends ScoreFunction {
|
|||
}
|
||||
|
||||
public WeightFactorFunction(float weight) {
|
||||
super(CombineFunction.MULTIPLY);
|
||||
this.scoreFunction = SCORE_ONE;
|
||||
this.weight = weight;
|
||||
this(weight, null, null);
|
||||
}
|
||||
|
||||
public WeightFactorFunction(float weight, @Nullable String functionName) {
|
||||
this(weight, null, functionName);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -112,9 +124,15 @@ public class WeightFactorFunction extends ScoreFunction {
|
|||
}
|
||||
|
||||
private static class ScoreOne extends ScoreFunction {
|
||||
private final String functionName;
|
||||
|
||||
protected ScoreOne(CombineFunction scoreCombiner) {
|
||||
this(scoreCombiner, null);
|
||||
}
|
||||
|
||||
protected ScoreOne(CombineFunction scoreCombiner, @Nullable String functionName) {
|
||||
super(scoreCombiner);
|
||||
this.functionName = functionName;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -127,7 +145,10 @@ public class WeightFactorFunction extends ScoreFunction {
|
|||
|
||||
@Override
|
||||
public Explanation explainScore(int docId, Explanation subQueryScore) {
|
||||
return Explanation.match(1.0f, "constant score 1.0 - no function provided");
|
||||
return Explanation.match(
|
||||
1.0f,
|
||||
"constant score 1.0" + Functions.nameOrEmptyFunc(functionName) + " - no function provided"
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
package org.opensearch.index.query;
|
||||
|
||||
import org.apache.lucene.search.join.ScoreMode;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.bytes.BytesReference;
|
||||
import org.opensearch.common.geo.GeoPoint;
|
||||
import org.opensearch.common.geo.ShapeRelation;
|
||||
|
@ -452,7 +453,17 @@ public final class QueryBuilders {
|
|||
* @param function The function builder used to custom score
|
||||
*/
|
||||
public static FunctionScoreQueryBuilder functionScoreQuery(ScoreFunctionBuilder function) {
|
||||
return new FunctionScoreQueryBuilder(function);
|
||||
return functionScoreQuery(function, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* A query that allows to define a custom scoring function.
|
||||
*
|
||||
* @param function The function builder used to custom score
|
||||
* @param queryName The query name
|
||||
*/
|
||||
public static FunctionScoreQueryBuilder functionScoreQuery(ScoreFunctionBuilder function, @Nullable String queryName) {
|
||||
return new FunctionScoreQueryBuilder(function, queryName);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -43,9 +43,11 @@ import org.apache.lucene.search.Scorer;
|
|||
import org.apache.lucene.search.TwoPhaseIterator;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.opensearch.OpenSearchException;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.ParsingException;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.io.stream.StreamOutput;
|
||||
import org.opensearch.common.lucene.search.function.Functions;
|
||||
import org.opensearch.common.xcontent.XContentBuilder;
|
||||
import org.opensearch.common.xcontent.XContentParser;
|
||||
import org.opensearch.script.FilterScript;
|
||||
|
@ -153,17 +155,19 @@ public class ScriptQueryBuilder extends AbstractQueryBuilder<ScriptQueryBuilder>
|
|||
}
|
||||
FilterScript.Factory factory = context.compile(script, FilterScript.CONTEXT);
|
||||
FilterScript.LeafFactory filterScript = factory.newFactory(script.getParams(), context.lookup());
|
||||
return new ScriptQuery(script, filterScript);
|
||||
return new ScriptQuery(script, filterScript, queryName);
|
||||
}
|
||||
|
||||
static class ScriptQuery extends Query {
|
||||
|
||||
final Script script;
|
||||
final FilterScript.LeafFactory filterScript;
|
||||
final String queryName;
|
||||
|
||||
ScriptQuery(Script script, FilterScript.LeafFactory filterScript) {
|
||||
ScriptQuery(Script script, FilterScript.LeafFactory filterScript, @Nullable String queryName) {
|
||||
this.script = script;
|
||||
this.filterScript = filterScript;
|
||||
this.queryName = queryName;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -171,6 +175,7 @@ public class ScriptQueryBuilder extends AbstractQueryBuilder<ScriptQueryBuilder>
|
|||
StringBuilder buffer = new StringBuilder();
|
||||
buffer.append("ScriptQuery(");
|
||||
buffer.append(script);
|
||||
buffer.append(Functions.nameOrEmptyArg(queryName));
|
||||
buffer.append(")");
|
||||
return buffer.toString();
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.opensearch.common.Nullable;
|
||||
|
||||
/**
|
||||
* Implement this interface to provide a decay function that is executed on a
|
||||
|
@ -45,7 +46,7 @@ public interface DecayFunction {
|
|||
|
||||
double evaluate(double value, double scale);
|
||||
|
||||
Explanation explainFunction(String valueString, double value, double scale);
|
||||
Explanation explainFunction(String valueString, double value, double scale, @Nullable String functionName);
|
||||
|
||||
/**
|
||||
* The final scale parameter is computed from the scale parameter given by
|
||||
|
|
|
@ -35,6 +35,7 @@ package org.opensearch.index.query.functionscore;
|
|||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.opensearch.OpenSearchParseException;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.ParsingException;
|
||||
import org.opensearch.common.bytes.BytesReference;
|
||||
import org.opensearch.common.geo.GeoDistance;
|
||||
|
@ -93,10 +94,31 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
this(fieldName, origin, scale, offset, DEFAULT_DECAY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience constructor that converts its parameters into json to parse on the data nodes.
|
||||
*/
|
||||
protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) {
|
||||
this(fieldName, origin, scale, offset, DEFAULT_DECAY, functionName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience constructor that converts its parameters into json to parse on the data nodes.
|
||||
*/
|
||||
protected DecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) {
|
||||
this(fieldName, origin, scale, offset, decay, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience constructor that converts its parameters into json to parse on the data nodes.
|
||||
*/
|
||||
protected DecayFunctionBuilder(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
Object offset,
|
||||
double decay,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
if (fieldName == null) {
|
||||
throw new IllegalArgumentException("decay function: field name must not be null");
|
||||
}
|
||||
|
@ -123,6 +145,7 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
} catch (IOException e) {
|
||||
throw new IllegalArgumentException("unable to build inner function object", e);
|
||||
}
|
||||
setFunctionName(functionName);
|
||||
}
|
||||
|
||||
protected DecayFunctionBuilder(String fieldName, BytesReference functionBytes) {
|
||||
|
@ -285,7 +308,16 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
);
|
||||
}
|
||||
IndexNumericFieldData numericFieldData = context.getForField(fieldType);
|
||||
return new NumericFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), numericFieldData, mode);
|
||||
return new NumericFieldDataScoreFunction(
|
||||
origin,
|
||||
scale,
|
||||
decay,
|
||||
offset,
|
||||
getDecayFunction(),
|
||||
numericFieldData,
|
||||
mode,
|
||||
getFunctionName()
|
||||
);
|
||||
}
|
||||
|
||||
private AbstractDistanceScoreFunction parseGeoVariable(
|
||||
|
@ -325,7 +357,7 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
double scale = DistanceUnit.DEFAULT.parse(scaleString, DistanceUnit.DEFAULT);
|
||||
double offset = DistanceUnit.DEFAULT.parse(offsetString, DistanceUnit.DEFAULT);
|
||||
IndexGeoPointFieldData indexFieldData = context.getForField(fieldType);
|
||||
return new GeoFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), indexFieldData, mode);
|
||||
return new GeoFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), indexFieldData, mode, getFunctionName());
|
||||
|
||||
}
|
||||
|
||||
|
@ -375,7 +407,16 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
val = TimeValue.parseTimeValue(offsetString, TimeValue.timeValueHours(24), DecayFunctionParser.class.getSimpleName() + ".offset");
|
||||
double offset = val.getMillis();
|
||||
IndexNumericFieldData numericFieldData = context.getForField(dateFieldType);
|
||||
return new NumericFieldDataScoreFunction(origin, scale, decay, offset, getDecayFunction(), numericFieldData, mode);
|
||||
return new NumericFieldDataScoreFunction(
|
||||
origin,
|
||||
scale,
|
||||
decay,
|
||||
offset,
|
||||
getDecayFunction(),
|
||||
numericFieldData,
|
||||
mode,
|
||||
getFunctionName()
|
||||
);
|
||||
}
|
||||
|
||||
static class GeoFieldDataScoreFunction extends AbstractDistanceScoreFunction {
|
||||
|
@ -392,9 +433,10 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
double offset,
|
||||
DecayFunction func,
|
||||
IndexGeoPointFieldData fieldData,
|
||||
MultiValueMode mode
|
||||
MultiValueMode mode,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
super(scale, decay, offset, func, mode);
|
||||
super(scale, decay, offset, func, mode, functionName);
|
||||
this.origin = origin;
|
||||
this.fieldData = fieldData;
|
||||
}
|
||||
|
@ -485,9 +527,10 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
double offset,
|
||||
DecayFunction func,
|
||||
IndexNumericFieldData fieldData,
|
||||
MultiValueMode mode
|
||||
MultiValueMode mode,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
super(scale, decay, offset, func, mode);
|
||||
super(scale, decay, offset, func, mode, functionName);
|
||||
this.fieldData = fieldData;
|
||||
this.origin = origin;
|
||||
}
|
||||
|
@ -569,13 +612,15 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
protected final double offset;
|
||||
private final DecayFunction func;
|
||||
protected final MultiValueMode mode;
|
||||
protected final String functionName;
|
||||
|
||||
public AbstractDistanceScoreFunction(
|
||||
double userSuppiedScale,
|
||||
double decay,
|
||||
double offset,
|
||||
DecayFunction func,
|
||||
MultiValueMode mode
|
||||
MultiValueMode mode,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
super(CombineFunction.MULTIPLY);
|
||||
this.mode = mode;
|
||||
|
@ -591,6 +636,7 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
throw new IllegalArgumentException(FunctionScoreQueryBuilder.NAME + " : offset must be > 0.0");
|
||||
}
|
||||
this.offset = offset;
|
||||
this.functionName = functionName;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -624,7 +670,7 @@ public abstract class DecayFunctionBuilder<DFB extends DecayFunctionBuilder<DFB>
|
|||
return Explanation.match(
|
||||
(float) score(docId, subQueryScore.getValue().floatValue()),
|
||||
"Function for field " + getFieldName() + ":",
|
||||
func.explainFunction(getDistanceString(ctx, docId), value, scale)
|
||||
func.explainFunction(getDistanceString(ctx, docId), value, scale, functionName)
|
||||
);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -33,8 +33,10 @@
|
|||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.bytes.BytesReference;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.lucene.search.function.Functions;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
@ -45,6 +47,10 @@ public class ExponentialDecayFunctionBuilder extends DecayFunctionBuilder<Expone
|
|||
);
|
||||
public static final DecayFunction EXP_DECAY_FUNCTION = new ExponentialDecayScoreFunction();
|
||||
|
||||
public ExponentialDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) {
|
||||
super(fieldName, origin, scale, offset, functionName);
|
||||
}
|
||||
|
||||
public ExponentialDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset) {
|
||||
super(fieldName, origin, scale, offset);
|
||||
}
|
||||
|
@ -53,6 +59,17 @@ public class ExponentialDecayFunctionBuilder extends DecayFunctionBuilder<Expone
|
|||
super(fieldName, origin, scale, offset, decay);
|
||||
}
|
||||
|
||||
public ExponentialDecayFunctionBuilder(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
Object offset,
|
||||
double decay,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
super(fieldName, origin, scale, offset, decay, functionName);
|
||||
}
|
||||
|
||||
ExponentialDecayFunctionBuilder(String fieldName, BytesReference functionBytes) {
|
||||
super(fieldName, functionBytes);
|
||||
}
|
||||
|
@ -82,8 +99,11 @@ public class ExponentialDecayFunctionBuilder extends DecayFunctionBuilder<Expone
|
|||
}
|
||||
|
||||
@Override
|
||||
public Explanation explainFunction(String valueExpl, double value, double scale) {
|
||||
return Explanation.match((float) evaluate(value, scale), "exp(- " + valueExpl + " * " + -1 * scale + ")");
|
||||
public Explanation explainFunction(String valueExpl, double value, double scale, @Nullable String functionName) {
|
||||
return Explanation.match(
|
||||
(float) evaluate(value, scale),
|
||||
"exp(- " + valueExpl + " * " + -1 * scale + Functions.nameOrEmptyArg(functionName) + ")"
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.opensearch.OpenSearchException;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.ParsingException;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.io.stream.StreamOutput;
|
||||
|
@ -63,10 +64,15 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
|
|||
private FieldValueFactorFunction.Modifier modifier = DEFAULT_MODIFIER;
|
||||
|
||||
public FieldValueFactorFunctionBuilder(String fieldName) {
|
||||
this(fieldName, null);
|
||||
}
|
||||
|
||||
public FieldValueFactorFunctionBuilder(String fieldName, @Nullable String functionName) {
|
||||
if (fieldName == null) {
|
||||
throw new IllegalArgumentException("field_value_factor: field must not be null");
|
||||
}
|
||||
this.field = fieldName;
|
||||
setFunctionName(functionName);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -166,7 +172,7 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
|
|||
} else {
|
||||
fieldData = context.getForField(fieldType);
|
||||
}
|
||||
return new FieldValueFactorFunction(field, factor, modifier, missing, fieldData);
|
||||
return new FieldValueFactorFunction(field, factor, modifier, missing, fieldData, getFunctionName());
|
||||
}
|
||||
|
||||
public static FieldValueFactorFunctionBuilder fromXContent(XContentParser parser) throws IOException, ParsingException {
|
||||
|
@ -176,6 +182,7 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
|
|||
FieldValueFactorFunction.Modifier modifier = FieldValueFactorFunction.Modifier.NONE;
|
||||
Double missing = null;
|
||||
XContentParser.Token token;
|
||||
String functionName = null;
|
||||
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
|
||||
if (token == XContentParser.Token.FIELD_NAME) {
|
||||
currentFieldName = parser.currentName();
|
||||
|
@ -188,6 +195,8 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
|
|||
modifier = FieldValueFactorFunction.Modifier.fromString(parser.text());
|
||||
} else if ("missing".equals(currentFieldName)) {
|
||||
missing = parser.doubleValue();
|
||||
} else if (FunctionScoreQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
|
||||
functionName = parser.text();
|
||||
} else {
|
||||
throw new ParsingException(parser.getTokenLocation(), NAME + " query does not support [" + currentFieldName + "]");
|
||||
}
|
||||
|
@ -204,8 +213,9 @@ public class FieldValueFactorFunctionBuilder extends ScoreFunctionBuilder<FieldV
|
|||
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] required field 'field' missing");
|
||||
}
|
||||
|
||||
FieldValueFactorFunctionBuilder fieldValueFactorFunctionBuilder = new FieldValueFactorFunctionBuilder(field).factor(boostFactor)
|
||||
.modifier(modifier);
|
||||
FieldValueFactorFunctionBuilder fieldValueFactorFunctionBuilder = new FieldValueFactorFunctionBuilder(field, functionName).factor(
|
||||
boostFactor
|
||||
).modifier(modifier);
|
||||
if (missing != null) {
|
||||
fieldValueFactorFunctionBuilder.missing(missing);
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ package org.opensearch.index.query.functionscore;
|
|||
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.ParseField;
|
||||
import org.opensearch.common.ParsingException;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
|
@ -111,7 +112,17 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
|
|||
* @param filterFunctionBuilders the filters and functions
|
||||
*/
|
||||
public FunctionScoreQueryBuilder(FilterFunctionBuilder[] filterFunctionBuilders) {
|
||||
this(new MatchAllQueryBuilder(), filterFunctionBuilders);
|
||||
this(filterFunctionBuilders, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a function_score query that executes the provided filters and functions on all documents
|
||||
*
|
||||
* @param filterFunctionBuilders the filters and functions
|
||||
* @param queryName the query name
|
||||
*/
|
||||
public FunctionScoreQueryBuilder(FilterFunctionBuilder[] filterFunctionBuilders, @Nullable String queryName) {
|
||||
this(new MatchAllQueryBuilder().queryName(queryName), filterFunctionBuilders);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -120,7 +131,20 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
|
|||
* @param scoreFunctionBuilder score function that is executed
|
||||
*/
|
||||
public FunctionScoreQueryBuilder(ScoreFunctionBuilder<?> scoreFunctionBuilder) {
|
||||
this(new MatchAllQueryBuilder(), new FilterFunctionBuilder[] { new FilterFunctionBuilder(scoreFunctionBuilder) });
|
||||
this(scoreFunctionBuilder, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a function_score query that will execute the function provided on all documents
|
||||
*
|
||||
* @param scoreFunctionBuilder score function that is executed
|
||||
* @param queryName the query name
|
||||
*/
|
||||
public FunctionScoreQueryBuilder(ScoreFunctionBuilder<?> scoreFunctionBuilder, @Nullable String queryName) {
|
||||
this(
|
||||
new MatchAllQueryBuilder().queryName(queryName),
|
||||
new FilterFunctionBuilder[] { new FilterFunctionBuilder(scoreFunctionBuilder) }
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -316,15 +340,17 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
|
|||
int i = 0;
|
||||
for (FilterFunctionBuilder filterFunctionBuilder : filterFunctionBuilders) {
|
||||
ScoreFunction scoreFunction = filterFunctionBuilder.getScoreFunction().toFunction(context);
|
||||
if (filterFunctionBuilder.getFilter().getName().equals(MatchAllQueryBuilder.NAME)) {
|
||||
final QueryBuilder builder = filterFunctionBuilder.getFilter();
|
||||
if (builder.getName().equals(MatchAllQueryBuilder.NAME)) {
|
||||
filterFunctions[i++] = scoreFunction;
|
||||
} else {
|
||||
Query filter = filterFunctionBuilder.getFilter().toQuery(context);
|
||||
filterFunctions[i++] = new FunctionScoreQuery.FilterScoreFunction(filter, scoreFunction);
|
||||
Query filter = builder.toQuery(context);
|
||||
filterFunctions[i++] = new FunctionScoreQuery.FilterScoreFunction(filter, scoreFunction, builder.queryName());
|
||||
}
|
||||
}
|
||||
|
||||
Query query = this.query.toQuery(context);
|
||||
final QueryBuilder builder = this.query;
|
||||
Query query = builder.toQuery(context);
|
||||
if (query == null) {
|
||||
query = new MatchAllDocsQuery();
|
||||
}
|
||||
|
@ -332,12 +358,12 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
|
|||
CombineFunction boostMode = this.boostMode == null ? DEFAULT_BOOST_MODE : this.boostMode;
|
||||
// handle cases where only one score function and no filter was provided. In this case we create a FunctionScoreQuery.
|
||||
if (filterFunctions.length == 0) {
|
||||
return new FunctionScoreQuery(query, minScore, maxBoost);
|
||||
return new FunctionScoreQuery(query, builder.queryName(), minScore, maxBoost);
|
||||
} else if (filterFunctions.length == 1 && filterFunctions[0] instanceof FunctionScoreQuery.FilterScoreFunction == false) {
|
||||
return new FunctionScoreQuery(query, filterFunctions[0], boostMode, minScore, maxBoost);
|
||||
return new FunctionScoreQuery(query, builder.queryName(), filterFunctions[0], boostMode, minScore, maxBoost);
|
||||
}
|
||||
// in all other cases we create a FunctionScoreQuery with filters
|
||||
return new FunctionScoreQuery(query, scoreMode, filterFunctions, boostMode, minScore, maxBoost);
|
||||
return new FunctionScoreQuery(query, builder.queryName(), scoreMode, filterFunctions, boostMode, minScore, maxBoost);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -606,6 +632,7 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
|
|||
QueryBuilder filter = null;
|
||||
ScoreFunctionBuilder<?> scoreFunction = null;
|
||||
Float functionWeight = null;
|
||||
String functionName = null;
|
||||
if (token != XContentParser.Token.START_OBJECT) {
|
||||
throw new ParsingException(
|
||||
parser.getTokenLocation(),
|
||||
|
@ -635,6 +662,8 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
|
|||
} else if (token.isValue()) {
|
||||
if (WEIGHT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
|
||||
functionWeight = parser.floatValue();
|
||||
} else if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
|
||||
functionName = parser.text();
|
||||
} else {
|
||||
throw new ParsingException(
|
||||
parser.getTokenLocation(),
|
||||
|
@ -652,6 +681,10 @@ public class FunctionScoreQueryBuilder extends AbstractQueryBuilder<FunctionScor
|
|||
scoreFunction.setWeight(functionWeight);
|
||||
}
|
||||
}
|
||||
|
||||
if (functionName != null && scoreFunction != null) {
|
||||
scoreFunction.setFunctionName(functionName);
|
||||
}
|
||||
}
|
||||
if (filter == null) {
|
||||
filter = new MatchAllQueryBuilder();
|
||||
|
|
|
@ -33,9 +33,11 @@
|
|||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.ParseField;
|
||||
import org.opensearch.common.bytes.BytesReference;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.lucene.search.function.Functions;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
@ -49,10 +51,25 @@ public class GaussDecayFunctionBuilder extends DecayFunctionBuilder<GaussDecayFu
|
|||
super(fieldName, origin, scale, offset);
|
||||
}
|
||||
|
||||
public GaussDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) {
|
||||
super(fieldName, origin, scale, offset, functionName);
|
||||
}
|
||||
|
||||
public GaussDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) {
|
||||
super(fieldName, origin, scale, offset, decay);
|
||||
}
|
||||
|
||||
public GaussDecayFunctionBuilder(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
Object offset,
|
||||
double decay,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
super(fieldName, origin, scale, offset, decay, functionName);
|
||||
}
|
||||
|
||||
GaussDecayFunctionBuilder(String fieldName, BytesReference functionBytes) {
|
||||
super(fieldName, functionBytes);
|
||||
}
|
||||
|
@ -75,7 +92,6 @@ public class GaussDecayFunctionBuilder extends DecayFunctionBuilder<GaussDecayFu
|
|||
}
|
||||
|
||||
private static final class GaussScoreFunction implements DecayFunction {
|
||||
|
||||
@Override
|
||||
public double evaluate(double value, double scale) {
|
||||
// note that we already computed scale^2 in processScale() so we do
|
||||
|
@ -84,8 +100,11 @@ public class GaussDecayFunctionBuilder extends DecayFunctionBuilder<GaussDecayFu
|
|||
}
|
||||
|
||||
@Override
|
||||
public Explanation explainFunction(String valueExpl, double value, double scale) {
|
||||
return Explanation.match((float) evaluate(value, scale), "exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1 * scale + ")");
|
||||
public Explanation explainFunction(String valueExpl, double value, double scale, @Nullable String functionName) {
|
||||
return Explanation.match(
|
||||
(float) evaluate(value, scale),
|
||||
"exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1 * scale + Functions.nameOrEmptyArg(functionName) + ")"
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -33,8 +33,10 @@
|
|||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.bytes.BytesReference;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.lucene.search.function.Functions;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
@ -47,10 +49,25 @@ public class LinearDecayFunctionBuilder extends DecayFunctionBuilder<LinearDecay
|
|||
super(fieldName, origin, scale, offset);
|
||||
}
|
||||
|
||||
public LinearDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, @Nullable String functionName) {
|
||||
super(fieldName, origin, scale, offset, functionName);
|
||||
}
|
||||
|
||||
public LinearDecayFunctionBuilder(String fieldName, Object origin, Object scale, Object offset, double decay) {
|
||||
super(fieldName, origin, scale, offset, decay);
|
||||
}
|
||||
|
||||
public LinearDecayFunctionBuilder(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
Object offset,
|
||||
double decay,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
super(fieldName, origin, scale, offset, decay, functionName);
|
||||
}
|
||||
|
||||
LinearDecayFunctionBuilder(String fieldName, BytesReference functionBytes) {
|
||||
super(fieldName, functionBytes);
|
||||
}
|
||||
|
@ -80,8 +97,11 @@ public class LinearDecayFunctionBuilder extends DecayFunctionBuilder<LinearDecay
|
|||
}
|
||||
|
||||
@Override
|
||||
public Explanation explainFunction(String valueExpl, double value, double scale) {
|
||||
return Explanation.match((float) evaluate(value, scale), "max(0.0, ((" + scale + " - " + valueExpl + ")/" + scale + ")");
|
||||
public Explanation explainFunction(String valueExpl, double value, double scale, @Nullable String functionName) {
|
||||
return Explanation.match(
|
||||
(float) evaluate(value, scale),
|
||||
"max(0.0, ((" + scale + " - " + valueExpl + ")/" + scale + Functions.nameOrEmptyArg(functionName) + ")"
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
|
||||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.ParsingException;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.io.stream.StreamOutput;
|
||||
|
@ -58,6 +59,10 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
|
|||
|
||||
public RandomScoreFunctionBuilder() {}
|
||||
|
||||
public RandomScoreFunctionBuilder(@Nullable String functionName) {
|
||||
setFunctionName(functionName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a stream.
|
||||
*/
|
||||
|
@ -166,7 +171,7 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
|
|||
final int salt = (context.index().getName().hashCode() << 10) | context.getShardId();
|
||||
if (seed == null) {
|
||||
// DocID-based random score generation
|
||||
return new RandomScoreFunction(hash(context.nowInMillis()), salt, null);
|
||||
return new RandomScoreFunction(hash(context.nowInMillis()), salt, null, getFunctionName());
|
||||
} else {
|
||||
final MappedFieldType fieldType;
|
||||
if (field != null) {
|
||||
|
@ -181,7 +186,7 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
|
|||
if (fieldType == null) {
|
||||
if (context.getMapperService().documentMapper() == null) {
|
||||
// no mappings: the index is empty anyway
|
||||
return new RandomScoreFunction(hash(context.nowInMillis()), salt, null);
|
||||
return new RandomScoreFunction(hash(context.nowInMillis()), salt, null, getFunctionName());
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"Field [" + field + "] is not mapped on [" + context.index() + "] and cannot be used as a source of random numbers."
|
||||
|
@ -193,7 +198,7 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
|
|||
} else {
|
||||
seed = hash(context.nowInMillis());
|
||||
}
|
||||
return new RandomScoreFunction(seed, salt, context.getForField(fieldType));
|
||||
return new RandomScoreFunction(seed, salt, context.getForField(fieldType), getFunctionName());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -231,6 +236,8 @@ public class RandomScoreFunctionBuilder extends ScoreFunctionBuilder<RandomScore
|
|||
}
|
||||
} else if ("field".equals(currentFieldName)) {
|
||||
randomScoreFunctionBuilder.setField(parser.text());
|
||||
} else if (FunctionScoreQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
|
||||
randomScoreFunctionBuilder.setFunctionName(parser.text());
|
||||
} else {
|
||||
throw new ParsingException(parser.getTokenLocation(), NAME + " query does not support [" + currentFieldName + "]");
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
|
||||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.opensearch.Version;
|
||||
import org.opensearch.common.io.stream.NamedWriteable;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.io.stream.StreamOutput;
|
||||
|
@ -47,6 +48,7 @@ import java.util.Objects;
|
|||
public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>> implements ToXContentFragment, NamedWriteable {
|
||||
|
||||
private Float weight;
|
||||
private String functionName;
|
||||
|
||||
/**
|
||||
* Standard empty constructor.
|
||||
|
@ -58,11 +60,17 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
|
|||
*/
|
||||
public ScoreFunctionBuilder(StreamInput in) throws IOException {
|
||||
weight = checkWeight(in.readOptionalFloat());
|
||||
if (in.getVersion().onOrAfter(Version.V_2_0_0)) {
|
||||
functionName = in.readOptionalString();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalFloat(weight);
|
||||
if (out.getVersion().onOrAfter(Version.V_2_0_0)) {
|
||||
out.writeOptionalString(functionName);
|
||||
}
|
||||
doWriteTo(out);
|
||||
}
|
||||
|
||||
|
@ -99,11 +107,30 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
|
|||
return weight;
|
||||
}
|
||||
|
||||
/**
|
||||
* The name of this function
|
||||
*/
|
||||
public String getFunctionName() {
|
||||
return functionName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the name of this function
|
||||
*/
|
||||
public void setFunctionName(String functionName) {
|
||||
this.functionName = functionName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
if (weight != null) {
|
||||
builder.field(FunctionScoreQueryBuilder.WEIGHT_FIELD.getPreferredName(), weight);
|
||||
}
|
||||
|
||||
if (functionName != null) {
|
||||
builder.field(FunctionScoreQueryBuilder.NAME_FIELD.getPreferredName(), functionName);
|
||||
}
|
||||
|
||||
doXContent(builder, params);
|
||||
return builder;
|
||||
}
|
||||
|
@ -128,7 +155,7 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
|
|||
}
|
||||
@SuppressWarnings("unchecked")
|
||||
FB other = (FB) obj;
|
||||
return Objects.equals(weight, other.getWeight()) && doEquals(other);
|
||||
return Objects.equals(weight, other.getWeight()) && Objects.equals(functionName, other.getFunctionName()) && doEquals(other);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -139,7 +166,7 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
|
|||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
return Objects.hash(getClass(), weight, doHashCode());
|
||||
return Objects.hash(getClass(), weight, functionName, doHashCode());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -156,7 +183,7 @@ public abstract class ScoreFunctionBuilder<FB extends ScoreFunctionBuilder<FB>>
|
|||
if (weight == null) {
|
||||
return scoreFunction;
|
||||
}
|
||||
return new WeightFactorFunction(weight, scoreFunction);
|
||||
return new WeightFactorFunction(weight, scoreFunction, getFunctionName());
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
|
||||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.script.Script;
|
||||
import org.opensearch.script.ScriptType;
|
||||
|
||||
|
@ -46,10 +47,29 @@ public class ScoreFunctionBuilders {
|
|||
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, null);
|
||||
}
|
||||
|
||||
public static ExponentialDecayFunctionBuilder exponentialDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, null, functionName);
|
||||
}
|
||||
|
||||
public static ExponentialDecayFunctionBuilder exponentialDecayFunction(String fieldName, Object origin, Object scale, Object offset) {
|
||||
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset);
|
||||
}
|
||||
|
||||
public static ExponentialDecayFunctionBuilder exponentialDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
Object offset,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, functionName);
|
||||
}
|
||||
|
||||
public static ExponentialDecayFunctionBuilder exponentialDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
|
@ -60,10 +80,30 @@ public class ScoreFunctionBuilders {
|
|||
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, decay);
|
||||
}
|
||||
|
||||
public static ExponentialDecayFunctionBuilder exponentialDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
Object offset,
|
||||
double decay,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
return new ExponentialDecayFunctionBuilder(fieldName, origin, scale, offset, decay, functionName);
|
||||
}
|
||||
|
||||
public static GaussDecayFunctionBuilder gaussDecayFunction(String fieldName, Object origin, Object scale) {
|
||||
return new GaussDecayFunctionBuilder(fieldName, origin, scale, null);
|
||||
}
|
||||
|
||||
public static GaussDecayFunctionBuilder gaussDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
return new GaussDecayFunctionBuilder(fieldName, origin, scale, null, functionName);
|
||||
}
|
||||
|
||||
public static GaussDecayFunctionBuilder gaussDecayFunction(String fieldName, Object origin, Object scale, Object offset) {
|
||||
return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset);
|
||||
}
|
||||
|
@ -72,6 +112,26 @@ public class ScoreFunctionBuilders {
|
|||
return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset, decay);
|
||||
}
|
||||
|
||||
public static GaussDecayFunctionBuilder gaussDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
Object offset,
|
||||
double decay,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
return new GaussDecayFunctionBuilder(fieldName, origin, scale, offset, decay, functionName);
|
||||
}
|
||||
|
||||
public static LinearDecayFunctionBuilder linearDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
return new LinearDecayFunctionBuilder(fieldName, origin, scale, null, functionName);
|
||||
}
|
||||
|
||||
public static LinearDecayFunctionBuilder linearDecayFunction(String fieldName, Object origin, Object scale) {
|
||||
return new LinearDecayFunctionBuilder(fieldName, origin, scale, null);
|
||||
}
|
||||
|
@ -80,6 +140,16 @@ public class ScoreFunctionBuilders {
|
|||
return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset);
|
||||
}
|
||||
|
||||
public static LinearDecayFunctionBuilder linearDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
Object offset,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, functionName);
|
||||
}
|
||||
|
||||
public static LinearDecayFunctionBuilder linearDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
|
@ -90,23 +160,54 @@ public class ScoreFunctionBuilders {
|
|||
return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, decay);
|
||||
}
|
||||
|
||||
public static LinearDecayFunctionBuilder linearDecayFunction(
|
||||
String fieldName,
|
||||
Object origin,
|
||||
Object scale,
|
||||
Object offset,
|
||||
double decay,
|
||||
@Nullable String functionName
|
||||
) {
|
||||
return new LinearDecayFunctionBuilder(fieldName, origin, scale, offset, decay, functionName);
|
||||
}
|
||||
|
||||
public static ScriptScoreFunctionBuilder scriptFunction(Script script) {
|
||||
return (new ScriptScoreFunctionBuilder(script));
|
||||
return scriptFunction(script, null);
|
||||
}
|
||||
|
||||
public static ScriptScoreFunctionBuilder scriptFunction(String script) {
|
||||
return (new ScriptScoreFunctionBuilder(new Script(ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG, script, emptyMap())));
|
||||
return scriptFunction(script, null);
|
||||
}
|
||||
|
||||
public static RandomScoreFunctionBuilder randomFunction() {
|
||||
return new RandomScoreFunctionBuilder();
|
||||
return randomFunction(null);
|
||||
}
|
||||
|
||||
public static WeightBuilder weightFactorFunction(float weight) {
|
||||
return (WeightBuilder) (new WeightBuilder().setWeight(weight));
|
||||
return weightFactorFunction(weight, null);
|
||||
}
|
||||
|
||||
public static FieldValueFactorFunctionBuilder fieldValueFactorFunction(String fieldName) {
|
||||
return new FieldValueFactorFunctionBuilder(fieldName);
|
||||
return fieldValueFactorFunction(fieldName, null);
|
||||
}
|
||||
|
||||
public static ScriptScoreFunctionBuilder scriptFunction(Script script, @Nullable String functionName) {
|
||||
return new ScriptScoreFunctionBuilder(script, functionName);
|
||||
}
|
||||
|
||||
public static ScriptScoreFunctionBuilder scriptFunction(String script, @Nullable String functionName) {
|
||||
return new ScriptScoreFunctionBuilder(new Script(ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG, script, emptyMap()), functionName);
|
||||
}
|
||||
|
||||
public static RandomScoreFunctionBuilder randomFunction(@Nullable String functionName) {
|
||||
return new RandomScoreFunctionBuilder(functionName);
|
||||
}
|
||||
|
||||
public static WeightBuilder weightFactorFunction(float weight, @Nullable String functionName) {
|
||||
return (WeightBuilder) (new WeightBuilder(functionName).setWeight(weight));
|
||||
}
|
||||
|
||||
public static FieldValueFactorFunctionBuilder fieldValueFactorFunction(String fieldName, @Nullable String functionName) {
|
||||
return new FieldValueFactorFunctionBuilder(fieldName, functionName);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
|
||||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.ParsingException;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.io.stream.StreamOutput;
|
||||
|
@ -57,10 +58,15 @@ public class ScriptScoreFunctionBuilder extends ScoreFunctionBuilder<ScriptScore
|
|||
private final Script script;
|
||||
|
||||
public ScriptScoreFunctionBuilder(Script script) {
|
||||
this(script, null);
|
||||
}
|
||||
|
||||
public ScriptScoreFunctionBuilder(Script script, @Nullable String functionName) {
|
||||
if (script == null) {
|
||||
throw new IllegalArgumentException("script must not be null");
|
||||
}
|
||||
this.script = script;
|
||||
setFunctionName(functionName);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -112,7 +118,8 @@ public class ScriptScoreFunctionBuilder extends ScoreFunctionBuilder<ScriptScore
|
|||
searchScript,
|
||||
context.index().getName(),
|
||||
context.getShardId(),
|
||||
context.indexVersionCreated()
|
||||
context.indexVersionCreated(),
|
||||
getFunctionName()
|
||||
);
|
||||
} catch (Exception e) {
|
||||
throw new QueryShardException(context, "script_score: the script could not be loaded", e);
|
||||
|
|
|
@ -195,9 +195,11 @@ public class ScriptScoreQueryBuilder extends AbstractQueryBuilder<ScriptScoreQue
|
|||
}
|
||||
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
|
||||
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup());
|
||||
Query query = this.query.toQuery(context);
|
||||
final QueryBuilder queryBuilder = this.query;
|
||||
Query query = queryBuilder.toQuery(context);
|
||||
return new ScriptScoreQuery(
|
||||
query,
|
||||
queryBuilder.queryName(),
|
||||
script,
|
||||
scoreScriptFactory,
|
||||
minScore,
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
|
||||
package org.opensearch.index.query.functionscore;
|
||||
|
||||
import org.opensearch.common.Nullable;
|
||||
import org.opensearch.common.io.stream.StreamInput;
|
||||
import org.opensearch.common.io.stream.StreamOutput;
|
||||
import org.opensearch.common.lucene.search.function.ScoreFunction;
|
||||
|
@ -51,6 +52,13 @@ public class WeightBuilder extends ScoreFunctionBuilder<WeightBuilder> {
|
|||
*/
|
||||
public WeightBuilder() {}
|
||||
|
||||
/**
|
||||
* Standard constructor.
|
||||
*/
|
||||
public WeightBuilder(@Nullable String functionName) {
|
||||
setFunctionName(functionName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a stream.
|
||||
*/
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
package org.opensearch.script;
|
||||
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.opensearch.common.Nullable;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
@ -49,7 +50,21 @@ public interface ExplainableScoreScript {
|
|||
* want to explain how that was computed.
|
||||
*
|
||||
* @param subQueryScore the Explanation for _score
|
||||
* @deprecated please use {@code explain(Explanation subQueryScore, @Nullable String scriptName)}
|
||||
*/
|
||||
@Deprecated
|
||||
Explanation explain(Explanation subQueryScore) throws IOException;
|
||||
|
||||
/**
|
||||
* Build the explanation of the current document being scored
|
||||
* The script score needs the Explanation of the sub query score because it might use _score and
|
||||
* want to explain how that was computed.
|
||||
*
|
||||
* @param subQueryScore the Explanation for _score
|
||||
* @param scriptName the script name
|
||||
*/
|
||||
default Explanation explain(Explanation subQueryScore, @Nullable String scriptName) throws IOException {
|
||||
return explain(subQueryScore);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -88,6 +88,7 @@ import java.util.Collection;
|
|||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
import static org.hamcrest.CoreMatchers.endsWith;
|
||||
import static org.hamcrest.core.Is.is;
|
||||
import static org.hamcrest.core.IsEqual.equalTo;
|
||||
import static org.hamcrest.core.IsNot.not;
|
||||
|
@ -283,7 +284,8 @@ public class FunctionScoreTests extends OpenSearchTestCase {
|
|||
0,
|
||||
GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION,
|
||||
new IndexNumericFieldDataStub(),
|
||||
MultiValueMode.MAX
|
||||
MultiValueMode.MAX,
|
||||
null
|
||||
);
|
||||
private static final ScoreFunction EXP_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
|
||||
0,
|
||||
|
@ -292,7 +294,8 @@ public class FunctionScoreTests extends OpenSearchTestCase {
|
|||
0,
|
||||
ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION,
|
||||
new IndexNumericFieldDataStub(),
|
||||
MultiValueMode.MAX
|
||||
MultiValueMode.MAX,
|
||||
null
|
||||
);
|
||||
private static final ScoreFunction LIN_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
|
||||
0,
|
||||
|
@ -301,7 +304,48 @@ public class FunctionScoreTests extends OpenSearchTestCase {
|
|||
0,
|
||||
LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION,
|
||||
new IndexNumericFieldDataStub(),
|
||||
MultiValueMode.MAX
|
||||
MultiValueMode.MAX,
|
||||
null
|
||||
);
|
||||
|
||||
private static final ScoreFunction RANDOM_SCORE_FUNCTION_NAMED = new RandomScoreFunction(0, 0, new IndexFieldDataStub(), "func1");
|
||||
private static final ScoreFunction FIELD_VALUE_FACTOR_FUNCTION_NAMED = new FieldValueFactorFunction(
|
||||
"test",
|
||||
1,
|
||||
FieldValueFactorFunction.Modifier.LN,
|
||||
1.0,
|
||||
null,
|
||||
"func1"
|
||||
);
|
||||
private static final ScoreFunction GAUSS_DECAY_FUNCTION_NAMED = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
|
||||
0,
|
||||
1,
|
||||
0.1,
|
||||
0,
|
||||
GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION,
|
||||
new IndexNumericFieldDataStub(),
|
||||
MultiValueMode.MAX,
|
||||
"func1"
|
||||
);
|
||||
private static final ScoreFunction EXP_DECAY_FUNCTION_NAMED = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
|
||||
0,
|
||||
1,
|
||||
0.1,
|
||||
0,
|
||||
ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION,
|
||||
new IndexNumericFieldDataStub(),
|
||||
MultiValueMode.MAX,
|
||||
"func1"
|
||||
);
|
||||
private static final ScoreFunction LIN_DECAY_FUNCTION_NAMED = new DecayFunctionBuilder.NumericFieldDataScoreFunction(
|
||||
0,
|
||||
1,
|
||||
0.1,
|
||||
0,
|
||||
LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION,
|
||||
new IndexNumericFieldDataStub(),
|
||||
MultiValueMode.MAX,
|
||||
"func1"
|
||||
);
|
||||
private static final ScoreFunction WEIGHT_FACTOR_FUNCTION = new WeightFactorFunction(4);
|
||||
private static final String TEXT = "The way out is through.";
|
||||
|
@ -383,6 +427,58 @@ public class FunctionScoreTests extends OpenSearchTestCase {
|
|||
assertThat(functionExplanation.getDetails()[0].getDetails()[1].getDetails().length, equalTo(0));
|
||||
}
|
||||
|
||||
public void testExplainFunctionScoreQueryWithName() throws IOException {
|
||||
Explanation functionExplanation = getFunctionScoreExplanation(searcher, RANDOM_SCORE_FUNCTION_NAMED);
|
||||
checkFunctionScoreExplanation(functionExplanation, "random score function (seed: 0, field: test, _name: func1)");
|
||||
assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0));
|
||||
|
||||
functionExplanation = getFunctionScoreExplanation(searcher, FIELD_VALUE_FACTOR_FUNCTION_NAMED);
|
||||
checkFunctionScoreExplanation(functionExplanation, "field value function(_name: func1): ln(doc['test'].value?:1.0 * factor=1.0)");
|
||||
assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0));
|
||||
|
||||
functionExplanation = getFunctionScoreExplanation(searcher, GAUSS_DECAY_FUNCTION_NAMED);
|
||||
checkFunctionScoreExplanation(functionExplanation, "Function for field test:");
|
||||
assertThat(
|
||||
functionExplanation.getDetails()[0].getDetails()[0].toString(),
|
||||
equalTo(
|
||||
"0.1 = exp(-0.5*pow(MAX[Math.max(Math.abs"
|
||||
+ "(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.21714724095162594, _name: func1)\n"
|
||||
)
|
||||
);
|
||||
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
|
||||
|
||||
functionExplanation = getFunctionScoreExplanation(searcher, EXP_DECAY_FUNCTION_NAMED);
|
||||
checkFunctionScoreExplanation(functionExplanation, "Function for field test:");
|
||||
assertThat(
|
||||
functionExplanation.getDetails()[0].getDetails()[0].toString(),
|
||||
equalTo(
|
||||
"0.1 = exp(- MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)] * 2.3025850929940455, _name: func1)\n"
|
||||
)
|
||||
);
|
||||
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
|
||||
|
||||
functionExplanation = getFunctionScoreExplanation(searcher, LIN_DECAY_FUNCTION_NAMED);
|
||||
checkFunctionScoreExplanation(functionExplanation, "Function for field test:");
|
||||
assertThat(
|
||||
functionExplanation.getDetails()[0].getDetails()[0].toString(),
|
||||
equalTo(
|
||||
"0.1 = max(0.0, ((1.1111111111111112"
|
||||
+ " - MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)])/1.1111111111111112, _name: func1)\n"
|
||||
)
|
||||
);
|
||||
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
|
||||
|
||||
functionExplanation = getFunctionScoreExplanation(searcher, new WeightFactorFunction(4, RANDOM_SCORE_FUNCTION_NAMED));
|
||||
checkFunctionScoreExplanation(functionExplanation, "product of:");
|
||||
assertThat(
|
||||
functionExplanation.getDetails()[0].getDetails()[0].toString(),
|
||||
endsWith("random score function (seed: 0, field: test, _name: func1)\n")
|
||||
);
|
||||
assertThat(functionExplanation.getDetails()[0].getDetails()[1].toString(), equalTo("4.0 = weight\n"));
|
||||
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
|
||||
assertThat(functionExplanation.getDetails()[0].getDetails()[1].getDetails().length, equalTo(0));
|
||||
}
|
||||
|
||||
public Explanation getFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction scoreFunction) throws IOException {
|
||||
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new TermQuery(TERM), scoreFunction, CombineFunction.AVG, 0.0f, 100);
|
||||
Weight weight = searcher.createWeight(searcher.rewrite(functionScoreQuery), org.apache.lucene.search.ScoreMode.COMPLETE, 1f);
|
||||
|
|
|
@ -110,6 +110,34 @@ public class ScriptScoreQueryTests extends OpenSearchTestCase {
|
|||
assertThat(explanation.getValue(), equalTo(1.0));
|
||||
}
|
||||
|
||||
public void testExplainWithName() throws IOException {
|
||||
Script script = new Script("script using explain");
|
||||
ScoreScript.LeafFactory factory = newFactory(script, true, explanation -> {
|
||||
assertNotNull(explanation);
|
||||
explanation.set("this explains the score");
|
||||
return 1.0;
|
||||
});
|
||||
|
||||
ScriptScoreQuery query = new ScriptScoreQuery(
|
||||
Queries.newMatchAllQuery(),
|
||||
"query1",
|
||||
script,
|
||||
factory,
|
||||
null,
|
||||
"index",
|
||||
0,
|
||||
Version.CURRENT
|
||||
);
|
||||
Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
|
||||
Explanation explanation = weight.explain(leafReaderContext, 0);
|
||||
assertNotNull(explanation);
|
||||
assertThat(explanation.getDescription(), equalTo("this explains the score"));
|
||||
assertThat(explanation.getValue(), equalTo(1.0));
|
||||
|
||||
assertThat(explanation.getDetails(), arrayWithSize(1));
|
||||
assertThat(explanation.getDetails()[0].getDescription(), equalTo("*:* (_name: query1)"));
|
||||
}
|
||||
|
||||
public void testExplainDefault() throws IOException {
|
||||
Script script = new Script("script without setting explanation");
|
||||
ScoreScript.LeafFactory factory = newFactory(script, true, explanation -> 1.5);
|
||||
|
|
Loading…
Reference in New Issue