diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index b570c48e6db..761616f8364 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -156,6 +156,10 @@ Changes in Runtime Behavior tokens by default, preventing a number of bugs when the filter is chained with tokenfilters that change the length of their tokens (Alan Woodward) +* LUCENE-8633: IntervalQuery scores do not use term weighting any more, the score + is instead calculated as a function of the sloppy frequency of the matching + intervals. (Alan Woodward, Jim Ferenczi) + New Features * LUCENE-8340: LongPoint#newDistanceQuery may be used to boost scores based on diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ConjunctionIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ConjunctionIntervalsSource.java index 6cbfadab38a..ec4341de709 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ConjunctionIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ConjunctionIntervalsSource.java @@ -62,6 +62,15 @@ class ConjunctionIntervalsSource extends IntervalsSource { } } + @Override + public int minExtent() { + int minExtent = 0; + for (IntervalsSource source : subSources) { + minExtent += source.minExtent(); + } + return minExtent; + } + @Override public IntervalIterator intervals(String field, LeafReaderContext ctx) throws IOException { List subIntervals = new ArrayList<>(); diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DifferenceIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DifferenceIntervalsSource.java index e4b7fd9aad8..7289d04ba25 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DifferenceIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DifferenceIntervalsSource.java @@ -86,4 +86,9 @@ class DifferenceIntervalsSource extends IntervalsSource { public void extractTerms(String field, Set terms) { minuend.extractTerms(field, terms); } + + @Override + public int minExtent() { + return minuend.minExtent(); + } } diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DisjunctionIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DisjunctionIntervalsSource.java index 68a9e5decdd..79089c7a5e0 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DisjunctionIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DisjunctionIntervalsSource.java @@ -90,6 +90,15 @@ class DisjunctionIntervalsSource extends IntervalsSource { } } + @Override + public int minExtent() { + int minExtent = subSources.get(0).minExtent(); + for (int i = 1; i < subSources.size(); i++) { + minExtent = Math.min(minExtent, subSources.get(i).minExtent()); + } + return minExtent; + } + private static class DisjunctionIntervalIterator extends IntervalIterator { final DocIdSetIterator approximation; diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ExtendedIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ExtendedIntervalsSource.java index d4e3bfa5693..864a4b573ca 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ExtendedIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ExtendedIntervalsSource.java @@ -61,6 +61,15 @@ class ExtendedIntervalsSource extends IntervalsSource { source.extractTerms(field, terms); } + @Override + public int minExtent() { + int minExtent = before + source.minExtent() + after; + if (minExtent < 0) { + return Integer.MAX_VALUE; + } + return minExtent; + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteredIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteredIntervalsSource.java index 8eac88d8b23..c2b4d6012c8 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteredIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteredIntervalsSource.java @@ -77,6 +77,11 @@ public abstract class FilteredIntervalsSource extends IntervalsSource { return IntervalMatches.asMatches(filtered, mi, doc); } + @Override + public int minExtent() { + return in.minExtent(); + } + @Override public void extractTerms(String field, Set terms) { in.extractTerms(field, terms); diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteringConjunctionIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteringConjunctionIntervalsSource.java new file mode 100644 index 00000000000..cc029821e27 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteringConjunctionIntervalsSource.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.intervals; + +import java.util.Arrays; + +/** + * An intervals source that combines two other sources, requiring both of them to + * be present in order to match, but using the minExtent of one of them + */ +class FilteringConjunctionIntervalsSource extends ConjunctionIntervalsSource { + + private final IntervalsSource source; + + FilteringConjunctionIntervalsSource(IntervalsSource source, IntervalsSource filter, IntervalFunction function) { + super(Arrays.asList(source, filter), function); + this.source = source; + } + + @Override + public int minExtent() { + return source.minExtent(); + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalQuery.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalQuery.java index 4e2569ca1bd..62fe0679bf0 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalQuery.java @@ -18,28 +18,21 @@ package org.apache.lucene.search.intervals; import java.io.IOException; -import java.util.HashSet; import java.util.Objects; import java.util.Set; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; -import org.apache.lucene.index.TermStates; -import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.FilterMatchesIterator; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.LeafSimScorer; import org.apache.lucene.search.Matches; import org.apache.lucene.search.MatchesIterator; import org.apache.lucene.search.MatchesUtils; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.Weight; -import org.apache.lucene.search.similarities.Similarity; -import org.apache.lucene.util.ArrayUtil; /** * A query that retrieves documents containing intervals returned from an @@ -47,11 +40,30 @@ import org.apache.lucene.util.ArrayUtil; * * Static constructor functions for various different sources can be found in the * {@link Intervals} class + * + * Scores for this query are computed as a function of the sloppy frequency of + * intervals appearing in a particular document. Sloppy frequency is calculated + * from the number of matching intervals, and their width, with wider intervals + * contributing lower values. The scores can be adjusted with two optional + * parameters: + *
    + *
  • pivot - the sloppy frequency value at which the overall score of the + * document will equal 0.5. The default value is 1
  • + *
  • exp - higher values of this parameter make the function grow more slowly + * below the pivot and faster higher than the pivot. The default value is 1
  • + *
