diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java index 1b8a4e5dfce..378c2afc904 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java @@ -17,18 +17,19 @@ package org.apache.lucene.document; import java.io.IOException; +import java.util.Objects; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.Similarity.SimScorer; @@ -82,7 +83,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer; *
* The constants in the above formulas typically need training in order to
* compute optimal values. If you don't know where to start, the
- * {@link #newSaturationQuery(IndexSearcher, String, String)} method uses
+ * {@link #newSaturationQuery(String, String)} method uses
* {@code 1f} as a weight and tries to guess a sensible value for the
* {@code pivot} parameter of the saturation function based on index
* statistics, which shouldn't perform too bad. Here is an example, assuming
@@ -93,7 +94,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
* .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
* .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
* .build();
- * Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
+ * Query boost = FeatureField.newSaturationQuery("features", "pagerank");
* Query boostedQuery = new BooleanQuery.Builder()
* .add(query, Occur.MUST)
* .add(boost, Occur.SHOULD)
@@ -210,6 +211,7 @@ public final class FeatureField extends Field {
static abstract class FeatureFunction {
abstract SimScorer scorer(String field, float w);
abstract Explanation explain(String field, String feature, float w, int freq);
+ FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
}
static final class LogFunction extends FeatureFunction {
@@ -263,24 +265,38 @@ public final class FeatureField extends Field {
static final class SaturationFunction extends FeatureFunction {
- private final float pivot;
+ private final String field, feature;
+ private final Float pivot;
- SaturationFunction(float pivot) {
+ SaturationFunction(String field, String feature, Float pivot) {
+ this.field = field;
+ this.feature = feature;
this.pivot = pivot;
}
+ @Override
+ public FeatureFunction rewrite(IndexReader reader) throws IOException {
+ if (pivot != null) {
+ return super.rewrite(reader);
+ }
+ float newPivot = computePivotFeatureValue(reader, field, feature);
+ return new SaturationFunction(field, feature, newPivot);
+ }
+
@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
SaturationFunction that = (SaturationFunction) obj;
- return pivot == that.pivot;
+ return Objects.equals(field, that.field) &&
+ Objects.equals(feature, that.feature) &&
+ Objects.equals(pivot, that.pivot);
}
@Override
public int hashCode() {
- return Float.hashCode(pivot);
+ return Objects.hash(field, feature, pivot);
}
@Override
@@ -290,6 +306,10 @@ public final class FeatureField extends Field {
@Override
SimScorer scorer(String field, float weight) {
+ if (pivot == null) {
+ throw new IllegalStateException("Rewrite first");
+ }
+ final float pivot = this.pivot; // unbox
return new SimScorer(field) {
@Override
public float score(float freq, long norm) {
@@ -416,36 +436,34 @@ public final class FeatureField extends Field {
* @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
*/
public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) {
- if (weight <= 0 || weight > MAX_WEIGHT) {
- throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
- }
- if (pivot <= 0 || Float.isFinite(pivot) == false) {
- throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
- }
- Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(pivot));
- if (weight != 1f) {
- q = new BoostQuery(q, weight);
- }
- return q;
+ return newSaturationQuery(fieldName, featureName, weight, Float.valueOf(pivot));
}
/**
* Same as {@link #newSaturationQuery(String, String, float, float)} but
- * uses {@code 1f} as a weight and tries to compute a sensible default value
- * for {@code pivot} using
- * {@link #computePivotFeatureValue(IndexSearcher, String, String)}. This
- * isn't expected to give an optimal configuration of these parameters but
- * should be a good start if you have no idea what the values of these
- * parameters should be.
- * @param searcher the {@link IndexSearcher} that you will search against
- * @param featureFieldName the field that stores features
- * @param featureName the name of the feature
+ * {@code 1f} is used as a weight and a reasonably good default pivot value
+ * is computed based on index statistics and is approximately equal to the
+ * geometric mean of all values that exist in the index.
+ * @param fieldName field that stores features
+ * @param featureName name of the feature
+ * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
*/
- public static Query newSaturationQuery(IndexSearcher searcher,
- String featureFieldName, String featureName) throws IOException {
- float weight = 1f;
- float pivot = computePivotFeatureValue(searcher, featureFieldName, featureName);
- return newSaturationQuery(featureFieldName, featureName, weight, pivot);
+ public static Query newSaturationQuery(String fieldName, String featureName) {
+ return newSaturationQuery(fieldName, featureName, 1f, null);
+ }
+
+ private static Query newSaturationQuery(String fieldName, String featureName, float weight, Float pivot) {
+ if (weight <= 0 || weight > MAX_WEIGHT) {
+ throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
+ }
+ if (pivot != null && (pivot <= 0 || Float.isFinite(pivot) == false)) {
+ throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
+ }
+ Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(fieldName, featureName, pivot));
+ if (weight != 1f) {
+ q = new BoostQuery(q, weight);
+ }
+ return q;
}
/**
@@ -483,13 +501,20 @@ public final class FeatureField extends Field {
* representation in practice before converting it back to a float. Given that
* floats store the exponent in the higher bits, it means that the result will
* be an approximation of the geometric mean of all feature values.
- * @param searcher the {@link IndexSearcher} to search against
+ * @param reader the {@link IndexReader} to search against
* @param featureField the field that stores features
* @param featureName the name of the feature
*/
- public static float computePivotFeatureValue(IndexSearcher searcher, String featureField, String featureName) throws IOException {
+ static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) throws IOException {
Term term = new Term(featureField, featureName);
- TermStates states = TermStates.build(searcher.getIndexReader().getContext(), term, true);
+ TermStates states = TermStates.build(reader.getContext(), term, true);
+ if (states.docFreq() == 0) {
+ // avoid division by 0
+ // The return value doesn't matter much here, the term doesn't exist,
+ // it will never be used for scoring. Just Make sure to return a legal
+ // value.
+ return 1;
+ }
float avgFreq = (float) ((double) states.totalTermFreq() / states.docFreq());
return decodeFeatureValue(avgFreq);
}
diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
index 2b387120d4d..add1b4af581 100644
--- a/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/FeatureQuery.java
@@ -22,6 +22,7 @@ import java.util.Set;
import org.apache.lucene.document.FeatureField.FeatureFunction;
import org.apache.lucene.index.ImpactsEnum;
+import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
@@ -50,6 +51,15 @@ final class FeatureQuery extends Query {
this.function = Objects.requireNonNull(function);
}
+ @Override
+ public Query rewrite(IndexReader reader) throws IOException {
+ FeatureFunction rewritten = function.rewrite(reader);
+ if (function != rewritten) {
+ return new FeatureQuery(fieldName, featureName, rewritten);
+ }
+ return super.rewrite(reader);
+ }
+
@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
@@ -80,7 +90,16 @@ final class FeatureQuery extends Query {
}
@Override
- public void extractTerms(Set