convert weight functions tests to unit tests

This commit is contained in:
Britta Weber 2015-10-06 01:45:39 +02:00
parent 0915adaa71
commit 473d25beed
2 changed files with 198 additions and 298 deletions

View File

@ -25,10 +25,7 @@ import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.*;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.*;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
@ -48,14 +45,20 @@ import org.junit.Test;
import java.io.IOException;
import java.util.Collection;
import java.util.concurrent.ExecutionException;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.core.IsEqual.equalTo;
public class FunctionScoreTests extends ESTestCase {
private static final String UNSUPPORTED = "Method not implemented. This is just a stub for testing.";
class IndexFieldDataStub implements IndexFieldData<AtomicFieldData> {
/**
* Stub for IndexFieldData. Needed by some score functions. Returns 1 as count always.
*/
private static class IndexFieldDataStub implements IndexFieldData<AtomicFieldData> {
@Override
public MappedFieldType.Names getFieldNames() {
return new MappedFieldType.Names("test");
@ -136,7 +139,10 @@ public class FunctionScoreTests extends ESTestCase {
}
}
class IndexNumericFieldDataStub implements IndexNumericFieldData {
/**
* Stub for IndexNumericFieldData needed by some score functions. Returns 1 as value always.
*/
private static class IndexNumericFieldDataStub implements IndexNumericFieldData {
@Override
public NumericType getNumericType() {
@ -232,6 +238,12 @@ public class FunctionScoreTests extends ESTestCase {
}
}
private static final ScoreFunction RANDOM_SCORE_FUNCTION = new RandomScoreFunction(0, 0, new IndexFieldDataStub());
private static final ScoreFunction FIELD_VALUE_FACTOR_FUNCTION = new FieldValueFactorFunction("test", 1, FieldValueFactorFunction.Modifier.LN, new Double(1), null);
private static final ScoreFunction GAUSS_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX);
private static final ScoreFunction EXP_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX);
private static final ScoreFunction LIN_DECAY_FUNCTION = new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX);
private static final ScoreFunction WEIGHT_FACTOR_FUNCTION = new WeightFactorFunction(4);
private static final String TEXT = "The way out is through.";
private static final String FIELD = "test";
private static final Term TERM = new Term(FIELD, "through");
@ -265,28 +277,35 @@ public class FunctionScoreTests extends ESTestCase {
@Test
public void testExplainFunctionScoreQuery() throws IOException {
Explanation functionExplanation = getFunctionScoreExplanation(searcher, new RandomScoreFunction(0, 0, new IndexFieldDataStub()));
Explanation functionExplanation = getFunctionScoreExplanation(searcher, RANDOM_SCORE_FUNCTION);
checkFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)");
assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, new FieldValueFactorFunction("test", 1, FieldValueFactorFunction.Modifier.LN, new Double(1), null));
functionExplanation = getFunctionScoreExplanation(searcher, FIELD_VALUE_FACTOR_FUNCTION);
checkFunctionScoreExplanation(functionExplanation, "field value function: ln(doc['test'].value?:1.0 * factor=1.0)");
assertThat(functionExplanation.getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX));
functionExplanation = getFunctionScoreExplanation(searcher, GAUSS_DECAY_FUNCTION);
checkFunctionScoreExplanation(functionExplanation, "Function for field test:");
assertThat(functionExplanation.getDetails()[0].getDetails()[0].toString(), equalTo("0.1 = exp(-0.5*pow(MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.21714724095162594)\n"));
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX));
functionExplanation = getFunctionScoreExplanation(searcher, EXP_DECAY_FUNCTION);
checkFunctionScoreExplanation(functionExplanation, "Function for field test:");
assertThat(functionExplanation.getDetails()[0].getDetails()[0].toString(), equalTo("0.1 = exp(- MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)] * 2.3025850929940455)\n"));
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX));
functionExplanation = getFunctionScoreExplanation(searcher, LIN_DECAY_FUNCTION);
checkFunctionScoreExplanation(functionExplanation, "Function for field test:");
assertThat(functionExplanation.getDetails()[0].getDetails()[0].toString(), equalTo("0.1 = max(0.0, ((1.1111111111111112 - MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)])/1.1111111111111112)\n"));
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFunctionScoreExplanation(searcher, WEIGHT_FACTOR_FUNCTION);
checkFunctionScoreExplanation(functionExplanation, "product of:");
assertThat(functionExplanation.getDetails()[0].getDetails()[0].toString(), equalTo("1.0 = constant score 1.0 - no function provided\n"));
assertThat(functionExplanation.getDetails()[0].getDetails()[1].toString(), equalTo("4.0 = weight\n"));
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails().length, equalTo(0));
assertThat(functionExplanation.getDetails()[0].getDetails()[1].getDetails().length, equalTo(0));
}
public Explanation getFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction scoreFunction) throws IOException {
@ -303,26 +322,36 @@ public class FunctionScoreTests extends ESTestCase {
@Test
public void testExplainFiltersFunctionScoreQuery() throws IOException {
Explanation functionExplanation = getFiltersFunctionScoreExplanation(searcher, new RandomScoreFunction(0, 0, new IndexFieldDataStub()));
Explanation functionExplanation = getFiltersFunctionScoreExplanation(searcher, RANDOM_SCORE_FUNCTION);
checkFiltersFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)", 0);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails().length, equalTo(0));
functionExplanation = getFiltersFunctionScoreExplanation(searcher, new FieldValueFactorFunction("test", 1, FieldValueFactorFunction.Modifier.LN, new Double(1), null));
functionExplanation = getFiltersFunctionScoreExplanation(searcher, FIELD_VALUE_FACTOR_FUNCTION);
checkFiltersFunctionScoreExplanation(functionExplanation, "field value function: ln(doc['test'].value?:1.0 * factor=1.0)", 0);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails().length, equalTo(0));
functionExplanation = getFiltersFunctionScoreExplanation(searcher, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX));
functionExplanation = getFiltersFunctionScoreExplanation(searcher, GAUSS_DECAY_FUNCTION);
checkFiltersFunctionScoreExplanation(functionExplanation, "Function for field test:", 0);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].toString(), equalTo("0.1 = exp(-0.5*pow(MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)],2.0)/0.21714724095162594)\n"));
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFiltersFunctionScoreExplanation(searcher, EXP_DECAY_FUNCTION);
checkFiltersFunctionScoreExplanation(functionExplanation, "Function for field test:", 0);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].toString(), equalTo("0.1 = exp(- MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)] * 2.3025850929940455)\n"));
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].getDetails().length, equalTo(0));
functionExplanation = getFiltersFunctionScoreExplanation(searcher, LIN_DECAY_FUNCTION);
checkFiltersFunctionScoreExplanation(functionExplanation, "Function for field test:", 0);
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].toString(), equalTo("0.1 = max(0.0, ((1.1111111111111112 - MAX[Math.max(Math.abs(1.0(=doc value) - 0.0(=origin))) - 0.0(=offset), 0)])/1.1111111111111112)\n"));
assertThat(functionExplanation.getDetails()[0].getDetails()[0].getDetails()[1].getDetails()[0].getDetails().length, equalTo(0));
// now test all together
functionExplanation = getFiltersFunctionScoreExplanation(searcher
, new RandomScoreFunction(0, 0, new IndexFieldDataStub())
, new FieldValueFactorFunction("test", 1, FieldValueFactorFunction.Modifier.LN, new Double(1), null)
, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, GaussDecayFunctionBuilder.GAUSS_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX)
, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, ExponentialDecayFunctionBuilder.EXP_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX)
, new DecayFunctionBuilder.NumericFieldDataScoreFunction(0, 1, 0.1, 0, LinearDecayFunctionBuilder.LINEAR_DECAY_FUNCTION, new IndexNumericFieldDataStub(), MultiValueMode.MAX)
, RANDOM_SCORE_FUNCTION
, FIELD_VALUE_FACTOR_FUNCTION
, GAUSS_DECAY_FUNCTION
, EXP_DECAY_FUNCTION
, LIN_DECAY_FUNCTION
);
checkFiltersFunctionScoreExplanation(functionExplanation, "random score function (seed: 0)", 0);
@ -345,15 +374,19 @@ public class FunctionScoreTests extends ESTestCase {
}
public Explanation getFiltersFunctionScoreExplanation(IndexSearcher searcher, ScoreFunction... scoreFunctions) throws IOException {
FiltersFunctionScoreQuery filtersFunctionScoreQuery = getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode.AVG, CombineFunction.AVG, scoreFunctions);
Weight weight = filtersFunctionScoreQuery.createWeight(searcher, true);
Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0);
return explanation.getDetails()[1];
}
public FiltersFunctionScoreQuery getFiltersFunctionScoreQuery(FiltersFunctionScoreQuery.ScoreMode scoreMode, CombineFunction combineFunction, ScoreFunction... scoreFunctions) {
FiltersFunctionScoreQuery.FilterFunction[] filterFunctions = new FiltersFunctionScoreQuery.FilterFunction[scoreFunctions.length];
for (int i = 0; i < scoreFunctions.length; i++) {
filterFunctions[i] = new FiltersFunctionScoreQuery.FilterFunction(
new TermQuery(TERM), scoreFunctions[i]);
}
FiltersFunctionScoreQuery filtersFunctionScoreQuery = new FiltersFunctionScoreQuery(new TermQuery(TERM), FiltersFunctionScoreQuery.ScoreMode.AVG, filterFunctions, 100, new Float(0.0), CombineFunction.AVG);
Weight weight = filtersFunctionScoreQuery.createWeight(searcher, true);
Explanation explanation = weight.explain(searcher.getIndexReader().leaves().get(0), 0);
return explanation.getDetails()[1];
return new FiltersFunctionScoreQuery(new TermQuery(TERM), scoreMode, filterFunctions, Float.MAX_VALUE, Float.MAX_VALUE * -1, combineFunction);
}
public void checkFiltersFunctionScoreExplanation(Explanation randomExplanation, String functionExpl, int whichFunction) {
@ -363,4 +396,141 @@ public class FunctionScoreTests extends ESTestCase {
assertThat(randomExplanation.getDetails()[0].getDetails()[whichFunction].getDetails()[0].getDescription(), equalTo("match filter: " + FIELD + ":" + TERM.text()));
assertThat(randomExplanation.getDetails()[0].getDetails()[whichFunction].getDetails()[1].getDescription(), equalTo(functionExpl));
}
private static float[] randomFloats(int size) {
float[] weights = new float[size];
for (int i = 0; i < weights.length; i++) {
weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-5f;
}
return weights;
}
private static class ScoreFunctionStub extends ScoreFunction {
private float score;
ScoreFunctionStub(float score) {
super(CombineFunction.REPLACE);
this.score = score;
}
@Override
public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException {
return new LeafScoreFunction() {
@Override
public double score(int docId, float subQueryScore) {
return score;
}
@Override
public Explanation explainScore(int docId, Explanation subQueryScore) throws IOException {
throw new UnsupportedOperationException(UNSUPPORTED);
}
};
}
@Override
public boolean needsScores() {
return false;
}
@Override
protected boolean doEquals(ScoreFunction other) {
return false;
}
}
@Test
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
int numFunctions = randomIntBetween(1, 3);
float[] weights = randomFloats(numFunctions);
float[] scores = randomFloats(numFunctions);
ScoreFunctionStub[] scoreFunctionStubs = new ScoreFunctionStub[numFunctions];
for (int i = 0; i < numFunctions; i++) {
scoreFunctionStubs[i] = new ScoreFunctionStub(scores[i]);
}
WeightFactorFunction[] weightFunctionStubs = new WeightFactorFunction[numFunctions];
for (int i = 0; i < numFunctions; i++) {
weightFunctionStubs[i] = new WeightFactorFunction(weights[i], scoreFunctionStubs[i]);
}
FiltersFunctionScoreQuery filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MULTIPLY
, CombineFunction.REPLACE
, weightFunctionStubs
);
TopDocs topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1);
float scoreWithWeight = topDocsWithWeights.scoreDocs[0].score;
double score = 1;
for (int i = 0; i < weights.length; i++) {
score *= weights[i] * scores[i];
}
assertThat(scoreWithWeight / score, closeTo(1, 1.e-5d));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.SUM
, CombineFunction.REPLACE
, weightFunctionStubs
);
topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1);
scoreWithWeight = topDocsWithWeights.scoreDocs[0].score;
double sum = 0;
for (int i = 0; i < weights.length; i++) {
sum += weights[i] * scores[i];
}
assertThat(scoreWithWeight / sum, closeTo(1, 1.e-5d));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.AVG
, CombineFunction.REPLACE
, weightFunctionStubs
);
topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1);
scoreWithWeight = topDocsWithWeights.scoreDocs[0].score;
double norm = 0;
sum = 0;
for (int i = 0; i < weights.length; i++) {
norm += weights[i];
sum += weights[i] * scores[i];
}
assertThat(scoreWithWeight * norm / sum, closeTo(1, 1.e-5d));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MIN
, CombineFunction.REPLACE
, weightFunctionStubs
);
topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1);
scoreWithWeight = topDocsWithWeights.scoreDocs[0].score;
double min = Double.POSITIVE_INFINITY;
for (int i = 0; i < weights.length; i++) {
min = Math.min(min, weights[i] * scores[i]);
}
assertThat(scoreWithWeight / min, closeTo(1, 1.e-5d));
filtersFunctionScoreQueryWithWeights = getFiltersFunctionScoreQuery(
FiltersFunctionScoreQuery.ScoreMode.MAX
, CombineFunction.REPLACE
, weightFunctionStubs
);
topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1);
scoreWithWeight = topDocsWithWeights.scoreDocs[0].score;
double max = Double.NEGATIVE_INFINITY;
for (int i = 0; i < weights.length; i++) {
max = Math.max(max, weights[i] * scores[i]);
}
assertThat(scoreWithWeight / max, closeTo(1, 1.e-5d));
}
@Test
public void checkWeightOnlyCreatesBoostFunction() throws IOException {
FunctionScoreQuery filtersFunctionScoreQueryWithWeights = new FunctionScoreQuery(new MatchAllDocsQuery(), new WeightFactorFunction(2), 0.0f, CombineFunction.MULTIPLY, 100);
TopDocs topDocsWithWeights = searcher.search(filtersFunctionScoreQueryWithWeights, 1);
float score = topDocsWithWeights.scoreDocs[0].score;
assertThat(score, equalTo(2.0f));
}
}