+ * + * Optimal values for both pivot and exp depend on the type of queries and corpus of + * documents being queried. + * + * Scores are bounded to between 0 and 1. For higher contributions, wrap the query + * in a {@link org.apache.lucene.search.BoostQuery} */ public final class IntervalQuery extends Query { private final String field; private final IntervalsSource intervalsSource; + private final IntervalScoreFunction scoreFunction; /** * Create a new IntervalQuery @@ -59,10 +71,41 @@ public final class IntervalQuery extends Query { * @param intervalsSource an {@link IntervalsSource} to retrieve intervals from */ public IntervalQuery(String field, IntervalsSource intervalsSource) { - this.field = field; - this.intervalsSource = intervalsSource; + this(field, intervalsSource, IntervalScoreFunction.saturationFunction(1)); } + /** + * Create a new IntervalQuery with a scoring pivot + * + * @param field the field to query + * @param intervalsSource an {@link IntervalsSource} to retrieve intervals from + * @param pivot the sloppy frequency value at which the score will be 0.5, must be within (0, +Infinity) + */ + public IntervalQuery(String field, IntervalsSource intervalsSource, float pivot) { + this(field, intervalsSource, IntervalScoreFunction.saturationFunction(pivot)); + } + + /** + * Create a new IntervalQuery with a scoring pivot and exponent + * @param field the field to query + * @param intervalsSource an {@link IntervalsSource} to retrieve intervals from + * @param pivot the sloppy frequency value at which the score will be 0.5, must be within (0, +Infinity) + * @param exp exponent, higher values make the function grow slower before 'pivot' and faster + * after 'pivot', must be in (0, +Infinity) + */ + public IntervalQuery(String field, IntervalsSource intervalsSource, float pivot, float exp) { + this(field, intervalsSource, IntervalScoreFunction.sigmoidFunction(pivot, exp)); + } + + private IntervalQuery(String field, IntervalsSource intervalsSource, IntervalScoreFunction scoreFunction) { + this.field = field; + this.intervalsSource = intervalsSource; + this.scoreFunction = scoreFunction; + } + + /** + * The field to query + */ public String getField() { return field; } @@ -74,26 +117,7 @@ public final class IntervalQuery extends Query { @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - return new IntervalWeight(this, scoreMode.needsScores() ? buildSimScorer(searcher, boost) : null, - searcher.getSimilarity(), scoreMode); - } - - private Similarity.SimScorer buildSimScorer(IndexSearcher searcher, float boost) throws IOException { - Set terms = new HashSet<>(); - intervalsSource.extractTerms(field, terms); - TermStatistics[] termStats = new TermStatistics[terms.size()]; - int termUpTo = 0; - for (Term term : terms) { - TermStatistics termStatistics = searcher.termStatistics(term, TermStates.build(searcher.getTopReaderContext(), term, true)); - if (termStatistics != null) { - termStats[termUpTo++] = termStatistics; - } - } - if (termUpTo == 0) { - return null; - } - CollectionStatistics collectionStats = searcher.collectionStatistics(field); - return searcher.getSimilarity().scorer(boost, collectionStats, ArrayUtil.copyOfSubArray(termStats, 0, termUpTo)); + return new IntervalWeight(this, boost, scoreMode); } @Override @@ -112,15 +136,13 @@ public final class IntervalQuery extends Query { private class IntervalWeight extends Weight { - final Similarity.SimScorer simScorer; - final Similarity similarity; final ScoreMode scoreMode; + final float boost; - public IntervalWeight(Query query, Similarity.SimScorer simScorer, Similarity similarity, ScoreMode scoreMode) { + public IntervalWeight(Query query, float boost, ScoreMode scoreMode) { super(query); - this.simScorer = simScorer; - this.similarity = similarity; this.scoreMode = scoreMode; + this.boost = boost; } @Override @@ -134,7 +156,8 @@ public final class IntervalQuery extends Query { if (scorer != null) { int newDoc = scorer.iterator().advance(doc); if (newDoc == doc) { - return scorer.explain("weight("+getQuery()+" in "+doc+") [" + similarity.getClass().getSimpleName() + "]"); + float freq = scorer.freq(); + return scoreFunction.explain(intervalsSource.toString(), boost, freq); } } return Explanation.noMatch("no matching intervals"); @@ -161,9 +184,7 @@ public final class IntervalQuery extends Query { IntervalIterator intervals = intervalsSource.intervals(field, context); if (intervals == null) return null; - LeafSimScorer leafScorer = simScorer == null ? null - : new LeafSimScorer(simScorer, context.reader(), field, scoreMode.needsScores()); - return new IntervalScorer(this, intervals, leafScorer); + return new IntervalScorer(this, intervals, intervalsSource.minExtent(), boost, scoreFunction); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalScoreFunction.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalScoreFunction.java new file mode 100644 index 00000000000..855b398545c --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalScoreFunction.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search.intervals; + +import java.util.Objects; + +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.similarities.Similarity; + +abstract class IntervalScoreFunction { + + static IntervalScoreFunction saturationFunction(float pivot) { + if (pivot <= 0 || Float.isFinite(pivot) == false) { + throw new IllegalArgumentException("pivot must be > 0, got: " + pivot); + } + return new SaturationFunction(pivot); + } + + static IntervalScoreFunction sigmoidFunction(float pivot, float exp) { + if (pivot <= 0 || Float.isFinite(pivot) == false) { + throw new IllegalArgumentException("pivot must be > 0, got: " + pivot); + } + if (exp <= 0 || Float.isFinite(exp) == false) { + throw new IllegalArgumentException("exp must be > 0, got: " + exp); + } + return new SigmoidFunction(pivot, exp); + } + + public abstract Similarity.SimScorer scorer(float weight); + + public abstract Explanation explain(String interval, float weight, float sloppyFreq); + + @Override + public abstract boolean equals(Object other); + + @Override + public abstract int hashCode(); + + @Override + public abstract String toString(); + + private static class SaturationFunction extends IntervalScoreFunction { + + final float pivot; + + private SaturationFunction(float pivot) { + this.pivot = pivot; + } + + @Override + public Similarity.SimScorer scorer(float weight) { + return new Similarity.SimScorer() { + @Override + public float score(float freq, long norm) { + // should be f / (f + k) but we rewrite it to + // 1 - k / (f + k) to make sure it doesn't decrease + // with f in spite of rounding + return weight * (1.0f - pivot / (pivot + freq)); + } + }; + } + + @Override + public Explanation explain(String interval, float weight, float sloppyFreq) { + float score = scorer(weight).score(sloppyFreq, 1L); + return Explanation.match(score, + "Saturation function on interval frequency, computed as w * S / (S + k) from:", + Explanation.match(weight, "w, weight of this function"), + Explanation.match(pivot, "k, pivot feature value that would give a score contribution equal to w/2"), + Explanation.match(sloppyFreq, "S, the sloppy frequency of the interval " + interval)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SaturationFunction that = (SaturationFunction) o; + return Float.compare(that.pivot, pivot) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(pivot); + } + + @Override + public String toString() { + return "SaturationFunction(pivot=" + pivot + ")"; + } + } + + private static class SigmoidFunction extends IntervalScoreFunction { + + private final float pivot, a; + private final double pivotPa; + + private SigmoidFunction(float pivot, float a) { + this.pivot = pivot; + this.a = a; + this.pivotPa = Math.pow(pivot, a); + } + + @Override + public Similarity.SimScorer scorer(float weight) { + return new Similarity.SimScorer() { + @Override + public float score(float freq, long norm) { + // should be f^a / (f^a + k^a) but we rewrite it to + // 1 - k^a / (f + k^a) to make sure it doesn't decrease + // with f in spite of rounding + return (float) (weight * (1.0f - pivotPa / (Math.pow(freq, a) + pivotPa))); + } + }; + } + + @Override + public Explanation explain(String interval, float weight, float sloppyFreq) { + float score = scorer(weight).score(sloppyFreq, 1L); + return Explanation.match(score, + "Sigmoid function on interval frequency, computed as w * S^a / (S^a + k^a) from:", + Explanation.match(weight, "w, weight of this function"), + Explanation.match(pivot, "k, pivot feature value that would give a score contribution equal to w/2"), + Explanation.match(a, "a, exponent, higher values make the function grow slower before k and faster after k"), + Explanation.match(sloppyFreq, "S, the sloppy frequency of the interval " + interval)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SigmoidFunction that = (SigmoidFunction) o; + return Float.compare(that.pivot, pivot) == 0 && + Float.compare(that.a, a) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(pivot, a); + } + + @Override + public String toString() { + return "SigmoidFunction(pivot=" + pivot + ", a=" + a + ")"; + } + } + +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalScorer.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalScorer.java index 6672905df96..18b88994fc8 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalScorer.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalScorer.java @@ -20,24 +20,27 @@ package org.apache.lucene.search.intervals; import java.io.IOException; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.LeafSimScorer; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.similarities.Similarity; class IntervalScorer extends Scorer { private final IntervalIterator intervals; - private final LeafSimScorer simScorer; + private final Similarity.SimScorer simScorer; + private final float boost; + private final int minExtent; - private float freq = -1; + private float freq; private int lastScoredDoc = -1; - protected IntervalScorer(Weight weight, IntervalIterator intervals, LeafSimScorer simScorer) { + IntervalScorer(Weight weight, IntervalIterator intervals, int minExtent, float boost, IntervalScoreFunction scoreFunction) { super(weight); this.intervals = intervals; - this.simScorer = simScorer; + this.minExtent = minExtent; + this.boost = boost; + this.simScorer = scoreFunction.scorer(boost); } @Override @@ -48,19 +51,10 @@ class IntervalScorer extends Scorer { @Override public float score() throws IOException { ensureFreq(); - return simScorer.score(docID(), freq); + return simScorer.score(freq, 1); } - public Explanation explain(String topLevel) throws IOException { - ensureFreq(); - Explanation freqExplanation = Explanation.match(freq, "intervalFreq=" + freq); - Explanation scoreExplanation = simScorer.explain(docID(), freqExplanation); - return Explanation.match(scoreExplanation.getValue(), - topLevel + ", result of:", - scoreExplanation); - } - - public float freq() throws IOException { + float freq() throws IOException { ensureFreq(); return freq; } @@ -70,7 +64,8 @@ class IntervalScorer extends Scorer { lastScoredDoc = docID(); freq = 0; do { - freq += (1.0 / (intervals.end() - intervals.start() + 1)); + int length = (intervals.end() - intervals.start() + 1); + freq += 1.0 / Math.max(length - minExtent + 1, 1); } while (intervals.nextInterval() != IntervalIterator.NO_MORE_INTERVALS); } @@ -97,9 +92,8 @@ class IntervalScorer extends Scorer { } @Override - public float getMaxScore(int upTo) throws IOException { - return Float.POSITIVE_INFINITY; + public float getMaxScore(int upTo) { + return boost; } - } diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/Intervals.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/Intervals.java index 1b6dbae1f0e..1c8d71a38db 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/Intervals.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/Intervals.java @@ -177,7 +177,7 @@ public final class Intervals { * @param reference the source to filter by */ public static IntervalsSource overlapping(IntervalsSource source, IntervalsSource reference) { - return new ConjunctionIntervalsSource(Arrays.asList(source, reference), IntervalFunction.OVERLAPPING); + return new FilteringConjunctionIntervalsSource(source, reference, IntervalFunction.OVERLAPPING); } /** @@ -230,7 +230,7 @@ public final class Intervals { * @param small the {@link IntervalsSource} to filter by */ public static IntervalsSource containing(IntervalsSource big, IntervalsSource small) { - return new ConjunctionIntervalsSource(Arrays.asList(big, small), IntervalFunction.CONTAINING); + return new FilteringConjunctionIntervalsSource(big, small, IntervalFunction.CONTAINING); } /** @@ -255,7 +255,7 @@ public final class Intervals { * @param big the {@link IntervalsSource} to filter by */ public static IntervalsSource containedBy(IntervalsSource small, IntervalsSource big) { - return new ConjunctionIntervalsSource(Arrays.asList(small, big), IntervalFunction.CONTAINED_BY); + return new FilteringConjunctionIntervalsSource(small, big, IntervalFunction.CONTAINED_BY); } /** @@ -269,8 +269,8 @@ public final class Intervals { * Returns intervals from the source that appear before intervals from the reference */ public static IntervalsSource before(IntervalsSource source, IntervalsSource reference) { - return new ConjunctionIntervalsSource(Arrays.asList(source, - Intervals.extend(new OffsetIntervalsSource(reference, true), Integer.MAX_VALUE, 0)), + return new FilteringConjunctionIntervalsSource(source, + Intervals.extend(new OffsetIntervalsSource(reference, true), Integer.MAX_VALUE, 0), IntervalFunction.CONTAINED_BY); } @@ -278,8 +278,8 @@ public final class Intervals { * Returns intervals from the source that appear after intervals from the reference */ public static IntervalsSource after(IntervalsSource source, IntervalsSource reference) { - return new ConjunctionIntervalsSource(Arrays.asList(source, - Intervals.extend(new OffsetIntervalsSource(reference, false), 0, Integer.MAX_VALUE)), + return new FilteringConjunctionIntervalsSource(source, + Intervals.extend(new OffsetIntervalsSource(reference, false), 0, Integer.MAX_VALUE), IntervalFunction.CONTAINED_BY); } diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalsSource.java index 14d9471b6c8..dc4161fa051 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalsSource.java @@ -56,12 +56,17 @@ public abstract class IntervalsSource { public abstract MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException; /** - * Expert: collect {@link Term} objects from this source, to be used for top-level term scoring + * Expert: collect {@link Term} objects from this source * @param field the field to be scored * @param terms a {@link Set} which terms should be added to */ public abstract void extractTerms(String field, Set terms); + /** + * Return the minimum possible width of an interval returned by this source + */ + public abstract int minExtent(); + @Override public abstract int hashCode(); diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MinimumShouldMatchIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MinimumShouldMatchIntervalsSource.java index e97f60c9967..1935c628ee1 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MinimumShouldMatchIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MinimumShouldMatchIntervalsSource.java @@ -91,6 +91,20 @@ class MinimumShouldMatchIntervalsSource extends IntervalsSource { } } + @Override + public int minExtent() { + int[] subExtents = new int[sources.length]; + for (int i = 0; i < subExtents.length; i++) { + subExtents[i] = sources[i].minExtent(); + } + Arrays.sort(subExtents); + int minExtent = 0; + for (int i = 0; i < minShouldMatch; i++) { + minExtent += subExtents[i]; + } + return minExtent; + } + @Override public String toString() { return "AtLeast(" diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/OffsetIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/OffsetIntervalsSource.java index 470e2b539f6..b2ca30224e5 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/OffsetIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/OffsetIntervalsSource.java @@ -148,6 +148,11 @@ class OffsetIntervalsSource extends IntervalsSource { in.extractTerms(field, terms); } + @Override + public int minExtent() { + return 1; + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/TermIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/TermIntervalsSource.java index f58d8677c9d..1b5444ae5b0 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/TermIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/TermIntervalsSource.java @@ -195,6 +195,11 @@ class TermIntervalsSource extends IntervalsSource { }; } + @Override + public int minExtent() { + return 1; + } + @Override public int hashCode() { return Objects.hash(term); diff --git a/lucene/sandbox/src/test/org/apache/lucene/search/intervals/TestIntervalQuery.java b/lucene/sandbox/src/test/org/apache/lucene/search/intervals/TestIntervalQuery.java index 61106052e1a..9cea616b3a7 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/search/intervals/TestIntervalQuery.java +++ b/lucene/sandbox/src/test/org/apache/lucene/search/intervals/TestIntervalQuery.java @@ -24,9 +24,11 @@ import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.CheckHits; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.util.LuceneTestCase; import org.junit.Ignore; @@ -194,4 +196,53 @@ public class TestIntervalQuery extends LuceneTestCase { Intervals.phrase(Intervals.term("w1"), Intervals.extend(Intervals.term("w2"), 1, 0))); checkHits(q, new int[]{ 1, 2, 5 }); } + + public void testScoring() throws IOException { + + IntervalsSource source = Intervals.ordered(Intervals.or(Intervals.term("w1"), Intervals.term("w2")), Intervals.term("w3")); + + Query q = new IntervalQuery(field, source); + TopDocs td = searcher.search(q, 10); + assertEquals(5, td.totalHits.value); + assertEquals(1, td.scoreDocs[0].doc); + assertEquals(3, td.scoreDocs[1].doc); + assertEquals(0, td.scoreDocs[2].doc); + assertEquals(5, td.scoreDocs[3].doc); + assertEquals(2, td.scoreDocs[4].doc); + + Query boostQ = new BoostQuery(q, 2); + TopDocs boostTD = searcher.search(boostQ, 10); + assertEquals(5, boostTD.totalHits.value); + for (int i = 0; i < 5; i++) { + assertEquals(td.scoreDocs[i].score * 2, boostTD.scoreDocs[i].score, 0); + } + + // change the pivot - order should remain the same + Query q1 = new IntervalQuery(field, source, 2); + TopDocs td1 = searcher.search(q1, 10); + assertEquals(5, td1.totalHits.value); + assertEquals(0.5f, td1.scoreDocs[0].score, 0); // freq=pivot + for (int i = 0; i < 5; i++) { + assertEquals(td.scoreDocs[i].doc, td1.scoreDocs[i].doc); + } + + // increase the exp, docs higher than pivot should get a higher score, and vice versa + Query q2 = new IntervalQuery(field, source, 1.2f, 2f); + TopDocs td2 = searcher.search(q2, 10); + assertEquals(5, td2.totalHits.value); + for (int i = 0; i < 5; i++) { + assertEquals(td.scoreDocs[i].doc, td2.scoreDocs[i].doc); + if (i < 2) { + assertTrue(td.scoreDocs[i].score < td2.scoreDocs[i].score); + } + else { + assertTrue(td.scoreDocs[i].score > td2.scoreDocs[i].score); + } + } + + // check valid bounds + expectThrows(IllegalArgumentException.class, () -> new IntervalQuery(field, source, -1)); + expectThrows(IllegalArgumentException.class, () -> new IntervalQuery(field, source, 1, -1f)); + } + } diff --git a/lucene/sandbox/src/test/org/apache/lucene/search/intervals/TestIntervals.java b/lucene/sandbox/src/test/org/apache/lucene/search/intervals/TestIntervals.java index d1c2479eaa1..5f58ebf5c87 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/search/intervals/TestIntervals.java +++ b/lucene/sandbox/src/test/org/apache/lucene/search/intervals/TestIntervals.java @@ -189,6 +189,8 @@ public class TestIntervals extends LuceneTestCase { assertMatch(mi, 4, 4, 27, 35); assertMatch(mi, 7, 7, 47, 55); assertFalse(mi.next()); + + assertEquals(1, source.minExtent()); } public void testOrderedNearIntervals() throws IOException { @@ -214,6 +216,8 @@ public class TestIntervals extends LuceneTestCase { assertMatch(sub, 17, 17, 97, 100); assertFalse(sub.next()); assertFalse(mi.next()); + + assertEquals(2, source.minExtent()); } public void testPhraseIntervals() throws IOException { @@ -235,6 +239,8 @@ public class TestIntervals extends LuceneTestCase { assertMatch(sub, 4, 4, 26, 34); assertFalse(sub.next()); assertMatch(mi, 6, 7, 41, 55); + + assertEquals(2, source.minExtent()); } public void testUnorderedNearIntervals() throws IOException { @@ -261,6 +267,8 @@ public class TestIntervals extends LuceneTestCase { assertGaps(source, 1, "field1", new int[]{ 1, 0, 10 }); + + assertEquals(2, source.minExtent()); } public void testIntervalDisjunction() throws IOException { @@ -280,6 +288,8 @@ public class TestIntervals extends LuceneTestCase { assertMatch(mi, 7, 7, 31, 36); assertNull(mi.getSubMatches()); assertFalse(mi.next()); + + assertEquals(1, source.minExtent()); } public void testCombinationDisjunction() throws IOException { @@ -292,6 +302,8 @@ public class TestIntervals extends LuceneTestCase { { 3, 8 }, {}, {}, {}, {} }); + + assertEquals(2, source.minExtent()); } public void testNesting() throws IOException { @@ -307,6 +319,7 @@ public class TestIntervals extends LuceneTestCase { { 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 17 }, {} }); + assertEquals(3, source.minExtent()); assertNull(getMatches(source, 0, "field1")); MatchesIterator mi = getMatches(source, 1, "field1"); @@ -384,6 +397,7 @@ public class TestIntervals extends LuceneTestCase { assertMatch(sub, 21, 21, 114, 118); assertFalse(sub.next()); assertFalse(it.next()); + assertEquals(4, source.minExtent()); } public void testUnorderedDistinct() throws IOException { @@ -447,6 +461,7 @@ public class TestIntervals extends LuceneTestCase { assertMatch(subs, 21, 21, 114, 118); assertFalse(subs.next()); assertFalse(mi.next()); + assertEquals(1, source.minExtent()); } public void testContaining() throws IOException { @@ -476,6 +491,7 @@ public class TestIntervals extends LuceneTestCase { assertMatch(subs, 21, 21, 114, 118); assertFalse(subs.next()); assertFalse(mi.next()); + assertEquals(2, source.minExtent()); } public void testNotContaining() throws IOException { @@ -498,6 +514,7 @@ public class TestIntervals extends LuceneTestCase { assertMatch(subs, 6, 6, 41, 46); assertFalse(subs.next()); assertFalse(mi.next()); + assertEquals(2, source.minExtent()); } public void testMaxGaps() throws IOException { @@ -512,6 +529,8 @@ public class TestIntervals extends LuceneTestCase { MatchesIterator mi = getMatches(source, 5, "field2"); assertMatch(mi, 0, 3, 0, 11); + assertEquals(3, source.minExtent()); + } public void testNestedMaxGaps() throws IOException { @@ -531,6 +550,8 @@ public class TestIntervals extends LuceneTestCase { assertMatch(mi, 0, 3, 0, 11); assertMatch(mi, 3, 6, 9, 20); assertMatch(mi, 4, 8, 12, 26); + + assertEquals(3, source.minExtent()); } public void testMinimumShouldMatch() throws IOException { @@ -564,6 +585,8 @@ public class TestIntervals extends LuceneTestCase { assertMatch(subs, 7, 7, 47, 55); assertMatch(subs, 11, 11, 67, 71); + assertEquals(3, source.minExtent()); + } public void testDefinedGaps() throws IOException { @@ -580,6 +603,7 @@ public class TestIntervals extends LuceneTestCase { { 3, 7 }, {} }); + assertEquals(5, source.minExtent()); MatchesIterator mi = getMatches(source, 1, "field1"); assertMatch(mi, 3, 7, 20, 55); @@ -594,6 +618,8 @@ public class TestIntervals extends LuceneTestCase { {}, {}, {}, {}, {}, { 0, Integer.MAX_VALUE - 1, 0, Integer.MAX_VALUE - 1, 5, Integer.MAX_VALUE - 1 } }); + + assertEquals(Integer.MAX_VALUE, source.minExtent()); } public void testAfter() throws IOException { @@ -616,6 +642,8 @@ public class TestIntervals extends LuceneTestCase { assertMatch(sub, 5, 5, 35, 39); assertMatch(sub, 7, 7, 47, 55); assertFalse(sub.next()); + + assertEquals(1, source.minExtent()); } public void testBefore() throws IOException { @@ -628,6 +656,7 @@ public class TestIntervals extends LuceneTestCase { { 5, 5 }, {} }); + assertEquals(1, source.minExtent()); } public void testWithin() throws IOException { @@ -641,6 +670,7 @@ public class TestIntervals extends LuceneTestCase { { 2, 2 }, {} }); + assertEquals(1, source.minExtent()); } public void testOverlapping() throws IOException { @@ -670,6 +700,8 @@ public class TestIntervals extends LuceneTestCase { assertMatch(sub, 5, 5, 35, 39); assertFalse(sub.next()); assertMatch(mi, 7, 17, 41, 118); + + assertEquals(2, source.minExtent()); } }