From 72163649b16a909ba21baca4cb6d3840f60855e8 Mon Sep 17 00:00:00 2001 From: David Smiley Date: Mon, 3 Jan 2022 08:26:50 -0500 Subject: [PATCH] LUCENE-10252: ValueSource.asDoubleValues should not compute the score (#519) ValueSource.asDoubleValues and asLongValues should not compute the score unless asked to -- typically never. This fixes a performance regression since 7.3 LUCENE-8099 when some older boosting queries were replaced with this. --- lucene/CHANGES.txt | 4 + .../lucene/queries/function/ValueSource.java | 85 +++++----- .../queries/function/TestValueSources.java | 156 +++++++++++++++++- 3 files changed, 204 insertions(+), 41 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 1293cdc7b80..014b13e3530 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -100,6 +100,10 @@ Optimizations * LUCENE-10321: Tweak MultiRangeQuery interval tree creation to skip "pulling up" mins. (Greg Miller) +* LUCENE-10252: ValueSource.asDoubleValues and asLongValues should not compute the score unless + asked to -- typically never. This fixes a performance regression since 7.3 LUCENE-8099 when some + older boosting queries were replaced with this. (David Smiley) + Bug Fixes --------------------- diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/ValueSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/ValueSource.java index 68539a92492..88658389c42 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/ValueSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/ValueSource.java @@ -17,7 +17,6 @@ package org.apache.lucene.queries.function; import java.io.IOException; -import java.util.HashMap; import java.util.IdentityHashMap; import java.util.Map; import java.util.Objects; @@ -77,56 +76,75 @@ public abstract class ValueSource { return context; } - private static class ScoreAndDoc extends Scorable { - - int current = -1; + private static class ScorableView extends Scorable { + final DoubleValues scores; + int docId = -1; + int scoresDocId = -1; float score = 0; - @Override - public int docID() { - return current; + public ScorableView(int docId, float score) { + this(null); + this.docId = this.scoresDocId = docId; + this.score = score; + } + + public ScorableView(DoubleValues scores) { + this.scores = scores == null ? DoubleValues.EMPTY : scores; } @Override - public float score() { + public int docID() { + return docId; + } + + @Override + public float score() throws IOException { + // ensure we calculate the score at most once + if (scoresDocId != docId) { + scoresDocId = docId; + if (scores.advanceExact(docId)) { + score = (float) scores.doubleValue(); + } else { + score = 0; + } + } return score; } } /** Expose this ValueSource as a LongValuesSource */ public LongValuesSource asLongValuesSource() { - return new WrappedLongValuesSource(this); + return new WrappedLongValuesSource(this, null); } private static class WrappedLongValuesSource extends LongValuesSource { private final ValueSource in; + private final IndexSearcher searcher; - private WrappedLongValuesSource(ValueSource in) { + private WrappedLongValuesSource(ValueSource in, IndexSearcher searcher) { this.in = in; + this.searcher = searcher; } @Override public LongValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { - Map context = new IdentityHashMap<>(); - ScoreAndDoc scorer = new ScoreAndDoc(); + Map context = newContext(searcher); + + var scorer = new ScorableView(scores); context.put("scorer", scorer); - final FunctionValues fv = in.getValues(context, ctx); + + FunctionValues fv = in.getValues(context, ctx); return new LongValues() { @Override public long longValue() throws IOException { - return fv.longVal(scorer.current); + return fv.longVal(scorer.docId); } @Override public boolean advanceExact(int doc) throws IOException { - scorer.current = doc; - if (scores != null && scores.advanceExact(doc)) { - scorer.score = (float) scores.doubleValue(); - } else { - scorer.score = 0; - } + scorer.docId = doc; return fv.exists(doc); } }; @@ -162,7 +180,7 @@ public abstract class ValueSource { @Override public LongValuesSource rewrite(IndexSearcher searcher) throws IOException { - return this; + return new WrappedLongValuesSource(in, searcher); } } @@ -183,26 +201,22 @@ public abstract class ValueSource { @Override public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { - Map context = new HashMap<>(); - ScoreAndDoc scorer = new ScoreAndDoc(); + Map context = newContext(searcher); + + var scorer = new ScorableView(scores); context.put("scorer", scorer); - context.put("searcher", searcher); + FunctionValues fv = in.getValues(context, ctx); return new DoubleValues() { @Override public double doubleValue() throws IOException { - return fv.doubleVal(scorer.current); + return fv.doubleVal(scorer.docId); } @Override - public boolean advanceExact(int doc) throws IOException { - scorer.current = doc; - if (scores != null && scores.advanceExact(doc)) { - scorer.score = (float) scores.doubleValue(); - } else { - scorer.score = 0; - } + public boolean advanceExact(int doc) { + scorer.docId = doc; // ValueSource will return values even if exists() is false, generally a default // of some kind. To preserve this behaviour with the iterator, we need to always // return 'true' here. @@ -224,11 +238,8 @@ public abstract class ValueSource { @Override public Explanation explain(LeafReaderContext ctx, int docId, Explanation scoreExplanation) throws IOException { - Map context = new HashMap<>(); - ScoreAndDoc scorer = new ScoreAndDoc(); - scorer.score = scoreExplanation.getValue().floatValue(); - context.put("scorer", scorer); - context.put("searcher", searcher); + Map context = newContext(searcher); + context.put("scorer", new ScorableView(docId, scoreExplanation.getValue().floatValue())); FunctionValues fv = in.getValues(context, ctx); return fv.explain(docId); } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/function/TestValueSources.java b/lucene/queries/src/test/org/apache/lucene/queries/function/TestValueSources.java index 8575990566a..75c2103c046 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/function/TestValueSources.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/function/TestValueSources.java @@ -34,18 +34,60 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.queries.function.docvalues.FloatDocValues; -import org.apache.lucene.queries.function.valuesource.*; +import org.apache.lucene.queries.function.valuesource.BytesRefFieldSource; +import org.apache.lucene.queries.function.valuesource.ConstValueSource; +import org.apache.lucene.queries.function.valuesource.DivFloatFunction; +import org.apache.lucene.queries.function.valuesource.DocFreqValueSource; +import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource; +import org.apache.lucene.queries.function.valuesource.DoubleFieldSource; +import org.apache.lucene.queries.function.valuesource.FloatFieldSource; +import org.apache.lucene.queries.function.valuesource.IDFValueSource; +import org.apache.lucene.queries.function.valuesource.IfFunction; +import org.apache.lucene.queries.function.valuesource.IntFieldSource; +import org.apache.lucene.queries.function.valuesource.JoinDocFreqValueSource; +import org.apache.lucene.queries.function.valuesource.LinearFloatFunction; +import org.apache.lucene.queries.function.valuesource.LiteralValueSource; +import org.apache.lucene.queries.function.valuesource.LongFieldSource; +import org.apache.lucene.queries.function.valuesource.MaxDocValueSource; +import org.apache.lucene.queries.function.valuesource.MaxFloatFunction; +import org.apache.lucene.queries.function.valuesource.MinFloatFunction; +import org.apache.lucene.queries.function.valuesource.MultiBoolFunction; +import org.apache.lucene.queries.function.valuesource.MultiFloatFunction; +import org.apache.lucene.queries.function.valuesource.MultiFunction; +import org.apache.lucene.queries.function.valuesource.MultiValuedDoubleFieldSource; +import org.apache.lucene.queries.function.valuesource.MultiValuedFloatFieldSource; +import org.apache.lucene.queries.function.valuesource.MultiValuedIntFieldSource; +import org.apache.lucene.queries.function.valuesource.MultiValuedLongFieldSource; +import org.apache.lucene.queries.function.valuesource.NormValueSource; +import org.apache.lucene.queries.function.valuesource.NumDocsValueSource; +import org.apache.lucene.queries.function.valuesource.PowFloatFunction; +import org.apache.lucene.queries.function.valuesource.ProductFloatFunction; +import org.apache.lucene.queries.function.valuesource.QueryValueSource; +import org.apache.lucene.queries.function.valuesource.RangeMapFloatFunction; +import org.apache.lucene.queries.function.valuesource.ReciprocalFloatFunction; +import org.apache.lucene.queries.function.valuesource.ScaleFloatFunction; +import org.apache.lucene.queries.function.valuesource.SumFloatFunction; +import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource; +import org.apache.lucene.queries.function.valuesource.TFValueSource; +import org.apache.lucene.queries.function.valuesource.TermFreqValueSource; +import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource; import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.FilterScorer; +import org.apache.lucene.search.FilterWeight; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortedNumericSelector.Type; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.store.Directory; @@ -674,14 +716,75 @@ public class TestValueSources extends LuceneTestCase { } } - public void testWrappingAsDoubleValues() throws IOException { + public void testWrappingAsDoubleValues() throws Exception { + + class AssertScoreComputedOnceQuery extends Query { + + private final Query in; + + public AssertScoreComputedOnceQuery(Query query) { + in = query; + } + + @Override + public String toString(String field) { + return in.toString(field); + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + return new FilterWeight(in.createWeight(searcher, scoreMode, boost)) { + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + return new FilterScorer(super.scorer(context)) { + int lastDocId = -1; + + @Override + public float score() throws IOException { + assertTrue("shouldn't re-compute score", lastDocId != docID()); + this.lastDocId = docID(); + return super.score(); + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return in.getMaxScore(upTo); + } + }; + } + }; + } + + @Override + public Query rewrite(IndexReader reader) throws IOException { + var rewrite = in.rewrite(reader); + return rewrite == in ? this : new AssertScoreComputedOnceQuery(rewrite); + } + + @Override + public void visit(QueryVisitor visitor) { + in.visit(visitor); + } + + @Override + public boolean equals(Object obj) { + throw new UnsupportedOperationException(); + } + + @Override + public int hashCode() { + return in.hashCode(); + } + } FunctionScoreQuery q = FunctionScoreQuery.boostByValue( - new TermQuery(new Term("f", "t")), + new AssertScoreComputedOnceQuery(new TermQuery(new Term("text", "test"))), new DoubleFieldSource("double").asDoubleValuesSource()); - searcher.createWeight(searcher.rewrite(q), ScoreMode.COMPLETE, 1); + var topFieldDocs = searcher.search(q, 1); + assertTrue(topFieldDocs.scoreDocs.length > 0); // assert that the query has not cached a reference to the IndexSearcher FunctionScoreQuery.MultiplicativeBoostValuesSource source1 = @@ -691,6 +794,51 @@ public class TestValueSources extends LuceneTestCase { assertNull(source2.searcher); } + /** Tests "scorer" key-value inside the Map context to ValueSource */ + public void testScorerContext() throws IOException { + // a VS that yields the score + class ScoreValueSource extends ValueSource { + @Override + public FunctionValues getValues(Map context, LeafReaderContext readerContext) + throws IOException { + var scorer = (Scorable) context.get("scorer"); + assertNotNull(scorer); + return new FloatDocValues(this) { + @Override + public float floatVal(int doc) throws IOException { + assertEquals(doc, scorer.docID()); + return scorer.score(); + } + }; + } + + @Override + public boolean equals(Object o) { + return this == o; // just for a test + } + + @Override + public int hashCode() { + return 0; // just for a test + } + + @Override + public String description() { + return "score"; + } + } + + var plainQ = new TermQuery(new Term("text", "test")); + float origScore = searcher.search(plainQ, 1).scoreDocs[0].score; + + // boosts the score by the value source (which is the score), thus score^2 + var scoreSquaredQ = + FunctionScoreQuery.boostByValue(plainQ, new ScoreValueSource().asDoubleValuesSource()); + var topFieldDocs = searcher.search(scoreSquaredQ, 1); + assertTrue(topFieldDocs.scoreDocs.length > 0); + assertEquals(origScore * origScore, topFieldDocs.scoreDocs[0].score, 0.00001); + } + public void testBuildingFromDoubleValues() throws Exception { DoubleValuesSource dvs = ValueSource.fromDoubleValuesSource(DoubleValuesSource.fromDoubleField("double"))