mirror of https://github.com/apache/lucene.git

LUCENE-9594: Add linear function for FeatureField

This adds a linear function and a newLinearQuery factory method to FeatureField.

commit 5897d14fe4 (parent d65041359e)
FeatureField.java

@@ -66,7 +66,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
  * 2<sup>-8</sup> = 0.00390625.
  * <p>
  * Given a scoring factor {@code S > 0} and its weight {@code w > 0}, there
- * are three ways that S can be turned into a score:
+ * are four ways that S can be turned into a score:
  * <ul>
  * <li>{@link #newLogQuery w * log(a + S)}, with a ≥ 1. This function
  * usually makes sense because the distribution of scoring factors
@@ -82,6 +82,12 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
  * than the two above but is also harder to tune due to the fact it has
  * 2 parameters. Like with {@code satu}, values are in the 0..1 range and
  * 0.5 is obtained when S and k are equal.
+ * <li>{@link #newLinearQuery w * S}. Expert: This function doesn't apply
+ * any transformation to an indexed feature value, and the indexed value itself,
+ * multiplied by weight, determines the score. Thus, there is an expectation
+ * that a feature value is encoded in the index in a way that makes
+ * sense for scoring.
+ *
  * </ul>
  * <p>
  * The constants in the above formulas typically need training in order to
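For orientation, here is a minimal sketch, not part of this commit, of how a feature is indexed and how the four factory methods described in the javadoc above might be used. The "features" field, the "pagerank" feature, and all parameter values are illustrative assumptions.

  import org.apache.lucene.document.Document;
  import org.apache.lucene.document.FeatureField;
  import org.apache.lucene.search.Query;

  public class FeatureFieldSketch {
    // Index time: store "pagerank" as a feature value of the "features" field.
    static Document buildDoc(float pagerank) {
      Document doc = new Document();
      doc.add(new FeatureField("features", "pagerank", pagerank));
      return doc;
    }

    // Query time: each factory turns the indexed value S into a score, scaled by the weight w.
    static Query[] buildQueries() {
      return new Query[] {
          FeatureField.newLogQuery("features", "pagerank", 1f, 4.5f),          // w * log(a + S)
          FeatureField.newSaturationQuery("features", "pagerank", 1f, 12f),    // w * S / (S + k)
          FeatureField.newSigmoidQuery("features", "pagerank", 1f, 12f, 0.6f), // w * S^a / (S^a + k^a)
          FeatureField.newLinearQuery("features", "pagerank", 1f),             // w * S, added by this commit
      };
    }
  }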
@@ -217,6 +223,46 @@ public final class FeatureField extends Field {
     FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
   }
 
+  static final class LinearFunction extends FeatureFunction {
+    @Override
+    SimScorer scorer(float w) {
+      return new SimScorer() {
+        @Override
+        public float score(float freq, long norm) {
+          return (w * decodeFeatureValue(freq));
+        }
+      };
+    }
+
+    @Override
+    Explanation explain(String field, String feature, float w, int freq) {
+      float featureValue = decodeFeatureValue(freq);
+      float score = scorer(w).score(freq, 1L);
+      return Explanation.match(score,
+          "Linear function on the " + field + " field for the " + feature + " feature, computed as w * S from:",
+          Explanation.match(w, "w, weight of this function"),
+          Explanation.match(featureValue, "S, feature value"));
+    }
+
+    @Override
+    public String toString() {
+      return "LinearFunction";
+    }
+
+    @Override
+    public int hashCode() {
+      return getClass().hashCode();
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj == null || getClass() != obj.getClass()) {
+        return false;
+      }
+      return true;
+    }
+  }
+
   static final class LogFunction extends FeatureFunction {
 
     private final float scalingFactor;
@@ -406,6 +452,26 @@ public final class FeatureField extends Field {
    */
   private static final float MAX_WEIGHT = Long.SIZE;
 
+  /**
+   * Return a new {@link Query} that will score documents as
+   * {@code weight * S} where S is the value of the static feature.
+   * @param fieldName   field that stores features
+   * @param featureName name of the feature
+   * @param weight      weight to give to this feature, must be in (0,64]
+   * @throws IllegalArgumentException if weight is not in (0,64]
+   */
+  public static Query newLinearQuery(String fieldName, String featureName, float weight) {
+    if (weight <= 0 || weight > MAX_WEIGHT) {
+      throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
+    }
+    Query q = new FeatureQuery(fieldName, featureName, new LinearFunction());
+    if (weight != 1f) {
+      q = new BoostQuery(q, weight);
+    }
+    return q;
+  }
+
   /**
    * Return a new {@link Query} that will score documents as
    * {@code weight * Math.log(scalingFactor + S)} where S is the value of the static feature.
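In practice a feature query like this is usually combined with a text query as an optional clause, so that the feature only boosts the score of documents that already match. A hedged sketch, not part of the diff; the "body" and "features" field names, the "pagerank" feature, and the 0.5f weight are illustrative assumptions:

  import org.apache.lucene.document.FeatureField;
  import org.apache.lucene.index.Term;
  import org.apache.lucene.search.BooleanClause.Occur;
  import org.apache.lucene.search.BooleanQuery;
  import org.apache.lucene.search.Query;
  import org.apache.lucene.search.TermQuery;

  public class LinearBoostSketch {
    static Query boostedByPagerank() {
      // Main relevance query over the document body.
      Query text = new TermQuery(new Term("body", "lucene"));
      // Adds weight * S to the score of matching documents, where S is the
      // raw indexed pagerank value (no log/saturation/sigmoid transformation).
      Query boost = FeatureField.newLinearQuery("features", "pagerank", 0.5f);
      return new BooleanQuery.Builder()
          .add(text, Occur.MUST)
          .add(boost, Occur.SHOULD)
          .build();
    }
  }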
TestFeatureField.java

@@ -101,6 +101,24 @@ public class TestFeatureField extends LuceneTestCase {
 
     assertEquals(DocIdSetIterator.NO_MORE_DOCS, s.iterator().nextDoc());
 
+    q = FeatureField.newLinearQuery("features", "pagerank", 3f);
+    w = q.createWeight(searcher, ScoreMode.TOP_SCORES, 2);
+    s = w.scorer(context);
+
+    assertEquals(0, s.iterator().nextDoc());
+    assertEquals((float) (6.0 * 10), s.score(), 0f);
+
+    assertEquals(1, s.iterator().nextDoc());
+    assertEquals((float) (6.0 * 100), s.score(), 0f);
+
+    assertEquals(3, s.iterator().nextDoc());
+    assertEquals((float) (6.0 * 1), s.score(), 0f);
+
+    assertEquals(4, s.iterator().nextDoc());
+    assertEquals((float) (6.0 * 42), s.score(), 0f);
+
+    assertEquals(DocIdSetIterator.NO_MORE_DOCS, s.iterator().nextDoc());
+
     q = FeatureField.newSaturationQuery("features", "pagerank", 3f, 4.5f);
     w = q.createWeight(searcher, ScoreMode.TOP_SCORES, 2);
     s = w.scorer(context);
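Note, not part of the diff: the factor of 6.0 in the assertions above is the 3f weight given to newLinearQuery multiplied by the boost of 2 that the test passes to createWeight; the remaining factor (10, 100, 1, 42) is the raw pagerank feature value indexed for each matching document.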
@@ -188,16 +206,19 @@ public class TestFeatureField extends LuceneTestCase {
     IndexSearcher searcher = new IndexSearcher(reader);
 
     QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", 1f, 4.5f), searcher);
+    QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", 1f), searcher);
     QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", 1f, 12f), searcher);
     QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", 1f, 12f, 0.6f), searcher);
 
     // Test boosts that are > 1
     QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", 3f, 4.5f), searcher);
+    QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", 3f), searcher);
     QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", 3f, 12f), searcher);
     QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", 3f, 12f, 0.6f), searcher);
 
     // Test boosts that are < 1
     QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", .2f, 4.5f), searcher);
+    QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", .2f), searcher);
     QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", .2f, 12f), searcher);
     QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", .2f, 12f, 0.6f), searcher);
@@ -209,6 +230,10 @@ public class TestFeatureField extends LuceneTestCase {
     doTestSimScorer(new FeatureField.LogFunction(4.5f).scorer(3f));
   }
 
+  public void testLinearSimScorer() {
+    doTestSimScorer(new FeatureField.LinearFunction().scorer(1f));
+  }
+
   public void testSatuSimScorer() {
     doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer(3f));
   }