Add offset to decay function score

Docs within the offset will be scored with 1.0, decay only starts after
offset is reached.
This commit is contained in:
Britta Weber 2013-08-16 16:14:29 +02:00
parent c0288a62e6
commit 41b4a14933
7 changed files with 144 additions and 28 deletions

View File

@ -29,25 +29,33 @@ public abstract class DecayFunctionBuilder implements ScoreFunctionBuilder {
protected static final String REFERNECE = "reference";
protected static final String SCALE = "scale";
protected static final String SCALE_WEIGHT = "scale_weight";
protected static final String OFFSET = "offset";
private String fieldName;
private Object reference;
private Object scale;
private double scaleWeight = -1;
private Object offset;
public DecayFunctionBuilder(String fieldName, Object reference, Object scale) {
this.fieldName = fieldName;
this.reference = reference;
this.scale = scale;
}
public DecayFunctionBuilder setScaleWeight(double scaleWeight) {
if(scaleWeight <=0 || scaleWeight >= 1.0) {
if (scaleWeight <= 0 || scaleWeight >= 1.0) {
throw new ElasticSearchIllegalStateException("scale weight parameter must be in range 0..1!");
}
this.scaleWeight = scaleWeight;
return this;
}
public DecayFunctionBuilder setOffset(Object offset) {
this.offset = offset;
return this;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(getName());
@ -57,6 +65,9 @@ public abstract class DecayFunctionBuilder implements ScoreFunctionBuilder {
if (scaleWeight > 0) {
builder.field(SCALE_WEIGHT, scaleWeight);
}
if (offset != null) {
builder.field(OFFSET, offset);
}
builder.endObject();
builder.endObject();
return builder;

View File

@ -164,6 +164,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
double scale = 0;
double reference = 0;
double scaleWeight = 0.5;
double offset = 0.0d;
boolean scaleFound = false;
boolean refFound = false;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
@ -177,6 +178,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
} else if (parameterName.equals(DecayFunctionBuilder.REFERNECE)) {
reference = parser.doubleValue();
refFound = true;
} else if (parameterName.equals(DecayFunctionBuilder.OFFSET)) {
offset = parser.doubleValue();
} else {
throw new ElasticSearchParseException("Parameter " + parameterName + " not supported!");
}
@ -186,7 +189,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
+ " must be set for numeric fields.");
}
IndexNumericFieldData<?> numericFieldData = parseContext.fieldData().getForField(mapper);
return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, getDecayFunction(), numericFieldData);
return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, offset, getDecayFunction(), numericFieldData);
}
private ScoreFunction parseGeoVariable(String fieldName, XContentParser parser, QueryParseContext parseContext,
@ -195,6 +198,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
String parameterName = null;
GeoPoint reference = new GeoPoint();
String scaleString = "1km";
String offsetString = "0km";
double scaleWeight = 0.5;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
@ -205,6 +209,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
reference = GeoPoint.parse(parser);
} else if (parameterName.equals(DecayFunctionBuilder.SCALE_WEIGHT)) {
scaleWeight = parser.doubleValue();
} else if (parameterName.equals(DecayFunctionBuilder.OFFSET)) {
offsetString = parser.text();
} else {
throw new ElasticSearchParseException("Parameter " + parameterName + " not supported!");
}
@ -213,9 +219,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
throw new ElasticSearchParseException(DecayFunctionBuilder.REFERNECE + "must be set for geo fields.");
}
double scale = DistanceUnit.parse(scaleString, DistanceUnit.METERS, DistanceUnit.METERS);
double offset = DistanceUnit.parse(offsetString, DistanceUnit.METERS, DistanceUnit.METERS);
IndexGeoPointFieldData<?> indexFieldData = parseContext.fieldData().getForField(mapper);
return new GeoFieldDataScoreFunction(reference, scale, scaleWeight, getDecayFunction(), indexFieldData);
return new GeoFieldDataScoreFunction(reference, scale, scaleWeight, offset, getDecayFunction(), indexFieldData);
}
@ -225,6 +231,7 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
String parameterName = null;
String scaleString = null;
String referenceString = null;
String offsetString = "0d";
double scaleWeight = 0.5;
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
@ -235,6 +242,8 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
referenceString = parser.text();
} else if (parameterName.equals(DecayFunctionBuilder.SCALE_WEIGHT)) {
scaleWeight = parser.doubleValue();
} else if (parameterName.equals(DecayFunctionBuilder.OFFSET)) {
offsetString = parser.text();
} else {
throw new ElasticSearchParseException("Parameter " + parameterName + " not supported!");
}
@ -249,8 +258,10 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
}
TimeValue val = TimeValue.parseTimeValue(scaleString, TimeValue.timeValueHours(24));
double scale = val.getMillis();
val = TimeValue.parseTimeValue(offsetString, TimeValue.timeValueHours(24));
double offset = val.getMillis();
IndexNumericFieldData<?> numericFieldData = parseContext.fieldData().getForField(dateFieldMapper);
return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, getDecayFunction(), numericFieldData);
return new NumericFieldDataScoreFunction(reference, scale, scaleWeight, offset, getDecayFunction(), numericFieldData);
}
static class GeoFieldDataScoreFunction extends AbstractDistanceScoreFunction {
@ -261,9 +272,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
private static final GeoDistance distFunction = GeoDistance.fromString("arc");
public GeoFieldDataScoreFunction(GeoPoint reference, double scale, double scaleWeight, DecayFunction func,
public GeoFieldDataScoreFunction(GeoPoint reference, double scale, double scaleWeight, double offset, DecayFunction func,
IndexGeoPointFieldData<?> fieldData) {
super(scale, scaleWeight, func);
super(scale, scaleWeight, offset, func);
this.reference = reference;
this.fieldData = fieldData;
}
@ -276,21 +287,26 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
@Override
protected double distance(int docId) {
GeoPoint other = geoPointValues.getValueMissing(docId, reference);
return distFunction.calculate(reference.lat(), reference.lon(), other.lat(), other.lon(), DistanceUnit.METERS);
double distance = Math.abs(distFunction.calculate(reference.lat(), reference.lon(), other.lat(), other.lon(),
DistanceUnit.METERS)) - offset;
if (distance < 0.0d) {
distance = 0.0d;
}
return distance;
}
@Override
protected String getDistanceString(int docId) {
final GeoPoint other = geoPointValues.getValueMissing(docId, reference);
return "arcDistance(" + other + "(=doc value), " + reference + ") = " + distance(docId);
return "arcDistance(" + other + "(=doc value), " + reference + "(=reference)) - " + offset
+ "(=offset) < 0.0 ? 0.0: arcDistance(" + other + "(=doc value), " + reference + "(=reference)) - " + offset
+ "(=offset)";
}
@Override
protected String getFieldName() {
return fieldData.getFieldNames().fullName();
}
}
static class NumericFieldDataScoreFunction extends AbstractDistanceScoreFunction {
@ -299,9 +315,9 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
private final double reference;
private DoubleValues doubleValues;
public NumericFieldDataScoreFunction(double reference, double scale, double scaleWeight, DecayFunction func,
public NumericFieldDataScoreFunction(double reference, double scale, double scaleWeight, double offset, DecayFunction func,
IndexNumericFieldData<?> fieldData) {
super(scale, scaleWeight, func);
super(scale, scaleWeight, offset, func);
this.fieldData = fieldData;
this.reference = reference;
}
@ -312,12 +328,18 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
@Override
protected double distance(int docId) {
return doubleValues.getValueMissing(docId, reference) - reference;
double distance = Math.abs(doubleValues.getValueMissing(docId, reference) - reference) - offset;
if (distance < 0.0) {
distance = 0.0;
}
return distance;
}
@Override
protected String getDistanceString(int docId) {
return "(" + doubleValues.getValueMissing(docId, reference) + "(=doc value) - " + reference + ")";
return "Math.abs(" + doubleValues.getValueMissing(docId, reference) + "(=doc value) - " + reference + "(=reference)) - "
+ offset + "(=offset) < 0.0 ? 0.0: Math.abs(" + doubleValues.getValueMissing(docId, reference) + "(=doc value) - "
+ reference + ") - " + offset + "(=offset)";
}
@Override
@ -333,9 +355,10 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
public static abstract class AbstractDistanceScoreFunction extends ScoreFunction {
private final double scale;
protected final double offset;
private final DecayFunction func;
public AbstractDistanceScoreFunction(double userSuppiedScale, double userSuppliedScaleWeight, DecayFunction func) {
public AbstractDistanceScoreFunction(double userSuppiedScale, double userSuppliedScaleWeight, double offset, DecayFunction func) {
super(CombineFunction.MULT);
if (userSuppiedScale <= 0.0) {
throw new ElasticSearchIllegalArgumentException(FunctionScoreQueryParser.NAME + " : scale must be > 0.0.");
@ -346,6 +369,10 @@ public abstract class DecayFunctionParser implements ScoreFunctionParser {
}
this.scale = func.processScale(userSuppiedScale, userSuppliedScaleWeight);
this.func = func;
if (offset < 0.0d) {
throw new ElasticSearchIllegalArgumentException(FunctionScoreQueryParser.NAME + " : offset must be > 0.0");
}
this.offset = offset;
}
@Override

View File

@ -44,14 +44,14 @@ public class ExponentialDecayFunctionParser extends DecayFunctionParser {
@Override
public double evaluate(double value, double scale) {
return Math.exp(scale * Math.abs(value));
return Math.exp(scale * value);
}
@Override
public Explanation explainFunction(String valueExpl, double value, double scale) {
ComplexExplanation ce = new ComplexExplanation();
ce.setValue((float) evaluate(value, scale));
ce.setDescription("exp(- abs(" + valueExpl + ") * " + -1*scale + ")");
ce.setDescription("exp(- abs(" + valueExpl + ") * " + -1 * scale + ")");
return ce;
}

View File

@ -47,7 +47,7 @@ public class GaussDecayFunctionParser extends DecayFunctionParser {
public Explanation explainFunction(String valueExpl, double value, double scale) {
ComplexExplanation ce = new ComplexExplanation();
ce.setValue((float) evaluate(value, scale));
ce.setDescription("-exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1*scale + ")");
ce.setDescription("-exp(-0.5*pow(" + valueExpl + ",2.0)/" + -1 * scale + ")");
return ce;
}

View File

@ -44,7 +44,7 @@ public class LinearDecayFunctionParser extends DecayFunctionParser {
@Override
public double evaluate(double value, double scale) {
return Math.max(0.0, (scale - Math.abs(value)) / scale);
return Math.max(0.0, (scale - value) / scale);
}
@Override

View File

@ -149,6 +149,84 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
assertThat(sh.getAt(1).getId(), equalTo("2"));
}
@Test
public void testDistanceScoreGeoLinGaussExpWithOffset() throws Exception {
createIndexMapped("test", "type1", "test", "string", "num", "double");
ensureYellow();
// add tw docs within offset
List<IndexRequestBuilder> indexBuilders = new ArrayList<IndexRequestBuilder>();
indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId("1").setIndex("test")
.setSource(jsonBuilder().startObject().field("test", "value").field("num", 0.5).endObject()));
indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId("2").setIndex("test")
.setSource(jsonBuilder().startObject().field("test", "value").field("num", 1.7).endObject()));
// add docs outside offset
int numDummyDocs = 20;
for (int i = 0; i < numDummyDocs; i++) {
indexBuilders.add(new IndexRequestBuilder(client()).setType("type1").setId(Integer.toString(i + 3)).setIndex("test")
.setSource(jsonBuilder().startObject().field("test", "value").field("num", 3.0 + i).endObject()));
}
IndexRequestBuilder[] builders = indexBuilders.toArray(new IndexRequestBuilder[indexBuilders.size()]);
indexRandom("test", false, builders);
refresh();
// Test Gauss
DecayFunctionBuilder fb = new GaussDecayFunctionBuilder("num", 1.0, 5.0);
fb.setOffset(1.0);
ActionFuture<SearchResponse> response = client()
.search(searchRequest()
.searchType(SearchType.QUERY_THEN_FETCH)
.source(searchSource().explain(true).size(numDummyDocs + 2)
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName()))));
SearchResponse sr = response.actionGet();
SearchHits sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (numDummyDocs + 2)));
assertThat(sh.getAt(0).getId(), anyOf(equalTo("1"), equalTo("2")));
assertThat(sh.getAt(1).getId(), anyOf(equalTo("1"), equalTo("2")));
assertThat(sh.getAt(1).score(), equalTo(sh.getAt(0).score()));
for (int i = 0; i < numDummyDocs; i++) {
assertThat(sh.getAt(i + 2).getId(), equalTo(Integer.toString(i + 3)));
}
// Test Exp
fb = new ExponentialDecayFunctionBuilder("num", 1.0, 5.0);
fb.setOffset(1.0);
response = client()
.search(searchRequest()
.searchType(SearchType.QUERY_THEN_FETCH)
.source(searchSource().explain(true).size(numDummyDocs + 2)
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName()))));
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (numDummyDocs + 2)));
assertThat(sh.getAt(0).getId(), anyOf(equalTo("1"), equalTo("2")));
assertThat(sh.getAt(1).getId(), anyOf(equalTo("1"), equalTo("2")));
assertThat(sh.getAt(1).score(), equalTo(sh.getAt(0).score()));
for (int i = 0; i < numDummyDocs; i++) {
assertThat(sh.getAt(i + 2).getId(), equalTo(Integer.toString(i + 3)));
}
// Test Lin
fb = new LinearDecayFunctionBuilder("num", 1.0, 20.0);
fb.setOffset(1.0);
response = client()
.search(searchRequest()
.searchType(SearchType.QUERY_THEN_FETCH)
.source(searchSource().explain(true).size(numDummyDocs + 2)
.query(functionScoreQuery(termQuery("test", "value")).add(fb).boostMode(CombineFunction.REPLACE.getName()))));
sr = response.actionGet();
sh = sr.getHits();
assertThat(sh.getTotalHits(), equalTo((long) (numDummyDocs + 2)));
assertThat(sh.getAt(0).getId(), anyOf(equalTo("1"), equalTo("2")));
assertThat(sh.getAt(1).getId(), anyOf(equalTo("1"), equalTo("2")));
assertThat(sh.getAt(1).score(), equalTo(sh.getAt(0).score()));
}
@Test
public void testBoostModeSettingWorks() throws Exception {
@ -405,7 +483,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
ActionFuture<SearchResponse> response = client().search(
searchRequest().searchType(SearchType.QUERY_THEN_FETCH).source(
searchSource().explain(true).query(
searchSource().explain(false).query(
functionScoreQuery(termQuery("test", "value")).add(new MatchAllFilterBuilder(), gfb1)
.add(new MatchAllFilterBuilder(), gfb2).scoreMode("multiply"))));
@ -464,7 +542,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
.size(numDocs)
.query(functionScoreQuery(termQuery("test", "value")).add(new MatchAllFilterBuilder(), gfb1)
.add(new MatchAllFilterBuilder(), gfb2).add(new MatchAllFilterBuilder(), gfb3)
.scoreMode("multiply"))));
.scoreMode("multiply").boostMode(CombineFunction.REPLACE.getName()))));
SearchResponse sr = response.actionGet();
ElasticsearchAssertions.assertNoFailures(sr);
@ -476,6 +554,7 @@ public class DecayFunctionScoreTests extends AbstractSharedClusterTest {
}
for (int i = 0; i < numDocs - 1; i++) {
assertThat(scores[i], lessThan(scores[i + 1]));
}
}

View File

@ -118,7 +118,7 @@ public class FunctionScorePluginTests extends AbstractNodesTests {
public static class CustomDistanceScoreParser extends DecayFunctionParser {
public static final String[] NAMES = {"linear_mult", "linearMult"};
public static final String[] NAMES = { "linear_mult", "linearMult" };
@Override
public String[] getNames() {
@ -138,7 +138,8 @@ public class FunctionScorePluginTests extends AbstractNodesTests {
@Override
public double evaluate(double value, double scale) {
return Math.abs(value);
return value;
}
@Override
@ -155,10 +156,8 @@ public class FunctionScorePluginTests extends AbstractNodesTests {
}
}
public class CustomDistanceScoreBuilder extends DecayFunctionBuilder {
public CustomDistanceScoreBuilder(String fieldName, Object reference, Object scale) {
super(fieldName, reference, scale);
}