From e3db2b9906aa8e17141f9978a09ceeee5870657c Mon Sep 17 00:00:00 2001 From: Grant Ingersoll Date: Fri, 14 Aug 2009 12:14:19 +0000 Subject: [PATCH] LUCENE-1790: pass in position information for scoring git-svn-id: https://svn.apache.org/repos/asf/lucene/java/trunk@804178 13f79535-47bb-0310-9956-ffa450edef68 --- .../org/apache/lucene/search/Similarity.java | 6 ++++-- .../search/payloads/AveragePayloadFunction.java | 4 +++- .../payloads/BoostingFunctionTermQuery.java | 4 ++-- .../search/payloads/BoostingNearQuery.java | 16 ++++++++++------ .../search/payloads/MaxPayloadFunction.java | 4 +++- .../search/payloads/MinPayloadFunction.java | 4 +++- .../lucene/search/payloads/PayloadFunction.java | 12 +++++++++--- .../payloads/BoostingFunctionTermQueryTest.java | 2 +- .../search/payloads/TestBoostingNearQuery.java | 9 +++++---- .../search/payloads/TestBoostingTermQuery.java | 6 +++--- 10 files changed, 43 insertions(+), 24 deletions(-) diff --git a/src/java/org/apache/lucene/search/Similarity.java b/src/java/org/apache/lucene/search/Similarity.java index c11adfb9244..e8987acb081 100644 --- a/src/java/org/apache/lucene/search/Similarity.java +++ b/src/java/org/apache/lucene/search/Similarity.java @@ -546,7 +546,7 @@ public abstract class Similarity implements Serializable { * @param length The length in the array * @return An implementation dependent float to be used as a scoring factor * - * @deprecated See {@link #scorePayload(int, String, byte[], int, int)} + * @deprecated See {@link #scorePayload(int, String, int, int, byte[], int, int)} */ //TODO: When removing this, set the default value below to return 1. public float scorePayload(String fieldName, byte [] payload, int offset, int length) @@ -564,13 +564,15 @@ public abstract class Similarity implements Serializable { * * @param docId The docId currently being scored. If this value is {@link #NO_DOC_ID_PROVIDED}, then it should be assumed that the PayloadQuery implementation does not provide document information * @param fieldName The fieldName of the term this payload belongs to + * @param start The start position of the payload + * @param end The end position of the payload * @param payload The payload byte array to be scored * @param offset The offset into the payload array * @param length The length in the array * @return An implementation dependent float to be used as a scoring factor * */ - public float scorePayload(int docId, String fieldName, byte [] payload, int offset, int length) + public float scorePayload(int docId, String fieldName, int start, int end, byte [] payload, int offset, int length) { //TODO: When removing the deprecated scorePayload above, set this to return 1 return scorePayload(fieldName, payload, offset, length); diff --git a/src/java/org/apache/lucene/search/payloads/AveragePayloadFunction.java b/src/java/org/apache/lucene/search/payloads/AveragePayloadFunction.java index 0dcc4387388..aa05f63cbb4 100644 --- a/src/java/org/apache/lucene/search/payloads/AveragePayloadFunction.java +++ b/src/java/org/apache/lucene/search/payloads/AveragePayloadFunction.java @@ -1,4 +1,6 @@ package org.apache.lucene.search.payloads; + +import org.apache.lucene.index.Term; /** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -25,7 +27,7 @@ package org.apache.lucene.search.payloads; **/ public class AveragePayloadFunction extends PayloadFunction{ - public float currentScore(int docId, String field, int numPayloadsSeen, float currentScore, float currentPayloadScore) { + public float currentScore(int docId, String field, int start, int end, int numPayloadsSeen, float currentScore, float currentPayloadScore) { return currentPayloadScore + currentScore; } diff --git a/src/java/org/apache/lucene/search/payloads/BoostingFunctionTermQuery.java b/src/java/org/apache/lucene/search/payloads/BoostingFunctionTermQuery.java index 9208d8125e3..d326d02dc33 100644 --- a/src/java/org/apache/lucene/search/payloads/BoostingFunctionTermQuery.java +++ b/src/java/org/apache/lucene/search/payloads/BoostingFunctionTermQuery.java @@ -106,8 +106,8 @@ public class BoostingFunctionTermQuery extends SpanTermQuery implements Payload protected void processPayload(Similarity similarity) throws IOException { if (positions.isPayloadAvailable()) { payload = positions.getPayload(payload, 0); - payloadScore = function.currentScore(doc, term.field(), payloadsSeen, payloadScore, - similarity.scorePayload(doc, term.field(), payload, 0, positions.getPayloadLength())); + payloadScore = function.currentScore(doc, term.field(), spans.start(), spans.end(), payloadsSeen, payloadScore, + similarity.scorePayload(doc, term.field(), spans.start(), spans.end(), payload, 0, positions.getPayloadLength())); payloadsSeen++; } else { diff --git a/src/java/org/apache/lucene/search/payloads/BoostingNearQuery.java b/src/java/org/apache/lucene/search/payloads/BoostingNearQuery.java index 7569fb495c0..1abfdc36978 100644 --- a/src/java/org/apache/lucene/search/payloads/BoostingNearQuery.java +++ b/src/java/org/apache/lucene/search/payloads/BoostingNearQuery.java @@ -102,12 +102,12 @@ public class BoostingNearQuery extends SpanNearQuery implements PayloadQuery { for (int i = 0; i < subSpans.length; i++) { if (subSpans[i] instanceof NearSpansOrdered) { if (((NearSpansOrdered) subSpans[i]).isPayloadAvailable()) { - processPayloads(((NearSpansOrdered) subSpans[i]).getPayload()); + processPayloads(((NearSpansOrdered) subSpans[i]).getPayload(), subSpans[i].start(), subSpans[i].end()); } getPayloads(((NearSpansOrdered) subSpans[i]).getSubSpans()); } else if (subSpans[i] instanceof NearSpansUnordered) { if (((NearSpansUnordered) subSpans[i]).isPayloadAvailable()) { - processPayloads(((NearSpansUnordered) subSpans[i]).getPayload()); + processPayloads(((NearSpansUnordered) subSpans[i]).getPayload(), subSpans[i].start(), subSpans[i].end()); } getPayloads(((NearSpansUnordered) subSpans[i]).getSubSpans()); } @@ -115,15 +115,19 @@ public class BoostingNearQuery extends SpanNearQuery implements PayloadQuery { } /** - * By default, sums the payloads, but can be overridden to do other things. + * By default, uses the {@link PayloadFunction} to score the payloads, but can be overridden to do other things. * * @param payLoads The payloads + * @param start The start position of the span being scored + * @param end The end position of the span being scored + * + * @see {@link org.apache.lucene.search.spans.Spans} */ - protected void processPayloads(Collection payLoads) { + protected void processPayloads(Collection payLoads, int start, int end) { for (Iterator iterator = payLoads.iterator(); iterator.hasNext();) { byte[] thePayload = (byte[]) iterator.next(); - payloadScore = function.currentScore(doc, fieldName, payloadsSeen, payloadScore, - similarity.scorePayload(doc, fieldName, thePayload, 0, thePayload.length)); + payloadScore = function.currentScore(doc, fieldName, start, end, payloadsSeen, payloadScore, + similarity.scorePayload(doc, fieldName, spans.start(), spans.end(), thePayload, 0, thePayload.length)); ++payloadsSeen; } } diff --git a/src/java/org/apache/lucene/search/payloads/MaxPayloadFunction.java b/src/java/org/apache/lucene/search/payloads/MaxPayloadFunction.java index 5565299f86b..b0b50bedb77 100644 --- a/src/java/org/apache/lucene/search/payloads/MaxPayloadFunction.java +++ b/src/java/org/apache/lucene/search/payloads/MaxPayloadFunction.java @@ -1,4 +1,6 @@ package org.apache.lucene.search.payloads; + +import org.apache.lucene.index.Term; /** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -24,7 +26,7 @@ package org.apache.lucene.search.payloads; * **/ public class MaxPayloadFunction extends PayloadFunction{ - public float currentScore(int docId, String field, int numPayloadsSeen, float currentScore, float currentPayloadScore) { + public float currentScore(int docId, String field, int start, int end, int numPayloadsSeen, float currentScore, float currentPayloadScore) { return Math.max(currentPayloadScore, currentScore); } diff --git a/src/java/org/apache/lucene/search/payloads/MinPayloadFunction.java b/src/java/org/apache/lucene/search/payloads/MinPayloadFunction.java index 357d2d75269..cd68469525b 100644 --- a/src/java/org/apache/lucene/search/payloads/MinPayloadFunction.java +++ b/src/java/org/apache/lucene/search/payloads/MinPayloadFunction.java @@ -1,5 +1,7 @@ package org.apache.lucene.search.payloads; +import org.apache.lucene.index.Term; + /** * Calculates the miniumum payload seen @@ -7,7 +9,7 @@ package org.apache.lucene.search.payloads; **/ public class MinPayloadFunction extends PayloadFunction { - public float currentScore(int docId, String field, int numPayloadsSeen, float currentScore, float currentPayloadScore) { + public float currentScore(int docId, String field, int start, int end, int numPayloadsSeen, float currentScore, float currentPayloadScore) { return Math.min(currentPayloadScore, currentScore); } diff --git a/src/java/org/apache/lucene/search/payloads/PayloadFunction.java b/src/java/org/apache/lucene/search/payloads/PayloadFunction.java index 2d0c53e6872..51ae6d5cf55 100644 --- a/src/java/org/apache/lucene/search/payloads/PayloadFunction.java +++ b/src/java/org/apache/lucene/search/payloads/PayloadFunction.java @@ -16,6 +16,8 @@ package org.apache.lucene.search.payloads; * limitations under the License. */ +import org.apache.lucene.index.Term; + import java.io.Serializable; @@ -37,13 +39,17 @@ public abstract class PayloadFunction implements Serializable { /** * Calculate the score up to this point for this doc and field * @param docId The current doc - * @param field The current field + * @param field The field + * @param start The start position of the matching Span + * @param end The end position of the matching Span * @param numPayloadsSeen The number of payloads seen so far * @param currentScore The current score so far * @param currentPayloadScore The score for the current payload - * @return The new current score + * @return The new current Score + * + * @see org.apache.lucene.search.spans.Spans */ - public abstract float currentScore(int docId, String field, int numPayloadsSeen, float currentScore, float currentPayloadScore); + public abstract float currentScore(int docId, String field, int start, int end, int numPayloadsSeen, float currentScore, float currentPayloadScore); /** * Calculate the final score for all the payloads seen so far for this doc/field diff --git a/src/test/org/apache/lucene/search/payloads/BoostingFunctionTermQueryTest.java b/src/test/org/apache/lucene/search/payloads/BoostingFunctionTermQueryTest.java index f6a1e26c892..701bd97f55a 100644 --- a/src/test/org/apache/lucene/search/payloads/BoostingFunctionTermQueryTest.java +++ b/src/test/org/apache/lucene/search/payloads/BoostingFunctionTermQueryTest.java @@ -262,7 +262,7 @@ public class BoostingFunctionTermQueryTest extends LuceneTestCase { static class BoostingSimilarity extends DefaultSimilarity { // TODO: Remove warning after API has been finalized - public float scorePayload(int docId, String fieldName, byte[] payload, int offset, int length) { + public float scorePayload(int docId, String fieldName, int start, int end, byte[] payload, int offset, int length) { //we know it is size 4 here, so ignore the offset/length return payload[0]; } diff --git a/src/test/org/apache/lucene/search/payloads/TestBoostingNearQuery.java b/src/test/org/apache/lucene/search/payloads/TestBoostingNearQuery.java index 5f47f9bed4f..ffe03e67188 100644 --- a/src/test/org/apache/lucene/search/payloads/TestBoostingNearQuery.java +++ b/src/test/org/apache/lucene/search/payloads/TestBoostingNearQuery.java @@ -184,10 +184,11 @@ public class TestBoostingNearQuery extends LuceneTestCase { // must be static for weight serialization tests static class BoostingSimilarity extends DefaultSimilarity { - public float scorePayload(int docId, String fieldName, byte[] payload, int offset, int length) { - return payload[0]; - } - +// TODO: Remove warning after API has been finalized + public float scorePayload(int docId, String fieldName, int start, int end, byte[] payload, int offset, int length) { + //we know it is size 4 here, so ignore the offset/length + return payload[0]; + } //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //Make everything else 1 so we see the effect of the payload //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! diff --git a/src/test/org/apache/lucene/search/payloads/TestBoostingTermQuery.java b/src/test/org/apache/lucene/search/payloads/TestBoostingTermQuery.java index 414b5901430..1c0b1bc6a87 100644 --- a/src/test/org/apache/lucene/search/payloads/TestBoostingTermQuery.java +++ b/src/test/org/apache/lucene/search/payloads/TestBoostingTermQuery.java @@ -206,15 +206,15 @@ public class TestBoostingTermQuery extends LuceneTestCase { CheckHits.checkHitCollector(query, PayloadHelper.NO_PAYLOAD_FIELD, searcher, results); } - // must be static for weight serialization tests + // must be static for weight serialization tests static class BoostingSimilarity extends DefaultSimilarity { - // TODO: Remove warning after API has been finalized - public float scorePayload(int docId, String fieldName, byte[] payload, int offset, int length) { + public float scorePayload(int docId, String fieldName, int start, int end, byte[] payload, int offset, int length) { //we know it is size 4 here, so ignore the offset/length return payload[0]; } + //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! //Make everything else 1 so we see the effect of the payload //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!