View File

@ -19,19 +19,12 @@
package org.elasticsearch.messy.tests;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.geo.GeoPoint;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction;
import org.elasticsearch.common.lucene.search.function.FiltersFunctionScoreQuery;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.index.query.functionscore.weight.WeightBuilder;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.groovy.GroovyPlugin;
@ -48,289 +41,26 @@ import java.util.concurrent.ExecutionException;
import static org.elasticsearch.client.Requests.searchRequest;
import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.elasticsearch.index.query.QueryBuilders.*;
import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.*;
import static org.elasticsearch.index.query.QueryBuilders.functionScoreQuery;
import static org.elasticsearch.index.query.QueryBuilders.termQuery;
import static org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders.scriptFunction;
import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
import static org.elasticsearch.search.builder.SearchSourceBuilder.searchSource;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.*;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
public class FunctionScoreTests extends ESIntegTestCase {
static final String TYPE = "type";
static final String INDEX = "index";
static final String TEXT_FIELD = "text_field";
static final String DOUBLE_FIELD = "double_field";
static final String GEO_POINT_FIELD = "geo_point_field";
static final XContentBuilder SIMPLE_DOC;
static final XContentBuilder MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD;
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singleton(GroovyPlugin.class);
}
static {
XContentBuilder simpleDoc;
XContentBuilder mappingWithDoubleAndGeoPointAndTestField;
try {
simpleDoc = jsonBuilder().startObject()
.field(TEXT_FIELD, "value")
.startObject(GEO_POINT_FIELD)
.field("lat", 10)
.field("lon", 20)
.endObject()
.field(DOUBLE_FIELD, Math.E)
.endObject();
} catch (IOException e) {
throw new ElasticsearchException("Exception while initializing FunctionScoreIT", e);
}
SIMPLE_DOC = simpleDoc;
try {
mappingWithDoubleAndGeoPointAndTestField = jsonBuilder().startObject()
.startObject(TYPE)
.startObject("properties")
.startObject(TEXT_FIELD)
.field("type", "string")
.endObject()
.startObject(GEO_POINT_FIELD)
.field("type", "geo_point")
.endObject()
.startObject(DOUBLE_FIELD)
.field("type", "double")
.endObject()
.endObject()
.endObject()
.endObject();
} catch (IOException e) {
throw new ElasticsearchException("Exception while initializing FunctionScoreIT", e);
}
MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD = mappingWithDoubleAndGeoPointAndTestField;
}
@Test
public void simpleWeightedFunctionsTest() throws IOException, ExecutionException, InterruptedException {
assertAcked(prepareCreate(INDEX).addMapping(
TYPE, MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD
));
ensureYellow();
index(INDEX, TYPE, "1", SIMPLE_DOC);
refresh();
SearchResponse response = client().search(
searchRequest().source(
searchSource().query(
functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{
new FunctionScoreQueryBuilder.FilterFunctionBuilder(gaussDecayFunction(GEO_POINT_FIELD, new GeoPoint(10, 20), "1000km")),
new FunctionScoreQueryBuilder.FilterFunctionBuilder(fieldValueFactorFunction(DOUBLE_FIELD).modifier(FieldValueFactorFunction.Modifier.LN)),
new FunctionScoreQueryBuilder.FilterFunctionBuilder(scriptFunction(new Script("_index['" + TEXT_FIELD + "']['value'].tf()")))
})))).actionGet();
SearchResponse responseWithWeights = client().search(
searchRequest().source(
searchSource().query(
functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{
new FunctionScoreQueryBuilder.FilterFunctionBuilder(gaussDecayFunction(GEO_POINT_FIELD, new GeoPoint(10, 20), "1000km").setWeight(2)),
new FunctionScoreQueryBuilder.FilterFunctionBuilder(fieldValueFactorFunction(DOUBLE_FIELD).modifier(FieldValueFactorFunction.Modifier.LN).setWeight(2)),
new FunctionScoreQueryBuilder.FilterFunctionBuilder(scriptFunction(new Script("_index['" + TEXT_FIELD + "']['value'].tf()")).setWeight(2))
})))).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).getScore(), is(1.0f));
assertThat(responseWithWeights.getHits().getAt(0).getScore(), is(8.0f));
}
@Test
public void simpleWeightedFunctionsTestWithRandomWeightsAndRandomCombineMode() throws IOException, ExecutionException, InterruptedException {
assertAcked(prepareCreate(INDEX).addMapping(
TYPE,
MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD));
ensureYellow();
XContentBuilder doc = SIMPLE_DOC;
index(INDEX, TYPE, "1", doc);
refresh();
ScoreFunctionBuilder[] scoreFunctionBuilders = getScoreFunctionBuilders();
float[] weights = createRandomWeights(scoreFunctionBuilders.length);
float[] scores = getScores(scoreFunctionBuilders);
int weightscounter = 0;
FunctionScoreQueryBuilder.FilterFunctionBuilder[] filterFunctionBuilders = new FunctionScoreQueryBuilder.FilterFunctionBuilder[scoreFunctionBuilders.length];
for (ScoreFunctionBuilder builder : scoreFunctionBuilders) {
filterFunctionBuilders[weightscounter] = new FunctionScoreQueryBuilder.FilterFunctionBuilder(builder.setWeight(weights[weightscounter]));
weightscounter++;
}
FiltersFunctionScoreQuery.ScoreMode scoreMode = randomFrom(FiltersFunctionScoreQuery.ScoreMode.AVG, FiltersFunctionScoreQuery.ScoreMode.SUM,
FiltersFunctionScoreQuery.ScoreMode.MIN, FiltersFunctionScoreQuery.ScoreMode.MAX, FiltersFunctionScoreQuery.ScoreMode.MULTIPLY);
FunctionScoreQueryBuilder withWeights = functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), filterFunctionBuilders).scoreMode(scoreMode);
SearchResponse responseWithWeights = client().search(
searchRequest().source(searchSource().query(withWeights))
).actionGet();
double expectedScore = computeExpectedScore(weights, scores, scoreMode);
assertThat((float) expectedScore / responseWithWeights.getHits().getAt(0).getScore(), is(1.0f));
}
protected double computeExpectedScore(float[] weights, float[] scores, FiltersFunctionScoreQuery.ScoreMode scoreMode) {
double expectedScore;
switch(scoreMode) {
case MULTIPLY:
expectedScore = 1.0;
break;
case MAX:
expectedScore = Float.MAX_VALUE * -1.0;
break;
case MIN:
expectedScore = Float.MAX_VALUE;
break;
default:
expectedScore = 0.0;
break;
}
float weightSum = 0;
for (int i = 0; i < weights.length; i++) {
double functionScore = (double) weights[i] * scores[i];
weightSum += weights[i];
switch(scoreMode) {
case AVG:
expectedScore += functionScore;
break;
case MAX:
expectedScore = Math.max(functionScore, expectedScore);
break;
case MIN:
expectedScore = Math.min(functionScore, expectedScore);
break;
case SUM:
expectedScore += functionScore;
break;
case MULTIPLY:
expectedScore *= functionScore;
break;
default:
throw new UnsupportedOperationException();
}
}
if (scoreMode == FiltersFunctionScoreQuery.ScoreMode.AVG) {
expectedScore /= weightSum;
}
return expectedScore;
}
@Test
public void simpleWeightedFunctionsTestSingleFunction() throws IOException, ExecutionException, InterruptedException {
assertAcked(prepareCreate(INDEX).addMapping(
TYPE,
MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD));
ensureYellow();
XContentBuilder doc = jsonBuilder().startObject()
.field(TEXT_FIELD, "value")
.startObject(GEO_POINT_FIELD)
.field("lat", 12)
.field("lon", 21)
.endObject()
.field(DOUBLE_FIELD, 10)
.endObject();
index(INDEX, TYPE, "1", doc);
refresh();
ScoreFunctionBuilder[] scoreFunctionBuilders = getScoreFunctionBuilders();
ScoreFunctionBuilder scoreFunctionBuilder = scoreFunctionBuilders[randomInt(3)];
float[] weights = createRandomWeights(1);
float[] scores = getScores(scoreFunctionBuilder);
FunctionScoreQueryBuilder withWeights = functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), scoreFunctionBuilder.setWeight(weights[0]));
SearchResponse responseWithWeights = client().search(
searchRequest().source(searchSource().query(withWeights))
).actionGet();
assertThat( (double) scores[0] * weights[0]/ responseWithWeights.getHits().getAt(0).getScore(), closeTo(1.0, 1.e-6));
}
private float[] getScores(ScoreFunctionBuilder... scoreFunctionBuilders) {
float[] scores = new float[scoreFunctionBuilders.length];
int scorecounter = 0;
for (ScoreFunctionBuilder builder : scoreFunctionBuilders) {
SearchResponse response = client().search(
searchRequest().source(
searchSource().query(
functionScoreQuery(constantScoreQuery(termQuery(TEXT_FIELD, "value")), builder)
))).actionGet();
scores[scorecounter] = response.getHits().getAt(0).getScore();
scorecounter++;
}
return scores;
}
private float[] createRandomWeights(int size) {
float[] weights = new float[size];
for (int i = 0; i < weights.length; i++) {
weights[i] = randomFloat() * (randomBoolean() ? 1.0f : -1.0f) * randomInt(100) + 1.e-6f;
}
return weights;
}
public ScoreFunctionBuilder[] getScoreFunctionBuilders() {
ScoreFunctionBuilder[] builders = new ScoreFunctionBuilder[4];
builders[0] = gaussDecayFunction(GEO_POINT_FIELD, new GeoPoint(10, 20), "1000km");
builders[1] = randomFunction(10);
builders[2] = fieldValueFactorFunction(DOUBLE_FIELD).modifier(FieldValueFactorFunction.Modifier.LN);
builders[3] = scriptFunction(new Script("_index['" + TEXT_FIELD + "']['value'].tf()"));
return builders;
}
@Test
public void checkWeightOnlyCreatesBoostFunction() throws IOException {
assertAcked(prepareCreate(INDEX).addMapping(
TYPE,
MAPPING_WITH_DOUBLE_AND_GEO_POINT_AND_TEXT_FIELD));
ensureYellow();
index(INDEX, TYPE, "1", SIMPLE_DOC);
refresh();
String query =jsonBuilder().startObject()
.startObject("query")
.startObject("function_score")
.startArray("functions")
.startObject()
.field("weight",2)
.endObject()
.endArray()
.endObject()
.endObject()
.endObject().string();
SearchResponse response = client().search(
searchRequest().source(new BytesArray(query))
).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(2.0f));
query =jsonBuilder().startObject()
.startObject("query")
.startObject("function_score")
.field("weight",2)
.endObject()
.endObject()
.endObject().string();
response = client().search(
searchRequest().source(new BytesArray(query))
).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(2.0f));
response = client().search(
searchRequest().source(searchSource().query(functionScoreQuery(new WeightBuilder().setWeight(2.0f))))
).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(2.0f));
response = client().search(
searchRequest().source(searchSource().query(functionScoreQuery(weightFactorFunction(2.0f))))
).actionGet();
assertSearchResponse(response);
assertThat(response.getHits().getAt(0).score(), equalTo(2.0f));
}
@Test
public void testScriptScoresNested() throws IOException {