From f4bdb44fd8c6d13845355cefffccb963941bfa8f Mon Sep 17 00:00:00 2001 From: Erik Hatcher Date: Thu, 27 Apr 2017 15:08:51 -0400 Subject: [PATCH] LUCENE-7481: Fix PayloadScoreQuery rewrite --- .../lucene/queries/payloads/PayloadScoreQuery.java | 12 ++++++++++++ .../queries/payloads/TestPayloadScoreQuery.java | 13 +++++++++++++ 2 files changed, 25 insertions(+) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java index 0cddd00a24e..43d33fd9d40 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java @@ -21,12 +21,14 @@ import java.util.Map; import java.util.Set; import java.util.Objects; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermContext; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.search.similarities.Similarity.SimScorer; @@ -80,6 +82,16 @@ public class PayloadScoreQuery extends SpanQuery { return wrappedQuery.getField(); } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query matchRewritten = wrappedQuery.rewrite(reader); + if (wrappedQuery != matchRewritten && matchRewritten instanceof SpanQuery) { + return new PayloadScoreQuery((SpanQuery)matchRewritten, function, includeSpanScore); + } + return super.rewrite(reader); + } + + @Override public String toString(String field) { return "PayloadScoreQuery[" + wrappedQuery.toString(field) + "; " + function.getClass().getSimpleName() + "; " + includeSpanScore + "]"; diff --git a/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestPayloadScoreQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestPayloadScoreQuery.java index fa387762b6b..50a62112b3b 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestPayloadScoreQuery.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestPayloadScoreQuery.java @@ -17,6 +17,8 @@ package org.apache.lucene.queries.payloads; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.MockTokenizer; @@ -37,8 +39,10 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.QueryUtils; import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.WildcardQuery; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.spans.SpanContainingQuery; +import org.apache.lucene.search.spans.SpanMultiTermQueryWrapper; import org.apache.lucene.search.spans.SpanNearQuery; import org.apache.lucene.search.spans.SpanOrQuery; import org.apache.lucene.search.spans.SpanQuery; @@ -193,6 +197,15 @@ public class TestPayloadScoreQuery extends LuceneTestCase { assertFalse(query3.equals(query4)); } + public void testRewrite() throws IOException { + SpanMultiTermQueryWrapper xyz = new SpanMultiTermQueryWrapper(new WildcardQuery(new Term("field", "xyz*"))); + PayloadScoreQuery psq = new PayloadScoreQuery(xyz, new AveragePayloadFunction(), false); + + // if query wasn't rewritten properly, the query would have failed with "Rewrite first!" + searcher.search(psq, 1); + } + + private static IndexSearcher searcher; private static IndexReader reader; private static Directory directory;