mirror of https://github.com/apache/lucene.git
LUCENE-8315: Make FeatureField easier to use.
This commit is contained in:
parent
e7cf4929f8
commit
bd20cb3c87
|
@ -17,18 +17,19 @@
|
||||||
package org.apache.lucene.document;
|
package org.apache.lucene.document;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.analysis.TokenStream;
|
import org.apache.lucene.analysis.TokenStream;
|
||||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||||
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
|
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
|
||||||
import org.apache.lucene.index.IndexOptions;
|
import org.apache.lucene.index.IndexOptions;
|
||||||
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.index.TermStates;
|
import org.apache.lucene.index.TermStates;
|
||||||
import org.apache.lucene.search.BooleanQuery;
|
import org.apache.lucene.search.BooleanQuery;
|
||||||
import org.apache.lucene.search.BoostQuery;
|
import org.apache.lucene.search.BoostQuery;
|
||||||
import org.apache.lucene.search.Explanation;
|
import org.apache.lucene.search.Explanation;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
|
||||||
import org.apache.lucene.search.Query;
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.similarities.BM25Similarity;
|
import org.apache.lucene.search.similarities.BM25Similarity;
|
||||||
import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
||||||
|
@ -82,7 +83,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
||||||
* <p>
|
* <p>
|
||||||
* The constants in the above formulas typically need training in order to
|
* The constants in the above formulas typically need training in order to
|
||||||
* compute optimal values. If you don't know where to start, the
|
* compute optimal values. If you don't know where to start, the
|
||||||
* {@link #newSaturationQuery(IndexSearcher, String, String)} method uses
|
* {@link #newSaturationQuery(String, String)} method uses
|
||||||
* {@code 1f} as a weight and tries to guess a sensible value for the
|
* {@code 1f} as a weight and tries to guess a sensible value for the
|
||||||
* {@code pivot} parameter of the saturation function based on index
|
* {@code pivot} parameter of the saturation function based on index
|
||||||
* statistics, which shouldn't perform too badly. Here is an example, assuming
|
* statistics, which shouldn't perform too badly. Here is an example, assuming
|
||||||
|
@ -93,7 +94,7 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
||||||
* .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
|
* .add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
|
||||||
* .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
|
* .add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
|
||||||
* .build();
|
* .build();
|
||||||
* Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
|
* Query boost = FeatureField.newSaturationQuery("features", "pagerank");
|
||||||
* Query boostedQuery = new BooleanQuery.Builder()
|
* Query boostedQuery = new BooleanQuery.Builder()
|
||||||
* .add(query, Occur.MUST)
|
* .add(query, Occur.MUST)
|
||||||
* .add(boost, Occur.SHOULD)
|
* .add(boost, Occur.SHOULD)
|
||||||
|
@ -210,6 +211,7 @@ public final class FeatureField extends Field {
|
||||||
static abstract class FeatureFunction {
|
static abstract class FeatureFunction {
|
||||||
abstract SimScorer scorer(String field, float w);
|
abstract SimScorer scorer(String field, float w);
|
||||||
abstract Explanation explain(String field, String feature, float w, int freq);
|
abstract Explanation explain(String field, String feature, float w, int freq);
|
||||||
|
FeatureFunction rewrite(IndexReader reader) throws IOException { return this; }
|
||||||
}
|
}
|
||||||
|
|
||||||
static final class LogFunction extends FeatureFunction {
|
static final class LogFunction extends FeatureFunction {
|
||||||
|
@ -263,24 +265,38 @@ public final class FeatureField extends Field {
|
||||||
|
|
||||||
static final class SaturationFunction extends FeatureFunction {
|
static final class SaturationFunction extends FeatureFunction {
|
||||||
|
|
||||||
private final float pivot;
|
private final String field, feature;
|
||||||
|
private final Float pivot;
|
||||||
|
|
||||||
SaturationFunction(float pivot) {
|
SaturationFunction(String field, String feature, Float pivot) {
|
||||||
|
this.field = field;
|
||||||
|
this.feature = feature;
|
||||||
this.pivot = pivot;
|
this.pivot = pivot;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FeatureFunction rewrite(IndexReader reader) throws IOException {
|
||||||
|
if (pivot != null) {
|
||||||
|
return super.rewrite(reader);
|
||||||
|
}
|
||||||
|
float newPivot = computePivotFeatureValue(reader, field, feature);
|
||||||
|
return new SaturationFunction(field, feature, newPivot);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object obj) {
|
public boolean equals(Object obj) {
|
||||||
if (obj == null || getClass() != obj.getClass()) {
|
if (obj == null || getClass() != obj.getClass()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
SaturationFunction that = (SaturationFunction) obj;
|
SaturationFunction that = (SaturationFunction) obj;
|
||||||
return pivot == that.pivot;
|
return Objects.equals(field, that.field) &&
|
||||||
|
Objects.equals(feature, that.feature) &&
|
||||||
|
Objects.equals(pivot, that.pivot);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Float.hashCode(pivot);
|
return Objects.hash(field, feature, pivot);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -290,6 +306,10 @@ public final class FeatureField extends Field {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
SimScorer scorer(String field, float weight) {
|
SimScorer scorer(String field, float weight) {
|
||||||
|
if (pivot == null) {
|
||||||
|
throw new IllegalStateException("Rewrite first");
|
||||||
|
}
|
||||||
|
final float pivot = this.pivot; // unbox
|
||||||
return new SimScorer(field) {
|
return new SimScorer(field) {
|
||||||
@Override
|
@Override
|
||||||
public float score(float freq, long norm) {
|
public float score(float freq, long norm) {
|
||||||
|
@ -416,36 +436,34 @@ public final class FeatureField extends Field {
|
||||||
* @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
|
* @throws IllegalArgumentException if weight is not in (0,64] or pivot is not in (0, +Infinity)
|
||||||
*/
|
*/
|
||||||
public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) {
|
public static Query newSaturationQuery(String fieldName, String featureName, float weight, float pivot) {
|
||||||
if (weight <= 0 || weight > MAX_WEIGHT) {
|
return newSaturationQuery(fieldName, featureName, weight, Float.valueOf(pivot));
|
||||||
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
|
|
||||||
}
|
|
||||||
if (pivot <= 0 || Float.isFinite(pivot) == false) {
|
|
||||||
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
|
|
||||||
}
|
|
||||||
Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(pivot));
|
|
||||||
if (weight != 1f) {
|
|
||||||
q = new BoostQuery(q, weight);
|
|
||||||
}
|
|
||||||
return q;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Same as {@link #newSaturationQuery(String, String, float, float)} but
|
* Same as {@link #newSaturationQuery(String, String, float, float)} but
|
||||||
* uses {@code 1f} as a weight and tries to compute a sensible default value
|
* {@code 1f} is used as a weight and a reasonably good default pivot value
|
||||||
* for {@code pivot} using
|
* is computed based on index statistics and is approximately equal to the
|
||||||
* {@link #computePivotFeatureValue(IndexSearcher, String, String)}. This
|
* geometric mean of all values that exist in the index.
|
||||||
* isn't expected to give an optimal configuration of these parameters but
|
* @param fieldName field that stores features
|
||||||
* should be a good start if you have no idea what the values of these
|
* @param featureName name of the feature
|
||||||
* parameters should be.
|
* @return a query that uses {@code 1f} as a weight and whose pivot is computed from index statistics when the query is rewritten
|
||||||
* @param searcher the {@link IndexSearcher} that you will search against
|
|
||||||
* @param featureFieldName the field that stores features
|
|
||||||
* @param featureName the name of the feature
|
|
||||||
*/
|
*/
|
||||||
public static Query newSaturationQuery(IndexSearcher searcher,
|
public static Query newSaturationQuery(String fieldName, String featureName) {
|
||||||
String featureFieldName, String featureName) throws IOException {
|
return newSaturationQuery(fieldName, featureName, 1f, null);
|
||||||
float weight = 1f;
|
}
|
||||||
float pivot = computePivotFeatureValue(searcher, featureFieldName, featureName);
|
|
||||||
return newSaturationQuery(featureFieldName, featureName, weight, pivot);
|
private static Query newSaturationQuery(String fieldName, String featureName, float weight, Float pivot) {
|
||||||
|
if (weight <= 0 || weight > MAX_WEIGHT) {
|
||||||
|
throw new IllegalArgumentException("weight must be in (0, " + MAX_WEIGHT + "], got: " + weight);
|
||||||
|
}
|
||||||
|
if (pivot != null && (pivot <= 0 || Float.isFinite(pivot) == false)) {
|
||||||
|
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
|
||||||
|
}
|
||||||
|
Query q = new FeatureQuery(fieldName, featureName, new SaturationFunction(fieldName, featureName, pivot));
|
||||||
|
if (weight != 1f) {
|
||||||
|
q = new BoostQuery(q, weight);
|
||||||
|
}
|
||||||
|
return q;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -483,13 +501,20 @@ public final class FeatureField extends Field {
|
||||||
* representation in practice before converting it back to a float. Given that
|
* representation in practice before converting it back to a float. Given that
|
||||||
* floats store the exponent in the higher bits, it means that the result will
|
* floats store the exponent in the higher bits, it means that the result will
|
||||||
* be an approximation of the geometric mean of all feature values.
|
* be an approximation of the geometric mean of all feature values.
|
||||||
* @param searcher the {@link IndexSearcher} to search against
|
* @param reader the {@link IndexReader} to search against
|
||||||
* @param featureField the field that stores features
|
* @param featureField the field that stores features
|
||||||
* @param featureName the name of the feature
|
* @param featureName the name of the feature
|
||||||
*/
|
*/
|
||||||
public static float computePivotFeatureValue(IndexSearcher searcher, String featureField, String featureName) throws IOException {
|
static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) throws IOException {
|
||||||
Term term = new Term(featureField, featureName);
|
Term term = new Term(featureField, featureName);
|
||||||
TermStates states = TermStates.build(searcher.getIndexReader().getContext(), term, true);
|
TermStates states = TermStates.build(reader.getContext(), term, true);
|
||||||
|
if (states.docFreq() == 0) {
|
||||||
|
// avoid division by 0
|
||||||
|
// The return value doesn't matter much here, the term doesn't exist,
|
||||||
|
// it will never be used for scoring. Just make sure to return a legal
|
||||||
|
// value.
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
float avgFreq = (float) ((double) states.totalTermFreq() / states.docFreq());
|
float avgFreq = (float) ((double) states.totalTermFreq() / states.docFreq());
|
||||||
return decodeFeatureValue(avgFreq);
|
return decodeFeatureValue(avgFreq);
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ import java.util.Set;
|
||||||
|
|
||||||
import org.apache.lucene.document.FeatureField.FeatureFunction;
|
import org.apache.lucene.document.FeatureField.FeatureFunction;
|
||||||
import org.apache.lucene.index.ImpactsEnum;
|
import org.apache.lucene.index.ImpactsEnum;
|
||||||
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.PostingsEnum;
|
import org.apache.lucene.index.PostingsEnum;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
|
@ -50,6 +51,15 @@ final class FeatureQuery extends Query {
|
||||||
this.function = Objects.requireNonNull(function);
|
this.function = Objects.requireNonNull(function);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Query rewrite(IndexReader reader) throws IOException {
|
||||||
|
FeatureFunction rewritten = function.rewrite(reader);
|
||||||
|
if (function != rewritten) {
|
||||||
|
return new FeatureQuery(fieldName, featureName, rewritten);
|
||||||
|
}
|
||||||
|
return super.rewrite(reader);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object obj) {
|
public boolean equals(Object obj) {
|
||||||
if (obj == null || getClass() != obj.getClass()) {
|
if (obj == null || getClass() != obj.getClass()) {
|
||||||
|
@ -80,7 +90,16 @@ final class FeatureQuery extends Query {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void extractTerms(Set<Term> terms) {}
|
public void extractTerms(Set<Term> terms) {
|
||||||
|
if (scoreMode.needsScores() == false) {
|
||||||
|
// features are irrelevant to highlighting, skip
|
||||||
|
} else {
|
||||||
|
// extracting the term here will help get better scoring with
|
||||||
|
// distributed term statistics if the saturation function is used
|
||||||
|
// and the pivot value is computed automatically
|
||||||
|
terms.add(new Term(fieldName, featureName));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
||||||
|
|
|
@ -17,10 +17,15 @@
|
||||||
package org.apache.lucene.document;
|
package org.apache.lucene.document;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
import org.apache.lucene.document.Field.Store;
|
import org.apache.lucene.document.Field.Store;
|
||||||
import org.apache.lucene.index.DirectoryReader;
|
import org.apache.lucene.index.DirectoryReader;
|
||||||
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.index.MultiReader;
|
||||||
import org.apache.lucene.index.RandomIndexWriter;
|
import org.apache.lucene.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.search.BooleanClause.Occur;
|
import org.apache.lucene.search.BooleanClause.Occur;
|
||||||
|
@ -210,7 +215,7 @@ public class TestFeatureField extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSatuSimScorer() {
|
public void testSatuSimScorer() {
|
||||||
doTestSimScorer(new FeatureField.SaturationFunction(20f).scorer("foo", 3f));
|
doTestSimScorer(new FeatureField.SaturationFunction("foo", "bar", 20f).scorer("foo", 3f));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSigmSimScorer() {
|
public void testSigmSimScorer() {
|
||||||
|
@ -230,6 +235,14 @@ public class TestFeatureField extends LuceneTestCase {
|
||||||
public void testComputePivotFeatureValue() throws IOException {
|
public void testComputePivotFeatureValue() throws IOException {
|
||||||
Directory dir = newDirectory();
|
Directory dir = newDirectory();
|
||||||
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig());
|
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig());
|
||||||
|
|
||||||
|
// Make sure that we create a legal pivot on missing features
|
||||||
|
DirectoryReader reader = writer.getReader();
|
||||||
|
float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
|
||||||
|
assertTrue(Float.isFinite(pivot));
|
||||||
|
assertTrue(pivot > 0);
|
||||||
|
reader.close();
|
||||||
|
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
FeatureField pagerank = new FeatureField("features", "pagerank", 1);
|
FeatureField pagerank = new FeatureField("features", "pagerank", 1);
|
||||||
doc.add(pagerank);
|
doc.add(pagerank);
|
||||||
|
@ -248,11 +261,10 @@ public class TestFeatureField extends LuceneTestCase {
|
||||||
pagerank.setFeatureValue(42);
|
pagerank.setFeatureValue(42);
|
||||||
writer.addDocument(doc);
|
writer.addDocument(doc);
|
||||||
|
|
||||||
DirectoryReader reader = writer.getReader();
|
reader = writer.getReader();
|
||||||
writer.close();
|
writer.close();
|
||||||
|
|
||||||
IndexSearcher searcher = new IndexSearcher(reader);
|
pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
|
||||||
float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
|
|
||||||
double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean
|
double expected = Math.pow(10 * 100 * 1 * 42, 1/4.); // geometric mean
|
||||||
assertEquals(expected, pivot, 0.1);
|
assertEquals(expected, pivot, 0.1);
|
||||||
|
|
||||||
|
@ -260,6 +272,27 @@ public class TestFeatureField extends LuceneTestCase {
|
||||||
dir.close();
|
dir.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testExtractTerms() throws IOException {
|
||||||
|
IndexReader reader = new MultiReader();
|
||||||
|
IndexSearcher searcher = newSearcher(reader);
|
||||||
|
Query query = FeatureField.newLogQuery("field", "term", 2f, 42);
|
||||||
|
|
||||||
|
Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1f);
|
||||||
|
Set<Term> terms = new HashSet<>();
|
||||||
|
weight.extractTerms(terms);
|
||||||
|
assertEquals(Collections.emptySet(), terms);
|
||||||
|
|
||||||
|
terms = new HashSet<>();
|
||||||
|
weight = searcher.createWeight(query, ScoreMode.COMPLETE, 1f);
|
||||||
|
weight.extractTerms(terms);
|
||||||
|
assertEquals(Collections.singleton(new Term("field", "term")), terms);
|
||||||
|
|
||||||
|
terms = new HashSet<>();
|
||||||
|
weight = searcher.createWeight(query, ScoreMode.TOP_SCORES, 1f);
|
||||||
|
weight.extractTerms(terms);
|
||||||
|
assertEquals(Collections.singleton(new Term("field", "term")), terms);
|
||||||
|
}
|
||||||
|
|
||||||
public void testDemo() throws IOException {
|
public void testDemo() throws IOException {
|
||||||
Directory dir = newDirectory();
|
Directory dir = newDirectory();
|
||||||
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()
|
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, newIndexWriterConfig()
|
||||||
|
@ -298,7 +331,7 @@ public class TestFeatureField extends LuceneTestCase {
|
||||||
.add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
|
.add(new TermQuery(new Term("body", "apache")), Occur.SHOULD)
|
||||||
.add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
|
.add(new TermQuery(new Term("body", "lucene")), Occur.SHOULD)
|
||||||
.build();
|
.build();
|
||||||
Query boost = FeatureField.newSaturationQuery(searcher, "features", "pagerank");
|
Query boost = FeatureField.newSaturationQuery("features", "pagerank");
|
||||||
Query boostedQuery = new BooleanQuery.Builder()
|
Query boostedQuery = new BooleanQuery.Builder()
|
||||||
.add(query, Occur.MUST)
|
.add(query, Occur.MUST)
|
||||||
.add(boost, Occur.SHOULD)
|
.add(boost, Occur.SHOULD)
|
||||||
|
|
Loading…
Reference in New Issue