add more combine functions and rename PLAIN to REPLACE

This commit is contained in:
Britta Weber 2013-08-16 15:11:41 +02:00
parent db100aa2de
commit 6035134047
5 changed files with 223 additions and 20 deletions

View File

@ -47,7 +47,7 @@ public enum CombineFunction {
return res; return res;
} }
}, },
PLAIN { REPLACE {
@Override @Override
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) { public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
return toFloat(queryBoost * Math.min(funcScore, maxBoost)); return toFloat(queryBoost * Math.min(funcScore, maxBoost));
@ -55,7 +55,7 @@ public enum CombineFunction {
@Override @Override
public String getName() { public String getName() {
return "plain"; return "replace";
} }
@Override @Override
@ -70,20 +70,132 @@ public enum CombineFunction {
return res; return res;
} }
},
SUM {
@Override
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
return toFloat(queryBoost * (queryScore + Math.min(funcScore, maxBoost)));
}
@Override
public String getName() {
return "sum";
}
@Override
public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) {
float score = queryBoost * (Math.min(funcExpl.getValue(), maxBoost) + queryExpl.getValue());
ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:");
ComplexExplanation minExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of");
minExpl.addDetail(funcExpl);
minExpl.addDetail(new Explanation(maxBoost, "maxBoost"));
ComplexExplanation sumExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost) + queryExpl.getValue(),
"sum of");
sumExpl.addDetail(queryExpl);
sumExpl.addDetail(minExpl);
res.addDetail(sumExpl);
res.addDetail(new Explanation(queryBoost, "queryBoost"));
return res;
}
},
AVG {
@Override
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
return toFloat((queryBoost * (Math.min(funcScore, maxBoost) + queryScore) / 2.0));
}
@Override
public String getName() {
return "avg";
}
@Override
public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) {
float score = toFloat(queryBoost * (queryExpl.getValue() + Math.min(funcExpl.getValue(), maxBoost)) / 2.0);
ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:");
ComplexExplanation minExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of");
minExpl.addDetail(funcExpl);
minExpl.addDetail(new Explanation(maxBoost, "maxBoost"));
ComplexExplanation avgExpl = new ComplexExplanation(true,
toFloat((Math.min(funcExpl.getValue(), maxBoost) + queryExpl.getValue()) / 2.0), "avg of");
avgExpl.addDetail(queryExpl);
avgExpl.addDetail(minExpl);
res.addDetail(avgExpl);
res.addDetail(new Explanation(queryBoost, "queryBoost"));
return res;
}
},
MIN {
@Override
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
return toFloat(queryBoost * Math.min(queryScore, Math.min(funcScore, maxBoost)));
}
@Override
public String getName() {
return "min";
}
@Override
public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) {
float score = toFloat(queryBoost * Math.min(queryExpl.getValue(), Math.min(funcExpl.getValue(), maxBoost)));
ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:");
ComplexExplanation innerMinExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of");
innerMinExpl.addDetail(funcExpl);
innerMinExpl.addDetail(new Explanation(maxBoost, "maxBoost"));
ComplexExplanation outerMinExpl = new ComplexExplanation(true, Math.min(Math.min(funcExpl.getValue(), maxBoost),
queryExpl.getValue()), "min of");
outerMinExpl.addDetail(queryExpl);
outerMinExpl.addDetail(innerMinExpl);
res.addDetail(outerMinExpl);
res.addDetail(new Explanation(queryBoost, "queryBoost"));
return res;
}
},
MAX {
@Override
public float combine(double queryBoost, double queryScore, double funcScore, double maxBoost) {
return toFloat(queryBoost * (Math.max(queryScore, Math.min(funcScore, maxBoost))));
}
@Override
public String getName() {
return "max";
}
@Override
public ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost) {
float score = toFloat(queryBoost * Math.max(queryExpl.getValue(), Math.min(funcExpl.getValue(), maxBoost)));
ComplexExplanation res = new ComplexExplanation(true, score, "function score, product of:");
ComplexExplanation innerMinExpl = new ComplexExplanation(true, Math.min(funcExpl.getValue(), maxBoost), "Math.min of");
innerMinExpl.addDetail(funcExpl);
innerMinExpl.addDetail(new Explanation(maxBoost, "maxBoost"));
ComplexExplanation outerMaxExpl = new ComplexExplanation(true, Math.max(Math.min(funcExpl.getValue(), maxBoost),
queryExpl.getValue()), "max of");
outerMaxExpl.addDetail(queryExpl);
outerMaxExpl.addDetail(innerMinExpl);
res.addDetail(outerMaxExpl);
res.addDetail(new Explanation(queryBoost, "queryBoost"));
return res;
}
}; };
public abstract float combine(double queryBoost, double queryScore, double funcScore, double maxBoost); public abstract float combine(double queryBoost, double queryScore, double funcScore, double maxBoost);
public abstract String getName(); public abstract String getName();
public static float toFloat(double input) { public static float toFloat(double input) {
assert deviation(input) <= 0.001 : "input " + input + " out of float scope for function score deviation: " + deviation(input); assert deviation(input) <= 0.001 : "input " + input + " out of float scope for function score deviation: " + deviation(input);
return (float) input; return (float) input;
} }
private static double deviation(double input) { // only with assert! private static double deviation(double input) { // only with assert!
float floatVersion = (float)input; float floatVersion = (float) input;
return Double.compare(floatVersion, input) == 0 || input == 0.0d ? 0 : 1.d-(floatVersion) / input; return Double.compare(floatVersion, input) == 0 || input == 0.0d ? 0 : 1.d - (floatVersion) / input;
} }
public abstract ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost); public abstract ComplexExplanation explain(float queryBoost, Explanation queryExpl, Explanation funcExpl, float maxBoost);

