LUCENE-9594 Add linear function for FeatureField

This adds a linear function and newLinearQuery for FeatureField
This commit is contained in:
Mayya Sharipova 2020-11-10 17:08:08 -05:00 committed by GitHub
parent d65041359e
commit 5897d14fe4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 92 additions and 1 deletions

View File

@ -66,7 +66,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
* 2<sup>-8</sup> = 0.00390625.
* <p>
* Given a scoring factor {@code S > 0} and its weight {@code w > 0}, there
* are three ways that S can be turned into a score:
* are four ways that S can be turned into a score:
* <ul>
* <li>{@link #newLogQuery w * log(a + S)}, with a &ge; 1. This function
* usually makes sense because the distribution of scoring factors
@ -82,6 +82,12 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
* than the two above but is also harder to tune due to the fact it has
* 2 parameters. Like with {@code satu}, values are in the 0..1 range and
* 0.5 is obtained when S and k are equal.
* <li>{@link #newLinearQuery w * S}. Expert: This function doesn't apply
* any transformation to an indexed feature value, and the indexed value itself,
* multiplied by weight, determines the score. Thus, there is an expectation
* that a feature value is encoded in the index in a way that makes
* sense for scoring.
*
* </ul>
* <p>
* The constants in the above formulas typically need training in order to
@ -217,6 +223,46 @@ public final class FeatureField extends Field {
FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
}
static final class LinearFunction extends FeatureFunction {
@Override
SimScorer scorer(float w) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return (w * decodeFeatureValue(freq));
}
};
}
@Override
Explanation explain(String field, String feature, float w, int freq) {
float featureValue = decodeFeatureValue(freq);
float score = scorer(w).score(freq, 1L);
return Explanation.match(score,
"Linear function on the " + field + " field for the " + feature + " feature, computed as w * S from:",
Explanation.match(w, "w, weight of this function"),
Explanation.match(featureValue, "S, feature value"));
}
@Override
public String toString() {
return "LinearFunction";
}
@Override
public int hashCode() {
return getClass().hashCode();
}
@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
return true;
}
};
static final class LogFunction extends FeatureFunction {
private final float scalingFactor;
@ -406,6 +452,26 @@ public final class FeatureField extends Field {
*/
private static final float MAX_WEIGHT = Long.SIZE;
/**
* Return a new {@link Query} that will score documents as
* {@code weight * S} where S is the value of the static feature.
* @param fieldName field that stores features
* @param featureName name of the feature
* @param weight weight to give to this feature, must be in (0,64]
* @throws IllegalArgumentException if weight is not in (0,64]
*/
public static Query newLinearQuery(String fieldName, String featureName, float weight) {
if (weight <= 0 || weight > MAX_WEIGHT) {
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
}
Query q = new FeatureQuery(fieldName, featureName, new LinearFunction());
if (weight != 1f) {
q = new BoostQuery(q, weight);
}
return q;
}
/**
* Return a new {@link Query} that will score documents as
* {@code weight * Math.log(scalingFactor + S)} where S is the value of the static feature.

View File

@ -101,6 +101,24 @@ public class TestFeatureField extends LuceneTestCase {
assertEquals(DocIdSetIterator.NO_MORE_DOCS, s.iterator().nextDoc());
q = FeatureField.newLinearQuery("features", "pagerank", 3f);
w = q.createWeight(searcher, ScoreMode.TOP_SCORES, 2);
s = w.scorer(context);
assertEquals(0, s.iterator().nextDoc());
assertEquals((float) (6.0 * 10), s.score(), 0f);
assertEquals(1, s.iterator().nextDoc());
assertEquals((float) (6.0 * 100), s.score(), 0f);
assertEquals(3, s.iterator().nextDoc());
assertEquals((float) (6.0 * 1), s.score(), 0f);
assertEquals(4, s.iterator().nextDoc());
assertEquals((float) (6.0 * 42), s.score(), 0f);
assertEquals(DocIdSetIterator.NO_MORE_DOCS, s.iterator().nextDoc());
q = FeatureField.newSaturationQuery("features", "pagerank", 3f, 4.5f);
w = q.createWeight(searcher, ScoreMode.TOP_SCORES, 2);
s = w.scorer(context);
@ -188,16 +206,19 @@ public class TestFeatureField extends LuceneTestCase {
IndexSearcher searcher = new IndexSearcher(reader);
QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", 1f, 4.5f), searcher);
QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", 1f), searcher);
QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", 1f, 12f), searcher);
QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", 1f, 12f, 0.6f), searcher);
// Test boosts that are > 1
QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", 3f, 4.5f), searcher);
QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", 3f), searcher);
QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", 3f, 12f), searcher);
QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", 3f, 12f, 0.6f), searcher);
// Test boosts that are < 1
QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", .2f, 4.5f), searcher);
QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", .2f), searcher);
QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", .2f, 12f), searcher);
QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", .2f, 12f, 0.6f), searcher);
@ -209,6 +230,10 @@ public class TestFeatureField extends LuceneTestCase {
doTestSimScorer(new FeatureField.LogFunction(4.5f).scorer(3f));
}
public void testLinearSimScorer() {
doTestSimScorer(new FeatureField.LinearFunction().scorer(1f));
}
public void testSatuSimScorer() {
doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer(3f));
}