mirror of https://github.com/apache/lucene.git
LUCENE-9594 Add linear function for FeatureField
This adds a linear function and newLinearQuery for FeatureField
This commit is contained in:
parent
d65041359e
commit
5897d14fe4
|
@ -66,7 +66,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
|||
* 2<sup>-8</sup> = 0.00390625.
|
||||
* <p>
|
||||
* Given a scoring factor {@code S > 0} and its weight {@code w > 0}, there
|
||||
* are three ways that S can be turned into a score:
|
||||
* are four ways that S can be turned into a score:
|
||||
* <ul>
|
||||
* <li>{@link #newLogQuery w * log(a + S)}, with a ≥ 1. This function
|
||||
* usually makes sense because the distribution of scoring factors
|
||||
|
@ -82,6 +82,12 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
|||
* than the two above but is also harder to tune due to the fact it has
|
||||
* 2 parameters. Like with {@code satu}, values are in the 0..1 range and
|
||||
* 0.5 is obtained when S and k are equal.
|
||||
* <li>{@link #newLinearQuery w * S}. Expert: This function doesn't apply
|
||||
* any transformation to an indexed feature value, and the indexed value itself,
|
||||
* multiplied by weight, determines the score. Thus, there is an expectation
|
||||
* that a feature value is encoded in the index in a way that makes
|
||||
* sense for scoring.
|
||||
*
|
||||
* </ul>
|
||||
* <p>
|
||||
* The constants in the above formulas typically need training in order to
|
||||
|
@ -217,6 +223,46 @@ public final class FeatureField extends Field {
|
|||
FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
|
||||
}
|
||||
|
||||
static final class LinearFunction extends FeatureFunction {
|
||||
@Override
|
||||
SimScorer scorer(float w) {
|
||||
return new SimScorer() {
|
||||
@Override
|
||||
public float score(float freq, long norm) {
|
||||
return (w * decodeFeatureValue(freq));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
Explanation explain(String field, String feature, float w, int freq) {
|
||||
float featureValue = decodeFeatureValue(freq);
|
||||
float score = scorer(w).score(freq, 1L);
|
||||
return Explanation.match(score,
|
||||
"Linear function on the " + field + " field for the " + feature + " feature, computed as w * S from:",
|
||||
Explanation.match(w, "w, weight of this function"),
|
||||
Explanation.match(featureValue, "S, feature value"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "LinearFunction";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return getClass().hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
static final class LogFunction extends FeatureFunction {
|
||||
|
||||
private final float scalingFactor;
|
||||
|
@ -406,6 +452,26 @@ public final class FeatureField extends Field {
|
|||
*/
|
||||
private static final float MAX_WEIGHT = Long.SIZE;
|
||||
|
||||
|
||||
/**
|
||||
* Return a new {@link Query} that will score documents as
|
||||
* {@code weight * S} where S is the value of the static feature.
|
||||
* @param fieldName field that stores features
|
||||
* @param featureName name of the feature
|
||||
* @param weight weight to give to this feature, must be in (0,64]
|
||||
* @throws IllegalArgumentException if weight is not in (0,64]
|
||||
*/
|
||||
public static Query newLinearQuery(String fieldName, String featureName, float weight) {
|
||||
if (weight <= 0 || weight > MAX_WEIGHT) {
|
||||
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
|
||||
}
|
||||
Query q = new FeatureQuery(fieldName, featureName, new LinearFunction());
|
||||
if (weight != 1f) {
|
||||
q = new BoostQuery(q, weight);
|
||||
}
|
||||
return q;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new {@link Query} that will score documents as
|
||||
* {@code weight * Math.log(scalingFactor + S)} where S is the value of the static feature.
|
||||
|
|
|
@ -101,6 +101,24 @@ public class TestFeatureField extends LuceneTestCase {
|
|||
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, s.iterator().nextDoc());
|
||||
|
||||
q = FeatureField.newLinearQuery("features", "pagerank", 3f);
|
||||
w = q.createWeight(searcher, ScoreMode.TOP_SCORES, 2);
|
||||
s = w.scorer(context);
|
||||
|
||||
assertEquals(0, s.iterator().nextDoc());
|
||||
assertEquals((float) (6.0 * 10), s.score(), 0f);
|
||||
|
||||
assertEquals(1, s.iterator().nextDoc());
|
||||
assertEquals((float) (6.0 * 100), s.score(), 0f);
|
||||
|
||||
assertEquals(3, s.iterator().nextDoc());
|
||||
assertEquals((float) (6.0 * 1), s.score(), 0f);
|
||||
|
||||
assertEquals(4, s.iterator().nextDoc());
|
||||
assertEquals((float) (6.0 * 42), s.score(), 0f);
|
||||
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, s.iterator().nextDoc());
|
||||
|
||||
q = FeatureField.newSaturationQuery("features", "pagerank", 3f, 4.5f);
|
||||
w = q.createWeight(searcher, ScoreMode.TOP_SCORES, 2);
|
||||
s = w.scorer(context);
|
||||
|
@ -188,16 +206,19 @@ public class TestFeatureField extends LuceneTestCase {
|
|||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
|
||||
QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", 1f, 4.5f), searcher);
|
||||
QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", 1f), searcher);
|
||||
QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", 1f, 12f), searcher);
|
||||
QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", 1f, 12f, 0.6f), searcher);
|
||||
|
||||
// Test boosts that are > 1
|
||||
QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", 3f, 4.5f), searcher);
|
||||
QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", 3f), searcher);
|
||||
QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", 3f, 12f), searcher);
|
||||
QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", 3f, 12f, 0.6f), searcher);
|
||||
|
||||
// Test boosts that are < 1
|
||||
QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", .2f, 4.5f), searcher);
|
||||
QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", .2f), searcher);
|
||||
QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", .2f, 12f), searcher);
|
||||
QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", .2f, 12f, 0.6f), searcher);
|
||||
|
||||
|
@ -209,6 +230,10 @@ public class TestFeatureField extends LuceneTestCase {
|
|||
doTestSimScorer(new FeatureField.LogFunction(4.5f).scorer(3f));
|
||||
}
|
||||
|
||||
public void testLinearSimScorer() {
|
||||
doTestSimScorer(new FeatureField.LinearFunction().scorer(1f));
|
||||
}
|
||||
|
||||
public void testSatuSimScorer() {
|
||||
doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer(3f));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue