SOLR-11199: Support OR queries in the PayloadScoreParser and a sum function

This commit is contained in:
Varun Thacker 2017-08-08 14:52:57 -07:00
parent ea85543ace
commit 7ed0a40eaa
6 changed files with 98 additions and 9 deletions

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.payloads;
/**
 * A {@link PayloadFunction} whose document score is the running total of every
 * payload score encountered for that document.
 * <p>
 * This class declares no fields of its own, so a single instance can be shared
 * and reused freely.
 */
public class SumPayloadFunction extends PayloadFunction {

  /**
   * Folds one more payload into the running total.
   *
   * @param currentScore the total accumulated so far
   * @param currentPayloadScore the score of the payload just seen
   * @return the accumulated total including the new payload score
   */
  @Override
  public float currentScore(int docId, String field, int start, int end, int numPayloadsSeen, float currentScore, float currentPayloadScore) {
    final float updatedTotal = currentPayloadScore + currentScore;
    return updatedTotal;
  }

  /**
   * Produces the final score for a document.
   *
   * @param numPayloadsSeen how many payloads were seen
   * @param payloadScore the accumulated payload total
   * @return {@code payloadScore} when at least one payload was seen, otherwise {@code 1}
   */
  @Override
  public float docScore(int docId, String field, int numPayloadsSeen, float payloadScore) {
    if (numPayloadsSeen <= 0) {
      // No payload was seen: fall back to a score of 1.
      return 1f;
    }
    return payloadScore;
  }

  /** Hash depends solely on the runtime class, consistent with {@link #equals(Object)}. */
  @Override
  public int hashCode() {
    return 31 + getClass().hashCode();
  }

  /** Two instances are equal exactly when they share the same runtime class. */
  @Override
  public boolean equals(Object obj) {
    return this == obj || (obj != null && getClass() == obj.getClass());
  }
}

View File

@ -65,6 +65,11 @@ New Features
* SOLR-11126: Node level health check handler (Anshum Gupta)
* SOLR-11199: The payload_score query parser supports an "operator" param. Supported operators are "or" and "phrase" (default).
A new "sum" function is also added. Example:
{!payload_score f=payload_field func=sum operator=or}A B C (Varun Thacker)
Bug Fixes
----------------------

View File

