LUCENE-8315: Make FeatureField easier to use.

This commit is contained in:
Adrien Grand 2018-05-16 17:16:16 +02:00
parent e7cf4929f8
commit bd20cb3c87
3 changed files with 118 additions and 41 deletions

View File

@ -17,18 +17,19 @@
package org.apache.lucene.document; package org.apache.lucene.document;
import java.io.IOException; import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates; import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer; import org.apache.lucene.search.similarities.Similarity.SimScorer;
@ -82,7 +83,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
* <p> * <p>
* The constants in the above formulas typically need training in order to * The constants in the above formulas typically need training in order to
* compute optimal values. If you don't know where to start, the * compute optimal values. If you don't know where to start, the
* {@link #newSaturationQuery(IndexSearcher, String, String)} method uses * {@link #newSaturationQuery(String, String)} method uses
* {@code 1f} as a weight and tries to guess a sensible value for the * {@code 1f} as a weight and tries to guess a sensible value for the
* {@code pivot} parameter of the saturation function based on index * {@code pivot} parameter of the saturation function based on index
* statistics, which shouldn't perform too bad. Here is an example, assuming * statistics, which shouldn't perform too bad. Here is an example, assuming
@ -93,7 +94,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
* .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD) * .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
* .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD) * .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
* .build(); * .build();
* Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank"); * Query boost = FeatureField.newSaturationQuery("features", "pagerank");
* Query boostedQuery = new BooleanQuery.Builder() * Query boostedQuery = new BooleanQuery.Builder()
* .add(query, Occur.MUST) * .add(query, Occur.MUST)
* .add(boost, Occur.SHOULD) * .add(boost, Occur.SHOULD)
@ -210,6 +211,7 @@ public final class FeatureField extends Field {
static abstract class FeatureFunction { static abstract class FeatureFunction {
abstract SimScorer scorer(String field, float w); abstract SimScorer scorer(String field, float w);
abstract Explanation explain(String field, String feature, float w, int freq); abstract Explanation explain(String field, String feature, float w, int freq);
FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
} }
static final class LogFunction extends FeatureFunction { static final class LogFunction extends FeatureFunction {
@ -263,24 +265,38 @@ public final class FeatureField extends Field {
static final class SaturationFunction extends FeatureFunction { static final class SaturationFunction extends FeatureFunction {
private final float pivot; private final String field, feature;
private final Float pivot;
SaturationFunction(float pivot) { SaturationFunction(String field, String feature, Float pivot) {
this.field = field;
this.feature = feature;
this.pivot = pivot; this.pivot = pivot;
} }
@Override
public FeatureFunction rewrite(IndexReader reader) throws IOException {
if (pivot != null) {
return super.rewrite(reader);
}
float newPivot = computePivotFeatureValue(reader, field, feature);
return new SaturationFunction(field, feature, newPivot);
}
@Override @Override
public boolean equals(Object obj) { public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) { if (obj == null || getClass() != obj.getClass()) {
return false; return false;
} }
SaturationFunction that = (SaturationFunction) obj; SaturationFunction that = (SaturationFunction) obj;
return pivot == that.pivot; return Objects.equals(field, that.field) &&
Objects.equals(feature, that.feature) &&
Objects.equals(pivot, that.pivot);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Float.hashCode(pivot); return Objects.hash(field, feature, pivot);
} }
@Override @Override
@ -290,6 +306,10 @@ public final class FeatureField extends Field {
@Override @Override
SimScorer scorer(String field, float weight) { SimScorer scorer(String field, float weight) {
if (pivot == null) {
throw new IllegalStateException("Rewrite first");
}
final float pivot = this.pivot; // unbox
return new SimScorer(field) { return new SimScorer(field) {
@Override @Override
public float score(float freq, long norm) { public float score(float freq, long norm) {
@ -416,36 +436,34 @@ public final class FeatureField extends Field {
* @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity) * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
*/ */
public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) { public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) {
if (weight <= 0 || weight > MAX_WEIGHT) { return newSaturationQuery(fieldName, featureName, weight, Float.valueOf(pivot));
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
}
if (pivot <= 0 || Float.isFinite(pivot) == false) {
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
}
Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(pivot));
if (weight != 1f) {
q = new BoostQuery(q, weight);
}
return q;
} }
/** /**
* Same as {@link #newSaturationQuery(String, String, float, float)} but * Same as {@link #newSaturationQuery(String, String, float, float)} but
* uses {@code 1f} as a weight and tries to compute a sensible default value * {@code 1f} is used as a weight and a reasonably good default pivot value
* for {@code pivot} using * is computed based on index statistics and is approximately equal to the
* {@link #computePivotFeatureValue(IndexSearcher, String, String)}. This * geometric mean of all values that exist in the index.
* isn't expected to give an optimal configuration of these parameters but * @param fieldName field that stores features
* should be a good start if you have no idea what the values of these * @param featureName name of the feature
* parameters should be. * @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
* @param searcher the {@link IndexSearcher} that you will search against
* @param featureFieldName the field that stores features
* @param featureName the name of the feature
*/ */
public static Query newSaturationQuery(IndexSearcher searcher, public static Query newSaturationQuery(String fieldName, String featureName) {
String featureFieldName, String featureName) throws IOException { return newSaturationQuery(fieldName, featureName, 1f, null);
float weight = 1f; }
float pivot = computePivotFeatureValue(searcher, featureFieldName, featureName);
return newSaturationQuery(featureFieldName, featureName, weight, pivot); private static Query newSaturationQuery(String fieldName, String featureName, float weight, Float pivot) {
if (weight <= 0 || weight > MAX_WEIGHT) {
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
}
if (pivot != null && (pivot <= 0 || Float.isFinite(pivot) == false)) {
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
}
Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(fieldName, featureName, pivot));
if (weight != 1f) {
q = new BoostQuery(q, weight);
}
return q;
} }
/** /**
@ -483,13 +501,20 @@ public final class FeatureField extends Field {
* representation in practice before converting it back to a float. Given that * representation in practice before converting it back to a float. Given that
* floats store the exponent in the higher bits, it means that the result will * floats store the exponent in the higher bits, it means that the result will
* be an approximation of the geometric mean of all feature values. * be an approximation of the geometric mean of all feature values.
* @param searcher the {@link IndexSearcher} to search against * @param reader the {@link IndexReader} to search against
* @param featureField the field that stores features * @param featureField the field that stores features
* @param featureName the name of the feature * @param featureName the name of the feature
*/ */
public static float computePivotFeatureValue(IndexSearcher searcher, String featureField, String featureName) throws IOException { static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) throws IOException {
Term term = new Term(featureField, featureName); Term term = new Term(featureField, featureName);
TermStates states = TermStates.build(searcher.getIndexReader().getContext(), term, true); TermStates states = TermStates.build(reader.getContext(), term, true);
if (states.docFreq() == 0) {
// avoid division by 0
// The return value doesn't matter much here, the term doesn't exist,
// it will never be used for scoring. Just Make sure to return a legal
// value.
return 1;
}
float avgFreq = (float) ((double) states.totalTermFreq() / states.docFreq()); float avgFreq = (float) ((double) states.totalTermFreq() / states.docFreq());
return decodeFeatureValue(avgFreq); return decodeFeatureValue(avgFreq);
} }

View File

@ -22,6 +22,7 @@ import java.util.Set;
import org.apache.lucene.document.FeatureField.FeatureFunction; import org.apache.lucene.document.FeatureField.FeatureFunction;
import org.apache.lucene.index.ImpactsEnum; import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
@ -50,6 +51,15 @@ final class FeatureQuery extends Query {
this.function = Objects.requireNonNull(function); this.function = Objects.requireNonNull(function);
} }
@Override
public Query rewrite(IndexReader reader) throws IOException {
FeatureFunction rewritten = function.rewrite(reader);
if (function != rewritten) {
return new FeatureQuery(fieldName, featureName, rewritten);
}
return super.rewrite(reader);
}
@Override @Override
public boolean equals(Object obj) { public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) { if (obj == null || getClass() != obj.getClass()) {
@ -80,7 +90,16 @@ final class FeatureQuery extends Query {
} }
@Override @Override
public void extractTerms(Set<Term> terms) {} public void extractTerms(Set<Term> terms) {
if (scoreMode.needsScores() == false) {
// features are irrelevant to highlighting, skip
} else {
// extracting the term here will help get better scoring with
// distributed term statistics if the saturation function is used
// and the pivot value is computed automatically
terms.add(new Term(fieldName, featureName));
}
}
@Override @Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException { public Explanation explain(LeafReaderContext context, int doc) throws IOException {

View File

@ -17,10 +17,15 @@
package org.apache.lucene.document; package org.apache.lucene.document;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.Field.Store;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiReader;
import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.BooleanClause.Occur;
@ -210,7 +215,7 @@ public class TestFeatureField extends LuceneTestCase {
} }
public void testSatuSimScorer() { public void testSatuSimScorer() {
doTestSimScorer(new FeatureField.SaturationFunction(20f).scorer("foo", 3f)); doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer("foo", 3f));
} }
public void testSigmSimScorer() { public void testSigmSimScorer() {
@ -230,6 +235,14 @@ public class TestFeatureField extends LuceneTestCase {
public void testComputePivotFeatureValue() throws IOException { public void testComputePivotFeatureValue() throws IOException {
Directory dir = newDirectory(); Directory dir = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig());
// Make sure that we create a legal pivot on missing features
DirectoryReader reader = writer.getReader();
float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
assertTrue(Float.isFinite(pivot));
assertTrue(pivot > 0);
reader.close();
Document doc = new Document(); Document doc = new Document();
FeatureField pagerank = new FeatureField("features", "pagerank", 1); FeatureField pagerank = new FeatureField("features", "pagerank", 1);
doc.add(pagerank); doc.add(pagerank);
@ -248,11 +261,10 @@ public class TestFeatureField extends LuceneTestCase {
pagerank.setFeatureValue(42); pagerank.setFeatureValue(42);
writer.addDocument(doc); writer.addDocument(doc);
DirectoryReader reader = writer.getReader(); reader = writer.getReader();
writer.close(); writer.close();
IndexSearcher searcher = new IndexSearcher(reader); pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean
assertEquals(expected, pivot, 0.1); assertEquals(expected, pivot, 0.1);
@ -260,6 +272,27 @@ public class TestFeatureField extends LuceneTestCase {
dir.close(); dir.close();
} }
public void testExtractTerms() throws IOException {
IndexReader reader = new MultiReader();
IndexSearcher searcher = newSearcher(reader);
Query query = FeatureField.newLogQuery("field", "term", 2f, 42);
Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1f);
Set<Term> terms = new HashSet<>();
weight.extractTerms(terms);
assertEquals(Collections.emptySet(), terms);
terms = new HashSet<>();
weight = searcher.createWeight(query, ScoreMode.COMPLETE, 1f);
weight.extractTerms(terms);
assertEquals(Collections.singleton(new Term("field", "term")), terms);
terms = new HashSet<>();
weight = searcher.createWeight(query, ScoreMode.TOP_SCORES, 1f);
weight.extractTerms(terms);
assertEquals(Collections.singleton(new Term("field", "term")), terms);
}
public void testDemo() throws IOException { public void testDemo() throws IOException {
Directory dir = newDirectory(); Directory dir = newDirectory();
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig() RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()
@ -298,7 +331,7 @@ public class TestFeatureField extends LuceneTestCase {
.add(new TermQuery(new Term("body", "apache")), Occur.SHOULD) .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
.add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD) .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
.build(); .build();
Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank"); Query boost = FeatureField.newSaturationQuery("features", "pagerank");
Query boostedQuery = new BooleanQuery.Builder() Query boostedQuery = new BooleanQuery.Builder()
.add(query, Occur.MUST) .add(query, Occur.MUST)
.add(boost, Occur.SHOULD) .add(boost, Occur.SHOULD)