diff --git a/core/src/test/java/org/elasticsearch/index/query/functionscore/FunctionScoreTests.java b/core/src/test/java/org/elasticsearch/index/query/functionscore/FunctionScoreTests.java index c1e5a5f3bf3..1c11e4a7a14 100644 --- a/core/src/test/java/org/elasticsearch/index/query/functionscore/FunctionScoreTests.java +++ b/core/src/test/java/org/elasticsearch/index/query/functionscore/FunctionScoreTests.java @@ -25,10 +25,7 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; import org.apache.lucene.index.*; -import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.Weight; +import org.apache.lucene.search.*; import org.apache.lucene.store.Directory; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; @@ -48,14 +45,20 @@ import org.junit.Test; import java.io.IOException; import java.util.Collection; +import java.util.concurrent.ExecutionException; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.core.IsEqual.equalTo; public class FunctionScoreTests extends ESTestCase { private static final String UNSUPPORTED = "Method not implemented. This is just a stub for testing."; - class IndexFieldDataStub implements IndexFieldData { + + /** + * Stub for IndexFieldData. Needed by some score functions. Returns 1 as count always. + */ + private static class IndexFieldDataStub implements IndexFieldData { @Override public MappedFieldType.Names getFieldNames() { return new MappedFieldType.Names("test"); @@ -136,7 +139,10 @@ public class FunctionScoreTests extends ESTestCase { } } - class IndexNumericFieldDataStub implements IndexNumericFieldData { + /** + * Stub for IndexNumericFieldData needed by some score functions. Returns 1 as value always. + */ + private static class IndexNumericFieldDataStub implements IndexNumericFieldData { @Override public NumericType getNumericType() { @@ -232,6 +238,12 @@ public class FunctionScoreTests extends ESTestCase { } } + private static final ScoreFunction RANDOM_SCORE_FUNCTION = new RandomScoreFunction(0, 0, new IndexFieldDataStub()); + private static final ScoreFunction FIELD_VALUE_FACTOR_FUNCTION = new FieldValueFactorFunction("test", 1, FieldValueFactorFunction.Modifier.LN, new Double(1), null); + private static final ScoreFunction GAUSS_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX); + private static final ScoreFunction EXP_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX); + private static final ScoreFunction LIN_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX); + private static final ScoreFunction WEIGHT_FACTOR_FUNCTION = new WeightFactorFunction(4); private static final String TEXT = "The way out is through."; private static final String FIELD = "test"; private static final Term TERM = new Term(FIELD, "through"); @@ -265,28 +277,35 @@ public class FunctionScoreTests extends ESTestCase { @Test public void testExplainFunctionScoreQuery() throws IOException { - Explanation functionExplanation = getFunctionScoreExplanation(searcher, new RandomScoreFunction(0, 0, new IndexFieldDataStub())); + Explanation functionExplanation = getFunctionScoreExplanation(searcher, RANDOM_SCORE_FUNCTION); checkFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)"); assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0)); - functionExplanation = getFunctionScoreExplanation(searcher, new FieldValueFactorFunction("test", 1, FieldValueFactorFunction.Modifier.LN, new Double(1), null)); + functionExplanation = getFunctionScoreExplanation(searcher, FIELD_VALUE_FACTOR_FUNCTION); checkFunctionScoreExplanation(functionExplanation, "field value function: ln(doc['test'].value?:1.0 * factor=1.0)"); assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0)); - functionExplanation = getFunctionScoreExplanation(searcher, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX)); + functionExplanation = getFunctionScoreExplanation(searcher, GAUSS_DECAY_FUNCTION); checkFunctionScoreExplanation(functionExplanation, "Function for field test:"); assertThat(functionExplanation.getDetails()[0].getDetails()[0].toString(), equalTo("0.1 = exp(-0.5*pow(MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.21714724095162594)\n")); assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0)); - functionExplanation = getFunctionScoreExplanation(searcher, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX)); + functionExplanation = getFunctionScoreExplanation(searcher, EXP_DECAY_FUNCTION); checkFunctionScoreExplanation(functionExplanation, "Function for field test:"); assertThat(functionExplanation.getDetails()[0].getDetails()[0].toString(), equalTo("0.1 = exp(- MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)] * 2.3025850929940455)\n")); assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0)); - functionExplanation = getFunctionScoreExplanation(searcher, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX)); + functionExplanation = getFunctionScoreExplanation(searcher, LIN_DECAY_FUNCTION); checkFunctionScoreExplanation(functionExplanation, "Function for field test:"); assertThat(functionExplanation.getDetails()[0].getDetails()[0].toString(), equalTo("0.1 = max(0.0, ((1.1111111111111112 - MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)])/1.1111111111111112)\n")); assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0)); + + functionExplanation = getFunctionScoreExplanation(searcher, WEIGHT_FACTOR_FUNCTION); + checkFunctionScoreExplanation(functionExplanation, "product of:"); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].toString(), equalTo("1.0 = constant score 1.0 - no function provided\n")); + assertThat(functionExplanation.getDetails()[0].getDetails()[1].toString(), equalTo("4.0 = weight\n")); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0)); + assertThat(functionExplanation.getDetails()[0].getDetails()[1].getDetails().length, equalTo(0)); } public Explanation getFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction scoreFunction) throws IOException { @@ -303,26 +322,36 @@ public class FunctionScoreTests extends ESTestCase { @Test public void testExplainFiltersFunctionScoreQuery() throws IOException { - Explanation functionExplanation = getFiltersFunctionScoreExplanation(searcher, new RandomScoreFunction(0, 0, new IndexFieldDataStub())); + Explanation functionExplanation = getFiltersFunctionScoreExplanation(searcher, RANDOM_SCORE_FUNCTION); checkFiltersFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)", 0); assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails().length, equalTo(0)); - functionExplanation = getFiltersFunctionScoreExplanation(searcher, new FieldValueFactorFunction("test", 1, FieldValueFactorFunction.Modifier.LN, new Double(1), null)); + functionExplanation = getFiltersFunctionScoreExplanation(searcher, FIELD_VALUE_FACTOR_FUNCTION); checkFiltersFunctionScoreExplanation(functionExplanation, "field value function: ln(doc['test'].value?:1.0 * factor=1.0)", 0); assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails().length, equalTo(0)); - functionExplanation = getFiltersFunctionScoreExplanation(searcher, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX)); + functionExplanation = getFiltersFunctionScoreExplanation(searcher, GAUSS_DECAY_FUNCTION); checkFiltersFunctionScoreExplanation(functionExplanation, "Function for field test:", 0); assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].toString(), equalTo("0.1 = exp(-0.5*pow(MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.21714724095162594)\n")); assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].getDetails().length, equalTo(0)); + functionExplanation = getFiltersFunctionScoreExplanation(searcher, EXP_DECAY_FUNCTION); + checkFiltersFunctionScoreExplanation(functionExplanation, "Function for field test:", 0); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].toString(), equalTo("0.1 = exp(- MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)] * 2.3025850929940455)\n")); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].getDetails().length, equalTo(0)); + + functionExplanation = getFiltersFunctionScoreExplanation(searcher, LIN_DECAY_FUNCTION); + checkFiltersFunctionScoreExplanation(functionExplanation, "Function for field test:", 0); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].toString(), equalTo("0.1 = max(0.0, ((1.1111111111111112 - MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)])/1.1111111111111112)\n")); + assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].getDetails().length, equalTo(0)); + // now test all together functionExplanation = getFiltersFunctionScoreExplanation(searcher - , new RandomScoreFunction(0, 0, new IndexFieldDataStub()) - , new FieldValueFactorFunction("test", 1, FieldValueFactorFunction.Modifier.LN, new Double(1), null) - , new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX) - , new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX) - , new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX) + , RANDOM_SCORE_FUNCTION + , FIELD_VALUE_FACTOR_FUNCTION + , GAUSS_DECAY_FUNCTION + , EXP_DECAY_FUNCTION + , LIN_DECAY_FUNCTION ); checkFiltersFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)", 0); @@ -345,15 +374,19 @@ public class FunctionScoreTests extends ESTestCase { } public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException { + FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions); + Weight weight = filtersFunctionScoreQuery.createWeight(searcher, true); + Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0); + return explanation.getDetails()[1]; + } + + public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) { FiltersFunctionScoreQuery.FilterFunction[] filterFunctions = new FiltersFunctionScoreQuery.FilterFunction[scoreFunctions.length]; for (int i = 0; i < scoreFunctions.length; i++) { filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction( new TermQuery(TERM), scoreFunctions[i]); } - FiltersFunctionScoreQuery filtersFunctionScoreQuery = new FiltersFunctionScoreQuery(new TermQuery(TERM), FiltersFunctionScoreQuery.ScoreMode.AVG, filterFunctions, 100, new Float(0.0), CombineFunction.AVG); - Weight weight = filtersFunctionScoreQuery.createWeight(searcher, true); - Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0); - return explanation.getDetails()[1]; + return new FiltersFunctionScoreQuery(new TermQuery(TERM), scoreMode, filterFunctions, Float.MAX_VALUE, Float.MAX_VALUE * -1, combineFunction); } public void checkFiltersFunctionScoreExplanation(Explanation randomExplanation, String functionExpl, int whichFunction) { @@ -363,4 +396,141 @@ public class FunctionScoreTests extends ESTestCase { assertThat(randomExplanation.getDetails()[0].getDetails()[whichFunction].getDetails()[0].getDescription(), equalTo("match filter: " + FIELD + ":" + TERM.text())); assertThat(randomExplanation.getDetails()[0].getDetails()[whichFunction].getDetails()[1].getDescription(), equalTo(functionExpl)); } + + private static float[] randomFloats(int size) { + float[] weights = new float[size]; + for (int i = 0; i < weights.length; i++) { + weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f; + } + return weights; + } + + private static class ScoreFunctionStub extends ScoreFunction { + private float score; + + ScoreFunctionStub(float score) { + super(CombineFunction.REPLACE); + this.score = score; + } + + @Override + public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException { + return new LeafScoreFunction() { + @Override + public double score(int docId, float subQueryScore) { + return score; + } + + @Override + public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException { + throw new UnsupportedOperationException(UNSUPPORTED); + } + }; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + protected boolean doEquals(ScoreFunction other) { + return false; + } + } + + @Test + public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException { + int numFunctions = randomIntBetween(1, 3); + float[] weights = randomFloats(numFunctions); + float[] scores = randomFloats(numFunctions); + ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions]; + for (int i = 0; i < numFunctions; i++) { + scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]); + } + WeightFactorFunction[] weightFunctionStubs = new WeightFactorFunction[numFunctions]; + for (int i = 0; i < numFunctions; i++) { + weightFunctionStubs[i] = new WeightFactorFunction(weights[i], scoreFunctionStubs[i]); + } + + FiltersFunctionScoreQuery filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( + FiltersFunctionScoreQuery.ScoreMode.MULTIPLY + , CombineFunction.REPLACE + , weightFunctionStubs + ); + + TopDocs topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1); + float scoreWithWeight = topDocsWithWeights.scoreDocs[0].score; + double score = 1; + for (int i = 0; i < weights.length; i++) { + score *= weights[i] * scores[i]; + } + assertThat(scoreWithWeight / score, closeTo(1, 1.e-5d)); + + filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( + FiltersFunctionScoreQuery.ScoreMode.SUM + , CombineFunction.REPLACE + , weightFunctionStubs + ); + + topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1); + scoreWithWeight = topDocsWithWeights.scoreDocs[0].score; + double sum = 0; + for (int i = 0; i < weights.length; i++) { + sum += weights[i] * scores[i]; + } + assertThat(scoreWithWeight / sum, closeTo(1, 1.e-5d)); + + filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( + FiltersFunctionScoreQuery.ScoreMode.AVG + , CombineFunction.REPLACE + , weightFunctionStubs + ); + + topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1); + scoreWithWeight = topDocsWithWeights.scoreDocs[0].score; + double norm = 0; + sum = 0; + for (int i = 0; i < weights.length; i++) { + norm += weights[i]; + sum += weights[i] * scores[i]; + } + assertThat(scoreWithWeight * norm / sum, closeTo(1, 1.e-5d)); + + filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( + FiltersFunctionScoreQuery.ScoreMode.MIN + , CombineFunction.REPLACE + , weightFunctionStubs + ); + + topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1); + scoreWithWeight = topDocsWithWeights.scoreDocs[0].score; + double min = Double.POSITIVE_INFINITY; + for (int i = 0; i < weights.length; i++) { + min = Math.min(min, weights[i] * scores[i]); + } + assertThat(scoreWithWeight / min, closeTo(1, 1.e-5d)); + + filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery( + FiltersFunctionScoreQuery.ScoreMode.MAX + , CombineFunction.REPLACE + , weightFunctionStubs + ); + + topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1); + scoreWithWeight = topDocsWithWeights.scoreDocs[0].score; + double max = Double.NEGATIVE_INFINITY; + for (int i = 0; i < weights.length; i++) { + max = Math.max(max, weights[i] * scores[i]); + } + assertThat(scoreWithWeight / max, closeTo(1, 1.e-5d)); + } + + @Test + public void checkWeightOnlyCreatesBoostFunction() throws IOException { + FunctionScoreQuery filtersFunctionScoreQueryWithWeights = new FunctionScoreQuery(new MatchAllDocsQuery(), new WeightFactorFunction(2), 0.0f, CombineFunction.MULTIPLY, 100); + TopDocs topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1); + float score = topDocsWithWeights.scoreDocs[0].score; + assertThat(score, equalTo(2.0f)); + } } \ No newline at end of file diff --git a/plugins/lang-groovy/src/test/java/org/elasticsearch/messy/tests/FunctionScoreTests.java b/plugins/lang-groovy/src/test/java/org/elasticsearch/messy/tests/FunctionScoreTests.java index 677eb98dfd1..51fc5a4de8b 100644 --- a/plugins/lang-groovy/src/test/java/org/elasticsearch/messy/tests/FunctionScoreTests.java +++ b/plugins/lang-groovy/src/test/java/org/elasticsearch/messy/tests/FunctionScoreTests.java @@ -19,19 +19,12 @@ package org.elasticsearch.messy.tests; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.geo.GeoPoint; import org.elasticsearch.common.lucene.search.function.CombineFunction; -import org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction; import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery; -import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; -import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder; -import org.elasticsearch.index.query.functionscore.weight.WeightBuilder; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.script.Script; import org.elasticsearch.script.groovy.GroovyPlugin; @@ -48,289 +41,26 @@ import java.util.concurrent.ExecutionException; import static org.elasticsearch.client.Requests.searchRequest; import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; -import static org.elasticsearch.index.query.QueryBuilders.*; -import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.*; +import static org.elasticsearch.index.query.QueryBuilders.functionScoreQuery; +import static org.elasticsearch.index.query.QueryBuilders.termQuery; +import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction; import static org.elasticsearch.search.aggregations.AggregationBuilders.terms; import static org.elasticsearch.search.builder.SearchSourceBuilder.searchSource; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; -import static org.hamcrest.Matchers.*; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; public class FunctionScoreTests extends ESIntegTestCase { static final String TYPE = "type"; static final String INDEX = "index"; - static final String TEXT_FIELD = "text_field"; - static final String DOUBLE_FIELD = "double_field"; - static final String GEO_POINT_FIELD = "geo_point_field"; - static final XContentBuilder SIMPLE_DOC; - static final XContentBuilder MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD; @Override protected Collection> nodePlugins() { return Collections.singleton(GroovyPlugin.class); } - static { - XContentBuilder simpleDoc; - XContentBuilder mappingWithDoubleAndGeoPointAndTestField; - try { - simpleDoc = jsonBuilder().startObject() - .field(TEXT_FIELD, "value") - .startObject(GEO_POINT_FIELD) - .field("lat", 10) - .field("lon", 20) - .endObject() - .field(DOUBLE_FIELD, Math.E) - .endObject(); - } catch (IOException e) { - throw new ElasticsearchException("Exception while initializing FunctionScoreIT", e); - } - SIMPLE_DOC = simpleDoc; - try { - - mappingWithDoubleAndGeoPointAndTestField = jsonBuilder().startObject() - .startObject(TYPE) - .startObject("properties") - .startObject(TEXT_FIELD) - .field("type", "string") - .endObject() - .startObject(GEO_POINT_FIELD) - .field("type", "geo_point") - .endObject() - .startObject(DOUBLE_FIELD) - .field("type", "double") - .endObject() - .endObject() - .endObject() - .endObject(); - } catch (IOException e) { - throw new ElasticsearchException("Exception while initializing FunctionScoreIT", e); - } - MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD = mappingWithDoubleAndGeoPointAndTestField; - } - - @Test - public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException { - assertAcked(prepareCreate(INDEX).addMapping( - TYPE, MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD - )); - ensureYellow(); - - index(INDEX, TYPE, "1", SIMPLE_DOC); - refresh(); - SearchResponse response = client().search( - searchRequest().source( - searchSource().query( - functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{ - new FunctionScoreQueryBuilder.FilterFunctionBuilder(gaussDecayFunction(GEO_POINT_FIELD, new GeoPoint(10, 20), "1000km")), - new FunctionScoreQueryBuilder.FilterFunctionBuilder(fieldValueFactorFunction(DOUBLE_FIELD).modifier(FieldValueFactorFunction.Modifier.LN)), - new FunctionScoreQueryBuilder.FilterFunctionBuilder(scriptFunction(new Script("_index['" + TEXT_FIELD + "']['value'].tf()"))) - })))).actionGet(); - SearchResponse responseWithWeights = client().search( - searchRequest().source( - searchSource().query( - functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{ - new FunctionScoreQueryBuilder.FilterFunctionBuilder(gaussDecayFunction(GEO_POINT_FIELD, new GeoPoint(10, 20), "1000km").setWeight(2)), - new FunctionScoreQueryBuilder.FilterFunctionBuilder(fieldValueFactorFunction(DOUBLE_FIELD).modifier(FieldValueFactorFunction.Modifier.LN).setWeight(2)), - new FunctionScoreQueryBuilder.FilterFunctionBuilder(scriptFunction(new Script("_index['" + TEXT_FIELD + "']['value'].tf()")).setWeight(2)) - })))).actionGet(); - - assertSearchResponse(response); - assertThat(response.getHits().getAt(0).getScore(), is(1.0f)); - assertThat(responseWithWeights.getHits().getAt(0).getScore(), is(8.0f)); - } - - @Test - public void simpleWeightedFunctionsTestWithRandomWeightsAndRandomCombineMode() throws IOException, ExecutionException, InterruptedException { - assertAcked(prepareCreate(INDEX).addMapping( - TYPE, - MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD)); - ensureYellow(); - - XContentBuilder doc = SIMPLE_DOC; - index(INDEX, TYPE, "1", doc); - refresh(); - ScoreFunctionBuilder[] scoreFunctionBuilders = getScoreFunctionBuilders(); - float[] weights = createRandomWeights(scoreFunctionBuilders.length); - float[] scores = getScores(scoreFunctionBuilders); - int weightscounter = 0; - FunctionScoreQueryBuilder.FilterFunctionBuilder[] filterFunctionBuilders = new FunctionScoreQueryBuilder.FilterFunctionBuilder[scoreFunctionBuilders.length]; - for (ScoreFunctionBuilder builder : scoreFunctionBuilders) { - filterFunctionBuilders[weightscounter] = new FunctionScoreQueryBuilder.FilterFunctionBuilder(builder.setWeight(weights[weightscounter])); - weightscounter++; - } - FiltersFunctionScoreQuery.ScoreMode scoreMode = randomFrom(FiltersFunctionScoreQuery.ScoreMode.AVG, FiltersFunctionScoreQuery.ScoreMode.SUM, - FiltersFunctionScoreQuery.ScoreMode.MIN, FiltersFunctionScoreQuery.ScoreMode.MAX, FiltersFunctionScoreQuery.ScoreMode.MULTIPLY); - FunctionScoreQueryBuilder withWeights = functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), filterFunctionBuilders).scoreMode(scoreMode); - - SearchResponse responseWithWeights = client().search( - searchRequest().source(searchSource().query(withWeights)) - ).actionGet(); - - double expectedScore = computeExpectedScore(weights, scores, scoreMode); - assertThat((float) expectedScore / responseWithWeights.getHits().getAt(0).getScore(), is(1.0f)); - } - - protected double computeExpectedScore(float[] weights, float[] scores, FiltersFunctionScoreQuery.ScoreMode scoreMode) { - double expectedScore; - switch(scoreMode) { - case MULTIPLY: - expectedScore = 1.0; - break; - case MAX: - expectedScore = Float.MAX_VALUE * -1.0; - break; - case MIN: - expectedScore = Float.MAX_VALUE; - break; - default: - expectedScore = 0.0; - break; - } - - float weightSum = 0; - for (int i = 0; i < weights.length; i++) { - double functionScore = (double) weights[i] * scores[i]; - weightSum += weights[i]; - switch(scoreMode) { - case AVG: - expectedScore += functionScore; - break; - case MAX: - expectedScore = Math.max(functionScore, expectedScore); - break; - case MIN: - expectedScore = Math.min(functionScore, expectedScore); - break; - case SUM: - expectedScore += functionScore; - break; - case MULTIPLY: - expectedScore *= functionScore; - break; - default: - throw new UnsupportedOperationException(); - } - } - if (scoreMode == FiltersFunctionScoreQuery.ScoreMode.AVG) { - expectedScore /= weightSum; - } - return expectedScore; - } - - @Test - public void simpleWeightedFunctionsTestSingleFunction() throws IOException, ExecutionException, InterruptedException { - assertAcked(prepareCreate(INDEX).addMapping( - TYPE, - MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD)); - ensureYellow(); - - XContentBuilder doc = jsonBuilder().startObject() - .field(TEXT_FIELD, "value") - .startObject(GEO_POINT_FIELD) - .field("lat", 12) - .field("lon", 21) - .endObject() - .field(DOUBLE_FIELD, 10) - .endObject(); - index(INDEX, TYPE, "1", doc); - refresh(); - ScoreFunctionBuilder[] scoreFunctionBuilders = getScoreFunctionBuilders(); - ScoreFunctionBuilder scoreFunctionBuilder = scoreFunctionBuilders[randomInt(3)]; - float[] weights = createRandomWeights(1); - float[] scores = getScores(scoreFunctionBuilder); - FunctionScoreQueryBuilder withWeights = functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), scoreFunctionBuilder.setWeight(weights[0])); - - SearchResponse responseWithWeights = client().search( - searchRequest().source(searchSource().query(withWeights)) - ).actionGet(); - - assertThat( (double) scores[0] * weights[0]/ responseWithWeights.getHits().getAt(0).getScore(), closeTo(1.0, 1.e-6)); - - } - - private float[] getScores(ScoreFunctionBuilder... scoreFunctionBuilders) { - float[] scores = new float[scoreFunctionBuilders.length]; - int scorecounter = 0; - for (ScoreFunctionBuilder builder : scoreFunctionBuilders) { - SearchResponse response = client().search( - searchRequest().source( - searchSource().query( - functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), builder) - ))).actionGet(); - scores[scorecounter] = response.getHits().getAt(0).getScore(); - scorecounter++; - } - return scores; - } - - private float[] createRandomWeights(int size) { - float[] weights = new float[size]; - for (int i = 0; i < weights.length; i++) { - weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-6f; - } - return weights; - } - - public ScoreFunctionBuilder[] getScoreFunctionBuilders() { - ScoreFunctionBuilder[] builders = new ScoreFunctionBuilder[4]; - builders[0] = gaussDecayFunction(GEO_POINT_FIELD, new GeoPoint(10, 20), "1000km"); - builders[1] = randomFunction(10); - builders[2] = fieldValueFactorFunction(DOUBLE_FIELD).modifier(FieldValueFactorFunction.Modifier.LN); - builders[3] = scriptFunction(new Script("_index['" + TEXT_FIELD + "']['value'].tf()")); - return builders; - } - - @Test - public void checkWeightOnlyCreatesBoostFunction() throws IOException { - assertAcked(prepareCreate(INDEX).addMapping( - TYPE, - MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD)); - ensureYellow(); - - index(INDEX, TYPE, "1", SIMPLE_DOC); - refresh(); - String query =jsonBuilder().startObject() - .startObject("query") - .startObject("function_score") - .startArray("functions") - .startObject() - .field("weight",2) - .endObject() - .endArray() - .endObject() - .endObject() - .endObject().string(); - SearchResponse response = client().search( - searchRequest().source(new BytesArray(query)) - ).actionGet(); - assertSearchResponse(response); - assertThat(response.getHits().getAt(0).score(), equalTo(2.0f)); - - query =jsonBuilder().startObject() - .startObject("query") - .startObject("function_score") - .field("weight",2) - .endObject() - .endObject() - .endObject().string(); - response = client().search( - searchRequest().source(new BytesArray(query)) - ).actionGet(); - assertSearchResponse(response); - assertThat(response.getHits().getAt(0).score(), equalTo(2.0f)); - response = client().search( - searchRequest().source(searchSource().query(functionScoreQuery(new WeightBuilder().setWeight(2.0f)))) - ).actionGet(); - assertSearchResponse(response); - assertThat(response.getHits().getAt(0).score(), equalTo(2.0f)); - response = client().search( - searchRequest().source(searchSource().query(functionScoreQuery(weightFactorFunction(2.0f)))) - ).actionGet(); - assertSearchResponse(response); - assertThat(response.getHits().getAt(0).score(), equalTo(2.0f)); - } @Test public void testScriptScoresNested() throws IOException {