diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java index 1b8a4e5dfce..378c2afc904 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java @@ -17,18 +17,19 @@ package org.apache.lucene.document; import java.io.IOException; +import java.util.Objects; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.Similarity.SimScorer; @@ -82,7 +83,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer; *

* The constants in the above formulas typically need training in order to * compute optimal values. If you don't know where to start, the - * {@link #newSaturationQuery(IndexSearcher, String, String)} method uses + * {@link #newSaturationQuery(String, String)} method uses * {@code 1f} as a weight and tries to guess a sensible value for the * {@code pivot} parameter of the saturation function based on index * statistics, which shouldn't perform too bad. Here is an example, assuming @@ -93,7 +94,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer; * .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD) * .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD) * .build(); - * Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank"); + * Query boost = FeatureField.newSaturationQuery("features", "pagerank"); * Query boostedQuery = new BooleanQuery.Builder() * .add(query, Occur.MUST) * .add(boost, Occur.SHOULD) @@ -210,6 +211,7 @@ public final class FeatureField extends Field { static abstract class FeatureFunction { abstract SimScorer scorer(String field, float w); abstract Explanation explain(String field, String feature, float w, int freq); + FeatureFunction rewrite(IndexReader reader) throws IOException { return this; } } static final class LogFunction extends FeatureFunction { @@ -263,24 +265,38 @@ public final class FeatureField extends Field { static final class SaturationFunction extends FeatureFunction { - private final float pivot; + private final String field, feature; + private final Float pivot; - SaturationFunction(float pivot) { + SaturationFunction(String field, String feature, Float pivot) { + this.field = field; + this.feature = feature; this.pivot = pivot; } + @Override + public FeatureFunction rewrite(IndexReader reader) throws IOException { + if (pivot != null) { + return super.rewrite(reader); + } + float newPivot = computePivotFeatureValue(reader, field, feature); + return new SaturationFunction(field, feature, newPivot); + } + @Override public boolean equals(Object obj) { if (obj == null || getClass() != obj.getClass()) { return false; } SaturationFunction that = (SaturationFunction) obj; - return pivot == that.pivot; + return Objects.equals(field, that.field) && + Objects.equals(feature, that.feature) && + Objects.equals(pivot, that.pivot); } @Override public int hashCode() { - return Float.hashCode(pivot); + return Objects.hash(field, feature, pivot); } @Override @@ -290,6 +306,10 @@ public final class FeatureField extends Field { @Override SimScorer scorer(String field, float weight) { + if (pivot == null) { + throw new IllegalStateException("Rewrite first"); + } + final float pivot = this.pivot; // unbox return new SimScorer(field) { @Override public float score(float freq, long norm) { @@ -416,36 +436,34 @@ public final class FeatureField extends Field { * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity) */ public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) { - if (weight <= 0 || weight > MAX_WEIGHT) { - throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight); - } - if (pivot <= 0 || Float.isFinite(pivot) == false) { - throw new IllegalArgumentException("pivot must be > 0, got: " + pivot); - } - Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(pivot)); - if (weight != 1f) { - q = new BoostQuery(q, weight); - } - return q; + return newSaturationQuery(fieldName, featureName, weight, Float.valueOf(pivot)); } /** * Same as {@link #newSaturationQuery(String, String, float, float)} but - * uses {@code 1f} as a weight and tries to compute a sensible default value - * for {@code pivot} using - * {@link #computePivotFeatureValue(IndexSearcher, String, String)}. This - * isn't expected to give an optimal configuration of these parameters but - * should be a good start if you have no idea what the values of these - * parameters should be. - * @param searcher the {@link IndexSearcher} that you will search against - * @param featureFieldName the field that stores features - * @param featureName the name of the feature + * {@code 1f} is used as a weight and a reasonably good default pivot value + * is computed based on index statistics and is approximately equal to the + * geometric mean of all values that exist in the index. + * @param fieldName field that stores features + * @param featureName name of the feature + * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity) */ - public static Query newSaturationQuery(IndexSearcher searcher, - String featureFieldName, String featureName) throws IOException { - float weight = 1f; - float pivot = computePivotFeatureValue(searcher, featureFieldName, featureName); - return newSaturationQuery(featureFieldName, featureName, weight, pivot); + public static Query newSaturationQuery(String fieldName, String featureName) { + return newSaturationQuery(fieldName, featureName, 1f, null); + } + + private static Query newSaturationQuery(String fieldName, String featureName, float weight, Float pivot) { + if (weight <= 0 || weight > MAX_WEIGHT) { + throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight); + } + if (pivot != null && (pivot <= 0 || Float.isFinite(pivot) == false)) { + throw new IllegalArgumentException("pivot must be > 0, got: " + pivot); + } + Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(fieldName, featureName, pivot)); + if (weight != 1f) { + q = new BoostQuery(q, weight); + } + return q; } /** @@ -483,13 +501,20 @@ public final class FeatureField extends Field { * representation in practice before converting it back to a float. Given that * floats store the exponent in the higher bits, it means that the result will * be an approximation of the geometric mean of all feature values. - * @param searcher the {@link IndexSearcher} to search against + * @param reader the {@link IndexReader} to search against * @param featureField the field that stores features * @param featureName the name of the feature */ - public static float computePivotFeatureValue(IndexSearcher searcher, String featureField, String featureName) throws IOException { + static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) throws IOException { Term term = new Term(featureField, featureName); - TermStates states = TermStates.build(searcher.getIndexReader().getContext(), term, true); + TermStates states = TermStates.build(reader.getContext(), term, true); + if (states.docFreq() == 0) { + // avoid division by 0 + // The return value doesn't matter much here, the term doesn't exist, + // it will never be used for scoring. Just Make sure to return a legal + // value. + return 1; + } float avgFreq = (float) ((double) states.totalTermFreq() / states.docFreq()); return decodeFeatureValue(avgFreq); } diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java index 2b387120d4d..add1b4af581 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java @@ -22,6 +22,7 @@ import java.util.Set; import org.apache.lucene.document.FeatureField.FeatureFunction; import org.apache.lucene.index.ImpactsEnum; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.Term; @@ -50,6 +51,15 @@ final class FeatureQuery extends Query { this.function = Objects.requireNonNull(function); } + @Override + public Query rewrite(IndexReader reader) throws IOException { + FeatureFunction rewritten = function.rewrite(reader); + if (function != rewritten) { + return new FeatureQuery(fieldName, featureName, rewritten); + } + return super.rewrite(reader); + } + @Override public boolean equals(Object obj) { if (obj == null || getClass() != obj.getClass()) { @@ -80,7 +90,16 @@ final class FeatureQuery extends Query { } @Override - public void extractTerms(Set terms) {} + public void extractTerms(Set terms) { + if (scoreMode.needsScores() == false) { + // features are irrelevant to highlighting, skip + } else { + // extracting the term here will help get better scoring with + // distributed term statistics if the saturation function is used + // and the pivot value is computed automatically + terms.add(new Term(fieldName, featureName)); + } + } @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { diff --git a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java index 2afc25032c4..312abdc680b 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java @@ -17,10 +17,15 @@ package org.apache.lucene.document; import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; import org.apache.lucene.document.Field.Store; import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.MultiReader; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause.Occur; @@ -210,7 +215,7 @@ public class TestFeatureField extends LuceneTestCase { } public void testSatuSimScorer() { - doTestSimScorer(new FeatureField.SaturationFunction(20f).scorer("foo", 3f)); + doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer("foo", 3f)); } public void testSigmSimScorer() { @@ -230,6 +235,14 @@ public class TestFeatureField extends LuceneTestCase { public void testComputePivotFeatureValue() throws IOException { Directory dir = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()); + + // Make sure that we create a legal pivot on missing features + DirectoryReader reader = writer.getReader(); + float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank"); + assertTrue(Float.isFinite(pivot)); + assertTrue(pivot > 0); + reader.close(); + Document doc = new Document(); FeatureField pagerank = new FeatureField("features", "pagerank", 1); doc.add(pagerank); @@ -248,11 +261,10 @@ public class TestFeatureField extends LuceneTestCase { pagerank.setFeatureValue(42); writer.addDocument(doc); - DirectoryReader reader = writer.getReader(); + reader = writer.getReader(); writer.close(); - IndexSearcher searcher = new IndexSearcher(reader); - float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank"); + pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank"); double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean assertEquals(expected, pivot, 0.1); @@ -260,6 +272,27 @@ public class TestFeatureField extends LuceneTestCase { dir.close(); } + public void testExtractTerms() throws IOException { + IndexReader reader = new MultiReader(); + IndexSearcher searcher = newSearcher(reader); + Query query = FeatureField.newLogQuery("field", "term", 2f, 42); + + Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1f); + Set terms = new HashSet<>(); + weight.extractTerms(terms); + assertEquals(Collections.emptySet(), terms); + + terms = new HashSet<>(); + weight = searcher.createWeight(query, ScoreMode.COMPLETE, 1f); + weight.extractTerms(terms); + assertEquals(Collections.singleton(new Term("field", "term")), terms); + + terms = new HashSet<>(); + weight = searcher.createWeight(query, ScoreMode.TOP_SCORES, 1f); + weight.extractTerms(terms); + assertEquals(Collections.singleton(new Term("field", "term")), terms); + } + public void testDemo() throws IOException { Directory dir = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig() @@ -298,7 +331,7 @@ public class TestFeatureField extends LuceneTestCase { .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD) .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD) .build(); - Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank"); + Query boost = FeatureField.newSaturationQuery("features", "pagerank"); Query boostedQuery = new BooleanQuery.Builder() .add(query, Occur.MUST) .add(boost, Occur.SHOULD)