mirror of https://github.com/apache/lucene.git
LUCENE-8216: Added a new BM25FQuery in sandbox to blend statistics across several fields using the BM25F formula
parent 492c3440de
commit fd96bc5ca6
@@ -148,6 +148,9 @@ New Features
  based on the haversine distance of a LatLonPoint field to a provided point. This is
  typically useful to boost by distance. (Ignacio Vera)

* LUCENE-8216: Added a new BM25FQuery in sandbox to blend statistics across several fields
  using the BM25F formula. (Adrien Grand, Jim Ferenczi)

Improvements

* LUCENE-7997: Add BaseSimilarityTestCase to sanity check similarities.
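
For context, a minimal usage sketch of the API this commit adds (the directory, field names, and search term below are illustrative assumptions, not part of the change):

import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.BM25FQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;

public class BM25FQueryExample {
  // Blend "title" and "body" statistics, counting title matches three times.
  static TopDocs search(Directory dir) throws Exception {
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);
      BM25FQuery query = new BM25FQuery.Builder() // default BM25 k1 and b
          .addField("title", 3f)
          .addField("body") // weight defaults to 1
          .addTerm(new BytesRef("lucene"))
          .build();
      return searcher.search(query, 10);
    }
  }
}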
@@ -0,0 +1,430 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermState;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.SimilarityBase;
import org.apache.lucene.util.BytesRef;

/**
 * A {@link Query} that treats multiple fields as a single stream and scores
 * terms as if you had indexed them as a single term in a single field.
 *
 * For scoring purposes this query implements the BM25F's simple formula
 * described in:
 * http://www.staff.city.ac.uk/~sb317/papers/foundations_bm25_review.pdf
 *
 * The per-field similarity is ignored but to be compatible each field must use
 * a {@link Similarity} at index time that encodes norms the same way as
 * {@link SimilarityBase#computeNorm}.
 *
 * @lucene.experimental
 */
public final class BM25FQuery extends Query {

  /**
   * A builder for {@link BM25FQuery}.
   */
  public static class Builder {
    private final BM25Similarity similarity;
    private final Map<String, FieldAndWeight> fieldAndWeights = new HashMap<>();
    private final Set<BytesRef> termsSet = new HashSet<>();

    /**
     * Default builder.
     */
    public Builder() {
      this.similarity = new BM25Similarity();
    }

    /**
     * Builder with the supplied parameter values.
     * @param k1 Controls non-linear term frequency normalization (saturation).
     * @param b Controls to what degree document length normalizes tf values.
     */
    public Builder(float k1, float b) {
      this.similarity = new BM25Similarity(k1, b);
    }

    /**
     * Adds a field to this builder.
     * @param field The field name.
     */
    public Builder addField(String field) {
      return addField(field, 1f);
    }

    /**
     * Adds a field to this builder.
     * @param field The field name.
     * @param weight The weight associated with this field.
     */
    public Builder addField(String field, float weight) {
      if (weight < 1) {
        throw new IllegalArgumentException("weight must be greater or equal to 1");
      }
      fieldAndWeights.put(field, new FieldAndWeight(field, weight));
      return this;
    }

    /**
     * Adds a term to this builder.
     */
    public Builder addTerm(BytesRef term) {
      if (termsSet.size() > BooleanQuery.getMaxClauseCount()) {
        throw new BooleanQuery.TooManyClauses();
      }
      termsSet.add(term);
      return this;
    }

    /**
     * Builds the {@link BM25FQuery}.
     */
    public BM25FQuery build() {
      int size = fieldAndWeights.size() * termsSet.size();
      if (size > BooleanQuery.getMaxClauseCount()) {
        throw new BooleanQuery.TooManyClauses();
      }
      BytesRef[] terms = termsSet.toArray(new BytesRef[0]);
      return new BM25FQuery(similarity, new TreeMap<>(fieldAndWeights), terms);
    }
  }

  static class FieldAndWeight {
    final String field;
    final float weight;

    FieldAndWeight(String field, float weight) {
      this.field = field;
      this.weight = weight;
    }
  }

  // the similarity to use for scoring.
  private final BM25Similarity similarity;
  // sorted map for fields.
  private final TreeMap<String, FieldAndWeight> fieldAndWeights;
  // array of terms, sorted.
  private final BytesRef terms[];
  // array of terms per field, sorted.
  private final Term fieldTerms[];

  private BM25FQuery(BM25Similarity similarity, TreeMap<String, FieldAndWeight> fieldAndWeights, BytesRef[] terms) {
    this.similarity = similarity;
    this.fieldAndWeights = fieldAndWeights;
    this.terms = terms;
    int numFieldTerms = fieldAndWeights.size() * terms.length;
    if (numFieldTerms > BooleanQuery.getMaxClauseCount()) {
      throw new BooleanQuery.TooManyClauses();
    }
    this.fieldTerms = new Term[numFieldTerms];
    Arrays.sort(terms);
    int pos = 0;
    for (String field : fieldAndWeights.keySet()) {
      for (BytesRef term : terms) {
        fieldTerms[pos++] = new Term(field, term);
      }
    }
  }

  public List<Term> getTerms() {
    return Collections.unmodifiableList(Arrays.asList(fieldTerms));
  }

  @Override
  public String toString(String field) {
    StringBuilder builder = new StringBuilder("BM25F((");
    int pos = 0;
    for (FieldAndWeight fieldWeight : fieldAndWeights.values()) {
      if (pos++ != 0) {
        builder.append(" ");
      }
      builder.append(fieldWeight.field);
      if (fieldWeight.weight != 1f) {
        builder.append("^");
        builder.append(fieldWeight.weight);
      }
    }
    builder.append(")(");
    pos = 0;
    for (BytesRef term : terms) {
      if (pos++ != 0) {
        builder.append(" ");
      }
      builder.append(term.utf8ToString());
    }
    builder.append("))");
    return builder.toString();
  }

  @Override
  public int hashCode() {
    return 31 * classHash() + Arrays.hashCode(terms);
  }

  @Override
  public boolean equals(Object other) {
    return sameClassAs(other) &&
        Arrays.equals(terms, ((BM25FQuery) other).terms);
  }

  @Override
  public Query rewrite(IndexReader reader) throws IOException {
    // optimize zero and single field cases
    if (terms.length == 0) {
      return new BooleanQuery.Builder().build();
    }
    // single field and one term
    if (fieldTerms.length == 1) {
      return new TermQuery(fieldTerms[0]);
    }
    // single field and multiple terms
    if (fieldAndWeights.size() == 1) {
      return new SynonymQuery(fieldTerms);
    }
    return this;
  }

  private BooleanQuery rewriteToBoolean() {
    // rewrite to a simple disjunction if the score is not needed.
    BooleanQuery.Builder bq = new BooleanQuery.Builder();
    for (Term term : fieldTerms) {
      bq.add(new TermQuery(term), BooleanClause.Occur.SHOULD);
    }
    return bq.build();
  }

  @Override
  public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
    if (scoreMode.needsScores()) {
      return new BM25FWeight(this, searcher, scoreMode, boost);
    } else {
      // rewrite to a simple disjunction if the score is not needed.
      Query bq = rewriteToBoolean();
      return searcher.rewrite(bq).createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, boost);
    }
  }

  class BM25FWeight extends Weight {
    private final IndexSearcher searcher;
    private final TermStates termStates[];
    private final Similarity.SimScorer simWeight;

    BM25FWeight(Query query, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
      super(query);
      assert scoreMode.needsScores();
      this.searcher = searcher;
      long docFreq = 0;
      long totalTermFreq = 0;
      termStates = new TermStates[fieldTerms.length];
      for (int i = 0; i < termStates.length; i++) {
        FieldAndWeight field = fieldAndWeights.get(fieldTerms[i].field());
        termStates[i] = TermStates.build(searcher.getTopReaderContext(), fieldTerms[i], true);
        TermStatistics termStats = searcher.termStatistics(fieldTerms[i], termStates[i]);
        if (termStats != null) {
          docFreq = Math.max(termStats.docFreq(), docFreq);
          totalTermFreq += (double) field.weight * termStats.totalTermFreq();
        }
      }
      if (docFreq > 0) {
        CollectionStatistics pseudoCollectionStats = mergeCollectionStatistics(searcher);
        TermStatistics pseudoTermStatistics = new TermStatistics(new BytesRef("pseudo_term"), docFreq, Math.max(1, totalTermFreq));
        this.simWeight = similarity.scorer(boost, pseudoCollectionStats, pseudoTermStatistics);
      } else {
        this.simWeight = null;
      }
    }

    private CollectionStatistics mergeCollectionStatistics(IndexSearcher searcher) throws IOException {
      long maxDoc = searcher.getIndexReader().maxDoc();
      long docCount = 0;
      long sumTotalTermFreq = 0;
      long sumDocFreq = 0;
      for (FieldAndWeight fieldWeight : fieldAndWeights.values()) {
        CollectionStatistics collectionStats = searcher.collectionStatistics(fieldWeight.field);
        if (collectionStats != null) {
          docCount = Math.max(collectionStats.docCount(), docCount);
          sumDocFreq = Math.max(collectionStats.sumDocFreq(), sumDocFreq);
          sumTotalTermFreq += (double) fieldWeight.weight * collectionStats.sumTotalTermFreq();
        }
      }

      return new CollectionStatistics("pseudo_field", maxDoc, docCount, sumTotalTermFreq, sumDocFreq);
    }

    @Override
    public void extractTerms(Set<Term> termSet) {
      termSet.addAll(Arrays.asList(fieldTerms));
    }

    @Override
    public Matches matches(LeafReaderContext context, int doc) throws IOException {
      Weight weight = searcher.rewrite(rewriteToBoolean()).createWeight(searcher, ScoreMode.COMPLETE, 1f);
      return weight.matches(context, doc);
    }

    @Override
    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
      Scorer scorer = scorer(context);
      if (scorer != null) {
        int newDoc = scorer.iterator().advance(doc);
        if (newDoc == doc) {
          final float freq;
          if (scorer instanceof BM25FScorer) {
            freq = ((BM25FScorer) scorer).freq();
          } else {
            assert scorer instanceof TermScorer;
            freq = ((TermScorer) scorer).freq();
          }
          final MultiNormsLeafSimScorer docScorer =
              new MultiNormsLeafSimScorer(simWeight, context.reader(), fieldAndWeights.values(), true);
          Explanation freqExplanation = Explanation.match(freq, "termFreq=" + freq);
          Explanation scoreExplanation = docScorer.explain(doc, freqExplanation);
          return Explanation.match(
              scoreExplanation.getValue(),
              "weight(" + getQuery() + " in " + doc + ") ["
                  + similarity.getClass().getSimpleName() + "], result of:",
              scoreExplanation);
        }
      }
      return Explanation.noMatch("no matching term");
    }

    @Override
    public Scorer scorer(LeafReaderContext context) throws IOException {
      List<PostingsEnum> iterators = new ArrayList<>();
      List<FieldAndWeight> fields = new ArrayList<>();
      for (int i = 0; i < fieldTerms.length; i++) {
        TermState state = termStates[i].get(context);
        if (state != null) {
          TermsEnum termsEnum = context.reader().terms(fieldTerms[i].field()).iterator();
          termsEnum.seekExact(fieldTerms[i].bytes(), state);
          PostingsEnum postingsEnum = termsEnum.postings(null, PostingsEnum.FREQS);
          iterators.add(postingsEnum);
          fields.add(fieldAndWeights.get(fieldTerms[i].field()));
        }
      }

      if (iterators.isEmpty()) {
        return null;
      }

      // we must optimize this case (term not in segment), disjunctions require >= 2 subs
      if (iterators.size() == 1) {
        final LeafSimScorer scoringSimScorer =
            new LeafSimScorer(simWeight, context.reader(), fields.get(0).field, true);
        return new TermScorer(this, iterators.get(0), scoringSimScorer);
      }
      final MultiNormsLeafSimScorer scoringSimScorer =
          new MultiNormsLeafSimScorer(simWeight, context.reader(), fields, true);
      LeafSimScorer nonScoringSimScorer = new LeafSimScorer(simWeight, context.reader(), "pseudo_field", false);
      // we use termscorers + disjunction as an impl detail
      DisiPriorityQueue queue = new DisiPriorityQueue(iterators.size());
      for (int i = 0; i < iterators.size(); i++) {
        float weight = fields.get(i).weight;
        queue.add(new WeightedDisiWrapper(new TermScorer(this, iterators.get(i), nonScoringSimScorer), weight));
      }
      // Even though it is called approximation, it is accurate since none of
      // the sub iterators are two-phase iterators.
      DocIdSetIterator iterator = new DisjunctionDISIApproximation(queue);
      return new BM25FScorer(this, queue, iterator, scoringSimScorer);
    }

    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
      return false;
    }
  }

  private static class WeightedDisiWrapper extends DisiWrapper {
    final float weight;

    WeightedDisiWrapper(Scorer scorer, float weight) {
      super(scorer);
      this.weight = weight;
    }

    float freq() throws IOException {
      return weight * ((PostingsEnum) iterator).freq();
    }
  }

  private static class BM25FScorer extends Scorer {
    private final DisiPriorityQueue queue;
    private final DocIdSetIterator iterator;
    private final MultiNormsLeafSimScorer simScorer;

    BM25FScorer(Weight weight, DisiPriorityQueue queue, DocIdSetIterator iterator, MultiNormsLeafSimScorer simScorer) {
      super(weight);
      this.queue = queue;
      this.iterator = iterator;
      this.simScorer = simScorer;
    }

    @Override
    public int docID() {
      return iterator.docID();
    }

    float freq() throws IOException {
      DisiWrapper w = queue.topList();
      float freq = ((WeightedDisiWrapper) w).freq();
      for (w = w.next; w != null; w = w.next) {
        freq += ((WeightedDisiWrapper) w).freq();
        if (freq < 0) { // overflow
          return Integer.MAX_VALUE;
        }
      }
      return freq;
    }

    @Override
    public float score() throws IOException {
      return simScorer.score(iterator.docID(), freq());
    }

    @Override
    public DocIdSetIterator iterator() {
      return iterator;
    }

    @Override
    public float getMaxScore(int upTo) throws IOException {
      return Float.POSITIVE_INFINITY;
    }
  }
}
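
A reading aid for the Weight and Scorer above: the "simple" BM25F variant referenced in the class Javadoc folds the per-field weights into the term frequency and the document length before a single BM25 saturation. A sketch of the intended score, in notation of my own rather than taken verbatim from the paper:

  \tilde{tf}_{t,d} = \sum_{f} w_f \cdot tf_{t,f,d}, \qquad
  \tilde{dl}_d = \sum_{f} w_f \cdot dl_{f,d}

  \mathrm{score}(t,d) = \mathrm{idf}(t) \cdot
    \frac{\tilde{tf}_{t,d}}{\tilde{tf}_{t,d} + k_1 \bigl(1 - b + b \cdot \tilde{dl}_d / \widetilde{avgdl}\bigr)}

In the code, WeightedDisiWrapper.freq() accumulates the weighted term frequency, MultiNormsLeafSimScorer (next file) supplies the weighted length norm, and the merged pseudo collection/term statistics built in BM25FWeight drive the idf term.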
@@ -0,0 +1,155 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.SmallFloat;

import static org.apache.lucene.search.BM25FQuery.FieldAndWeight;

/**
 * Copy of {@link LeafSimScorer} that sums a document's norms from multiple fields.
 */
final class MultiNormsLeafSimScorer {
  /**
   * Cache of decoded norms.
   */
  private static final float[] LENGTH_TABLE = new float[256];

  static {
    for (int i = 0; i < 256; i++) {
      LENGTH_TABLE[i] = SmallFloat.byte4ToInt((byte) i);
    }
  }

  private final SimScorer scorer;
  private final NumericDocValues norms;

  /**
   * Sole constructor: Score documents of {@code reader} with {@code scorer}.
   */
  MultiNormsLeafSimScorer(SimScorer scorer, LeafReader reader, Collection<FieldAndWeight> normFields, boolean needsScores) throws IOException {
    this.scorer = Objects.requireNonNull(scorer);
    if (needsScores) {
      final List<NumericDocValues> normsList = new ArrayList<>();
      final List<Float> weightList = new ArrayList<>();
      for (FieldAndWeight field : normFields) {
        NumericDocValues norms = reader.getNormValues(field.field);
        if (norms != null) {
          normsList.add(norms);
          weightList.add(field.weight);
        }
      }
      if (normsList.isEmpty()) {
        norms = null;
      } else if (normsList.size() == 1) {
        norms = normsList.get(0);
      } else {
        final NumericDocValues[] normsArr = normsList.toArray(new NumericDocValues[0]);
        final float[] weightArr = new float[normsList.size()];
        for (int i = 0; i < weightList.size(); i++) {
          weightArr[i] = weightList.get(i);
        }
        norms = new MultiFieldNormValues(normsArr, weightArr);
      }
    } else {
      norms = null;
    }
  }

  private long getNormValue(int doc) throws IOException {
    if (norms != null) {
      boolean found = norms.advanceExact(doc);
      assert found;
      return norms.longValue();
    } else {
      return 1L; // default norm
    }
  }

  /** Score the provided document assuming the given term document frequency.
   * This method must be called on non-decreasing sequences of doc ids.
   * @see SimScorer#score(float, long) */
  public float score(int doc, float freq) throws IOException {
    return scorer.score(freq, getNormValue(doc));
  }

  /** Explain the score for the provided document assuming the given term document frequency.
   * This method must be called on non-decreasing sequences of doc ids.
   * @see SimScorer#explain(Explanation, long) */
  public Explanation explain(int doc, Explanation freqExpl) throws IOException {
    return scorer.explain(freqExpl, getNormValue(doc));
  }

  private static class MultiFieldNormValues extends NumericDocValues {
    private final NumericDocValues[] normsArr;
    private final float[] weightArr;
    private long current;
    private int docID = -1;

    MultiFieldNormValues(NumericDocValues[] normsArr, float[] weightArr) {
      this.normsArr = normsArr;
      this.weightArr = weightArr;
    }

    @Override
    public long longValue() {
      return current;
    }

    @Override
    public boolean advanceExact(int target) throws IOException {
      float normValue = 0;
      for (int i = 0; i < normsArr.length; i++) {
        boolean found = normsArr[i].advanceExact(target);
        assert found;
        // mask to an unsigned byte so norms >= 128 index the 256-entry table correctly
        normValue += weightArr[i] * LENGTH_TABLE[((byte) normsArr[i].longValue()) & 0xFF];
      }
      current = SmallFloat.intToByte4(Math.round(normValue));
      return true;
    }

    @Override
    public int docID() {
      return docID;
    }

    @Override
    public int nextDoc() {
      throw new UnsupportedOperationException();
    }

    @Override
    public int advance(int target) {
      throw new UnsupportedOperationException();
    }

    @Override
    public long cost() {
      throw new UnsupportedOperationException();
    }
  }
}
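
One caveat worth spelling out: the merged norm above is lossy. Each per-field length is decoded from a one-byte norm, the weighted sum is rounded and re-encoded to a byte, and the BM25 scorer decodes it again. A standalone sketch of that round-trip, assuming only the org.apache.lucene.util.SmallFloat encoding used here (lengths and weights are illustrative):

import org.apache.lucene.util.SmallFloat;

public class NormRoundTripDemo {
  public static void main(String[] args) {
    // Per-field lengths as an index writer would encode them, one byte each.
    byte titleNorm = SmallFloat.intToByte4(7);
    byte bodyNorm = SmallFloat.intToByte4(250);
    // What MultiFieldNormValues#advanceExact computes for weights 3 and 1.
    float merged = 3f * SmallFloat.byte4ToInt(titleNorm)
        + 1f * SmallFloat.byte4ToInt(bodyNorm);
    long pseudoNorm = SmallFloat.intToByte4(Math.round(merged));
    System.out.println("merged length=" + merged + ", re-encoded norm byte=" + pseudoNorm);
  }
}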
@@ -0,0 +1,168 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import java.io.IOException;

import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;

public class TestBM25FQuery extends LuceneTestCase {
  public void testInvalid() {
    BM25FQuery.Builder builder = new BM25FQuery.Builder();
    IllegalArgumentException exc =
        expectThrows(IllegalArgumentException.class, () -> builder.addField("foo", 0.5f));
    assertEquals("weight must be greater or equal to 1", exc.getMessage());
  }

  public void testRewrite() throws IOException {
    BM25FQuery.Builder builder = new BM25FQuery.Builder();
    IndexReader reader = new MultiReader();
    IndexSearcher searcher = new IndexSearcher(reader);
    Query actual = searcher.rewrite(builder.build());
    assertEquals(actual, new MatchNoDocsQuery());
    builder.addField("field", 1f);
    actual = searcher.rewrite(builder.build());
    assertEquals(actual, new MatchNoDocsQuery());
    builder.addTerm(new BytesRef("foo"));
    actual = searcher.rewrite(builder.build());
    assertEquals(actual, new TermQuery(new Term("field", "foo")));
    builder.addTerm(new BytesRef("bar"));
    actual = searcher.rewrite(builder.build());
    assertEquals(actual, new SynonymQuery(new Term("field", "foo"),
        new Term("field", "bar")));
    builder.addField("another_field", 1f);
    Query query = builder.build();
    actual = searcher.rewrite(query);
    assertEquals(actual, query);
  }

  public void testToString() {
    assertEquals("BM25F(()())", new BM25FQuery.Builder().build().toString());
    BM25FQuery.Builder builder = new BM25FQuery.Builder();
    builder.addField("foo", 1f);
    assertEquals("BM25F((foo)())", builder.build().toString());
    builder.addTerm(new BytesRef("bar"));
    assertEquals("BM25F((foo)(bar))", builder.build().toString());
    builder.addField("title", 3f);
    assertEquals("BM25F((foo title^3.0)(bar))", builder.build().toString());
    builder.addTerm(new BytesRef("baz"));
    assertEquals("BM25F((foo title^3.0)(bar baz))", builder.build().toString());
  }

  public void testSameScore() throws IOException {
    Directory dir = newDirectory();
    RandomIndexWriter w = new RandomIndexWriter(random(), dir);

    Document doc = new Document();
    doc.add(new StringField("f", "a", Store.NO));
    w.addDocument(doc);

    doc = new Document();
    doc.add(new StringField("g", "a", Store.NO));
    for (int i = 0; i < 10; ++i) {
      w.addDocument(doc);
    }

    IndexReader reader = w.getReader();
    IndexSearcher searcher = newSearcher(reader);
    BM25FQuery query = new BM25FQuery.Builder()
        .addField("f", 1f)
        .addField("g", 1f)
        .addTerm(new BytesRef("a"))
        .build();
    TopScoreDocCollector collector = TopScoreDocCollector.create(Math.min(reader.numDocs(), Integer.MAX_VALUE), null, Integer.MAX_VALUE);
    searcher.search(query, collector);
    TopDocs topDocs = collector.topDocs();
    assertEquals(TotalHits.Relation.EQUAL_TO, topDocs.totalHits.relation);
    assertEquals(11, topDocs.totalHits.value);
    // All docs must have the same score
    for (int i = 0; i < topDocs.scoreDocs.length; ++i) {
      assertEquals(topDocs.scoreDocs[0].score, topDocs.scoreDocs[i].score, 0.0f);
    }

    reader.close();
    w.close();
    dir.close();
  }

  public void testAgainstCopyField() throws IOException {
    Directory dir = newDirectory();
    RandomIndexWriter w = new RandomIndexWriter(random(), dir, new MockAnalyzer(random()));
    int numMatch = atLeast(10);
    int boost1 = Math.max(1, random().nextInt(5));
    int boost2 = Math.max(1, random().nextInt(5));
    for (int i = 0; i < numMatch; i++) {
      Document doc = new Document();
      if (random().nextBoolean()) {
        doc.add(new TextField("a", "baz", Store.NO));
        doc.add(new TextField("b", "baz", Store.NO));
        for (int k = 0; k < boost1 + boost2; k++) {
          doc.add(new TextField("ab", "baz", Store.NO));
        }
        w.addDocument(doc);
        doc.clear();
      }
      int freqA = random().nextInt(5) + 1;
      for (int j = 0; j < freqA; j++) {
        doc.add(new TextField("a", "foo", Store.NO));
      }
      int freqB = random().nextInt(5) + 1;
      for (int j = 0; j < freqB; j++) {
        doc.add(new TextField("b", "foo", Store.NO));
      }
      int freqAB = freqA * boost1 + freqB * boost2;
      for (int j = 0; j < freqAB; j++) {
        doc.add(new TextField("ab", "foo", Store.NO));
      }
      w.addDocument(doc);
    }
    IndexReader reader = w.getReader();
    IndexSearcher searcher = newSearcher(reader);
    searcher.setSimilarity(new BM25Similarity());
    BM25FQuery query = new BM25FQuery.Builder()
        .addField("a", (float) boost1)
        .addField("b", (float) boost2)
        .addTerm(new BytesRef("foo"))
        .addTerm(new BytesRef("foo"))
        .build();

    TopScoreDocCollector bm25FCollector = TopScoreDocCollector.create(numMatch, null, Integer.MAX_VALUE);
    searcher.search(query, bm25FCollector);
    TopDocs bm25FTopDocs = bm25FCollector.topDocs();
    assertEquals(numMatch, bm25FTopDocs.totalHits.value);
    TopScoreDocCollector collector = TopScoreDocCollector.create(reader.numDocs(), null, Integer.MAX_VALUE);
    searcher.search(new TermQuery(new Term("ab", "foo")), collector);
    TopDocs topDocs = collector.topDocs();
    CheckHits.checkEqual(query, topDocs.scoreDocs, bm25FTopDocs.scoreDocs);

    reader.close();
    w.close();
    dir.close();
  }
}
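
The copy-field test above leans on a simple invariant: a BM25FQuery over fields a and b with weights boost1 and boost2 should score exactly like a plain TermQuery on a single field ab that repeats every a-token boost1 times and every b-token boost2 times. That is why each document indexes "foo" freqA * boost1 + freqB * boost2 times in ab; for example, with boost1 = 2, boost2 = 3, freqA = 1 and freqB = 2, BM25F sees a blended term frequency of 1 * 2 + 2 * 3 = 8, the same count the copy field stores.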