mirror of https://github.com/apache/lucene.git
LUCENE-8633: Remove term weighting from IntervalQuery scores
This commit is contained in:
parent
09778b2133
commit
a826649241
|
@ -156,6 +156,10 @@ Changes in Runtime Behavior
|
|||
tokens by default, preventing a number of bugs when the filter is chained with
|
||||
tokenfilters that change the length of their tokens (Alan Woodward)
|
||||
|
||||
* LUCENE-8633: IntervalQuery scores do not use term weighting any more; the score
|
||||
is instead calculated as a function of the sloppy frequency of the matching
|
||||
intervals. (Alan Woodward, Jim Ferenczi)
|
||||
|
||||
New Features
|
||||
|
||||
* LUCENE-8340: LongPoint#newDistanceQuery may be used to boost scores based on
|
||||
|
|
|
@ -62,6 +62,15 @@ class ConjunctionIntervalsSource extends IntervalsSource {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minExtent() {
|
||||
int minExtent = 0;
|
||||
for (IntervalsSource source : subSources) {
|
||||
minExtent += source.minExtent();
|
||||
}
|
||||
return minExtent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException {
|
||||
List<IntervalIterator> subIntervals = new ArrayList<>();
|
||||
|
|
|
@ -86,4 +86,9 @@ class DifferenceIntervalsSource extends IntervalsSource {
|
|||
public void extractTerms(String field, Set<Term> terms) {
|
||||
minuend.extractTerms(field, terms);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minExtent() {
|
||||
return minuend.minExtent();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -90,6 +90,15 @@ class DisjunctionIntervalsSource extends IntervalsSource {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minExtent() {
|
||||
int minExtent = subSources.get(0).minExtent();
|
||||
for (int i = 1; i < subSources.size(); i++) {
|
||||
minExtent = Math.min(minExtent, subSources.get(i).minExtent());
|
||||
}
|
||||
return minExtent;
|
||||
}
|
||||
|
||||
private static class DisjunctionIntervalIterator extends IntervalIterator {
|
||||
|
||||
final DocIdSetIterator approximation;
|
||||
|
|
|
@ -61,6 +61,15 @@ class ExtendedIntervalsSource extends IntervalsSource {
|
|||
source.extractTerms(field, terms);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minExtent() {
|
||||
int minExtent = before + source.minExtent() + after;
|
||||
if (minExtent < 0) {
|
||||
return Integer.MAX_VALUE;
|
||||
}
|
||||
return minExtent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
|
|
|
@ -77,6 +77,11 @@ public abstract class FilteredIntervalsSource extends IntervalsSource {
|
|||
return IntervalMatches.asMatches(filtered, mi, doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minExtent() {
|
||||
return in.minExtent();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void extractTerms(String field, Set<Term> terms) {
|
||||
in.extractTerms(field, terms);
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.search.intervals;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* An intervals source that combines two other sources, requiring both of them to
|
||||
* be present in order to match, but reporting only the minExtent of the first one (the source being filtered)
|
||||
*/
|
||||
class FilteringConjunctionIntervalsSource extends ConjunctionIntervalsSource {
|
||||
|
||||
private final IntervalsSource source;
|
||||
|
||||
FilteringConjunctionIntervalsSource(IntervalsSource source, IntervalsSource filter, IntervalFunction function) {
|
||||
super(Arrays.asList(source, filter), function);
|
||||
this.source = source;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minExtent() {
|
||||
return source.minExtent();
|
||||
}
|
||||
}
|
|
@ -18,28 +18,21 @@
|
|||
package org.apache.lucene.search.intervals;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashSet;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.TermStates;
|
||||
import org.apache.lucene.search.CollectionStatistics;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.FilterMatchesIterator;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.LeafSimScorer;
|
||||
import org.apache.lucene.search.Matches;
|
||||
import org.apache.lucene.search.MatchesIterator;
|
||||
import org.apache.lucene.search.MatchesUtils;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.Scorer;
|
||||
import org.apache.lucene.search.TermStatistics;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
|
||||
/**
|
||||
* A query that retrieves documents containing intervals returned from an
|
||||
|
@ -47,11 +40,30 @@ import org.apache.lucene.util.ArrayUtil;
|
|||
*
|
||||
* Static constructor functions for various different sources can be found in the
|
||||
* {@link Intervals} class
|
||||
*
|
||||
* Scores for this query are computed as a function of the sloppy frequency of
|
||||
* intervals appearing in a particular document. Sloppy frequency is calculated
|
||||
* from the number of matching intervals, and their width, with wider intervals
|
||||
* contributing lower values. The scores can be adjusted with two optional
|
||||
* parameters:
|
||||
* <ul>
|
||||
* <li>pivot - the sloppy frequency value at which the overall score of the
|
||||
* document will equal 0.5. The default value is 1</li>
|
||||
* <li>exp - higher values of this parameter make the function grow more slowly
|
||||
* below the pivot and faster above the pivot. The default value is 1</li>
|
||||
* </ul>
|
||||
*
|
||||
* Optimal values for both pivot and exp depend on the type of queries and corpus of
|
||||
* documents being queried.
|
||||
*
|
||||
* Scores are bounded between 0 and 1. For higher contributions, wrap the query
|
||||
* in a {@link org.apache.lucene.search.BoostQuery}
|
||||
*/
|
||||
public final class IntervalQuery extends Query {
|
||||
|
||||
private final String field;
|
||||
private final IntervalsSource intervalsSource;
|
||||
private final IntervalScoreFunction scoreFunction;
|
||||
|
||||
/**
|
||||
* Create a new IntervalQuery
|
||||
|
@ -59,10 +71,41 @@ public final class IntervalQuery extends Query {
|
|||
* @param intervalsSource an {@link IntervalsSource} to retrieve intervals from
|
||||
*/
|
||||
public IntervalQuery(String field, IntervalsSource intervalsSource) {
|
||||
this.field = field;
|
||||
this.intervalsSource = intervalsSource;
|
||||
this(field, intervalsSource, IntervalScoreFunction.saturationFunction(1));
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new IntervalQuery with a scoring pivot
|
||||
*
|
||||
* @param field the field to query
|
||||
* @param intervalsSource an {@link IntervalsSource} to retrieve intervals from
|
||||
* @param pivot the sloppy frequency value at which the score will be 0.5, must be within (0, +Infinity)
|
||||
*/
|
||||
public IntervalQuery(String field, IntervalsSource intervalsSource, float pivot) {
|
||||
this(field, intervalsSource, IntervalScoreFunction.saturationFunction(pivot));
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new IntervalQuery with a scoring pivot and exponent
|
||||
* @param field the field to query
|
||||
* @param intervalsSource an {@link IntervalsSource} to retrieve intervals from
|
||||
* @param pivot the sloppy frequency value at which the score will be 0.5, must be within (0, +Infinity)
|
||||
* @param exp exponent, higher values make the function grow slower before 'pivot' and faster
|
||||
* after 'pivot', must be in (0, +Infinity)
|
||||
*/
|
||||
public IntervalQuery(String field, IntervalsSource intervalsSource, float pivot, float exp) {
|
||||
this(field, intervalsSource, IntervalScoreFunction.sigmoidFunction(pivot, exp));
|
||||
}
|
||||
|
||||
private IntervalQuery(String field, IntervalsSource intervalsSource, IntervalScoreFunction scoreFunction) {
|
||||
this.field = field;
|
||||
this.intervalsSource = intervalsSource;
|
||||
this.scoreFunction = scoreFunction;
|
||||
}
|
||||
|
||||
/**
|
||||
* The field to query
|
||||
*/
|
||||
public String getField() {
|
||||
return field;
|
||||
}
|
||||
|
@ -74,26 +117,7 @@ public final class IntervalQuery extends Query {
|
|||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
|
||||
return new IntervalWeight(this, scoreMode.needsScores() ? buildSimScorer(searcher, boost) : null,
|
||||
searcher.getSimilarity(), scoreMode);
|
||||
}
|
||||
|
||||
private Similarity.SimScorer buildSimScorer(IndexSearcher searcher, float boost) throws IOException {
|
||||
Set<Term> terms = new HashSet<>();
|
||||
intervalsSource.extractTerms(field, terms);
|
||||
TermStatistics[] termStats = new TermStatistics[terms.size()];
|
||||
int termUpTo = 0;
|
||||
for (Term term : terms) {
|
||||
TermStatistics termStatistics = searcher.termStatistics(term, TermStates.build(searcher.getTopReaderContext(), term, true));
|
||||
if (termStatistics != null) {
|
||||
termStats[termUpTo++] = termStatistics;
|
||||
}
|
||||
}
|
||||
if (termUpTo == 0) {
|
||||
return null;
|
||||
}
|
||||
CollectionStatistics collectionStats = searcher.collectionStatistics(field);
|
||||
return searcher.getSimilarity().scorer(boost, collectionStats, ArrayUtil.copyOfSubArray(termStats, 0, termUpTo));
|
||||
return new IntervalWeight(this, boost, scoreMode);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -112,15 +136,13 @@ public final class IntervalQuery extends Query {
|
|||
|
||||
private class IntervalWeight extends Weight {
|
||||
|
||||
final Similarity.SimScorer simScorer;
|
||||
final Similarity similarity;
|
||||
final ScoreMode scoreMode;
|
||||
final float boost;
|
||||
|
||||
public IntervalWeight(Query query, Similarity.SimScorer simScorer, Similarity similarity, ScoreMode scoreMode) {
|
||||
public IntervalWeight(Query query, float boost, ScoreMode scoreMode) {
|
||||
super(query);
|
||||
this.simScorer = simScorer;
|
||||
this.similarity = similarity;
|
||||
this.scoreMode = scoreMode;
|
||||
this.boost = boost;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -134,7 +156,8 @@ public final class IntervalQuery extends Query {
|
|||
if (scorer != null) {
|
||||
int newDoc = scorer.iterator().advance(doc);
|
||||
if (newDoc == doc) {
|
||||
return scorer.explain("weight("+getQuery()+" in "+doc+") [" + similarity.getClass().getSimpleName() + "]");
|
||||
float freq = scorer.freq();
|
||||
return scoreFunction.explain(intervalsSource.toString(), boost, freq);
|
||||
}
|
||||
}
|
||||
return Explanation.noMatch("no matching intervals");
|
||||
|
@ -161,9 +184,7 @@ public final class IntervalQuery extends Query {
|
|||
IntervalIterator intervals = intervalsSource.intervals(field, context);
|
||||
if (intervals == null)
|
||||
return null;
|
||||
LeafSimScorer leafScorer = simScorer == null ? null
|
||||
: new LeafSimScorer(simScorer, context.reader(), field, scoreMode.needsScores());
|
||||
return new IntervalScorer(this, intervals, leafScorer);
|
||||
return new IntervalScorer(this, intervals, intervalsSource.minExtent(), boost, scoreFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.search.intervals;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
|
||||
abstract class IntervalScoreFunction {
|
||||
|
||||
static IntervalScoreFunction saturationFunction(float pivot) {
|
||||
if (pivot <= 0 || Float.isFinite(pivot) == false) {
|
||||
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
|
||||
}
|
||||
return new SaturationFunction(pivot);
|
||||
}
|
||||
|
||||
static IntervalScoreFunction sigmoidFunction(float pivot, float exp) {
|
||||
if (pivot <= 0 || Float.isFinite(pivot) == false) {
|
||||
throw new IllegalArgumentException("pivot must be > 0, got: " + pivot);
|
||||
}
|
||||
if (exp <= 0 || Float.isFinite(exp) == false) {
|
||||
throw new IllegalArgumentException("exp must be > 0, got: " + exp);
|
||||
}
|
||||
return new SigmoidFunction(pivot, exp);
|
||||
}
|
||||
|
||||
public abstract Similarity.SimScorer scorer(float weight);
|
||||
|
||||
public abstract Explanation explain(String interval, float weight, float sloppyFreq);
|
||||
|
||||
@Override
|
||||
public abstract boolean equals(Object other);
|
||||
|
||||
@Override
|
||||
public abstract int hashCode();
|
||||
|
||||
@Override
|
||||
public abstract String toString();
|
||||
|
||||
private static class SaturationFunction extends IntervalScoreFunction {
|
||||
|
||||
final float pivot;
|
||||
|
||||
private SaturationFunction(float pivot) {
|
||||
this.pivot = pivot;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Similarity.SimScorer scorer(float weight) {
|
||||
return new Similarity.SimScorer() {
|
||||
@Override
|
||||
public float score(float freq, long norm) {
|
||||
// should be f / (f + k) but we rewrite it to
|
||||
// 1 - k / (f + k) to make sure it doesn't decrease
|
||||
// with f in spite of rounding
|
||||
return weight * (1.0f - pivot / (pivot + freq));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(String interval, float weight, float sloppyFreq) {
|
||||
float score = scorer(weight).score(sloppyFreq, 1L);
|
||||
return Explanation.match(score,
|
||||
"Saturation function on interval frequency, computed as w * S / (S + k) from:",
|
||||
Explanation.match(weight, "w, weight of this function"),
|
||||
Explanation.match(pivot, "k, pivot feature value that would give a score contribution equal to w/2"),
|
||||
Explanation.match(sloppyFreq, "S, the sloppy frequency of the interval " + interval));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
SaturationFunction that = (SaturationFunction) o;
|
||||
return Float.compare(that.pivot, pivot) == 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(pivot);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "SaturationFunction(pivot=" + pivot + ")";
|
||||
}
|
||||
}
|
||||
|
||||
private static class SigmoidFunction extends IntervalScoreFunction {
|
||||
|
||||
private final float pivot, a;
|
||||
private final double pivotPa;
|
||||
|
||||
private SigmoidFunction(float pivot, float a) {
|
||||
this.pivot = pivot;
|
||||
this.a = a;
|
||||
this.pivotPa = Math.pow(pivot, a);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Similarity.SimScorer scorer(float weight) {
|
||||
return new Similarity.SimScorer() {
|
||||
@Override
|
||||
public float score(float freq, long norm) {
|
||||
// should be f^a / (f^a + k^a) but we rewrite it to
|
||||
// 1 - k^a / (f^a + k^a) to make sure it doesn't decrease
|
||||
// with f in spite of rounding
|
||||
return (float) (weight * (1.0f - pivotPa / (Math.pow(freq, a) + pivotPa)));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(String interval, float weight, float sloppyFreq) {
|
||||
float score = scorer(weight).score(sloppyFreq, 1L);
|
||||
return Explanation.match(score,
|
||||
"Sigmoid function on interval frequency, computed as w * S^a / (S^a + k^a) from:",
|
||||
Explanation.match(weight, "w, weight of this function"),
|
||||
Explanation.match(pivot, "k, pivot feature value that would give a score contribution equal to w/2"),
|
||||
Explanation.match(a, "a, exponent, higher values make the function grow slower before k and faster after k"),
|
||||
Explanation.match(sloppyFreq, "S, the sloppy frequency of the interval " + interval));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
SigmoidFunction that = (SigmoidFunction) o;
|
||||
return Float.compare(that.pivot, pivot) == 0 &&
|
||||
Float.compare(that.a, a) == 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(pivot, a);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "SigmoidFunction(pivot=" + pivot + ", a=" + a + ")";
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -20,24 +20,27 @@ package org.apache.lucene.search.intervals;
|
|||
import java.io.IOException;
|
||||
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.LeafSimScorer;
|
||||
import org.apache.lucene.search.Scorer;
|
||||
import org.apache.lucene.search.TwoPhaseIterator;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
|
||||
class IntervalScorer extends Scorer {
|
||||
|
||||
private final IntervalIterator intervals;
|
||||
private final LeafSimScorer simScorer;
|
||||
private final Similarity.SimScorer simScorer;
|
||||
private final float boost;
|
||||
private final int minExtent;
|
||||
|
||||
private float freq = -1;
|
||||
private float freq;
|
||||
private int lastScoredDoc = -1;
|
||||
|
||||
protected IntervalScorer(Weight weight, IntervalIterator intervals, LeafSimScorer simScorer) {
|
||||
IntervalScorer(Weight weight, IntervalIterator intervals, int minExtent, float boost, IntervalScoreFunction scoreFunction) {
|
||||
super(weight);
|
||||
this.intervals = intervals;
|
||||
this.simScorer = simScorer;
|
||||
this.minExtent = minExtent;
|
||||
this.boost = boost;
|
||||
this.simScorer = scoreFunction.scorer(boost);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -48,19 +51,10 @@ class IntervalScorer extends Scorer {
|
|||
@Override
|
||||
public float score() throws IOException {
|
||||
ensureFreq();
|
||||
return simScorer.score(docID(), freq);
|
||||
return simScorer.score(freq, 1);
|
||||
}
|
||||
|
||||
public Explanation explain(String topLevel) throws IOException {
|
||||
ensureFreq();
|
||||
Explanation freqExplanation = Explanation.match(freq, "intervalFreq=" + freq);
|
||||
Explanation scoreExplanation = simScorer.explain(docID(), freqExplanation);
|
||||
return Explanation.match(scoreExplanation.getValue(),
|
||||
topLevel + ", result of:",
|
||||
scoreExplanation);
|
||||
}
|
||||
|
||||
public float freq() throws IOException {
|
||||
float freq() throws IOException {
|
||||
ensureFreq();
|
||||
return freq;
|
||||
}
|
||||
|
@ -70,7 +64,8 @@ class IntervalScorer extends Scorer {
|
|||
lastScoredDoc = docID();
|
||||
freq = 0;
|
||||
do {
|
||||
freq += (1.0 / (intervals.end() - intervals.start() + 1));
|
||||
int length = (intervals.end() - intervals.start() + 1);
|
||||
freq += 1.0 / Math.max(length - minExtent + 1, 1);
|
||||
}
|
||||
while (intervals.nextInterval() != IntervalIterator.NO_MORE_INTERVALS);
|
||||
}
|
||||
|
@ -97,9 +92,8 @@ class IntervalScorer extends Scorer {
|
|||
}
|
||||
|
||||
@Override
|
||||
public float getMaxScore(int upTo) throws IOException {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
public float getMaxScore(int upTo) {
|
||||
return boost;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -177,7 +177,7 @@ public final class Intervals {
|
|||
* @param reference the source to filter by
|
||||
*/
|
||||
public static IntervalsSource overlapping(IntervalsSource source, IntervalsSource reference) {
|
||||
return new ConjunctionIntervalsSource(Arrays.asList(source, reference), IntervalFunction.OVERLAPPING);
|
||||
return new FilteringConjunctionIntervalsSource(source, reference, IntervalFunction.OVERLAPPING);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -230,7 +230,7 @@ public final class Intervals {
|
|||
* @param small the {@link IntervalsSource} to filter by
|
||||
*/
|
||||
public static IntervalsSource containing(IntervalsSource big, IntervalsSource small) {
|
||||
return new ConjunctionIntervalsSource(Arrays.asList(big, small), IntervalFunction.CONTAINING);
|
||||
return new FilteringConjunctionIntervalsSource(big, small, IntervalFunction.CONTAINING);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -255,7 +255,7 @@ public final class Intervals {
|
|||
* @param big the {@link IntervalsSource} to filter by
|
||||
*/
|
||||
public static IntervalsSource containedBy(IntervalsSource small, IntervalsSource big) {
|
||||
return new ConjunctionIntervalsSource(Arrays.asList(small, big), IntervalFunction.CONTAINED_BY);
|
||||
return new FilteringConjunctionIntervalsSource(small, big, IntervalFunction.CONTAINED_BY);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -269,8 +269,8 @@ public final class Intervals {
|
|||
* Returns intervals from the source that appear before intervals from the reference
|
||||
*/
|
||||
public static IntervalsSource before(IntervalsSource source, IntervalsSource reference) {
|
||||
return new ConjunctionIntervalsSource(Arrays.asList(source,
|
||||
Intervals.extend(new OffsetIntervalsSource(reference, true), Integer.MAX_VALUE, 0)),
|
||||
return new FilteringConjunctionIntervalsSource(source,
|
||||
Intervals.extend(new OffsetIntervalsSource(reference, true), Integer.MAX_VALUE, 0),
|
||||
IntervalFunction.CONTAINED_BY);
|
||||
}
|
||||
|
||||
|
@ -278,8 +278,8 @@ public final class Intervals {
|
|||
* Returns intervals from the source that appear after intervals from the reference
|
||||
*/
|
||||
public static IntervalsSource after(IntervalsSource source, IntervalsSource reference) {
|
||||
return new ConjunctionIntervalsSource(Arrays.asList(source,
|
||||
Intervals.extend(new OffsetIntervalsSource(reference, false), 0, Integer.MAX_VALUE)),
|
||||
return new FilteringConjunctionIntervalsSource(source,
|
||||
Intervals.extend(new OffsetIntervalsSource(reference, false), 0, Integer.MAX_VALUE),
|
||||
IntervalFunction.CONTAINED_BY);
|
||||
}
|
||||
|
||||
|
|
|
@ -56,12 +56,17 @@ public abstract class IntervalsSource {
|
|||
public abstract MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException;
|
||||
|
||||
/**
|
||||
* Expert: collect {@link Term} objects from this source, to be used for top-level term scoring
|
||||
* Expert: collect {@link Term} objects from this source
|
||||
* @param field the field to be scored
|
||||
* @param terms a {@link Set} which terms should be added to
|
||||
*/
|
||||
public abstract void extractTerms(String field, Set<Term> terms);
|
||||
|
||||
/**
|
||||
* Return the minimum possible width of an interval returned by this source
|
||||
*/
|
||||
public abstract int minExtent();
|
||||
|
||||
@Override
|
||||
public abstract int hashCode();
|
||||
|
||||
|
|
|
@ -91,6 +91,20 @@ class MinimumShouldMatchIntervalsSource extends IntervalsSource {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minExtent() {
|
||||
int[] subExtents = new int[sources.length];
|
||||
for (int i = 0; i < subExtents.length; i++) {
|
||||
subExtents[i] = sources[i].minExtent();
|
||||
}
|
||||
Arrays.sort(subExtents);
|
||||
int minExtent = 0;
|
||||
for (int i = 0; i < minShouldMatch; i++) {
|
||||
minExtent += subExtents[i];
|
||||
}
|
||||
return minExtent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "AtLeast("
|
||||
|
|
|
@ -148,6 +148,11 @@ class OffsetIntervalsSource extends IntervalsSource {
|
|||
in.extractTerms(field, terms);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minExtent() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
|
|
|
@ -195,6 +195,11 @@ class TermIntervalsSource extends IntervalsSource {
|
|||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public int minExtent() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(term);
|
||||
|
|
|
@ -24,9 +24,11 @@ import org.apache.lucene.document.Document;
|
|||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.search.BoostQuery;
|
||||
import org.apache.lucene.search.CheckHits;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
import org.junit.Ignore;
|
||||
|
@ -194,4 +196,53 @@ public class TestIntervalQuery extends LuceneTestCase {
|
|||
Intervals.phrase(Intervals.term("w1"), Intervals.extend(Intervals.term("w2"), 1, 0)));
|
||||
checkHits(q, new int[]{ 1, 2, 5 });
|
||||
}
|
||||
|
||||
public void testScoring() throws IOException {
|
||||
|
||||
IntervalsSource source = Intervals.ordered(Intervals.or(Intervals.term("w1"), Intervals.term("w2")), Intervals.term("w3"));
|
||||
|
||||
Query q = new IntervalQuery(field, source);
|
||||
TopDocs td = searcher.search(q, 10);
|
||||
assertEquals(5, td.totalHits.value);
|
||||
assertEquals(1, td.scoreDocs[0].doc);
|
||||
assertEquals(3, td.scoreDocs[1].doc);
|
||||
assertEquals(0, td.scoreDocs[2].doc);
|
||||
assertEquals(5, td.scoreDocs[3].doc);
|
||||
assertEquals(2, td.scoreDocs[4].doc);
|
||||
|
||||
Query boostQ = new BoostQuery(q, 2);
|
||||
TopDocs boostTD = searcher.search(boostQ, 10);
|
||||
assertEquals(5, boostTD.totalHits.value);
|
||||
for (int i = 0; i < 5; i++) {
|
||||
assertEquals(td.scoreDocs[i].score * 2, boostTD.scoreDocs[i].score, 0);
|
||||
}
|
||||
|
||||
// change the pivot - order should remain the same
|
||||
Query q1 = new IntervalQuery(field, source, 2);
|
||||
TopDocs td1 = searcher.search(q1, 10);
|
||||
assertEquals(5, td1.totalHits.value);
|
||||
assertEquals(0.5f, td1.scoreDocs[0].score, 0); // freq=pivot
|
||||
for (int i = 0; i < 5; i++) {
|
||||
assertEquals(td.scoreDocs[i].doc, td1.scoreDocs[i].doc);
|
||||
}
|
||||
|
||||
// increase the exp, docs higher than pivot should get a higher score, and vice versa
|
||||
Query q2 = new IntervalQuery(field, source, 1.2f, 2f);
|
||||
TopDocs td2 = searcher.search(q2, 10);
|
||||
assertEquals(5, td2.totalHits.value);
|
||||
for (int i = 0; i < 5; i++) {
|
||||
assertEquals(td.scoreDocs[i].doc, td2.scoreDocs[i].doc);
|
||||
if (i < 2) {
|
||||
assertTrue(td.scoreDocs[i].score < td2.scoreDocs[i].score);
|
||||
}
|
||||
else {
|
||||
assertTrue(td.scoreDocs[i].score > td2.scoreDocs[i].score);
|
||||
}
|
||||
}
|
||||
|
||||
// check valid bounds
|
||||
expectThrows(IllegalArgumentException.class, () -> new IntervalQuery(field, source, -1));
|
||||
expectThrows(IllegalArgumentException.class, () -> new IntervalQuery(field, source, 1, -1f));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -189,6 +189,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(mi, 4, 4, 27, 35);
|
||||
assertMatch(mi, 7, 7, 47, 55);
|
||||
assertFalse(mi.next());
|
||||
|
||||
assertEquals(1, source.minExtent());
|
||||
}
|
||||
|
||||
public void testOrderedNearIntervals() throws IOException {
|
||||
|
@ -214,6 +216,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(sub, 17, 17, 97, 100);
|
||||
assertFalse(sub.next());
|
||||
assertFalse(mi.next());
|
||||
|
||||
assertEquals(2, source.minExtent());
|
||||
}
|
||||
|
||||
public void testPhraseIntervals() throws IOException {
|
||||
|
@ -235,6 +239,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(sub, 4, 4, 26, 34);
|
||||
assertFalse(sub.next());
|
||||
assertMatch(mi, 6, 7, 41, 55);
|
||||
|
||||
assertEquals(2, source.minExtent());
|
||||
}
|
||||
|
||||
public void testUnorderedNearIntervals() throws IOException {
|
||||
|
@ -261,6 +267,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertGaps(source, 1, "field1", new int[]{
|
||||
1, 0, 10
|
||||
});
|
||||
|
||||
assertEquals(2, source.minExtent());
|
||||
}
|
||||
|
||||
public void testIntervalDisjunction() throws IOException {
|
||||
|
@ -280,6 +288,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(mi, 7, 7, 31, 36);
|
||||
assertNull(mi.getSubMatches());
|
||||
assertFalse(mi.next());
|
||||
|
||||
assertEquals(1, source.minExtent());
|
||||
}
|
||||
|
||||
public void testCombinationDisjunction() throws IOException {
|
||||
|
@ -292,6 +302,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
{ 3, 8 },
|
||||
{}, {}, {}, {}
|
||||
});
|
||||
|
||||
assertEquals(2, source.minExtent());
|
||||
}
|
||||
|
||||
public void testNesting() throws IOException {
|
||||
|
@ -307,6 +319,7 @@ public class TestIntervals extends LuceneTestCase {
|
|||
{ 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 17 },
|
||||
{}
|
||||
});
|
||||
assertEquals(3, source.minExtent());
|
||||
|
||||
assertNull(getMatches(source, 0, "field1"));
|
||||
MatchesIterator mi = getMatches(source, 1, "field1");
|
||||
|
@ -384,6 +397,7 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(sub, 21, 21, 114, 118);
|
||||
assertFalse(sub.next());
|
||||
assertFalse(it.next());
|
||||
assertEquals(4, source.minExtent());
|
||||
}
|
||||
|
||||
public void testUnorderedDistinct() throws IOException {
|
||||
|
@ -447,6 +461,7 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(subs, 21, 21, 114, 118);
|
||||
assertFalse(subs.next());
|
||||
assertFalse(mi.next());
|
||||
assertEquals(1, source.minExtent());
|
||||
}
|
||||
|
||||
public void testContaining() throws IOException {
|
||||
|
@ -476,6 +491,7 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(subs, 21, 21, 114, 118);
|
||||
assertFalse(subs.next());
|
||||
assertFalse(mi.next());
|
||||
assertEquals(2, source.minExtent());
|
||||
}
|
||||
|
||||
public void testNotContaining() throws IOException {
|
||||
|
@ -498,6 +514,7 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(subs, 6, 6, 41, 46);
|
||||
assertFalse(subs.next());
|
||||
assertFalse(mi.next());
|
||||
assertEquals(2, source.minExtent());
|
||||
}
|
||||
|
||||
public void testMaxGaps() throws IOException {
|
||||
|
@ -512,6 +529,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
MatchesIterator mi = getMatches(source, 5, "field2");
|
||||
assertMatch(mi, 0, 3, 0, 11);
|
||||
|
||||
assertEquals(3, source.minExtent());
|
||||
|
||||
}
|
||||
|
||||
public void testNestedMaxGaps() throws IOException {
|
||||
|
@ -531,6 +550,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(mi, 0, 3, 0, 11);
|
||||
assertMatch(mi, 3, 6, 9, 20);
|
||||
assertMatch(mi, 4, 8, 12, 26);
|
||||
|
||||
assertEquals(3, source.minExtent());
|
||||
}
|
||||
|
||||
public void testMinimumShouldMatch() throws IOException {
|
||||
|
@ -564,6 +585,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(subs, 7, 7, 47, 55);
|
||||
assertMatch(subs, 11, 11, 67, 71);
|
||||
|
||||
assertEquals(3, source.minExtent());
|
||||
|
||||
}
|
||||
|
||||
public void testDefinedGaps() throws IOException {
|
||||
|
@ -580,6 +603,7 @@ public class TestIntervals extends LuceneTestCase {
|
|||
{ 3, 7 },
|
||||
{}
|
||||
});
|
||||
assertEquals(5, source.minExtent());
|
||||
|
||||
MatchesIterator mi = getMatches(source, 1, "field1");
|
||||
assertMatch(mi, 3, 7, 20, 55);
|
||||
|
@ -594,6 +618,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
{}, {}, {}, {}, {},
|
||||
{ 0, Integer.MAX_VALUE - 1, 0, Integer.MAX_VALUE - 1, 5, Integer.MAX_VALUE - 1 }
|
||||
});
|
||||
|
||||
assertEquals(Integer.MAX_VALUE, source.minExtent());
|
||||
}
|
||||
|
||||
public void testAfter() throws IOException {
|
||||
|
@ -616,6 +642,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(sub, 5, 5, 35, 39);
|
||||
assertMatch(sub, 7, 7, 47, 55);
|
||||
assertFalse(sub.next());
|
||||
|
||||
assertEquals(1, source.minExtent());
|
||||
}
|
||||
|
||||
public void testBefore() throws IOException {
|
||||
|
@ -628,6 +656,7 @@ public class TestIntervals extends LuceneTestCase {
|
|||
{ 5, 5 },
|
||||
{}
|
||||
});
|
||||
assertEquals(1, source.minExtent());
|
||||
}
|
||||
|
||||
public void testWithin() throws IOException {
|
||||
|
@ -641,6 +670,7 @@ public class TestIntervals extends LuceneTestCase {
|
|||
{ 2, 2 },
|
||||
{}
|
||||
});
|
||||
assertEquals(1, source.minExtent());
|
||||
}
|
||||
|
||||
public void testOverlapping() throws IOException {
|
||||
|
@ -670,6 +700,8 @@ public class TestIntervals extends LuceneTestCase {
|
|||
assertMatch(sub, 5, 5, 35, 39);
|
||||
assertFalse(sub.next());
|
||||
assertMatch(mi, 7, 17, 41, 118);
|
||||
|
||||
assertEquals(2, source.minExtent());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue