mirror of https://github.com/apache/lucene.git
LUCENE-8315: Make FeatureField easier to use.
This commit is contained in:
parent
e7cf4929f8
commit
bd20cb3c87
|
@ -17,18 +17,19 @@
|
|||
package org.apache.lucene.document;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.analysis.TokenStream;
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
|
||||
import org.apache.lucene.index.IndexOptions;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.TermStates;
|
||||
import org.apache.lucene.search.BooleanQuery;
|
||||
import org.apache.lucene.search.BoostQuery;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.similarities.BM25Similarity;
|
||||
import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
||||
|
@ -82,7 +83,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
|||
* <p>
|
||||
* The constants in the above formulas typically need training in order to
|
||||
* compute optimal values. If you don't know where to start, the
|
||||
* {@link #newSaturationQuery(IndexSearcher, String, String)} method uses
|
||||
* {@link #newSaturationQuery(String, String)} method uses
|
||||
* {@code 1f} as a weight and tries to guess a sensible value for the
|
||||
* {@code pivot} parameter of the saturation function based on index
|
||||
* statistics, which shouldn't perform too bad. Here is an example, assuming
|
||||
|
@ -93,7 +94,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
|||
* .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
|
||||
* .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
|
||||
* .build();
|
||||
* Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
|
||||
* Query boost = FeatureField.newSaturationQuery("features", "pagerank");
|
||||
* Query boostedQuery = new BooleanQuery.Builder()
|
||||
* .add(query, Occur.MUST)
|
||||
* .add(boost, Occur.SHOULD)
|
||||
|
@ -210,6 +211,7 @@ public final class FeatureField extends Field {
|
|||
static abstract class FeatureFunction {
|
||||
abstract SimScorer scorer(String field, float w);
|
||||
abstract Explanation explain(String field, String feature, float w, int freq);
|
||||
FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
|
||||
}
|
||||
|
||||
static final class LogFunction extends FeatureFunction {
|
||||
|
@ -263,24 +265,38 @@ public final class FeatureField extends Field {
|
|||
|
||||
static final class SaturationFunction extends FeatureFunction {
|
||||
|
||||
private final float pivot;
|
||||
private final String field, feature;
|
||||
private final Float pivot;
|
||||
|
||||
SaturationFunction(float pivot) {
|
||||
SaturationFunction(String field, String feature, Float pivot) {
|
||||
this.field = field;
|
||||
this.feature = feature;
|
||||
this.pivot = pivot;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FeatureFunction rewrite(IndexReader reader) throws IOException {
|
||||
if (pivot != null) {
|
||||
return super.rewrite(reader);
|
||||
}
|
||||
float newPivot = computePivotFeatureValue(reader, field, feature);
|
||||
return new SaturationFunction(field, feature, newPivot);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (obj == null || getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
SaturationFunction that = (SaturationFunction) obj;
|
||||
return pivot == that.pivot;
|
||||
return Objects.equals(field, that.field) &&
|
||||
Objects.equals(feature, that.feature) &&
|
||||
Objects.equals(pivot, that.pivot);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Float.hashCode(pivot);
|
||||
return Objects.hash(field, feature, pivot);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -290,6 +306,10 @@ public final class FeatureField extends Field {
|
|||
|
||||
@Override
|
||||
SimScorer scorer(String field, float weight) {
|
||||
if (pivot == null) {
|
||||
throw new IllegalStateException("Rewrite first");
|
||||
}
|
||||
final float pivot = this.pivot; // unbox
|
||||
return new SimScorer(field) {
|
||||
@Override
|
||||
public float score(float freq, long norm) {
|
||||
|
@ -416,36 +436,34 @@ public final class FeatureField extends Field {
|
|||
* @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
|
||||
*/
|
||||
public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) {
|
||||
if (weight <= 0 || weight > MAX_WEIGHT) {
|
||||
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
|
||||
}
|
||||
if (pivot <= 0 || Float.isFinite(pivot) == false) {
|
||||
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
|
||||
}
|
||||
Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(pivot));
|
||||
if (weight != 1f) {
|
||||
q = new BoostQuery(q, weight);
|
||||
}
|
||||
return q;
|
||||
return newSaturationQuery(fieldName, featureName, weight, Float.valueOf(pivot));
|
||||
}
|
||||
|
||||
/**
|
||||
* Same as {@link #newSaturationQuery(String, String, float, float)} but
|
||||
* uses {@code 1f} as a weight and tries to compute a sensible default value
|
||||
* for {@code pivot} using
|
||||
* {@link #computePivotFeatureValue(IndexSearcher, String, String)}. This
|
||||
* isn't expected to give an optimal configuration of these parameters but
|
||||
* should be a good start if you have no idea what the values of these
|
||||
* parameters should be.
|
||||
* @param searcher the {@link IndexSearcher} that you will search against
|
||||
* @param featureFieldName the field that stores features
|
||||
* @param featureName the name of the feature
|
||||
* {@code 1f} is used as a weight and a reasonably good default pivot value
|
||||
* is computed based on index statistics and is approximately equal to the
|
||||
* geometric mean of all values that exist in the index.
|
||||
* @param fieldName field that stores features
|
||||
* @param featureName name of the feature
|
||||
* @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
|
||||
*/
|
||||
public static Query newSaturationQuery(IndexSearcher searcher,
|
||||
String featureFieldName, String featureName) throws IOException {
|
||||
float weight = 1f;
|
||||
float pivot = computePivotFeatureValue(searcher, featureFieldName, featureName);
|
||||
return newSaturationQuery(featureFieldName, featureName, weight, pivot);
|
||||
public static Query newSaturationQuery(String fieldName, String featureName) {
|
||||
return newSaturationQuery(fieldName, featureName, 1f, null);
|
||||
}
|
||||
|
||||
private static Query newSaturationQuery(String fieldName, String featureName, float weight, Float pivot) {
|
||||
if (weight <= 0 || weight > MAX_WEIGHT) {
|
||||
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
|
||||
}
|
||||
if (pivot != null && (pivot <= 0 || Float.isFinite(pivot) == false)) {
|
||||
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
|
||||
}
|
||||
Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(fieldName, featureName, pivot));
|
||||
if (weight != 1f) {
|
||||
q = new BoostQuery(q, weight);
|
||||
}
|
||||
return q;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -483,13 +501,20 @@ public final class FeatureField extends Field {
|
|||
* representation in practice before converting it back to a float. Given that
|
||||
* floats store the exponent in the higher bits, it means that the result will
|
||||
* be an approximation of the geometric mean of all feature values.
|
||||
* @param searcher the {@link IndexSearcher} to search against
|
||||
* @param reader the {@link IndexReader} to search against
|
||||
* @param featureField the field that stores features
|
||||
* @param featureName the name of the feature
|
||||
*/
|
||||
public static float computePivotFeatureValue(IndexSearcher searcher, String featureField, String featureName) throws IOException {
|
||||
static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) throws IOException {
|
||||
Term term = new Term(featureField, featureName);
|
||||
TermStates states = TermStates.build(searcher.getIndexReader().getContext(), term, true);
|
||||
TermStates states = TermStates.build(reader.getContext(), term, true);
|
||||
if (states.docFreq() == 0) {
|
||||
// avoid division by 0
|
||||
// The return value doesn't matter much here, the term doesn't exist,
|
||||
// it will never be used for scoring. Just Make sure to return a legal
|
||||
// value.
|
||||
return 1;
|
||||
}
|
||||
float avgFreq = (float) ((double) states.totalTermFreq() / states.docFreq());
|
||||
return decodeFeatureValue(avgFreq);
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.Set;
|
|||
|
||||
import org.apache.lucene.document.FeatureField.FeatureFunction;
|
||||
import org.apache.lucene.index.ImpactsEnum;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.PostingsEnum;
|
||||
import org.apache.lucene.index.Term;
|
||||
|
@ -50,6 +51,15 @@ final class FeatureQuery extends Query {
|
|||
this.function = Objects.requireNonNull(function);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Query rewrite(IndexReader reader) throws IOException {
|
||||
FeatureFunction rewritten = function.rewrite(reader);
|
||||
if (function != rewritten) {
|
||||
return new FeatureQuery(fieldName, featureName, rewritten);
|
||||
}
|
||||
return super.rewrite(reader);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (obj == null || getClass() != obj.getClass()) {
|
||||
|
@ -80,7 +90,16 @@ final class FeatureQuery extends Query {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void extractTerms(Set<Term> terms) {}
|
||||
public void extractTerms(Set<Term> terms) {
|
||||
if (scoreMode.needsScores() == false) {
|
||||
// features are irrelevant to highlighting, skip
|
||||
} else {
|
||||
// extracting the term here will help get better scoring with
|
||||
// distributed term statistics if the saturation function is used
|
||||
// and the pivot value is computed automatically
|
||||
terms.add(new Term(fieldName, featureName));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
||||
|
|
|
@ -17,10 +17,15 @@
|
|||
package org.apache.lucene.document;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.lucene.document.Field.Store;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.MultiReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.BooleanClause.Occur;
|
||||
|
@ -210,7 +215,7 @@ public class TestFeatureField extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testSatuSimScorer() {
|
||||
doTestSimScorer(new FeatureField.SaturationFunction(20f).scorer("foo", 3f));
|
||||
doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer("foo", 3f));
|
||||
}
|
||||
|
||||
public void testSigmSimScorer() {
|
||||
|
@ -230,6 +235,14 @@ public class TestFeatureField extends LuceneTestCase {
|
|||
public void testComputePivotFeatureValue() throws IOException {
|
||||
Directory dir = newDirectory();
|
||||
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig());
|
||||
|
||||
// Make sure that we create a legal pivot on missing features
|
||||
DirectoryReader reader = writer.getReader();
|
||||
float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
|
||||
assertTrue(Float.isFinite(pivot));
|
||||
assertTrue(pivot > 0);
|
||||
reader.close();
|
||||
|
||||
Document doc = new Document();
|
||||
FeatureField pagerank = new FeatureField("features", "pagerank", 1);
|
||||
doc.add(pagerank);
|
||||
|
@ -248,11 +261,10 @@ public class TestFeatureField extends LuceneTestCase {
|
|||
pagerank.setFeatureValue(42);
|
||||
writer.addDocument(doc);
|
||||
|
||||
DirectoryReader reader = writer.getReader();
|
||||
reader = writer.getReader();
|
||||
writer.close();
|
||||
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
|
||||
pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
|
||||
double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean
|
||||
assertEquals(expected, pivot, 0.1);
|
||||
|
||||
|
@ -260,6 +272,27 @@ public class TestFeatureField extends LuceneTestCase {
|
|||
dir.close();
|
||||
}
|
||||
|
||||
public void testExtractTerms() throws IOException {
|
||||
IndexReader reader = new MultiReader();
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
Query query = FeatureField.newLogQuery("field", "term", 2f, 42);
|
||||
|
||||
Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1f);
|
||||
Set<Term> terms = new HashSet<>();
|
||||
weight.extractTerms(terms);
|
||||
assertEquals(Collections.emptySet(), terms);
|
||||
|
||||
terms = new HashSet<>();
|
||||
weight = searcher.createWeight(query, ScoreMode.COMPLETE, 1f);
|
||||
weight.extractTerms(terms);
|
||||
assertEquals(Collections.singleton(new Term("field", "term")), terms);
|
||||
|
||||
terms = new HashSet<>();
|
||||
weight = searcher.createWeight(query, ScoreMode.TOP_SCORES, 1f);
|
||||
weight.extractTerms(terms);
|
||||
assertEquals(Collections.singleton(new Term("field", "term")), terms);
|
||||
}
|
||||
|
||||
public void testDemo() throws IOException {
|
||||
Directory dir = newDirectory();
|
||||
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()
|
||||
|
@ -298,7 +331,7 @@ public class TestFeatureField extends LuceneTestCase {
|
|||
.add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
|
||||
.add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
|
||||
.build();
|
||||
Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
|
||||
Query boost = FeatureField.newSaturationQuery("features", "pagerank");
|
||||
Query boostedQuery = new BooleanQuery.Builder()
|
||||
.add(query, Occur.MUST)
|
||||
.add(boost, Occur.SHOULD)
|
||||
|
|
Loading…
Reference in New Issue