View File

@ -36,7 +36,7 @@ public class ScriptScoreFunction extends ScoreFunction {
public ScriptScoreFunction(String sScript, Map<String, Object> params, SearchScript script) { public ScriptScoreFunction(String sScript, Map<String, Object> params, SearchScript script) {
super(CombineFunction.PLAIN); super(CombineFunction.REPLACE);
this.sScript = sScript; this.sScript = sScript;
this.params = params; this.params = params;
this.script = script; this.script = script;

View File

@ -1187,7 +1187,7 @@ public class SimpleChildQuerySearchTests extends AbstractSharedClusterTest {
"child", "child",
QueryBuilders.functionScoreQuery(matchQuery("c_field2", 0)) QueryBuilders.functionScoreQuery(matchQuery("c_field2", 0))
.add(new ScriptScoreFunctionBuilder().script("doc['c_field1'].value")) .add(new ScriptScoreFunctionBuilder().script("doc['c_field1'].value"))
.boostMode(CombineFunction.PLAIN.getName())).scoreType("sum")).execute().actionGet(); .boostMode(CombineFunction.REPLACE.getName())).scoreType("sum")).execute().actionGet();
assertThat(response.getHits().totalHits(), equalTo(3l)); assertThat(response.getHits().totalHits(), equalTo(3l));
assertThat(response.getHits().hits()[0].id(), equalTo("1")); assertThat(response.getHits().hits()[0].id(), equalTo("1"));
@ -1204,7 +1204,7 @@ public class SimpleChildQuerySearchTests extends AbstractSharedClusterTest {
"child", "child",
QueryBuilders.functionScoreQuery(matchQuery("c_field2", 0)) QueryBuilders.functionScoreQuery(matchQuery("c_field2", 0))
.add(new ScriptScoreFunctionBuilder().script("doc['c_field1'].value")) .add(new ScriptScoreFunctionBuilder().script("doc['c_field1'].value"))
.boostMode(CombineFunction.PLAIN.getName())).scoreType("max")).execute().actionGet(); .boostMode(CombineFunction.REPLACE.getName())).scoreType("max")).execute().actionGet();
assertThat(response.getHits().totalHits(), equalTo(3l)); assertThat(response.getHits().totalHits(), equalTo(3l));
assertThat(response.getHits().hits()[0].id(), equalTo("3")); assertThat(response.getHits().hits()[0].id(), equalTo("3"));
@ -1221,7 +1221,7 @@ public class SimpleChildQuerySearchTests extends AbstractSharedClusterTest {
"child", "child",
QueryBuilders.functionScoreQuery(matchQuery("c_field2", 0)) QueryBuilders.functionScoreQuery(matchQuery("c_field2", 0))
.add(new ScriptScoreFunctionBuilder().script("doc['c_field1'].value")) .add(new ScriptScoreFunctionBuilder().script("doc['c_field1'].value"))
.boostMode(CombineFunction.PLAIN.getName())).scoreType("avg")).execute().actionGet(); .boostMode(CombineFunction.REPLACE.getName())).scoreType("avg")).execute().actionGet();
assertThat(response.getHits().totalHits(), equalTo(3l)); assertThat(response.getHits().totalHits(), equalTo(3l));
assertThat(response.getHits().hits()[0].id(), equalTo("3")); assertThat(response.getHits().hits()[0].id(), equalTo("3"));
@ -1238,7 +1238,7 @@ public class SimpleChildQuerySearchTests extends AbstractSharedClusterTest {
"parent", "parent",
QueryBuilders.functionScoreQuery(matchQuery("p_field1", "p_value3")) QueryBuilders.functionScoreQuery(matchQuery("p_field1", "p_value3"))
.add(new ScriptScoreFunctionBuilder().script("doc['p_field2'].value")) .add(new ScriptScoreFunctionBuilder().script("doc['p_field2'].value"))
.boostMode(CombineFunction.PLAIN.getName())).scoreType("score")) .boostMode(CombineFunction.REPLACE.getName())).scoreType("score"))
.addSort(SortBuilders.fieldSort("c_field3")).addSort(SortBuilders.scoreSort()).execute().actionGet(); .addSort(SortBuilders.fieldSort("c_field3")).addSort(SortBuilders.scoreSort()).execute().actionGet();
assertThat(response.getHits().totalHits(), equalTo(7l)); assertThat(response.getHits().totalHits(), equalTo(7l));

