Add offset to decay function score
Docs within the offset will be scored with 1.0, decay only starts after offset is reached.
This commit is contained in:
parent
c0288a62e6
commit
41b4a14933
|
@ -29,25 +29,33 @@ public abstract class DecayFunctionBuilder implements ScoreFunctionBuilder {
|
|||
protected static final String REFERNECE = "reference";
|
||||
protected static final String SCALE = "scale";
|
||||
protected static final String SCALE_WEIGHT = "scale_weight";
|
||||
protected static final String OFFSET = "offset";
|
||||
|
||||
private String fieldName;
|
||||
private Object reference;
|
||||
private Object scale;
|
||||
private double scaleWeight = -1;
|
||||
private Object offset;
|
||||
|
||||
public DecayFunctionBuilder(String fieldName, Object reference, Object scale) {
|
||||
this.fieldName = fieldName;
|
||||
this.reference = reference;
|
||||
this.scale = scale;
|
||||
}
|
||||
|
||||
public DecayFunctionBuilder setScaleWeight(double scaleWeight) {
|
||||
if(scaleWeight <=0 || scaleWeight >= 1.0) {
|
||||
if (scaleWeight <= 0 || scaleWeight >= 1.0) {
|
||||
throw new ElasticSearchIllegalStateException("scale weight parameter must be in range 0..1!");
|
||||
}
|
||||
this.scaleWeight = scaleWeight;
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
public DecayFunctionBuilder setOffset(Object offset) {
|
||||
this.offset = offset;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject(getName());
|
||||
|
@ -57,6 +65,9 @@ public abstract class DecayFunctionBuilder implements ScoreFunctionBuilder {
|
|||
if (scaleWeight > 0) {
|
||||
builder.field(SCALE_WEIGHT, scaleWeight);
|
||||
}
|
||||
if (offset != null) {
|
||||
builder.field(OFFSET, offset);
|
||||
}
|
||||
builder.endObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
|
|
@ -164,6 +164,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
double scale = 0;
|
||||
double reference = 0;
|
||||
double scaleWeight = 0.5;
|
||||
double offset = 0.0d;
|
||||
boolean scaleFound = false;
|
||||
boolean refFound = false;
|
||||
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
|
||||
|
@ -177,6 +178,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
} else if (parameterName.equals(DecayFunctionBuilder.REFERNECE)) {
|
||||
reference = parser.doubleValue();
|
||||
refFound = true;
|
||||
} else if (parameterName.equals(DecayFunctionBuilder.OFFSET)) {
|
||||
offset = parser.doubleValue();
|
||||
} else {
|
||||
throw new ElasticSearchParseException("Parameter " + parameterName + " not supported!");
|
||||
}
|
||||
|
@ -186,7 +189,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
+ " must be set for numeric fields.");
|
||||
}
|
||||
IndexNumericFieldData<?> numericFieldData = parseContext.fieldData().getForField(mapper);
|
||||
return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, getDecayFunction(), numericFieldData);
|
||||
return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, offset, getDecayFunction(), numericFieldData);
|
||||
}
|
||||
|
||||
private ScoreFunction parseGeoVariable(String fieldName, XContentParser parser, QueryParseContext parseContext,
|
||||
|
@ -195,6 +198,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
String parameterName = null;
|
||||
GeoPoint reference = new GeoPoint();
|
||||
String scaleString = "1km";
|
||||
String offsetString = "0km";
|
||||
double scaleWeight = 0.5;
|
||||
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
|
||||
if (token == XContentParser.Token.FIELD_NAME) {
|
||||
|
@ -205,6 +209,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
reference = GeoPoint.parse(parser);
|
||||
} else if (parameterName.equals(DecayFunctionBuilder.SCALE_WEIGHT)) {
|
||||
scaleWeight = parser.doubleValue();
|
||||
} else if (parameterName.equals(DecayFunctionBuilder.OFFSET)) {
|
||||
offsetString = parser.text();
|
||||
} else {
|
||||
throw new ElasticSearchParseException("Parameter " + parameterName + " not supported!");
|
||||
}
|
||||
|
@ -213,9 +219,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
throw new ElasticSearchParseException(DecayFunctionBuilder.REFERNECE + "must be set for geo fields.");
|
||||
}
|
||||
double scale = DistanceUnit.parse(scaleString, DistanceUnit.METERS, DistanceUnit.METERS);
|
||||
|
||||
double offset = DistanceUnit.parse(offsetString, DistanceUnit.METERS, DistanceUnit.METERS);
|
||||
IndexGeoPointFieldData<?> indexFieldData = parseContext.fieldData().getForField(mapper);
|
||||
return new GeoFieldDataScoreFunction(reference, scale, scaleWeight, getDecayFunction(), indexFieldData);
|
||||
return new GeoFieldDataScoreFunction(reference, scale, scaleWeight, offset, getDecayFunction(), indexFieldData);
|
||||
|
||||
}
|
||||
|
||||
|
@ -225,6 +231,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
String parameterName = null;
|
||||
String scaleString = null;
|
||||
String referenceString = null;
|
||||
String offsetString = "0d";
|
||||
double scaleWeight = 0.5;
|
||||
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
|
||||
if (token == XContentParser.Token.FIELD_NAME) {
|
||||
|
@ -235,6 +242,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
referenceString = parser.text();
|
||||
} else if (parameterName.equals(DecayFunctionBuilder.SCALE_WEIGHT)) {
|
||||
scaleWeight = parser.doubleValue();
|
||||
} else if (parameterName.equals(DecayFunctionBuilder.OFFSET)) {
|
||||
offsetString = parser.text();
|
||||
} else {
|
||||
throw new ElasticSearchParseException("Parameter " + parameterName + " not supported!");
|
||||
}
|
||||
|
@ -249,8 +258,10 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
}
|
||||
TimeValue val = TimeValue.parseTimeValue(scaleString, TimeValue.timeValueHours(24));
|
||||
double scale = val.getMillis();
|
||||
val = TimeValue.parseTimeValue(offsetString, TimeValue.timeValueHours(24));
|
||||
double offset = val.getMillis();
|
||||
IndexNumericFieldData<?> numericFieldData = parseContext.fieldData().getForField(dateFieldMapper);
|
||||
return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, getDecayFunction(), numericFieldData);
|
||||
return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, offset, getDecayFunction(), numericFieldData);
|
||||
}
|
||||
|
||||
static class GeoFieldDataScoreFunction extends AbstractDistanceScoreFunction {
|
||||
|
@ -261,9 +272,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
|
||||
private static final GeoDistance distFunction = GeoDistance.fromString("arc");
|
||||
|
||||
public GeoFieldDataScoreFunction(GeoPoint reference, double scale, double scaleWeight, DecayFunction func,
|
||||
public GeoFieldDataScoreFunction(GeoPoint reference, double scale, double scaleWeight, double offset, DecayFunction func,
|
||||
IndexGeoPointFieldData<?> fieldData) {
|
||||
super(scale, scaleWeight, func);
|
||||
super(scale, scaleWeight, offset, func);
|
||||
this.reference = reference;
|
||||
this.fieldData = fieldData;
|
||||
}
|
||||
|
@ -276,21 +287,26 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
@Override
|
||||
protected double distance(int docId) {
|
||||
GeoPoint other = geoPointValues.getValueMissing(docId, reference);
|
||||
return distFunction.calculate(reference.lat(), reference.lon(), other.lat(), other.lon(), DistanceUnit.METERS);
|
||||
double distance = Math.abs(distFunction.calculate(reference.lat(), reference.lon(), other.lat(), other.lon(),
|
||||
DistanceUnit.METERS)) - offset;
|
||||
if (distance < 0.0d) {
|
||||
distance = 0.0d;
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getDistanceString(int docId) {
|
||||
final GeoPoint other = geoPointValues.getValueMissing(docId, reference);
|
||||
return "arcDistance(" + other + "(=doc value), " + reference + ") = " + distance(docId);
|
||||
|
||||
return "arcDistance(" + other + "(=doc value), " + reference + "(=reference)) - " + offset
|
||||
+ "(=offset) < 0.0 ? 0.0: arcDistance(" + other + "(=doc value), " + reference + "(=reference)) - " + offset
|
||||
+ "(=offset)";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getFieldName() {
|
||||
return fieldData.getFieldNames().fullName();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static class NumericFieldDataScoreFunction extends AbstractDistanceScoreFunction {
|
||||
|
@ -299,9 +315,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
private final double reference;
|
||||
private DoubleValues doubleValues;
|
||||
|
||||
public NumericFieldDataScoreFunction(double reference, double scale, double scaleWeight, DecayFunction func,
|
||||
public NumericFieldDataScoreFunction(double reference, double scale, double scaleWeight, double offset, DecayFunction func,
|
||||
IndexNumericFieldData<?> fieldData) {
|
||||
super(scale, scaleWeight, func);
|
||||
super(scale, scaleWeight, offset, func);
|
||||
this.fieldData = fieldData;
|
||||
this.reference = reference;
|
||||
}
|
||||
|
@ -312,12 +328,18 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
|
||||
@Override
|
||||
protected double distance(int docId) {
|
||||
return doubleValues.getValueMissing(docId, reference) - reference;
|
||||
double distance = Math.abs(doubleValues.getValueMissing(docId, reference) - reference) - offset;
|
||||
if (distance < 0.0) {
|
||||
distance = 0.0;
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getDistanceString(int docId) {
|
||||
return "(" + doubleValues.getValueMissing(docId, reference) + "(=doc value) - " + reference + ")";
|
||||
return "Math.abs(" + doubleValues.getValueMissing(docId, reference) + "(=doc value) - " + reference + "(=reference)) - "
|
||||
+ offset + "(=offset) < 0.0 ? 0.0: Math.abs(" + doubleValues.getValueMissing(docId, reference) + "(=doc value) - "
|
||||
+ reference + ") - " + offset + "(=offset)";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -333,9 +355,10 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
public static abstract class AbstractDistanceScoreFunction extends ScoreFunction {
|
||||
|
||||
private final double scale;
|
||||
protected final double offset;
|
||||
private final DecayFunction func;
|
||||
|
||||
public AbstractDistanceScoreFunction(double userSuppiedScale, double userSuppliedScaleWeight, DecayFunction func) {
|
||||
public AbstractDistanceScoreFunction(double userSuppiedScale, double userSuppliedScaleWeight, double offset, DecayFunction func) {
|
||||
super(CombineFunction.MULT);
|
||||
if (userSuppiedScale <= 0.0) {
|
||||
throw new ElasticSearchIllegalArgumentException(FunctionScoreQueryParser.NAME + " : scale must be > 0.0.");
|
||||
|
@ -346,6 +369,10 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
|
|||
}
|
||||
this.scale = func.processScale(userSuppiedScale, userSuppliedScaleWeight);
|
||||
this.func = func;
|
||||
if (offset < 0.0d) {
|
||||
throw new ElasticSearchIllegalArgumentException(FunctionScoreQueryParser.NAME + " : offset must be > 0.0");
|
||||
}
|
||||
this.offset = offset;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -44,14 +44,14 @@ public class ExponentialDecayFunctionParser extends DecayFunctionParser {
|
|||
|
||||
@Override
|
||||
public double evaluate(double value, double scale) {
|
||||
return Math.exp(scale * Math.abs(value));
|
||||
return Math.exp(scale * value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explainFunction(String valueExpl, double value, double scale) {
|
||||
ComplexExplanation ce = new ComplexExplanation();
|
||||
ce.setValue((float) evaluate(value, scale));
|
||||
ce.setDescription("exp(- abs(" + valueExpl + ") * " + -1*scale + ")");
|
||||
ce.setDescription("exp(- abs(" + valueExpl + ") * " + -1 * scale + ")");
|
||||
return ce;
|
||||
}
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ public class GaussDecayFunctionParser extends DecayFunctionParser {
|
|||
public Explanation explainFunction(String valueExpl, double value, double scale) {
|
||||
ComplexExplanation ce = new ComplexExplanation();
|
||||
ce.setValue((float) evaluate(value, scale));
|
||||
ce.setDescription("-exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1*scale + ")");
|
||||
ce.setDescription("-exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1 * scale + ")");
|
||||
return ce;
|
||||
}
|
||||
|
||||
|
|
|
@ -43,8 +43,8 @@ public class LinearDecayFunctionParser extends DecayFunctionParser {
|
|||
final static class LinearDecayScoreFunction implements DecayFunction {
|
||||
|
||||
@Override
|
||||
public double evaluate(double value, double scale) {
|
||||
return Math.max(0.0, (scale - Math.abs(value)) / scale);
|
||||
public double evaluate(double value, double scale) {
|
||||
return Math.max(0.0, (scale - value) / scale);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -149,6 +149,84 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
|
|||
assertThat(sh.getAt(1).getId(), equalTo("2"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDistanceScoreGeoLinGaussExpWithOffset() throws Exception {
|
||||
|
||||
createIndexMapped("test", "type1", "test", "string", "num", "double");
|
||||
ensureYellow();
|
||||
|
||||
// add tw docs within offset
|
||||
List<IndexRequestBuilder> indexBuilders = new ArrayList<IndexRequestBuilder>();
|
||||
indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId("1").setIndex("test")
|
||||
.setSource(jsonBuilder().startObject().field("test", "value").field("num", 0.5).endObject()));
|
||||
indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId("2").setIndex("test")
|
||||
.setSource(jsonBuilder().startObject().field("test", "value").field("num", 1.7).endObject()));
|
||||
|
||||
// add docs outside offset
|
||||
int numDummyDocs = 20;
|
||||
for (int i = 0; i < numDummyDocs; i++) {
|
||||
indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId(Integer.toString(i + 3)).setIndex("test")
|
||||
.setSource(jsonBuilder().startObject().field("test", "value").field("num", 3.0 + i).endObject()));
|
||||
}
|
||||
IndexRequestBuilder[] builders = indexBuilders.toArray(new IndexRequestBuilder[indexBuilders.size()]);
|
||||
|
||||
indexRandom("test", false, builders);
|
||||
refresh();
|
||||
|
||||
// Test Gauss
|
||||
DecayFunctionBuilder fb = new GaussDecayFunctionBuilder("num", 1.0, 5.0);
|
||||
fb.setOffset(1.0);
|
||||
|
||||
ActionFuture<SearchResponse> response = client()
|
||||
.search(searchRequest()
|
||||
.searchType(SearchType.QUERY_THEN_FETCH)
|
||||
.source(searchSource().explain(true).size(numDummyDocs + 2)
|
||||
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName()))));
|
||||
SearchResponse sr = response.actionGet();
|
||||
SearchHits sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (numDummyDocs + 2)));
|
||||
assertThat(sh.getAt(0).getId(), anyOf(equalTo("1"), equalTo("2")));
|
||||
assertThat(sh.getAt(1).getId(), anyOf(equalTo("1"), equalTo("2")));
|
||||
assertThat(sh.getAt(1).score(), equalTo(sh.getAt(0).score()));
|
||||
for (int i = 0; i < numDummyDocs; i++) {
|
||||
assertThat(sh.getAt(i + 2).getId(), equalTo(Integer.toString(i + 3)));
|
||||
}
|
||||
|
||||
// Test Exp
|
||||
fb = new ExponentialDecayFunctionBuilder("num", 1.0, 5.0);
|
||||
fb.setOffset(1.0);
|
||||
|
||||
response = client()
|
||||
.search(searchRequest()
|
||||
.searchType(SearchType.QUERY_THEN_FETCH)
|
||||
.source(searchSource().explain(true).size(numDummyDocs + 2)
|
||||
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName()))));
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (numDummyDocs + 2)));
|
||||
assertThat(sh.getAt(0).getId(), anyOf(equalTo("1"), equalTo("2")));
|
||||
assertThat(sh.getAt(1).getId(), anyOf(equalTo("1"), equalTo("2")));
|
||||
assertThat(sh.getAt(1).score(), equalTo(sh.getAt(0).score()));
|
||||
for (int i = 0; i < numDummyDocs; i++) {
|
||||
assertThat(sh.getAt(i + 2).getId(), equalTo(Integer.toString(i + 3)));
|
||||
}
|
||||
// Test Lin
|
||||
fb = new LinearDecayFunctionBuilder("num", 1.0, 20.0);
|
||||
fb.setOffset(1.0);
|
||||
|
||||
response = client()
|
||||
.search(searchRequest()
|
||||
.searchType(SearchType.QUERY_THEN_FETCH)
|
||||
.source(searchSource().explain(true).size(numDummyDocs + 2)
|
||||
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName()))));
|
||||
sr = response.actionGet();
|
||||
sh = sr.getHits();
|
||||
assertThat(sh.getTotalHits(), equalTo((long) (numDummyDocs + 2)));
|
||||
assertThat(sh.getAt(0).getId(), anyOf(equalTo("1"), equalTo("2")));
|
||||
assertThat(sh.getAt(1).getId(), anyOf(equalTo("1"), equalTo("2")));
|
||||
assertThat(sh.getAt(1).score(), equalTo(sh.getAt(0).score()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoostModeSettingWorks() throws Exception {
|
||||
|
||||
|
@ -405,7 +483,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
|
|||
|
||||
ActionFuture<SearchResponse> response = client().search(
|
||||
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
|
||||
searchSource().explain(true).query(
|
||||
searchSource().explain(false).query(
|
||||
functionScoreQuery(termQuery("test", "value")).add(new MatchAllFilterBuilder(), gfb1)
|
||||
.add(new MatchAllFilterBuilder(), gfb2).scoreMode("multiply"))));
|
||||
|
||||
|
@ -464,7 +542,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
|
|||
.size(numDocs)
|
||||
.query(functionScoreQuery(termQuery("test", "value")).add(new MatchAllFilterBuilder(), gfb1)
|
||||
.add(new MatchAllFilterBuilder(), gfb2).add(new MatchAllFilterBuilder(), gfb3)
|
||||
.scoreMode("multiply"))));
|
||||
.scoreMode("multiply").boostMode(CombineFunction.REPLACE.getName()))));
|
||||
|
||||
SearchResponse sr = response.actionGet();
|
||||
ElasticsearchAssertions.assertNoFailures(sr);
|
||||
|
@ -476,6 +554,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
|
|||
}
|
||||
for (int i = 0; i < numDocs - 1; i++) {
|
||||
assertThat(scores[i], lessThan(scores[i + 1]));
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -118,7 +118,7 @@ public class FunctionScorePluginTests extends AbstractNodesTests {
|
|||
|
||||
public static class CustomDistanceScoreParser extends DecayFunctionParser {
|
||||
|
||||
public static final String[] NAMES = {"linear_mult", "linearMult"};
|
||||
public static final String[] NAMES = { "linear_mult", "linearMult" };
|
||||
|
||||
@Override
|
||||
public String[] getNames() {
|
||||
|
@ -138,7 +138,8 @@ public class FunctionScorePluginTests extends AbstractNodesTests {
|
|||
|
||||
@Override
|
||||
public double evaluate(double value, double scale) {
|
||||
return Math.abs(value);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -155,10 +156,8 @@ public class FunctionScorePluginTests extends AbstractNodesTests {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
public class CustomDistanceScoreBuilder extends DecayFunctionBuilder {
|
||||
|
||||
|
||||
public CustomDistanceScoreBuilder(String fieldName, Object reference, Object scale) {
|
||||
super(fieldName, reference, scale);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue