diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java index 2ca048c5270..dafcbf4f046 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java @@ -66,7 +66,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer; * 2-8 = 0.00390625. *
* Given a scoring factor {@code S > 0} and its weight {@code w > 0}, there - * are three ways that S can be turned into a score: + * are four ways that S can be turned into a score: *
* The constants in the above formulas typically need training in order to @@ -217,6 +223,46 @@ public final class FeatureField extends Field { FeatureFunction rewrite(IndexReader reader) throws IOException { return this; } } + static final class LinearFunction extends FeatureFunction { + @Override + SimScorer scorer(float w) { + return new SimScorer() { + @Override + public float score(float freq, long norm) { + return (w * decodeFeatureValue(freq)); + } + }; + } + + @Override + Explanation explain(String field, String feature, float w, int freq) { + float featureValue = decodeFeatureValue(freq); + float score = scorer(w).score(freq, 1L); + return Explanation.match(score, + "Linear function on the " + field + " field for the " + feature + " feature, computed as w * S from:", + Explanation.match(w, "w, weight of this function"), + Explanation.match(featureValue, "S, feature value")); + } + + @Override + public String toString() { + return "LinearFunction"; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) { + return false; + } + return true; + } + }; + static final class LogFunction extends FeatureFunction { private final float scalingFactor; @@ -406,6 +452,26 @@ public final class FeatureField extends Field { */ private static final float MAX_WEIGHT = Long.SIZE; + + /** + * Return a new {@link Query} that will score documents as + * {@code weight * S} where S is the value of the static feature. + * @param fieldName field that stores features + * @param featureName name of the feature + * @param weight weight to give to this feature, must be in (0,64] + * @throws IllegalArgumentException if weight is not in (0,64] + */ + public static Query newLinearQuery(String fieldName, String featureName, float weight) { + if (weight <= 0 || weight > MAX_WEIGHT) { + throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight); + } + Query q = new FeatureQuery(fieldName, featureName, new LinearFunction()); + if (weight != 1f) { + q = new BoostQuery(q, weight); + } + return q; + } + /** * Return a new {@link Query} that will score documents as * {@code weight * Math.log(scalingFactor + S)} where S is the value of the static feature. diff --git a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java index 79534c819cf..bc63d2ec1cf 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java @@ -101,6 +101,24 @@ public class TestFeatureField extends LuceneTestCase { assertEquals(DocIdSetIterator.NO_MORE_DOCS, s.iterator().nextDoc()); + q = FeatureField.newLinearQuery("features", "pagerank", 3f); + w = q.createWeight(searcher, ScoreMode.TOP_SCORES, 2); + s = w.scorer(context); + + assertEquals(0, s.iterator().nextDoc()); + assertEquals((float) (6.0 * 10), s.score(), 0f); + + assertEquals(1, s.iterator().nextDoc()); + assertEquals((float) (6.0 * 100), s.score(), 0f); + + assertEquals(3, s.iterator().nextDoc()); + assertEquals((float) (6.0 * 1), s.score(), 0f); + + assertEquals(4, s.iterator().nextDoc()); + assertEquals((float) (6.0 * 42), s.score(), 0f); + + assertEquals(DocIdSetIterator.NO_MORE_DOCS, s.iterator().nextDoc()); + q = FeatureField.newSaturationQuery("features", "pagerank", 3f, 4.5f); w = q.createWeight(searcher, ScoreMode.TOP_SCORES, 2); s = w.scorer(context); @@ -188,16 +206,19 @@ public class TestFeatureField extends LuceneTestCase { IndexSearcher searcher = new IndexSearcher(reader); QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", 1f, 4.5f), searcher); + QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", 1f), searcher); QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", 1f, 12f), searcher); QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", 1f, 12f, 0.6f), searcher); // Test boosts that are > 1 QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", 3f, 4.5f), searcher); + QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", 3f), searcher); QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", 3f, 12f), searcher); QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", 3f, 12f, 0.6f), searcher); // Test boosts that are < 1 QueryUtils.check(random(), FeatureField.newLogQuery("features", "pagerank", .2f, 4.5f), searcher); + QueryUtils.check(random(), FeatureField.newLinearQuery("features", "pagerank", .2f), searcher); QueryUtils.check(random(), FeatureField.newSaturationQuery("features", "pagerank", .2f, 12f), searcher); QueryUtils.check(random(), FeatureField.newSigmoidQuery("features", "pagerank", .2f, 12f, 0.6f), searcher); @@ -209,6 +230,10 @@ public class TestFeatureField extends LuceneTestCase { doTestSimScorer(new FeatureField.LogFunction(4.5f).scorer(3f)); } + public void testLinearSimScorer() { + doTestSimScorer(new FeatureField.LinearFunction().scorer(1f)); + } + public void testSatuSimScorer() { doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer(3f)); }