add more combine functions and rename PLAIN to REPLACE
This commit is contained in:
parent
db100aa2de
commit
6035134047
|
@ -47,7 +47,7 @@ public enum CombineFunction {
|
|||
return res;
|
||||
}
|
||||
},
|
||||
PLAIN {
|
||||
REPLACE {
|
||||
@Override
|
||||
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
|
||||
return toFloat(queryBoost * Math.min(funcScore, maxBoost));
|
||||
|
@ -55,7 +55,7 @@ public enum CombineFunction {
|
|||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "plain";
|
||||
return "replace";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -70,20 +70,132 @@ public enum CombineFunction {
|
|||
return res;
|
||||
}
|
||||
|
||||
},
|
||||
SUM {
|
||||
@Override
|
||||
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
|
||||
return toFloat(queryBoost * (queryScore + Math.min(funcScore, maxBoost)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "sum";
|
||||
}
|
||||
|
||||
@Override
|
||||
public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) {
|
||||
float score = queryBoost * (Math.min(funcExpl.getValue(), maxBoost) + queryExpl.getValue());
|
||||
ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:");
|
||||
ComplexExplanation minExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of");
|
||||
minExpl.addDetail(funcExpl);
|
||||
minExpl.addDetail(new Explanation(maxBoost, "maxBoost"));
|
||||
ComplexExplanation sumExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost) + queryExpl.getValue(),
|
||||
"sum of");
|
||||
sumExpl.addDetail(queryExpl);
|
||||
sumExpl.addDetail(minExpl);
|
||||
res.addDetail(sumExpl);
|
||||
res.addDetail(new Explanation(queryBoost, "queryBoost"));
|
||||
return res;
|
||||
}
|
||||
|
||||
},
|
||||
AVG {
|
||||
@Override
|
||||
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
|
||||
return toFloat((queryBoost * (Math.min(funcScore, maxBoost) + queryScore) / 2.0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "avg";
|
||||
}
|
||||
|
||||
@Override
|
||||
public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) {
|
||||
float score = toFloat(queryBoost * (queryExpl.getValue() + Math.min(funcExpl.getValue(), maxBoost)) / 2.0);
|
||||
ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:");
|
||||
ComplexExplanation minExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of");
|
||||
minExpl.addDetail(funcExpl);
|
||||
minExpl.addDetail(new Explanation(maxBoost, "maxBoost"));
|
||||
ComplexExplanation avgExpl = new ComplexExplanation(true,
|
||||
toFloat((Math.min(funcExpl.getValue(), maxBoost) + queryExpl.getValue()) / 2.0), "avg of");
|
||||
avgExpl.addDetail(queryExpl);
|
||||
avgExpl.addDetail(minExpl);
|
||||
res.addDetail(avgExpl);
|
||||
res.addDetail(new Explanation(queryBoost, "queryBoost"));
|
||||
return res;
|
||||
}
|
||||
|
||||
},
|
||||
MIN {
|
||||
@Override
|
||||
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
|
||||
return toFloat(queryBoost * Math.min(queryScore, Math.min(funcScore, maxBoost)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "min";
|
||||
}
|
||||
|
||||
@Override
|
||||
public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) {
|
||||
float score = toFloat(queryBoost * Math.min(queryExpl.getValue(), Math.min(funcExpl.getValue(), maxBoost)));
|
||||
ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:");
|
||||
ComplexExplanation innerMinExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of");
|
||||
innerMinExpl.addDetail(funcExpl);
|
||||
innerMinExpl.addDetail(new Explanation(maxBoost, "maxBoost"));
|
||||
ComplexExplanation outerMinExpl = new ComplexExplanation(true, Math.min(Math.min(funcExpl.getValue(), maxBoost),
|
||||
queryExpl.getValue()), "min of");
|
||||
outerMinExpl.addDetail(queryExpl);
|
||||
outerMinExpl.addDetail(innerMinExpl);
|
||||
res.addDetail(outerMinExpl);
|
||||
res.addDetail(new Explanation(queryBoost, "queryBoost"));
|
||||
return res;
|
||||
}
|
||||
|
||||
},
|
||||
MAX {
|
||||
@Override
|
||||
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
|
||||
return toFloat(queryBoost * (Math.max(queryScore, Math.min(funcScore, maxBoost))));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "max";
|
||||
}
|
||||
|
||||
@Override
|
||||
public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) {
|
||||
float score = toFloat(queryBoost * Math.max(queryExpl.getValue(), Math.min(funcExpl.getValue(), maxBoost)));
|
||||
ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:");
|
||||
ComplexExplanation innerMinExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of");
|
||||
innerMinExpl.addDetail(funcExpl);
|
||||
innerMinExpl.addDetail(new Explanation(maxBoost, "maxBoost"));
|
||||
ComplexExplanation outerMaxExpl = new ComplexExplanation(true, Math.max(Math.min(funcExpl.getValue(), maxBoost),
|
||||
queryExpl.getValue()), "max of");
|
||||
outerMaxExpl.addDetail(queryExpl);
|
||||
outerMaxExpl.addDetail(innerMinExpl);
|
||||
res.addDetail(outerMaxExpl);
|
||||
res.addDetail(new Explanation(queryBoost, "queryBoost"));
|
||||
return res;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
public abstract float combine(double queryBoost, double queryScore, double funcScore, double maxBoost);
|
||||
|
||||
|
||||
public abstract String getName();
|
||||
|
||||
public static float toFloat(double input) {
|
||||
assert deviation(input) <= 0.001 : "input " + input + " out of float scope for function score deviation: " + deviation(input);
|
||||
return (float) input;
|
||||
}
|
||||
|
||||
|
||||
private static double deviation(double input) { // only with assert!
|
||||
float floatVersion = (float)input;
|
||||
return Double.compare(floatVersion, input) == 0 || input == 0.0d ? 0 : 1.d-(floatVersion) / input;
|
||||
float floatVersion = (float) input;
|
||||
return Double.compare(floatVersion, input) == 0 || input == 0.0d ? 0 : 1.d - (floatVersion) / input;
|
||||
}
|
||||
|
||||
public abstract ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost);
|
||||
|
|
|
@ -36,7 +36,7 @@ public class ScriptScoreFunction extends ScoreFunction {
|
|||
|
||||
|
||||
public ScriptScoreFunction(String sScript, Map<String, Object> params, SearchScript script) {
|
||||
super(CombineFunction.PLAIN);
|
||||
super(CombineFunction.REPLACE);
|
||||
this.sScript = sScript;
|
||||
this.params = params;
|
||||
this.script = script;
|
||||
|
|
|
@ -1187,7 +1187,7 @@ public class SimpleChildQuerySearchTests extends AbstractSharedClusterTest {
|
|||
"child",
|
||||
QueryBuilders.functionScoreQuery(matchQuery("c_field2", 0))
|
||||
.add(new ScriptScoreFunctionBuilder().script("doc['c_field1'].value"))
|
||||
.boostMode(CombineFunction.PLAIN.getName())).scoreType("sum")).execute().actionGet();
|
||||
.boostMode(CombineFunction.REPLACE.getName())).scoreType("sum")).execute().actionGet();
|
||||
|
||||
assertThat(response.getHits().totalHits(), equalTo(3l));
|
||||
assertThat(response.getHits().hits()[0].id(), equalTo("1"));
|
||||
|
@ -1204,7 +1204,7 @@ public class SimpleChildQuerySearchTests extends AbstractSharedClusterTest {
|
|||
"child",
|
||||
QueryBuilders.functionScoreQuery(matchQuery("c_field2", 0))
|
||||
.add(new ScriptScoreFunctionBuilder().script("doc['c_field1'].value"))
|
||||
.boostMode(CombineFunction.PLAIN.getName())).scoreType("max")).execute().actionGet();
|
||||
.boostMode(CombineFunction.REPLACE.getName())).scoreType("max")).execute().actionGet();
|
||||
|
||||
assertThat(response.getHits().totalHits(), equalTo(3l));
|
||||
assertThat(response.getHits().hits()[0].id(), equalTo("3"));
|
||||
|
@ -1221,7 +1221,7 @@ public class SimpleChildQuerySearchTests extends AbstractSharedClusterTest {
|
|||
"child",
|
||||
QueryBuilders.functionScoreQuery(matchQuery("c_field2", 0))
|
||||
.add(new ScriptScoreFunctionBuilder().script("doc['c_field1'].value"))
|
||||
.boostMode(CombineFunction.PLAIN.getName())).scoreType("avg")).execute().actionGet();
|
||||
.boostMode(CombineFunction.REPLACE.getName())).scoreType("avg")).execute().actionGet();
|
||||
|
||||
assertThat(response.getHits().totalHits(), equalTo(3l));
|
||||
assertThat(response.getHits().hits()[0].id(), equalTo("3"));
|
||||
|
@ -1238,7 +1238,7 @@ public class SimpleChildQuerySearchTests extends AbstractSharedClusterTest {
|
|||
"parent",
|
||||
QueryBuilders.functionScoreQuery(matchQuery("p_field1", "p_value3"))
|
||||
.add(new ScriptScoreFunctionBuilder().script("doc['p_field2'].value"))
|
||||
.boostMode(CombineFunction.PLAIN.getName())).scoreType("score"))
|
||||
.boostMode(CombineFunction.REPLACE.getName())).scoreType("score"))
|
||||
.addSort(SortBuilders.fieldSort("c_field3")).addSort(SortBuilders.scoreSort()).execute().actionGet();
|
||||
|
||||
assertThat(response.getHits().totalHits(), equalTo(7l));
|
||||
|
|
|
@ -195,7 +195,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
|
|||
response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
|
||||
searchSource().explain(true).query(
|
||||
functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.PLAIN.getName()))));
|
||||
functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName()))));
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (2)));
|
||||
|
@ -234,7 +234,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
|
|||
assertThat(sh.getTotalHits(), equalTo((long) (1)));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat((double) sh.getAt(0).score(), closeTo(0.30685282, 1.e-5));
|
||||
float[] coords = {11,20};
|
||||
float[] coords = { 11, 20 };
|
||||
fb = new GaussDecayFunctionBuilder("loc", coords, "1000km");
|
||||
|
||||
response = client().search(
|
||||
|
@ -248,6 +248,97 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
|
|||
assertThat((double) sh.getAt(0).score(), closeTo(0.30685282, 1.e-5));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCombineModes() throws Exception {
|
||||
|
||||
createIndexMapped("test", "type1", "test", "string", "num", "double");
|
||||
ensureYellow();
|
||||
|
||||
List<IndexRequestBuilder> indexBuilders = new ArrayList<IndexRequestBuilder>();
|
||||
indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId("1").setIndex("test")
|
||||
.setSource(jsonBuilder().startObject().field("test", "value").field("num", 1.0).endObject()));
|
||||
IndexRequestBuilder[] builders = indexBuilders.toArray(new IndexRequestBuilder[indexBuilders.size()]);
|
||||
|
||||
indexRandom("test", false, builders);
|
||||
refresh();
|
||||
|
||||
DecayFunctionBuilder fb = new GaussDecayFunctionBuilder("num", 0.0, 1.0).setScaleWeight(0.5);
|
||||
// function score should return 0.5 for this function
|
||||
|
||||
ActionFuture<SearchResponse> response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
|
||||
searchSource().explain(true).query(
|
||||
functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
|
||||
.boostMode(CombineFunction.MULT.getName()))));
|
||||
SearchResponse sr = response.actionGet();
|
||||
SearchHits sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (1)));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat((double) sh.getAt(0).score(), closeTo(0.30685282, 1.e-5));
|
||||
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
|
||||
|
||||
response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
|
||||
searchSource().explain(true).query(
|
||||
functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
|
||||
.boostMode(CombineFunction.REPLACE.getName()))));
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (1)));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat((double) sh.getAt(0).score(), closeTo(1.0, 1.e-5));
|
||||
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
|
||||
|
||||
response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
|
||||
searchSource().explain(true)
|
||||
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
|
||||
.boostMode(CombineFunction.SUM.getName()))));
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (1)));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat((double) sh.getAt(0).score(), closeTo(2.0 * (0.30685282 + 0.5), 1.e-5));
|
||||
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
|
||||
|
||||
response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
|
||||
searchSource().explain(true)
|
||||
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
|
||||
.boostMode(CombineFunction.AVG.getName()))));
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (1)));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat((double) sh.getAt(0).score(), closeTo((0.30685282 + 0.5), 1.e-5));
|
||||
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
|
||||
|
||||
response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
|
||||
searchSource().explain(true)
|
||||
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
|
||||
.boostMode(CombineFunction.MIN.getName()))));
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (1)));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat((double) sh.getAt(0).score(), closeTo(2.0 * (0.30685282), 1.e-5));
|
||||
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
|
||||
|
||||
response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
|
||||
searchSource().explain(true)
|
||||
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
|
||||
.boostMode(CombineFunction.MAX.getName()))));
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (1)));
|
||||
assertThat(sh.getAt(0).getId(), equalTo("1"));
|
||||
assertThat((double) sh.getAt(0).score(), closeTo(1.0, 1.e-5));
|
||||
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = SearchPhaseExecutionException.class)
|
||||
public void testExceptionThrownIfScaleLE0() throws Exception {
|
||||
|
||||
|
|
|
@ -482,9 +482,9 @@ public class QueryRescorerTests extends AbstractSharedClusterTest {
|
|||
.queryRescorer(
|
||||
QueryBuilders.boolQuery()
|
||||
.disableCoord(true)
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).add(new ScriptScoreFunctionBuilder().script("5.0f")).boostMode(CombineFunction.PLAIN.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).add(new ScriptScoreFunctionBuilder().script("7.0f")).boostMode(CombineFunction.PLAIN.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).add(new ScriptScoreFunctionBuilder().script("0.0f")).boostMode(CombineFunction.PLAIN.getName())))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).add(new ScriptScoreFunctionBuilder().script("5.0f")).boostMode(CombineFunction.REPLACE.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).add(new ScriptScoreFunctionBuilder().script("7.0f")).boostMode(CombineFunction.REPLACE.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).add(new ScriptScoreFunctionBuilder().script("0.0f")).boostMode(CombineFunction.REPLACE.getName())))
|
||||
.setQueryWeight(primaryWeight)
|
||||
.setRescoreQueryWeight(secondaryWeight);
|
||||
|
||||
|
@ -497,10 +497,10 @@ public class QueryRescorerTests extends AbstractSharedClusterTest {
|
|||
.setPreference("test") // ensure we hit the same shards for tie-breaking
|
||||
.setQuery(QueryBuilders.boolQuery()
|
||||
.disableCoord(true)
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).add(new ScriptScoreFunctionBuilder().script("2.0f")).boostMode(CombineFunction.PLAIN.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).add(new ScriptScoreFunctionBuilder().script("3.0f")).boostMode(CombineFunction.PLAIN.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[2])).add(new ScriptScoreFunctionBuilder().script("5.0f")).boostMode(CombineFunction.PLAIN.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).add(new ScriptScoreFunctionBuilder().script("0.2f")).boostMode(CombineFunction.PLAIN.getName())))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).add(new ScriptScoreFunctionBuilder().script("2.0f")).boostMode(CombineFunction.REPLACE.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).add(new ScriptScoreFunctionBuilder().script("3.0f")).boostMode(CombineFunction.REPLACE.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[2])).add(new ScriptScoreFunctionBuilder().script("5.0f")).boostMode(CombineFunction.REPLACE.getName()))
|
||||
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).add(new ScriptScoreFunctionBuilder().script("0.2f")).boostMode(CombineFunction.REPLACE.getName())))
|
||||
.setFrom(0)
|
||||
.setSize(10)
|
||||
.setRescorer(rescoreQuery)
|
||||
|
|
Loading…
Reference in New Issue