From 5c9bcc9e900de027931a86704a8ab5fd4c525d9f Mon Sep 17 00:00:00 2001 From: Alan Woodward Date: Fri, 3 Nov 2017 20:18:20 +0000 Subject: [PATCH] LUCENE-8038: Add PayloadDecoder --- lucene/CHANGES.txt | 4 + .../queries/payloads/PayloadDecoder.java | 47 +++ .../queries/payloads/PayloadScoreQuery.java | 83 +++-- .../TestDeprecatedPayloadScoreQuery.java | 327 ++++++++++++++++++ .../payloads/TestPayloadScoreQuery.java | 35 +- 5 files changed, 450 insertions(+), 46 deletions(-) create mode 100644 lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadDecoder.java create mode 100644 lucene/queries/src/test/org/apache/lucene/queries/payloads/TestDeprecatedPayloadScoreQuery.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 9477a7991a2..a3d6756ee5a 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -42,6 +42,10 @@ API Changes * LUCENE-8017: Weight now exposes a getCacheHelper() method to help query caches determine whether or not a query can be cached. (Alan Woodward) +* LUCENE-8038: Payload factors for scoring in PayloadScoreQuery are now + calculated by a PayloadDecoder, instead of delegating to the Similarity. + (Alan Woodward) + Bug Fixes * LUCENE-7991: KNearestNeighborDocumentClassifier.knnSearch no longer applies diff --git a/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadDecoder.java b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadDecoder.java new file mode 100644 index 00000000000..318797099f8 --- /dev/null +++ b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadDecoder.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.queries.payloads; + +import org.apache.lucene.util.BytesRef; + +/** + * Defines a way of converting payloads to float values, for use by {@link PayloadScoreQuery} + */ +public interface PayloadDecoder { + + /** + * Compute a float value based on the doc, position and payload + * @deprecated Use {@link #computePayloadFactor(BytesRef)} - doc and position can be taken + * into account in {@link PayloadFunction#currentScore(int, String, int, int, int, float, float)} + */ + @Deprecated + default float computePayloadFactor(int docID, int startPosition, int endPosition, BytesRef payload) { + return computePayloadFactor(payload); + } + + /** + * Compute a float value for the given payload + */ + float computePayloadFactor(BytesRef payload); + + /** + * A {@link PayloadDecoder} that returns the first byte of a payload as a float value, or 1 if the payload is null + */ + PayloadDecoder FLOAT_DECODER = bytes -> bytes == null ? 
1 : bytes.bytes[bytes.offset]; + +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java index 5df04c8ebb8..883ff0e4369 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java @@ -18,8 +18,8 @@ package org.apache.lucene.queries.payloads; import java.io.IOException; import java.util.Map; -import java.util.Set; import java.util.Objects; +import java.util.Set; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; @@ -29,7 +29,6 @@ import org.apache.lucene.index.TermContext; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; -import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.search.similarities.Similarity.SimScorer; import org.apache.lucene.search.spans.FilterSpans; @@ -42,36 +41,56 @@ import org.apache.lucene.util.BytesRef; /** * A Query class that uses a {@link PayloadFunction} to modify the score of a wrapped SpanQuery - * - * NOTE: In order to take advantage of this with the default scoring implementation - * ({@link ClassicSimilarity}), you must override {@link ClassicSimilarity#scorePayload(int, int, int, BytesRef)}, - * which returns 1 by default. 
- * - * @see org.apache.lucene.search.similarities.Similarity.SimScorer#computePayloadFactor(int, int, int, BytesRef) */ public class PayloadScoreQuery extends SpanQuery { private final SpanQuery wrappedQuery; private final PayloadFunction function; + private final PayloadDecoder decoder; private final boolean includeSpanScore; /** * Creates a new PayloadScoreQuery * @param wrappedQuery the query to wrap * @param function a PayloadFunction to use to modify the scores + * @param decoder a PayloadDecoder to convert payloads into float values * @param includeSpanScore include both span score and payload score in the scoring algorithm */ - public PayloadScoreQuery(SpanQuery wrappedQuery, PayloadFunction function, boolean includeSpanScore) { + public PayloadScoreQuery(SpanQuery wrappedQuery, PayloadFunction function, PayloadDecoder decoder, boolean includeSpanScore) { this.wrappedQuery = Objects.requireNonNull(wrappedQuery); this.function = Objects.requireNonNull(function); + this.decoder = decoder; this.includeSpanScore = includeSpanScore; } + /** + * Creates a new PayloadScoreQuery + * @param wrappedQuery the query to wrap + * @param function a PayloadFunction to use to modify the scores + * @param includeSpanScore include both span score and payload score in the scoring algorithm + * @deprecated Use {@link #PayloadScoreQuery(SpanQuery, PayloadFunction, PayloadDecoder, boolean)} + */ + @Deprecated + public PayloadScoreQuery(SpanQuery wrappedQuery, PayloadFunction function, boolean includeSpanScore) { + this(wrappedQuery, function, null, includeSpanScore); + } + /** * Creates a new PayloadScoreQuery that includes the underlying span scores * @param wrappedQuery the query to wrap * @param function a PayloadFunction to use to modify the scores */ + public PayloadScoreQuery(SpanQuery wrappedQuery, PayloadFunction function, PayloadDecoder decoder) { + this(wrappedQuery, function, decoder, true); + } + + /** + * Creates a new PayloadScoreQuery that includes the underlying 
span scores + * @param wrappedQuery the query to wrap + * @param function a PayloadFunction to use to modify the scores + * @deprecated Use {@link #PayloadScoreQuery(SpanQuery, PayloadFunction, PayloadDecoder)} + */ + @Deprecated public PayloadScoreQuery(SpanQuery wrappedQuery, PayloadFunction function) { this(wrappedQuery, function, true); } @@ -85,7 +104,7 @@ public class PayloadScoreQuery extends SpanQuery { public Query rewrite(IndexReader reader) throws IOException { Query matchRewritten = wrappedQuery.rewrite(reader); if (wrappedQuery != matchRewritten && matchRewritten instanceof SpanQuery) { - return new PayloadScoreQuery((SpanQuery)matchRewritten, function, includeSpanScore); + return new PayloadScoreQuery((SpanQuery)matchRewritten, function, decoder, includeSpanScore); } return super.rewrite(reader); } @@ -120,16 +139,13 @@ public class PayloadScoreQuery extends SpanQuery { private boolean equalsTo(PayloadScoreQuery other) { return wrappedQuery.equals(other.wrappedQuery) && - function.equals(other.function) && (includeSpanScore == other.includeSpanScore); + function.equals(other.function) && (includeSpanScore == other.includeSpanScore) && + Objects.equals(decoder, other.decoder); } @Override public int hashCode() { - int result = classHash(); - result = 31 * result + Objects.hashCode(wrappedQuery); - result = 31 * result + Objects.hashCode(function); - result = 31 * result + Objects.hashCode(includeSpanScore); - return result; + return Objects.hash(wrappedQuery, function, decoder, includeSpanScore); } private class PayloadSpanWeight extends SpanWeight { @@ -157,7 +173,8 @@ public class PayloadScoreQuery extends SpanQuery { if (spans == null) return null; SimScorer docScorer = innerWeight.getSimScorer(context); - PayloadSpans payloadSpans = new PayloadSpans(spans, docScorer); + PayloadSpans payloadSpans = new PayloadSpans(spans, + decoder == null ? 
new SimilarityPayloadDecoder(docScorer) : decoder); return new PayloadSpanScorer(this, payloadSpans, docScorer); } @@ -192,13 +209,13 @@ public class PayloadScoreQuery extends SpanQuery { private class PayloadSpans extends FilterSpans implements SpanCollector { - private final SimScorer docScorer; + private final PayloadDecoder decoder; public int payloadsSeen; public float payloadScore; - private PayloadSpans(Spans in, SimScorer docScorer) { + private PayloadSpans(Spans in, PayloadDecoder decoder) { super(in); - this.docScorer = docScorer; + this.decoder = decoder; } @Override @@ -215,9 +232,7 @@ public class PayloadScoreQuery extends SpanQuery { @Override public void collectLeaf(PostingsEnum postings, int position, Term term) throws IOException { BytesRef payload = postings.getPayload(); - if (payload == null) - return; - float payloadFactor = docScorer.computePayloadFactor(docID(), in.startPosition(), in.endPosition(), payload); + float payloadFactor = decoder.computePayloadFactor(docID(), in.startPosition(), in.endPosition(), payload); payloadScore = function.currentScore(docID(), getField(), in.startPosition(), in.endPosition(), payloadsSeen, payloadScore, payloadFactor); payloadsSeen++; @@ -262,4 +277,26 @@ public class PayloadScoreQuery extends SpanQuery { } + @Deprecated + private static class SimilarityPayloadDecoder implements PayloadDecoder { + + final Similarity.SimScorer docScorer; + + public SimilarityPayloadDecoder(Similarity.SimScorer docScorer) { + this.docScorer = docScorer; + } + + @Override + public float computePayloadFactor(int docID, int startPosition, int endPosition, BytesRef payload) { + if (payload == null) + return 0; + return docScorer.computePayloadFactor(docID, startPosition, endPosition, payload); + } + + @Override + public float computePayloadFactor(BytesRef payload) { + throw new UnsupportedOperationException(); + } + } + } diff --git a/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestDeprecatedPayloadScoreQuery.java 
b/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestDeprecatedPayloadScoreQuery.java new file mode 100644 index 00000000000..4693e6bea9f --- /dev/null +++ b/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestDeprecatedPayloadScoreQuery.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.queries.payloads; + +import java.io.IOException; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.MockTokenizer; +import org.apache.lucene.analysis.TokenFilter; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.Tokenizer; +import org.apache.lucene.analysis.tokenattributes.PayloadAttribute; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.CollectionStatistics; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.QueryUtils; +import org.apache.lucene.search.TermStatistics; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.WildcardQuery; +import org.apache.lucene.search.similarities.ClassicSimilarity; +import org.apache.lucene.search.spans.SpanContainingQuery; +import org.apache.lucene.search.spans.SpanMultiTermQueryWrapper; +import org.apache.lucene.search.spans.SpanNearQuery; +import org.apache.lucene.search.spans.SpanOrQuery; +import org.apache.lucene.search.spans.SpanQuery; +import org.apache.lucene.search.spans.SpanTermQuery; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.English; +import org.apache.lucene.util.LuceneTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestDeprecatedPayloadScoreQuery extends LuceneTestCase { + + private static void checkQuery(SpanQuery query, PayloadFunction function, int[] expectedDocs, float[] expectedScores) throws IOException { + checkQuery(query, function, true, expectedDocs, expectedScores); + } + + private static void checkQuery(SpanQuery query, 
PayloadFunction function, boolean includeSpanScore, int[] expectedDocs, float[] expectedScores) throws IOException { + + assertTrue("Expected docs and scores arrays must be the same length!", expectedDocs.length == expectedScores.length); + + PayloadScoreQuery psq = new PayloadScoreQuery(query, function, includeSpanScore); + TopDocs hits = searcher.search(psq, expectedDocs.length); + + for (int i = 0; i < hits.scoreDocs.length; i++) { + if (i > expectedDocs.length - 1) + fail("Unexpected hit in document " + hits.scoreDocs[i].doc); + if (hits.scoreDocs[i].doc != expectedDocs[i]) + fail("Unexpected hit in document " + hits.scoreDocs[i].doc); + assertEquals("Bad score in document " + expectedDocs[i], expectedScores[i], hits.scoreDocs[i].score, 0.000001); + } + + if (hits.scoreDocs.length > expectedDocs.length) + fail("Unexpected hit in document " + hits.scoreDocs[expectedDocs.length]); + + QueryUtils.check(random(), psq, searcher); + } + + @Test + public void testTermQuery() throws IOException { + + SpanTermQuery q = new SpanTermQuery(new Term("field", "eighteen")); + for (PayloadFunction fn + : new PayloadFunction[]{ new AveragePayloadFunction(), new MaxPayloadFunction(), new MinPayloadFunction() }) { + checkQuery(q, fn, new int[]{ 118, 218, 18 }, + new float[] { 4.0f, 4.0f, 2.0f }); + } + + } + + @Test + public void testOrQuery() throws IOException { + + SpanOrQuery q = new SpanOrQuery(new SpanTermQuery(new Term("field", "eighteen")), + new SpanTermQuery(new Term("field", "nineteen"))); + for (PayloadFunction fn + : new PayloadFunction[]{ new AveragePayloadFunction(), new MaxPayloadFunction(), new MinPayloadFunction() }) { + checkQuery(q, fn, new int[]{ 118, 119, 218, 219, 18, 19 }, + new float[] { 4.0f, 4.0f, 4.0f, 4.0f, 2.0f, 2.0f }); + } + + } + + @Test + public void testNearQuery() throws IOException { + + // 2 4 + // twenty two + // 2 4 4 4 + // one hundred twenty two + + SpanNearQuery q = new SpanNearQuery(new SpanQuery[]{ + new SpanTermQuery(new Term("field", 
"twenty")), + new SpanTermQuery(new Term("field", "two")) + }, 0, true); + + checkQuery(q, new MaxPayloadFunction(), new int[]{ 22, 122, 222 }, new float[]{ 4.0f, 4.0f, 4.0f }); + checkQuery(q, new MinPayloadFunction(), new int[]{ 122, 222, 22 }, new float[]{ 4.0f, 4.0f, 2.0f }); + checkQuery(q, new AveragePayloadFunction(), new int[] { 122, 222, 22 }, new float[] { 4.0f, 4.0f, 3.0f }); + + } + + @Test + public void testNestedNearQuery() throws Exception { + + // (one OR hundred) NEAR (twenty two) ~ 1 + // 2 4 4 4 + // one hundred twenty two + // two hundred twenty two + + SpanNearQuery q = new SpanNearQuery(new SpanQuery[]{ + new SpanOrQuery(new SpanTermQuery(new Term("field", "one")), new SpanTermQuery(new Term("field", "hundred"))), + new SpanNearQuery(new SpanQuery[]{ + new SpanTermQuery(new Term("field", "twenty")), + new SpanTermQuery(new Term("field", "two")) + }, 0, true) + }, 1, true); + + // check includeSpanScore makes a difference here + searcher.setSimilarity(new MultiplyingSimilarity()); + try { + checkQuery(q, new MaxPayloadFunction(), new int[]{ 122, 222 }, new float[]{ 20.901256561279297f, 17.06580352783203f }); + checkQuery(q, new MinPayloadFunction(), new int[]{ 222, 122 }, new float[]{ 17.06580352783203f, 10.450628280639648f }); + checkQuery(q, new AveragePayloadFunction(), new int[] { 122, 222 }, new float[]{ 19.15948486328125f, 17.06580352783203f }); + checkQuery(q, new MaxPayloadFunction(), false, new int[]{122, 222}, new float[]{4.0f, 4.0f}); + checkQuery(q, new MinPayloadFunction(), false, new int[]{222, 122}, new float[]{4.0f, 2.0f}); + checkQuery(q, new AveragePayloadFunction(), false, new int[]{222, 122}, new float[]{4.0f, 3.666666f}); + } + finally { + searcher.setSimilarity(similarity); + } + + } + + @Test + public void testSpanContainingQuery() throws Exception { + + // twenty WITHIN ((one OR hundred) NEAR two)~2 + SpanContainingQuery q = new SpanContainingQuery( + new SpanNearQuery(new SpanQuery[]{ + new SpanOrQuery(new 
SpanTermQuery(new Term("field", "one")), new SpanTermQuery(new Term("field", "hundred"))), + new SpanTermQuery(new Term("field", "two")) + }, 2, true), + new SpanTermQuery(new Term("field", "twenty")) + ); + + checkQuery(q, new AveragePayloadFunction(), new int[] { 222, 122 }, new float[]{ 4.0f, 3.666666f }); + checkQuery(q, new MaxPayloadFunction(), new int[]{ 122, 222 }, new float[]{ 4.0f, 4.0f }); + checkQuery(q, new MinPayloadFunction(), new int[]{ 222, 122 }, new float[]{ 4.0f, 2.0f }); + + } + + @Test + public void testEquality() { + SpanQuery sq1 = new SpanTermQuery(new Term("field", "one")); + SpanQuery sq2 = new SpanTermQuery(new Term("field", "two")); + PayloadFunction minFunc = new MinPayloadFunction(); + PayloadFunction maxFunc = new MaxPayloadFunction(); + PayloadScoreQuery query1 = new PayloadScoreQuery(sq1, minFunc, true); + PayloadScoreQuery query2 = new PayloadScoreQuery(sq2, minFunc, true); + PayloadScoreQuery query3 = new PayloadScoreQuery(sq2, maxFunc, true); + PayloadScoreQuery query4 = new PayloadScoreQuery(sq2, maxFunc, false); + PayloadScoreQuery query5 = new PayloadScoreQuery(sq1, minFunc); + + assertEquals(query1, query5); + assertFalse(query1.equals(query2)); + assertFalse(query1.equals(query3)); + assertFalse(query1.equals(query4)); + assertFalse(query2.equals(query3)); + assertFalse(query2.equals(query4)); + assertFalse(query3.equals(query4)); + } + + public void testRewrite() throws IOException { + SpanMultiTermQueryWrapper xyz = new SpanMultiTermQueryWrapper(new WildcardQuery(new Term("field", "xyz*"))); + PayloadScoreQuery psq = new PayloadScoreQuery(xyz, new AveragePayloadFunction(), false); + + // if query wasn't rewritten properly, the query would have failed with "Rewrite first!" 
+ searcher.search(psq, 1); + } + + + private static IndexSearcher searcher; + private static IndexReader reader; + private static Directory directory; + private static JustScorePayloadSimilarity similarity = new JustScorePayloadSimilarity(); + private static byte[] payload2 = new byte[]{2}; + private static byte[] payload4 = new byte[]{4}; + + private static class PayloadAnalyzer extends Analyzer { + @Override + public TokenStreamComponents createComponents(String fieldName) { + Tokenizer result = new MockTokenizer(MockTokenizer.SIMPLE, true); + return new TokenStreamComponents(result, new PayloadFilter(result)); + } + } + + private static class PayloadFilter extends TokenFilter { + + private int numSeen = 0; + private final PayloadAttribute payAtt; + + public PayloadFilter(TokenStream input) { + super(input); + payAtt = addAttribute(PayloadAttribute.class); + } + + @Override + public boolean incrementToken() throws IOException { + boolean result = false; + if (input.incrementToken()) { + if (numSeen % 4 == 0) { + payAtt.setPayload(new BytesRef(payload2)); + } else { + payAtt.setPayload(new BytesRef(payload4)); + } + numSeen++; + result = true; + } + return result; + } + + @Override + public void reset() throws IOException { + super.reset(); + this.numSeen = 0; + } + } + + @BeforeClass + public static void beforeClass() throws Exception { + directory = newDirectory(); + RandomIndexWriter writer = new RandomIndexWriter(random(), directory, + newIndexWriterConfig(new PayloadAnalyzer()) + .setMergePolicy(NoMergePolicy.INSTANCE) + .setSimilarity(similarity)); + //writer.infoStream = System.out; + for (int i = 0; i < 300; i++) { + Document doc = new Document(); + doc.add(newTextField("field", English.intToEnglish(i), Field.Store.YES)); + String txt = English.intToEnglish(i) +' '+English.intToEnglish(i+1); + doc.add(newTextField("field2", txt, Field.Store.YES)); + writer.addDocument(doc); + } + reader = writer.getReader(); + writer.close(); + + searcher = 
newSearcher(reader); + searcher.setSimilarity(similarity); + } + + @AfterClass + public static void afterClass() throws Exception { + searcher = null; + reader.close(); + reader = null; + directory.close(); + directory = null; + } + + static class MultiplyingSimilarity extends ClassicSimilarity { + + @Override + public float scorePayload(int docId, int start, int end, BytesRef payload) { + //payloads in this test are a single byte, so read the byte at the offset and ignore the length + return payload.bytes[payload.offset]; + } + + } + + static class JustScorePayloadSimilarity extends MultiplyingSimilarity { + + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + //Make everything else 1 so we see the effect of the payload + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + @Override + public float lengthNorm(int length) { + return 1; + } + + @Override + public float sloppyFreq(int distance) { + return 1.0f; + } + + @Override + public float tf(float freq) { + return 1.0f; + } + + // idf used for phrase queries + @Override + public Explanation idfExplain(CollectionStatistics collectionStats, TermStatistics[] termStats) { + return Explanation.match(1.0f, "Inexplicable"); + } + + @Override + public Explanation idfExplain(CollectionStatistics collectionStats, TermStatistics termStats) { + return Explanation.match(1.0f, "Inexplicable"); + } + + } + +} diff --git a/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestPayloadScoreQuery.java b/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestPayloadScoreQuery.java index f00cc9b4763..509246e09db 100644 --- a/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestPayloadScoreQuery.java +++ b/lucene/queries/src/test/org/apache/lucene/queries/payloads/TestPayloadScoreQuery.java @@ -62,7 +62,7 @@ public class TestPayloadScoreQuery extends LuceneTestCase { assertTrue("Expected docs and scores arrays must be the same length!", expectedDocs.length == expectedScores.length); - PayloadScoreQuery psq = new PayloadScoreQuery(query, function, 
includeSpanScore); + PayloadScoreQuery psq = new PayloadScoreQuery(query, function, PayloadDecoder.FLOAT_DECODER, includeSpanScore); TopDocs hits = searcher.search(psq, expectedDocs.length); for (int i = 0; i < hits.scoreDocs.length; i++) { @@ -140,7 +140,7 @@ public class TestPayloadScoreQuery extends LuceneTestCase { }, 1, true); // check includeSpanScore makes a difference here - searcher.setSimilarity(new MultiplyingSimilarity()); + searcher.setSimilarity(new ClassicSimilarity()); try { checkQuery(q, new MaxPayloadFunction(), new int[]{ 122, 222 }, new float[]{ 20.901256561279297f, 17.06580352783203f }); checkQuery(q, new MinPayloadFunction(), new int[]{ 222, 122 }, new float[]{ 17.06580352783203f, 10.450628280639648f }); @@ -179,11 +179,11 @@ public class TestPayloadScoreQuery extends LuceneTestCase { SpanQuery sq2 = new SpanTermQuery(new Term("field", "two")); PayloadFunction minFunc = new MinPayloadFunction(); PayloadFunction maxFunc = new MaxPayloadFunction(); - PayloadScoreQuery query1 = new PayloadScoreQuery(sq1, minFunc, true); - PayloadScoreQuery query2 = new PayloadScoreQuery(sq2, minFunc, true); - PayloadScoreQuery query3 = new PayloadScoreQuery(sq2, maxFunc, true); - PayloadScoreQuery query4 = new PayloadScoreQuery(sq2, maxFunc, false); - PayloadScoreQuery query5 = new PayloadScoreQuery(sq1, minFunc); + PayloadScoreQuery query1 = new PayloadScoreQuery(sq1, minFunc, PayloadDecoder.FLOAT_DECODER, true); + PayloadScoreQuery query2 = new PayloadScoreQuery(sq2, minFunc, PayloadDecoder.FLOAT_DECODER, true); + PayloadScoreQuery query3 = new PayloadScoreQuery(sq2, maxFunc, PayloadDecoder.FLOAT_DECODER, true); + PayloadScoreQuery query4 = new PayloadScoreQuery(sq2, maxFunc, PayloadDecoder.FLOAT_DECODER, false); + PayloadScoreQuery query5 = new PayloadScoreQuery(sq1, minFunc, PayloadDecoder.FLOAT_DECODER); assertEquals(query1, query5); assertFalse(query1.equals(query2)); @@ -195,8 +195,8 @@ public class TestPayloadScoreQuery extends LuceneTestCase { } public 
void testRewrite() throws IOException { - SpanMultiTermQueryWrapper xyz = new SpanMultiTermQueryWrapper(new WildcardQuery(new Term("field", "xyz*"))); - PayloadScoreQuery psq = new PayloadScoreQuery(xyz, new AveragePayloadFunction(), false); + SpanMultiTermQueryWrapper xyz = new SpanMultiTermQueryWrapper<>(new WildcardQuery(new Term("field", "xyz*"))); + PayloadScoreQuery psq = new PayloadScoreQuery(xyz, new AveragePayloadFunction(), PayloadDecoder.FLOAT_DECODER, false); // if query wasn't rewritten properly, the query would have failed with "Rewrite first!" searcher.search(psq, 1); @@ -255,8 +255,7 @@ public class TestPayloadScoreQuery extends LuceneTestCase { directory = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), directory, newIndexWriterConfig(new PayloadAnalyzer()) - .setMergePolicy(NoMergePolicy.INSTANCE) - .setSimilarity(similarity)); + .setMergePolicy(NoMergePolicy.INSTANCE)); //writer.infoStream = System.out; for (int i = 0; i < 300; i++) { Document doc = new Document(); @@ -269,7 +268,7 @@ public class TestPayloadScoreQuery extends LuceneTestCase { writer.close(); searcher = newSearcher(reader); - searcher.setSimilarity(similarity); + searcher.setSimilarity(new JustScorePayloadSimilarity()); } @AfterClass @@ -281,17 +280,7 @@ public class TestPayloadScoreQuery extends LuceneTestCase { directory = null; } - static class MultiplyingSimilarity extends ClassicSimilarity { - - @Override - public float scorePayload(int docId, int start, int end, BytesRef payload) { - //we know it is size 4 here, so ignore the offset/length - return payload.bytes[payload.offset]; - } - - } - - static class JustScorePayloadSimilarity extends MultiplyingSimilarity { + static class JustScorePayloadSimilarity extends ClassicSimilarity { //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //Make everything else 1 so we see the effect of the payload