diff --git a/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionBuilder.java b/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionBuilder.java index 98e67238513..37bb71da87c 100644 --- a/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionBuilder.java +++ b/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionBuilder.java @@ -29,25 +29,33 @@ public abstract class DecayFunctionBuilder implements ScoreFunctionBuilder { protected static final String REFERNECE = "reference"; protected static final String SCALE = "scale"; protected static final String SCALE_WEIGHT = "scale_weight"; + protected static final String OFFSET = "offset"; private String fieldName; private Object reference; private Object scale; private double scaleWeight = -1; + private Object offset; public DecayFunctionBuilder(String fieldName, Object reference, Object scale) { this.fieldName = fieldName; this.reference = reference; this.scale = scale; } + public DecayFunctionBuilder setScaleWeight(double scaleWeight) { - if(scaleWeight <=0 || scaleWeight >= 1.0) { + if (scaleWeight <= 0 || scaleWeight >= 1.0) { throw new ElasticSearchIllegalStateException("scale weight parameter must be in range 0..1!"); } this.scaleWeight = scaleWeight; return this; } - + + public DecayFunctionBuilder setOffset(Object offset) { + this.offset = offset; + return this; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(getName()); @@ -57,6 +65,9 @@ public abstract class DecayFunctionBuilder implements ScoreFunctionBuilder { if (scaleWeight > 0) { builder.field(SCALE_WEIGHT, scaleWeight); } + if (offset != null) { + builder.field(OFFSET, offset); + } builder.endObject(); builder.endObject(); return builder; diff --git a/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java b/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java index 207e89dc2ff..a583e15a74f 100644 --- a/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java +++ b/src/main/java/org/elasticsearch/index/query/functionscore/DecayFunctionParser.java @@ -164,6 +164,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { double scale = 0; double reference = 0; double scaleWeight = 0.5; + double offset = 0.0d; boolean scaleFound = false; boolean refFound = false; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { @@ -177,6 +178,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { } else if (parameterName.equals(DecayFunctionBuilder.REFERNECE)) { reference = parser.doubleValue(); refFound = true; + } else if (parameterName.equals(DecayFunctionBuilder.OFFSET)) { + offset = parser.doubleValue(); } else { throw new ElasticSearchParseException("Parameter " + parameterName + " not supported!"); } @@ -186,7 +189,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { + " must be set for numeric fields."); } IndexNumericFieldData numericFieldData = parseContext.fieldData().getForField(mapper); - return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, getDecayFunction(), numericFieldData); + return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, offset, getDecayFunction(), numericFieldData); } private ScoreFunction parseGeoVariable(String fieldName, XContentParser parser, QueryParseContext parseContext, @@ -195,6 +198,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { String parameterName = null; GeoPoint reference = new GeoPoint(); String scaleString = "1km"; + String offsetString = "0km"; double scaleWeight = 0.5; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -205,6 +209,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { reference = GeoPoint.parse(parser); } else if (parameterName.equals(DecayFunctionBuilder.SCALE_WEIGHT)) { scaleWeight = parser.doubleValue(); + } else if (parameterName.equals(DecayFunctionBuilder.OFFSET)) { + offsetString = parser.text(); } else { throw new ElasticSearchParseException("Parameter " + parameterName + " not supported!"); } @@ -213,9 +219,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { throw new ElasticSearchParseException(DecayFunctionBuilder.REFERNECE + "must be set for geo fields."); } double scale = DistanceUnit.parse(scaleString, DistanceUnit.METERS, DistanceUnit.METERS); - + double offset = DistanceUnit.parse(offsetString, DistanceUnit.METERS, DistanceUnit.METERS); IndexGeoPointFieldData indexFieldData = parseContext.fieldData().getForField(mapper); - return new GeoFieldDataScoreFunction(reference, scale, scaleWeight, getDecayFunction(), indexFieldData); + return new GeoFieldDataScoreFunction(reference, scale, scaleWeight, offset, getDecayFunction(), indexFieldData); } @@ -225,6 +231,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { String parameterName = null; String scaleString = null; String referenceString = null; + String offsetString = "0d"; double scaleWeight = 0.5; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -235,6 +242,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { referenceString = parser.text(); } else if (parameterName.equals(DecayFunctionBuilder.SCALE_WEIGHT)) { scaleWeight = parser.doubleValue(); + } else if (parameterName.equals(DecayFunctionBuilder.OFFSET)) { + offsetString = parser.text(); } else { throw new ElasticSearchParseException("Parameter " + parameterName + " not supported!"); } @@ -249,8 +258,10 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { } TimeValue val = TimeValue.parseTimeValue(scaleString, TimeValue.timeValueHours(24)); double scale = val.getMillis(); + val = TimeValue.parseTimeValue(offsetString, TimeValue.timeValueHours(24)); + double offset = val.getMillis(); IndexNumericFieldData numericFieldData = parseContext.fieldData().getForField(dateFieldMapper); - return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, getDecayFunction(), numericFieldData); + return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, offset, getDecayFunction(), numericFieldData); } static class GeoFieldDataScoreFunction extends AbstractDistanceScoreFunction { @@ -261,9 +272,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { private static final GeoDistance distFunction = GeoDistance.fromString("arc"); - public GeoFieldDataScoreFunction(GeoPoint reference, double scale, double scaleWeight, DecayFunction func, + public GeoFieldDataScoreFunction(GeoPoint reference, double scale, double scaleWeight, double offset, DecayFunction func, IndexGeoPointFieldData fieldData) { - super(scale, scaleWeight, func); + super(scale, scaleWeight, offset, func); this.reference = reference; this.fieldData = fieldData; } @@ -276,21 +287,26 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { @Override protected double distance(int docId) { GeoPoint other = geoPointValues.getValueMissing(docId, reference); - return distFunction.calculate(reference.lat(), reference.lon(), other.lat(), other.lon(), DistanceUnit.METERS); + double distance = Math.abs(distFunction.calculate(reference.lat(), reference.lon(), other.lat(), other.lon(), + DistanceUnit.METERS)) - offset; + if (distance < 0.0d) { + distance = 0.0d; + } + return distance; } @Override protected String getDistanceString(int docId) { final GeoPoint other = geoPointValues.getValueMissing(docId, reference); - return "arcDistance(" + other + "(=doc value), " + reference + ") = " + distance(docId); - + return "arcDistance(" + other + "(=doc value), " + reference + "(=reference)) - " + offset + + "(=offset) < 0.0 ? 0.0: arcDistance(" + other + "(=doc value), " + reference + "(=reference)) - " + offset + + "(=offset)"; } @Override protected String getFieldName() { return fieldData.getFieldNames().fullName(); } - } static class NumericFieldDataScoreFunction extends AbstractDistanceScoreFunction { @@ -299,9 +315,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { private final double reference; private DoubleValues doubleValues; - public NumericFieldDataScoreFunction(double reference, double scale, double scaleWeight, DecayFunction func, + public NumericFieldDataScoreFunction(double reference, double scale, double scaleWeight, double offset, DecayFunction func, IndexNumericFieldData fieldData) { - super(scale, scaleWeight, func); + super(scale, scaleWeight, offset, func); this.fieldData = fieldData; this.reference = reference; } @@ -312,12 +328,18 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { @Override protected double distance(int docId) { - return doubleValues.getValueMissing(docId, reference) - reference; + double distance = Math.abs(doubleValues.getValueMissing(docId, reference) - reference) - offset; + if (distance < 0.0) { + distance = 0.0; + } + return distance; } @Override protected String getDistanceString(int docId) { - return "(" + doubleValues.getValueMissing(docId, reference) + "(=doc value) - " + reference + ")"; + return "Math.abs(" + doubleValues.getValueMissing(docId, reference) + "(=doc value) - " + reference + "(=reference)) - " + + offset + "(=offset) < 0.0 ? 0.0: Math.abs(" + doubleValues.getValueMissing(docId, reference) + "(=doc value) - " + + reference + ") - " + offset + "(=offset)"; } @Override @@ -333,9 +355,10 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { public static abstract class AbstractDistanceScoreFunction extends ScoreFunction { private final double scale; + protected final double offset; private final DecayFunction func; - public AbstractDistanceScoreFunction(double userSuppiedScale, double userSuppliedScaleWeight, DecayFunction func) { + public AbstractDistanceScoreFunction(double userSuppiedScale, double userSuppliedScaleWeight, double offset, DecayFunction func) { super(CombineFunction.MULT); if (userSuppiedScale <= 0.0) { throw new ElasticSearchIllegalArgumentException(FunctionScoreQueryParser.NAME + " : scale must be > 0.0."); @@ -346,6 +369,10 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser { } this.scale = func.processScale(userSuppiedScale, userSuppliedScaleWeight); this.func = func; + if (offset < 0.0d) { + throw new ElasticSearchIllegalArgumentException(FunctionScoreQueryParser.NAME + " : offset must be > 0.0"); + } + this.offset = offset; } @Override diff --git a/src/main/java/org/elasticsearch/index/query/functionscore/exp/ExponentialDecayFunctionParser.java b/src/main/java/org/elasticsearch/index/query/functionscore/exp/ExponentialDecayFunctionParser.java index 04f31c5aff0..69048869e40 100644 --- a/src/main/java/org/elasticsearch/index/query/functionscore/exp/ExponentialDecayFunctionParser.java +++ b/src/main/java/org/elasticsearch/index/query/functionscore/exp/ExponentialDecayFunctionParser.java @@ -44,14 +44,14 @@ public class ExponentialDecayFunctionParser extends DecayFunctionParser { @Override public double evaluate(double value, double scale) { - return Math.exp(scale * Math.abs(value)); + return Math.exp(scale * value); } @Override public Explanation explainFunction(String valueExpl, double value, double scale) { ComplexExplanation ce = new ComplexExplanation(); ce.setValue((float) evaluate(value, scale)); - ce.setDescription("exp(- abs(" + valueExpl + ") * " + -1*scale + ")"); + ce.setDescription("exp(- abs(" + valueExpl + ") * " + -1 * scale + ")"); return ce; } diff --git a/src/main/java/org/elasticsearch/index/query/functionscore/gauss/GaussDecayFunctionParser.java b/src/main/java/org/elasticsearch/index/query/functionscore/gauss/GaussDecayFunctionParser.java index 9e9ca335a96..31a111722a4 100644 --- a/src/main/java/org/elasticsearch/index/query/functionscore/gauss/GaussDecayFunctionParser.java +++ b/src/main/java/org/elasticsearch/index/query/functionscore/gauss/GaussDecayFunctionParser.java @@ -47,7 +47,7 @@ public class GaussDecayFunctionParser extends DecayFunctionParser { public Explanation explainFunction(String valueExpl, double value, double scale) { ComplexExplanation ce = new ComplexExplanation(); ce.setValue((float) evaluate(value, scale)); - ce.setDescription("-exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1*scale + ")"); + ce.setDescription("-exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1 * scale + ")"); return ce; } diff --git a/src/main/java/org/elasticsearch/index/query/functionscore/lin/LinearDecayFunctionParser.java b/src/main/java/org/elasticsearch/index/query/functionscore/lin/LinearDecayFunctionParser.java index 497b2bcfbc8..a62fc77bd84 100644 --- a/src/main/java/org/elasticsearch/index/query/functionscore/lin/LinearDecayFunctionParser.java +++ b/src/main/java/org/elasticsearch/index/query/functionscore/lin/LinearDecayFunctionParser.java @@ -43,8 +43,8 @@ public class LinearDecayFunctionParser extends DecayFunctionParser { final static class LinearDecayScoreFunction implements DecayFunction { @Override - public double evaluate(double value, double scale) { - return Math.max(0.0, (scale - Math.abs(value)) / scale); + public double evaluate(double value, double scale) { + return Math.max(0.0, (scale - value) / scale); } @Override diff --git a/src/test/java/org/elasticsearch/test/integration/search/functionscore/DecayFunctionScoreTests.java b/src/test/java/org/elasticsearch/test/integration/search/functionscore/DecayFunctionScoreTests.java index 46bada63d3c..6eced4e9adb 100644 --- a/src/test/java/org/elasticsearch/test/integration/search/functionscore/DecayFunctionScoreTests.java +++ b/src/test/java/org/elasticsearch/test/integration/search/functionscore/DecayFunctionScoreTests.java @@ -149,6 +149,84 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest { assertThat(sh.getAt(1).getId(), equalTo("2")); } + @Test + public void testDistanceScoreGeoLinGaussExpWithOffset() throws Exception { + + createIndexMapped("test", "type1", "test", "string", "num", "double"); + ensureYellow(); + + // add tw docs within offset + List indexBuilders = new ArrayList(); + indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId("1").setIndex("test") + .setSource(jsonBuilder().startObject().field("test", "value").field("num", 0.5).endObject())); + indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId("2").setIndex("test") + .setSource(jsonBuilder().startObject().field("test", "value").field("num", 1.7).endObject())); + + // add docs outside offset + int numDummyDocs = 20; + for (int i = 0; i < numDummyDocs; i++) { + indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId(Integer.toString(i + 3)).setIndex("test") + .setSource(jsonBuilder().startObject().field("test", "value").field("num", 3.0 + i).endObject())); + } + IndexRequestBuilder[] builders = indexBuilders.toArray(new IndexRequestBuilder[indexBuilders.size()]); + + indexRandom("test", false, builders); + refresh(); + + // Test Gauss + DecayFunctionBuilder fb = new GaussDecayFunctionBuilder("num", 1.0, 5.0); + fb.setOffset(1.0); + + ActionFuture response = client() + .search(searchRequest() + .searchType(SearchType.QUERY_THEN_FETCH) + .source(searchSource().explain(true).size(numDummyDocs + 2) + .query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName())))); + SearchResponse sr = response.actionGet(); + SearchHits sh = sr.getHits(); + assertThat(sh.getTotalHits(), equalTo((long) (numDummyDocs + 2))); + assertThat(sh.getAt(0).getId(), anyOf(equalTo("1"), equalTo("2"))); + assertThat(sh.getAt(1).getId(), anyOf(equalTo("1"), equalTo("2"))); + assertThat(sh.getAt(1).score(), equalTo(sh.getAt(0).score())); + for (int i = 0; i < numDummyDocs; i++) { + assertThat(sh.getAt(i + 2).getId(), equalTo(Integer.toString(i + 3))); + } + + // Test Exp + fb = new ExponentialDecayFunctionBuilder("num", 1.0, 5.0); + fb.setOffset(1.0); + + response = client() + .search(searchRequest() + .searchType(SearchType.QUERY_THEN_FETCH) + .source(searchSource().explain(true).size(numDummyDocs + 2) + .query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName())))); + sr = response.actionGet(); + sh = sr.getHits(); + assertThat(sh.getTotalHits(), equalTo((long) (numDummyDocs + 2))); + assertThat(sh.getAt(0).getId(), anyOf(equalTo("1"), equalTo("2"))); + assertThat(sh.getAt(1).getId(), anyOf(equalTo("1"), equalTo("2"))); + assertThat(sh.getAt(1).score(), equalTo(sh.getAt(0).score())); + for (int i = 0; i < numDummyDocs; i++) { + assertThat(sh.getAt(i + 2).getId(), equalTo(Integer.toString(i + 3))); + } + // Test Lin + fb = new LinearDecayFunctionBuilder("num", 1.0, 20.0); + fb.setOffset(1.0); + + response = client() + .search(searchRequest() + .searchType(SearchType.QUERY_THEN_FETCH) + .source(searchSource().explain(true).size(numDummyDocs + 2) + .query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName())))); + sr = response.actionGet(); + sh = sr.getHits(); + assertThat(sh.getTotalHits(), equalTo((long) (numDummyDocs + 2))); + assertThat(sh.getAt(0).getId(), anyOf(equalTo("1"), equalTo("2"))); + assertThat(sh.getAt(1).getId(), anyOf(equalTo("1"), equalTo("2"))); + assertThat(sh.getAt(1).score(), equalTo(sh.getAt(0).score())); + } + @Test public void testBoostModeSettingWorks() throws Exception { @@ -405,7 +483,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest { ActionFuture response = client().search( searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source( - searchSource().explain(true).query( + searchSource().explain(false).query( functionScoreQuery(termQuery("test", "value")).add(new MatchAllFilterBuilder(), gfb1) .add(new MatchAllFilterBuilder(), gfb2).scoreMode("multiply")))); @@ -464,7 +542,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest { .size(numDocs) .query(functionScoreQuery(termQuery("test", "value")).add(new MatchAllFilterBuilder(), gfb1) .add(new MatchAllFilterBuilder(), gfb2).add(new MatchAllFilterBuilder(), gfb3) - .scoreMode("multiply")))); + .scoreMode("multiply").boostMode(CombineFunction.REPLACE.getName())))); SearchResponse sr = response.actionGet(); ElasticsearchAssertions.assertNoFailures(sr); @@ -476,6 +554,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest { } for (int i = 0; i < numDocs - 1; i++) { assertThat(scores[i], lessThan(scores[i + 1])); + } } diff --git a/src/test/java/org/elasticsearch/test/integration/search/functionscore/FunctionScorePluginTests.java b/src/test/java/org/elasticsearch/test/integration/search/functionscore/FunctionScorePluginTests.java index ab1011c4da8..8b20189e382 100644 --- a/src/test/java/org/elasticsearch/test/integration/search/functionscore/FunctionScorePluginTests.java +++ b/src/test/java/org/elasticsearch/test/integration/search/functionscore/FunctionScorePluginTests.java @@ -118,7 +118,7 @@ public class FunctionScorePluginTests extends AbstractNodesTests { public static class CustomDistanceScoreParser extends DecayFunctionParser { - public static final String[] NAMES = {"linear_mult", "linearMult"}; + public static final String[] NAMES = { "linear_mult", "linearMult" }; @Override public String[] getNames() { @@ -138,7 +138,8 @@ public class FunctionScorePluginTests extends AbstractNodesTests { @Override public double evaluate(double value, double scale) { - return Math.abs(value); + + return value; } @Override @@ -155,10 +156,8 @@ public class FunctionScorePluginTests extends AbstractNodesTests { } } - public class CustomDistanceScoreBuilder extends DecayFunctionBuilder { - public CustomDistanceScoreBuilder(String fieldName, Object reference, Object scale) { super(fieldName, reference, scale); }