View File

@ -195,7 +195,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
response = client().search( response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source( searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true).query( searchSource().explain(true).query(
functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.PLAIN.getName())))); functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName()))));
sr = response.actionGet(); sr = response.actionGet();
sh = sr.getHits(); sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (2))); assertThat(sh.getTotalHits(), equalTo((long) (2)));
@ -234,7 +234,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
assertThat(sh.getTotalHits(), equalTo((long) (1))); assertThat(sh.getTotalHits(), equalTo((long) (1)));
assertThat(sh.getAt(0).getId(), equalTo("1")); assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat((double) sh.getAt(0).score(), closeTo(0.30685282, 1.e-5)); assertThat((double) sh.getAt(0).score(), closeTo(0.30685282, 1.e-5));
float[] coords = {11,20}; float[] coords = { 11, 20 };
fb = new GaussDecayFunctionBuilder("loc", coords, "1000km"); fb = new GaussDecayFunctionBuilder("loc", coords, "1000km");
response = client().search( response = client().search(
@ -248,6 +248,97 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
assertThat((double) sh.getAt(0).score(), closeTo(0.30685282, 1.e-5)); assertThat((double) sh.getAt(0).score(), closeTo(0.30685282, 1.e-5));
} }
@Test
public void testCombineModes() throws Exception {
createIndexMapped("test", "type1", "test", "string", "num", "double");
ensureYellow();
List<IndexRequestBuilder> indexBuilders = new ArrayList<IndexRequestBuilder>();
indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId("1").setIndex("test")
.setSource(jsonBuilder().startObject().field("test", "value").field("num", 1.0).endObject()));
IndexRequestBuilder[] builders = indexBuilders.toArray(new IndexRequestBuilder[indexBuilders.size()]);
indexRandom("test", false, builders);
refresh();
DecayFunctionBuilder fb = new GaussDecayFunctionBuilder("num", 0.0, 1.0).setScaleWeight(0.5);
// function score should return 0.5 for this function
ActionFuture<SearchResponse> response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true).query(
functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
.boostMode(CombineFunction.MULT.getName()))));
SearchResponse sr = response.actionGet();
SearchHits sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (1)));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat((double) sh.getAt(0).score(), closeTo(0.30685282, 1.e-5));
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true).query(
functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
.boostMode(CombineFunction.REPLACE.getName()))));
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (1)));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat((double) sh.getAt(0).score(), closeTo(1.0, 1.e-5));
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true)
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
.boostMode(CombineFunction.SUM.getName()))));
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (1)));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat((double) sh.getAt(0).score(), closeTo(2.0 * (0.30685282 + 0.5), 1.e-5));
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true)
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
.boostMode(CombineFunction.AVG.getName()))));
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (1)));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat((double) sh.getAt(0).score(), closeTo((0.30685282 + 0.5), 1.e-5));
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true)
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
.boostMode(CombineFunction.MIN.getName()))));
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (1)));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat((double) sh.getAt(0).score(), closeTo(2.0 * (0.30685282), 1.e-5));
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true)
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boost(2.0f)
.boostMode(CombineFunction.MAX.getName()))));
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (1)));
assertThat(sh.getAt(0).getId(), equalTo("1"));
assertThat((double) sh.getAt(0).score(), closeTo(1.0, 1.e-5));
logger.info("--> Hit[0] {} Explanation:\n {}", sr.getHits().getAt(0).id(), sr.getHits().getAt(0).explanation());
}
@Test(expected = SearchPhaseExecutionException.class) @Test(expected = SearchPhaseExecutionException.class)
public void testExceptionThrownIfScaleLE0() throws Exception { public void testExceptionThrownIfScaleLE0() throws Exception {

View File

@ -482,9 +482,9 @@ public class QueryRescorerTests extends AbstractSharedClusterTest {
.queryRescorer( .queryRescorer(
QueryBuilders.boolQuery() QueryBuilders.boolQuery()
.disableCoord(true) .disableCoord(true)
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).add(new ScriptScoreFunctionBuilder().script("5.0f")).boostMode(CombineFunction.PLAIN.getName())) .should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).add(new ScriptScoreFunctionBuilder().script("5.0f")).boostMode(CombineFunction.REPLACE.getName()))
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).add(new ScriptScoreFunctionBuilder().script("7.0f")).boostMode(CombineFunction.PLAIN.getName())) .should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).add(new ScriptScoreFunctionBuilder().script("7.0f")).boostMode(CombineFunction.REPLACE.getName()))
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).add(new ScriptScoreFunctionBuilder().script("0.0f")).boostMode(CombineFunction.PLAIN.getName()))) .should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).add(new ScriptScoreFunctionBuilder().script("0.0f")).boostMode(CombineFunction.REPLACE.getName())))
.setQueryWeight(primaryWeight) .setQueryWeight(primaryWeight)
.setRescoreQueryWeight(secondaryWeight); .setRescoreQueryWeight(secondaryWeight);
@ -497,10 +497,10 @@ public class QueryRescorerTests extends AbstractSharedClusterTest {
.setPreference("test") // ensure we hit the same shards for tie-breaking .setPreference("test") // ensure we hit the same shards for tie-breaking
.setQuery(QueryBuilders.boolQuery() .setQuery(QueryBuilders.boolQuery()
.disableCoord(true) .disableCoord(true)
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).add(new ScriptScoreFunctionBuilder().script("2.0f")).boostMode(CombineFunction.PLAIN.getName())) .should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[0])).add(new ScriptScoreFunctionBuilder().script("2.0f")).boostMode(CombineFunction.REPLACE.getName()))
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).add(new ScriptScoreFunctionBuilder().script("3.0f")).boostMode(CombineFunction.PLAIN.getName())) .should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[1])).add(new ScriptScoreFunctionBuilder().script("3.0f")).boostMode(CombineFunction.REPLACE.getName()))
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[2])).add(new ScriptScoreFunctionBuilder().script("5.0f")).boostMode(CombineFunction.PLAIN.getName())) .should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[2])).add(new ScriptScoreFunctionBuilder().script("5.0f")).boostMode(CombineFunction.REPLACE.getName()))
.should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).add(new ScriptScoreFunctionBuilder().script("0.2f")).boostMode(CombineFunction.PLAIN.getName()))) .should(QueryBuilders.functionScoreQuery(QueryBuilders.termQuery("field1", intToEnglish[3])).add(new ScriptScoreFunctionBuilder().script("0.2f")).boostMode(CombineFunction.REPLACE.getName())))
.setFrom(0) .setFrom(0)
.setSize(10) .setSize(10)
.setRescorer(rescoreQuery) .setRescorer(rescoreQuery)