@ -37,11 +37,12 @@ import org.apache.solr.util.PayloadUtils;
* <br>Other parameters:
* <br><code>f</code>, the field (required)
* <br><code>func</code>, payload function (min, max, average, or sum; required)
* <br><code>includeSpanScore</code>, multiple payload function result by similarity score or not (default: false)
* <br><code>includeSpanScore</code>, multiply payload function result by similarity score or not (default: false)
* <br>Example: <code>{!payload_score f=weighted_terms_dpf}Foo Bar</code> creates a SpanNearQuery with "Foo" followed by "Bar"
*/
public class PayloadScoreQParserPlugin extends QParserPlugin {
public static final String NAME = "payload_score";
public static final String DEFAULT_OPERATOR = "phrase";
@Override
public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
@ -51,6 +52,10 @@ public class PayloadScoreQParserPlugin extends QParserPlugin {
String field = localParams.get(QueryParsing.F);
String value = localParams.get(QueryParsing.V);
String func = localParams.get("func");
String operator = localParams.get("operator", DEFAULT_OPERATOR);
if (!(operator.equalsIgnoreCase(DEFAULT_OPERATOR) || operator.equalsIgnoreCase("or"))) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Supported operators are : or , phrase");
}
boolean includeSpanScore = localParams.getBool("includeSpanScore", false);
if (field == null) {
@ -63,9 +68,9 @@ public class PayloadScoreQParserPlugin extends QParserPlugin {
FieldType ft = req.getCore().getLatestSchema().getFieldType(field);
Analyzer analyzer = ft.getQueryAnalyzer();
SpanQuery query = null;
SpanQuery query;
try {
query = PayloadUtils.createSpanQuery(field, value, analyzer);
query = PayloadUtils.createSpanQuery(field, value, analyzer, operator);
} catch (IOException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,e);
}

View File

@ -33,12 +33,15 @@ import org.apache.lucene.queries.payloads.AveragePayloadFunction;
import org.apache.lucene.queries.payloads.MaxPayloadFunction;
import org.apache.lucene.queries.payloads.MinPayloadFunction;
import org.apache.lucene.queries.payloads.PayloadFunction;
import org.apache.lucene.queries.payloads.SumPayloadFunction;
import org.apache.lucene.search.spans.SpanNearQuery;
import org.apache.lucene.search.spans.SpanOrQuery;
import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.search.spans.SpanTermQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.analysis.TokenizerChain;
import org.apache.solr.schema.FieldType;
import org.apache.solr.search.PayloadScoreQParserPlugin;
public class PayloadUtils {
public static String getPayloadEncoder(FieldType fieldType) {
@ -95,15 +98,22 @@ public class PayloadUtils {
if ("average".equals(func)) {
payloadFunction = new AveragePayloadFunction();
}
if ("sum".equals(func)) {
payloadFunction = new SumPayloadFunction();
}
return payloadFunction;
}
public static SpanQuery createSpanQuery(String field, String value, Analyzer analyzer) throws IOException {
return createSpanQuery(field, value, analyzer, PayloadScoreQParserPlugin.DEFAULT_OPERATOR);
}
/**
* The generated SpanQuery will be a SpanTermQuery when a single token is emitted. When several tokens are emitted
* it will be a SpanOrQuery if the operator is "or" (case-insensitive), otherwise an ordered, zero slop SpanNearQuery.
*/
public static SpanQuery createSpanQuery(String field, String value, Analyzer analyzer) throws IOException {
public static SpanQuery createSpanQuery(String field, String value, Analyzer analyzer, String operator) throws IOException {
// adapted this from QueryBuilder.createSpanQuery (which isn't currently public) and added reset(), end(), and close() calls
List<SpanTermQuery> terms = new ArrayList<>();
try (TokenStream in = analyzer.tokenStream(field, value)) {
@ -121,9 +131,11 @@ public class PayloadUtils {
query = null;
} else if (terms.size() == 1) {
query = terms.get(0);
} else if (operator != null && operator.equalsIgnoreCase("or")) {
query = new SpanOrQuery(terms.toArray(new SpanTermQuery[terms.size()]));
} else {
query = new SpanNearQuery(terms.toArray(new SpanTermQuery[terms.size()]), 0, true);
}
query = new SpanNearQuery(terms.toArray(new SpanTermQuery[terms.size()]), 0, true);
}
return query;
}
}

View File

@ -35,7 +35,6 @@ public class TestPayloadScoreQParserPlugin extends SolrTestCaseJ4 {
@Test
public void test() {
clearIndex();
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf v=B func=min}"), "//float[@name='score']='2.0'");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf v=mult func=min}"), "//float[@name='score']='50.0'");
@ -47,6 +46,15 @@ public class TestPayloadScoreQParserPlugin extends SolrTestCaseJ4 {
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf func=average}B C"), "//float[@name='score']='2.5'");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf func=max}A B C"), "//float[@name='score']='3.0'");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf func=sum}A B C"), "//float[@name='score']='6.0'");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf func=sum operator=or}A C"), "//float[@name='score']='4.0'");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf func=sum operator=or}A"), "//float[@name='score']='1.0'");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf func=sum operator=or}foo"), "//result[@numFound='0']");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf func=max operator=or}A C"), "//float[@name='score']='3.0'");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf func=min operator=or}A x"), "//float[@name='score']='1.0'");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf func=average operator=or}A C"), "//float[@name='score']='2.0'");
// TODO: fix this includeSpanScore test to be less brittle - score result is score of "A" (via BM25) multiplied by 1.0 (payload value)
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf v=A func=min}"), "//float[@name='score']='1.0'");
assertQ(req("fl","*,score", "q", "{!payload_score f=vals_dpf v=A func=min includeSpanScore=true}"), "//float[@name='score']='0.2876821'");

View File

@ -657,7 +657,11 @@ This parser accepts the following parameters:
The field to use (required).
`func`::
Payload function: min, max, average (required).
Payload function: min, max, average, sum (required).
`operator`::
Search operator: `or` or `phrase` (optional; the default is `phrase`). This defines whether the search query should be an OR
query or a phrase query.
`includeSpanScore`::
If `true`, multiplies the computed payload factor by the score of the original query. If `false`, the default, the computed payload factor is the score.