SOLR-6088: Add query re-ranking with the ReRankingQParserPlugin

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1600720 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Joel Bernstein 2014-06-05 18:28:30 +00:00
parent 529edff3ba
commit 77edd9d5ac
9 changed files with 790 additions and 102 deletions

View File

@ -95,7 +95,7 @@ public class DebugComponent extends SearchComponent
}
NamedList stdinfo = SolrPluginUtils.doStandardDebug( rb.req,
rb.getQueryString(), rb.getQuery(), results, rb.isDebugQuery(), rb.isDebugResults());
rb.getQueryString(), rb.wrap(rb.getQuery()), results, rb.isDebugQuery(), rb.isDebugResults());
NamedList info = rb.getDebugInfo();
if( info == null ) {
@ -234,7 +234,7 @@ public class DebugComponent extends SearchComponent
}
// No responses were received from shards. Show local query info.
SolrPluginUtils.doStandardQueryDebug(
rb.req, rb.getQueryString(), rb.getQuery(), rb.isDebugQuery(), info);
rb.req, rb.getQueryString(), rb.wrap(rb.getQuery()), rb.isDebugQuery(), info);
if (rb.isDebugQuery() && rb.getQparser() != null) {
rb.getQparser().addDebugInfo(info);
}

View File

@ -17,9 +17,12 @@
package org.apache.solr.handler.component;
import com.carrotsearch.hppc.IntOpenHashSet;
import com.carrotsearch.hppc.IntIntOpenHashMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.DocsEnum;
import org.apache.lucene.index.Fields;
@ -80,6 +83,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -98,6 +102,10 @@ public class QueryElevationComponent extends SearchComponent implements SolrCore
static final String CONFIG_FILE = "config-file";
static final String EXCLUDE = "exclude";
public static final String BOOSTED = "BOOSTED";
public static final String BOOSTED_DOCIDS = "BOOSTED_DOCIDS";
public static final String BOOSTED_PRIORITY = "BOOSTED_PRIORITY";
public static final String EXCLUDED = "EXCLUDED";
// Runtime param -- should be in common?
@ -401,6 +409,7 @@ public class QueryElevationComponent extends SearchComponent implements SolrCore
if (booster != null) {
rb.req.getContext().put(BOOSTED, booster.ids);
rb.req.getContext().put(BOOSTED_PRIORITY, booster.priority);
// Change the query to insert forced documents
if (exclusive == true) {
@ -519,6 +528,67 @@ public class QueryElevationComponent extends SearchComponent implements SolrCore
return null;
}
/**
 * Resolves the elevated (boosted) unique-key values to Lucene internal doc ids,
 * mapped to their elevation priority.
 *
 * The result is cached in the request context under {@link #BOOSTED_DOCIDS} so
 * multiple consumers (e.g. CollapsingQParserPlugin, ReRankQParserPlugin) resolve
 * the ids only once per request.
 *
 * @param indexSearcher searcher whose unique-key field is used for the lookup
 * @param boosted       map of unique-key term -&gt; priority; may be null (returns null)
 * @param context       request context used as a per-request cache; may be null
 * @return map of internal docid -&gt; priority, or null when {@code boosted} is null
 * @throws IOException on index access failure
 */
public static IntIntOpenHashMap getBoostDocs(SolrIndexSearcher indexSearcher, Map<BytesRef, Integer> boosted, Map context) throws IOException {

  IntIntOpenHashMap boostDocs = null;

  if(boosted != null) {

    //First see if it's already in the request context. Could have been put there
    //by another caller.
    if(context != null) {
      boostDocs = (IntIntOpenHashMap)context.get(BOOSTED_DOCIDS);
    }

    if(boostDocs != null) {
      return boostDocs;
    }

    //Not in the context yet so load it.
    SchemaField idField = indexSearcher.getSchema().getUniqueKeyField();
    String fieldName = idField.getName();

    // Work on a mutable copy so found terms can be dropped as segments are scanned.
    HashSet<BytesRef> localBoosts = new HashSet<>(boosted.size()*2);
    Iterator<BytesRef> boostedIt = boosted.keySet().iterator();
    while(boostedIt.hasNext()) {
      localBoosts.add(boostedIt.next());
    }

    boostDocs = new IntIntOpenHashMap(boosted.size()*2);

    List<AtomicReaderContext> leaves = indexSearcher.getTopReaderContext().leaves();
    TermsEnum termsEnum = null;
    DocsEnum docsEnum = null;
    for(AtomicReaderContext leaf : leaves) {
      AtomicReader reader = leaf.reader();
      int docBase = leaf.docBase;
      Bits liveDocs = reader.getLiveDocs();
      Terms terms = reader.terms(fieldName);
      if(terms == null) {
        // Segment has no postings for the unique-key field; previously this NPE'd.
        continue;
      }
      termsEnum = terms.iterator(termsEnum);
      Iterator<BytesRef> it = localBoosts.iterator();
      while(it.hasNext()) {
        BytesRef ref = it.next();
        if(termsEnum.seekExact(ref)) {
          docsEnum = termsEnum.docs(liveDocs, docsEnum);
          int doc = docsEnum.nextDoc();
          if(doc != DocsEnum.NO_MORE_DOCS) {
            //Found the document. A unique key matches at most one live doc,
            //so stop looking for this term in later segments.
            int p = boosted.get(ref);
            boostDocs.put(doc+docBase, p);
            it.remove();
          }
        }
      }
    }
  }

  if(context != null) {
    context.put(BOOSTED_DOCIDS, boostDocs);
  }

  return boostDocs;
}
@Override
public void process(ResponseBuilder rb) throws IOException {
// Do nothing -- the real work is modifying the input query

View File

@ -427,7 +427,7 @@ public class ResponseBuilder
return cmd;
}
private Query wrap(Query q) {
Query wrap(Query q) {
if(this.rankQuery != null) {
return this.rankQuery.wrap(q);
} else {

View File

@ -19,32 +19,20 @@ package org.apache.solr.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.DocsEnum;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.queries.function.FunctionQuery;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FilterCollector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.solr.common.SolrException;
@ -56,15 +44,14 @@ import org.apache.solr.request.LocalSolrQueryRequest;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.SolrRequestInfo;
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.schema.TrieFloatField;
import org.apache.solr.schema.TrieIntField;
import org.apache.solr.schema.TrieLongField;
import com.carrotsearch.hppc.FloatArrayList;
import com.carrotsearch.hppc.IntOpenHashSet;
import com.carrotsearch.hppc.cursors.IntCursor;
import com.carrotsearch.hppc.IntIntOpenHashMap;
import com.carrotsearch.hppc.cursors.IntIntCursor;
/**
@ -141,7 +128,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
private String min;
private boolean needsScores = true;
private int nullPolicy;
private Set<String> boosted;
private Map<BytesRef, Integer> boosted;
public static final int NULL_POLICY_IGNORE = 0;
public static final int NULL_POLICY_COLLAPSE = 1;
public static final int NULL_POLICY_EXPAND = 2;
@ -168,24 +155,9 @@ public class CollapsingQParserPlugin extends QParserPlugin {
}
public int hashCode() {
/*
* Checking for boosted here because the request context will not have the elevated docs
* until after the query is constructed. So to be sure there are no elevated docs in the query
* while checking the cache we must check the request context during the call to hashCode().
*/
if(this.boosted == null) {
SolrRequestInfo info = SolrRequestInfo.getRequestInfo();
if(info != null) {
this.boosted = (Set<String>)info.getReq().getContext().get(QueryElevationComponent.BOOSTED);
}
}
int hashCode = field.hashCode();
hashCode = max!=null ? hashCode+max.hashCode():hashCode;
hashCode = min!=null ? hashCode+min.hashCode():hashCode;
hashCode = boosted!=null ? hashCode+boosted.hashCode():hashCode;
hashCode = hashCode+nullPolicy;
hashCode = hashCode*((1+Float.floatToIntBits(this.getBoost()))*31);
return hashCode;
@ -199,7 +171,6 @@ public class CollapsingQParserPlugin extends QParserPlugin {
((this.max == null && c.max == null) || (this.max != null && c.max != null && this.max.equals(c.max))) &&
((this.min == null && c.min == null) || (this.min != null && c.min != null && this.min.equals(c.min))) &&
this.nullPolicy == c.nullPolicy &&
((this.boosted == null && c.boosted == null) || (this.boosted == c.boosted)) &&
this.getBoost()==c.getBoost()) {
return true;
}
@ -236,47 +207,11 @@ public class CollapsingQParserPlugin extends QParserPlugin {
} else {
throw new IOException("Invalid nullPolicy:"+nPolicy);
}
}
private IntOpenHashSet getBoostDocs(SolrIndexSearcher indexSearcher, Set<String> boosted) throws IOException {
IntOpenHashSet boostDocs = null;
if(boosted != null) {
SchemaField idField = indexSearcher.getSchema().getUniqueKeyField();
String fieldName = idField.getName();
HashSet<BytesRef> localBoosts = new HashSet(boosted.size()*2);
Iterator<String> boostedIt = boosted.iterator();
while(boostedIt.hasNext()) {
localBoosts.add(new BytesRef(boostedIt.next()));
}
boostDocs = new IntOpenHashSet(boosted.size()*2);
List<AtomicReaderContext>leaves = indexSearcher.getTopReaderContext().leaves();
TermsEnum termsEnum = null;
DocsEnum docsEnum = null;
for(AtomicReaderContext leaf : leaves) {
AtomicReader reader = leaf.reader();
int docBase = leaf.docBase;
Bits liveDocs = reader.getLiveDocs();
Terms terms = reader.terms(fieldName);
termsEnum = terms.iterator(termsEnum);
Iterator<BytesRef> it = localBoosts.iterator();
while(it.hasNext()) {
BytesRef ref = it.next();
if(termsEnum.seekExact(ref)) {
docsEnum = termsEnum.docs(liveDocs, docsEnum);
int doc = docsEnum.nextDoc();
if(doc != DocsEnum.NO_MORE_DOCS) {
//Found the document.
boostDocs.add(doc+docBase);
it.remove();
}
}
}
}
}
/**
 * Delegates boosted-doc resolution to {@link QueryElevationComponent}, which
 * also caches the mapping in the request context.
 */
private IntIntOpenHashMap getBoostDocs(SolrIndexSearcher indexSearcher, Map<BytesRef, Integer> boosted, Map context) throws IOException {
  return QueryElevationComponent.getBoostDocs(indexSearcher, boosted, context);
}
@ -284,8 +219,6 @@ public class CollapsingQParserPlugin extends QParserPlugin {
try {
SolrIndexSearcher searcher = (SolrIndexSearcher)indexSearcher;
IndexSchema schema = searcher.getSchema();
SchemaField schemaField = schema.getField(this.field);
SortedDocValues docValues = null;
FunctionQuery funcQuery = null;
@ -332,14 +265,22 @@ public class CollapsingQParserPlugin extends QParserPlugin {
int maxDoc = searcher.maxDoc();
int leafCount = searcher.getTopReaderContext().leaves().size();
if(this.boosted == null) {
SolrRequestInfo info = SolrRequestInfo.getRequestInfo();
if(info != null) {
this.boosted = (Set<String>)info.getReq().getContext().get(QueryElevationComponent.BOOSTED);
}
//Deal with boosted docs.
//We have to deal with it here rather then the constructor because
//because the QueryElevationComponent runs after the Queries are constructed.
IntIntOpenHashMap boostDocs = null;
Map context = null;
SolrRequestInfo info = SolrRequestInfo.getRequestInfo();
if(info != null) {
context = info.getReq().getContext();
}
IntOpenHashSet boostDocs = getBoostDocs(searcher, this.boosted);
if(this.boosted == null && context != null) {
this.boosted = (Map<BytesRef, Integer>)context.get(QueryElevationComponent.BOOSTED_PRIORITY);
}
boostDocs = getBoostDocs(searcher, this.boosted, context);
if (this.min != null || this.max != null) {
@ -442,14 +383,14 @@ public class CollapsingQParserPlugin extends QParserPlugin {
private float nullScore = -Float.MAX_VALUE;
private int nullDoc;
private FloatArrayList nullScores;
private IntOpenHashSet boostDocs;
private IntIntOpenHashMap boostDocs;
private int[] boostOrds;
public CollapsingScoreCollector(int maxDoc,
int segments,
SortedDocValues values,
int nullPolicy,
IntOpenHashSet boostDocs) {
IntIntOpenHashMap boostDocs) {
this.maxDoc = maxDoc;
this.contexts = new AtomicReaderContext[segments];
this.collapsedSet = new FixedBitSet(maxDoc);
@ -457,10 +398,10 @@ public class CollapsingQParserPlugin extends QParserPlugin {
if(this.boostDocs != null) {
//Set the elevated docs now.
IntOpenHashSet boostG = new IntOpenHashSet();
Iterator<IntCursor> it = this.boostDocs.iterator();
Iterator<IntIntCursor> it = this.boostDocs.iterator();
while(it.hasNext()) {
IntCursor cursor = it.next();
int i = cursor.value;
IntIntCursor cursor = it.next();
int i = cursor.key;
this.collapsedSet.set(i);
int ord = values.getOrd(i);
if(ord > -1) {
@ -558,7 +499,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
if(ord > -1) {
dummy.score = scores[ord];
} else if(this.boostDocs != null && boostDocs.contains(docId)) {
} else if(this.boostDocs != null && boostDocs.containsKey(docId)) {
//Elevated docs don't need a score.
dummy.score = 0F;
} else if (nullPolicy == CollapsingPostFilter.NULL_POLICY_COLLAPSE) {
@ -595,7 +536,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
private FieldValueCollapse fieldValueCollapse;
private boolean needsScores;
private IntOpenHashSet boostDocs;
private IntIntOpenHashMap boostDocs;
public CollapsingFieldValueCollector(int maxDoc,
int segments,
@ -605,7 +546,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
boolean max,
boolean needsScores,
FieldType fieldType,
IntOpenHashSet boostDocs,
IntIntOpenHashMap boostDocs,
FunctionQuery funcQuery, IndexSearcher searcher) throws IOException{
this.maxDoc = maxDoc;
@ -674,7 +615,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
int ord = values.getOrd(docId);
if(ord > -1) {
dummy.score = scores[ord];
} else if (boostDocs != null && boostDocs.contains(docId)) {
} else if (boostDocs != null && boostDocs.containsKey(docId)) {
//Its an elevated doc so no score is needed
dummy.score = 0F;
} else if (nullPolicy == CollapsingPostFilter.NULL_POLICY_COLLAPSE) {
@ -711,7 +652,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
protected float nullScore;
protected float[] scores;
protected FixedBitSet collapsedSet;
protected IntOpenHashSet boostDocs;
protected IntIntOpenHashMap boostDocs;
protected int[] boostOrds;
protected int nullDoc = -1;
protected boolean needsScores;
@ -726,7 +667,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
int nullPolicy,
boolean max,
boolean needsScores,
IntOpenHashSet boostDocs,
IntIntOpenHashMap boostDocs,
SortedDocValues values) {
this.field = field;
this.nullPolicy = nullPolicy;
@ -736,10 +677,10 @@ public class CollapsingQParserPlugin extends QParserPlugin {
this.boostDocs = boostDocs;
if(this.boostDocs != null) {
IntOpenHashSet boostG = new IntOpenHashSet();
Iterator<IntCursor> it = boostDocs.iterator();
Iterator<IntIntCursor> it = boostDocs.iterator();
while(it.hasNext()) {
IntCursor cursor = it.next();
int i = cursor.value;
IntIntCursor cursor = it.next();
int i = cursor.key;
this.collapsedSet.set(i);
int ord = values.getOrd(i);
if(ord > -1) {
@ -802,7 +743,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
int[] ords,
boolean max,
boolean needsScores,
IntOpenHashSet boostDocs, SortedDocValues values) throws IOException {
IntIntOpenHashMap boostDocs, SortedDocValues values) throws IOException {
super(maxDoc, field, nullPolicy, max, needsScores, boostDocs, values);
this.ords = ords;
this.ordVals = new int[ords.length];
@ -870,7 +811,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
int[] ords,
boolean max,
boolean needsScores,
IntOpenHashSet boostDocs, SortedDocValues values) throws IOException {
IntIntOpenHashMap boostDocs, SortedDocValues values) throws IOException {
super(maxDoc, field, nullPolicy, max, needsScores, boostDocs, values);
this.ords = ords;
this.ordVals = new long[ords.length];
@ -939,7 +880,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
int[] ords,
boolean max,
boolean needsScores,
IntOpenHashSet boostDocs, SortedDocValues values) throws IOException {
IntIntOpenHashMap boostDocs, SortedDocValues values) throws IOException {
super(maxDoc, field, nullPolicy, max, needsScores, boostDocs, values);
this.ords = ords;
this.ordVals = new float[ords.length];
@ -1013,7 +954,7 @@ public class CollapsingQParserPlugin extends QParserPlugin {
int[] ords,
boolean max,
boolean needsScores,
IntOpenHashSet boostDocs,
IntIntOpenHashMap boostDocs,
FunctionQuery funcQuery, IndexSearcher searcher, SortedDocValues values) throws IOException {
super(maxDoc, null, nullPolicy, max, needsScores, boostDocs, values);
this.valueSource = funcQuery.getValueSource();

View File

@ -60,7 +60,8 @@ public abstract class QParserPlugin implements NamedListInitializedPlugin, SolrI
BlockJoinChildQParserPlugin.NAME, BlockJoinChildQParserPlugin.class,
CollapsingQParserPlugin.NAME, CollapsingQParserPlugin.class,
SimpleQParserPlugin.NAME, SimpleQParserPlugin.class,
ComplexPhraseQParserPlugin.NAME, ComplexPhraseQParserPlugin.class
ComplexPhraseQParserPlugin.NAME, ComplexPhraseQParserPlugin.class,
ReRankQParserPlugin.NAME, ReRankQParserPlugin.class
};
/** return a {@link QParser} */

View File

@ -20,7 +20,9 @@ package org.apache.solr.search;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.IndexSearcher;
import org.apache.solr.handler.component.MergeStrategy;
import java.io.IOException;
import java.io.IOException;
@ -28,7 +30,7 @@ import java.io.IOException;
* <b>Note: This API is experimental and may change in non backward-compatible ways in the future</b>
**/
public abstract class RankQuery extends Query {
public abstract class RankQuery extends ExtendedQueryBase {
public abstract TopDocsCollector getTopDocsCollector(int len, SolrIndexSearcher.QueryCommand cmd, IndexSearcher searcher) throws IOException;
public abstract MergeStrategy getMergeStrategy();

View File

@ -0,0 +1,351 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;
import com.carrotsearch.hppc.IntIntOpenHashMap;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.search.QueryRescorer;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.MergeStrategy;
import org.apache.solr.handler.component.QueryElevationComponent;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.ScoreDoc;
import com.carrotsearch.hppc.IntFloatOpenHashMap;
import org.apache.lucene.util.Bits;
import org.apache.solr.request.SolrRequestInfo;
import java.io.IOException;
import java.util.Map;
import java.util.Arrays;
import java.util.Comparator;
/*
*
* Syntax: q=*:*&rq={!rerank reRankQuery=$rqq reRankDocs=300 reRankWeight=3}
*
*/
public class ReRankQParserPlugin extends QParserPlugin {
public static final String NAME = "rerank";
private static Query defaultQuery = new MatchAllDocsQuery();
/** No-op: this plugin reads no configuration at init time. */
public void init(NamedList args) {
}
/** Creates a per-request parser for the {!rerank} query syntax. */
public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
  return new ReRankQParser(qstr, localParams, params, req);
}
/**
 * Parser for the {!rerank} syntax. Local params:
 *   reRankQuery  (required) query used for the second scoring pass
 *   reRankDocs   (optional, default 200) number of top docs to rescore
 *   reRankWeight (optional, default 2.0) multiplier applied to the second-pass score
 */
private class ReRankQParser extends QParser {

  public ReRankQParser(String query, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
    super(query, localParams, params, req);
  }

  public Query parse() throws SyntaxError {
    String reRankQueryString = localParams.get("reRankQuery");
    if (reRankQueryString == null || reRankQueryString.trim().length() == 0) {
      // Fail with a clear parse error instead of the NullPointerException
      // that a missing/empty reRankQuery previously caused downstream.
      throw new SyntaxError("reRankQuery parameter is mandatory");
    }
    QParser reRankParser = QParser.getParser(reRankQueryString, null, req);
    Query reRankQuery = reRankParser.parse();

    int reRankDocs = localParams.getInt("reRankDocs", 200);
    double reRankWeight = localParams.getDouble("reRankWeight", 2.0d);

    int start = params.getInt(CommonParams.START, 0);
    int rows = params.getInt(CommonParams.ROWS, 10);

    // This ensures that reRankDocs >= docs needed to satisfy the result set.
    reRankDocs = Math.max(start + rows, reRankDocs);

    return new ReRankQuery(reRankQuery, reRankDocs, reRankWeight);
  }
}
/**
 * RankQuery that delegates first-pass collection to the wrapped main query and
 * rescores the top {@code reRankDocs} hits with {@code reRankQuery}.
 * Instances are used as filter/query-result cache keys, so hashCode/equals
 * must cover every field that affects results.
 */
private class ReRankQuery extends RankQuery {
  // Defaults to MatchAllDocsQuery until wrap() injects the real main query.
  private Query mainQuery = defaultQuery;
  private Query reRankQuery;
  private int reRankDocs;
  private double reRankWeight;
  // Elevated doc -> priority map; lazily pulled from the request context in
  // getTopDocsCollector() because elevation runs after query construction.
  private Map<BytesRef, Integer> boostedPriority;

  public int hashCode() {
    // NOTE(review): (int)reRankWeight truncates, so e.g. weights 2.0 and 2.5
    // collide. Legal (equals() still distinguishes them) but collision-prone.
    return mainQuery.hashCode()+reRankQuery.hashCode()+(int)reRankWeight+reRankDocs+(int)getBoost();
  }

  public boolean equals(Object o) {
    if(o instanceof ReRankQuery) {
      ReRankQuery rrq = (ReRankQuery)o;
      return (mainQuery.equals(rrq.mainQuery) &&
        reRankQuery.equals(rrq.reRankQuery) &&
        reRankWeight == rrq.reRankWeight &&
        reRankDocs == rrq.reRankDocs &&
        getBoost() == rrq.getBoost());
    }
    return false;
  }

  public ReRankQuery(Query reRankQuery, int reRankDocs, double reRankWeight) {
    this.reRankQuery = reRankQuery;
    this.reRankDocs = reRankDocs;
    this.reRankWeight = reRankWeight;
  }

  /**
   * Injects the main query (called via ResponseBuilder.wrap()). Returns this
   * instance, mutated in place — not a copy.
   */
  public RankQuery wrap(Query _mainQuery) {
    if(_mainQuery != null){
      this.mainQuery = _mainQuery;
    }
    return this;
  }

  // No custom distributed merge: default merge strategy applies.
  public MergeStrategy getMergeStrategy() {
    return null;
  }

  public TopDocsCollector getTopDocsCollector(int len, SolrIndexSearcher.QueryCommand cmd, IndexSearcher searcher) throws IOException {
    // Elevation data only reaches the request context after query parsing,
    // so it must be fetched here rather than in the constructor.
    if(this.boostedPriority == null) {
      SolrRequestInfo info = SolrRequestInfo.getRequestInfo();
      if(info != null) {
        Map context = info.getReq().getContext();
        this.boostedPriority = (Map<BytesRef, Integer>)context.get(QueryElevationComponent.BOOSTED_PRIORITY);
      }
    }
    return new ReRankCollector(reRankDocs, reRankQuery, reRankWeight, cmd, searcher, boostedPriority);
  }

  public String toString(String s) {
    // NOTE(review): "reRankWeigh" is missing the trailing 't'; left untouched
    // here because this is runtime output and changing it alters toString().
    return "{!rerank mainQuery='"+mainQuery.toString()+
      "' reRankQuery='"+reRankQuery.toString()+
      "' reRankDocs="+reRankDocs+
      " reRankWeigh="+reRankWeight+"}";
  }

  public String toString() {
    return toString(null);
  }

  public Weight createWeight(IndexSearcher searcher) throws IOException{
    return new ReRankWeight(mainQuery, reRankQuery, reRankWeight, searcher);
  }
}
/**
 * Weight that delegates scoring entirely to the main query's weight; the
 * re-rank query only participates in explain(), where the combined
 * (rescored) explanation is produced.
 */
private class ReRankWeight extends Weight{
  private Query reRankQuery;
  private IndexSearcher searcher;
  private Weight mainWeight;
  private double reRankWeight;

  public ReRankWeight(Query mainQuery, Query reRankQuery, double reRankWeight, IndexSearcher searcher) throws IOException {
    this.reRankQuery = reRankQuery;
    this.searcher = searcher;
    this.reRankWeight = reRankWeight;
    this.mainWeight = mainQuery.createWeight(searcher);
  }

  // Normalization, scoring and query identity all come from the main query.
  public float getValueForNormalization() throws IOException {
    return mainWeight.getValueForNormalization();
  }

  public Scorer scorer(AtomicReaderContext context, Bits bits) throws IOException {
    return mainWeight.scorer(context, bits);
  }

  public Query getQuery() {
    return mainWeight.getQuery();
  }

  public void normalize(float norm, float topLevelBoost) {
    mainWeight.normalize(norm, topLevelBoost);
  }

  /**
   * Explains the final score: main-query explanation combined with the
   * re-rank contribution (reRankWeight * secondPassScore when it matches).
   */
  public Explanation explain(AtomicReaderContext context, int doc) throws IOException {
    Explanation mainExplain = mainWeight.explain(context, doc);
    return new QueryRescorer(reRankQuery) {
      @Override
      protected float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore) {
        float score = firstPassScore;
        if (secondPassMatches) {
          score += reRankWeight * secondPassScore;
        }
        return score;
      }
    }.explain(searcher, mainExplain, context.docBase+doc);
  }
}
/**
 * Collector that gathers the top {@code reRankDocs} hits with a normal
 * score/field collector, then rescores them with the re-rank query in
 * topDocs(). When elevation is active, elevated docs are re-sorted to the
 * front using {@link BoostedComp}.
 */
private class ReRankCollector extends TopDocsCollector {

  private Query reRankQuery;
  private TopDocsCollector mainCollector;
  private IndexSearcher searcher;
  private int reRankDocs;
  private double reRankWeight;
  private Map<BytesRef, Integer> boostedPriority;

  public ReRankCollector(int reRankDocs,
                         Query reRankQuery,
                         double reRankWeight,
                         SolrIndexSearcher.QueryCommand cmd,
                         IndexSearcher searcher,
                         Map<BytesRef, Integer> boostedPriority) throws IOException {
    super(null);
    this.reRankQuery = reRankQuery;
    this.reRankDocs = reRankDocs;
    this.boostedPriority = boostedPriority;
    Sort sort = cmd.getSort();
    if(sort == null) {
      this.mainCollector = TopScoreDocCollector.create(this.reRankDocs,true);
    } else {
      sort = sort.rewrite(searcher);
      this.mainCollector = TopFieldCollector.create(sort, this.reRankDocs, false, true, true, true);
    }
    this.searcher = searcher;
    this.reRankWeight = reRankWeight;
  }

  public boolean acceptsDocsOutOfOrder() {
    return false;
  }

  // First-pass collection is fully delegated to the main collector.
  public void collect(int doc) throws IOException {
    mainCollector.collect(doc);
  }

  public void setScorer(Scorer scorer) throws IOException{
    mainCollector.setScorer(scorer);
  }

  public void doSetNextReader(AtomicReaderContext context) throws IOException{
    mainCollector.getLeafCollector(context);
  }

  public int getTotalHits() {
    return mainCollector.getTotalHits();
  }

  /**
   * Builds the additive rescorer used by both topDocs() branches:
   * finalScore = firstPassScore + reRankWeight * secondPassScore (when the
   * re-rank query matches). Extracted to remove the duplicated anonymous
   * classes.
   */
  private QueryRescorer newRescorer() {
    return new QueryRescorer(reRankQuery) {
      @Override
      protected float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore) {
        float score = firstPassScore;
        if (secondPassMatches) {
          score += reRankWeight * secondPassScore;
        }
        return score;
      }
    };
  }

  public TopDocs topDocs(int start, int howMany) {
    try {
      TopDocs mainDocs = mainCollector.topDocs(0, reRankDocs);

      if (mainDocs.scoreDocs.length == 0) {
        // Nothing collected: rescoring and boosted re-sorting are no-ops.
        return mainDocs;
      }

      if(boostedPriority != null) {
        SolrRequestInfo info = SolrRequestInfo.getRequestInfo();
        Map requestContext = null;
        if(info != null) {
          requestContext = info.getReq().getContext();
        }

        IntIntOpenHashMap boostedDocs = QueryElevationComponent.getBoostDocs((SolrIndexSearcher)searcher, boostedPriority, requestContext);

        // Rescore the full candidate window, then pull elevated docs to the
        // front before trimming to the requested page size.
        TopDocs rescoredDocs = newRescorer().rescore(searcher, mainDocs, reRankDocs);
        Arrays.sort(rescoredDocs.scoreDocs, new BoostedComp(boostedDocs, mainDocs.scoreDocs, rescoredDocs.getMaxScore()));

        if(howMany > rescoredDocs.scoreDocs.length) {
          howMany = rescoredDocs.scoreDocs.length;
        }
        ScoreDoc[] scoreDocs = new ScoreDoc[howMany];
        System.arraycopy(rescoredDocs.scoreDocs, 0, scoreDocs, 0, howMany);
        rescoredDocs.scoreDocs = scoreDocs;
        return rescoredDocs;
      } else {
        return newRescorer().rescore(searcher, mainDocs, howMany);
      }
    } catch (Exception e) {
      throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
    }
  }
}
/**
 * Comparator that sorts elevated (boosted) docs ahead of everything else by
 * assigning them a synthetic score above the rescored maximum
 * (maxScore + priority); all other docs keep their rescored score,
 * descending.
 */
public class BoostedComp implements Comparator {
  // docid -> synthetic score for elevated docs only.
  IntFloatOpenHashMap boostedMap;

  public BoostedComp(IntIntOpenHashMap boostedDocs, ScoreDoc[] scoreDocs, float maxScore) {
    this.boostedMap = new IntFloatOpenHashMap(boostedDocs.size()*2);

    for(int i=0; i<scoreDocs.length; i++) {
      if(boostedDocs.containsKey(scoreDocs[i].doc)) {
        // lget() returns the priority found by the preceding containsKey().
        boostedMap.put(scoreDocs[i].doc, maxScore+boostedDocs.lget());
      } else {
        // NOTE(review): stops at the first non-boosted doc, assuming all
        // boosted docs lead scoreDocs (first-pass order) — any boosted doc
        // after a non-boosted one is silently skipped. TODO confirm the
        // elevation component guarantees that ordering.
        break;
      }
    }
  }

  public int compare(Object o1, Object o2) {
    ScoreDoc doc1 = (ScoreDoc) o1;
    ScoreDoc doc2 = (ScoreDoc) o2;
    float score1 = doc1.score;
    float score2 = doc2.score;
    // Substitute the synthetic elevated score when present; lget() reads the
    // value located by the containsKey() call just above it.
    if(boostedMap.containsKey(doc1.doc)) {
      score1 = boostedMap.lget();
    }

    if(boostedMap.containsKey(doc2.doc)) {
      score2 = boostedMap.lget();
    }

    // Descending by (possibly substituted) score.
    if(score1 > score2) {
      return -1;
    } else if(score1 < score2) {
      return 1;
    } else {
      return 0;
    }
  }
}
}

View File

@ -127,6 +127,39 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
}
}
/**
 * Verifies that equivalent {!rerank} query strings — one using param
 * dereferencing ($rqq/$rdocs/$rweight), one using literal values — parse to
 * equal Query objects, both with and without an explicit mainQuery param.
 */
public void testReRankQuery() throws Exception {
  // Case 1: rows=10/start=0, so reRankDocs stays at the requested 20.
  SolrQueryRequest req = req("q", "*:*",
      "rqq", "{!edismax}hello",
      "rdocs", "20",
      "rweight", "2",
      "rows", "10",
      "start", "0");
  try {
    assertQueryEquals("rerank", req,
        "{!rerank reRankQuery=$rqq reRankDocs=$rdocs reRankWeight=$rweight}",
        "{!rerank reRankQuery=$rqq reRankDocs=20 reRankWeight=2}");
  } finally {
    req.close();
  }

  // Case 2: start=50/rows=100 forces reRankDocs up to start+rows internally;
  // equality must still hold for the two spellings.
  req = req("qq", "*:*",
      "rqq", "{!edismax}hello",
      "rdocs", "20",
      "rweight", "2",
      "rows", "100",
      "start", "50");
  try {
    assertQueryEquals("rerank", req,
        "{!rerank mainQuery=$qq reRankQuery=$rqq reRankDocs=$rdocs reRankWeight=$rweight}",
        "{!rerank mainQuery=$qq reRankQuery=$rqq reRankDocs=20 reRankWeight=2}");
  } finally {
    req.close();
  }
}
public void testQuerySwitch() throws Exception {
SolrQueryRequest req = req("myXXX", "XXX",
"myField", "foo_s",

View File

@ -0,0 +1,290 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import com.carrotsearch.hppc.IntOpenHashSet;
import java.io.IOException;
import java.util.*;
import java.util.Random;
public class TestReRankQParserPlugin extends SolrTestCaseJ4 {
// Boots a core with the collapse/rerank-enabled config once for all tests.
@BeforeClass
public static void beforeClass() throws Exception {
  initCore("solrconfig-collapseqparser.xml", "schema11.xml");
}
// Starts each test from an empty, committed index.
@Override
@Before
public void setUp() throws Exception {
  // if you override setUp or tearDown, you better call
  // the super classes version
  super.setUp();
  clearIndex();
  assertU(commit());
}
@Test
public void testReRankQueries() throws Exception {
String[] doc = {"id","1", "term_s", "YYYY", "group_s", "group1", "test_ti", "5", "test_tl", "10", "test_tf", "2000"};
assertU(adoc(doc));
assertU(commit());
String[] doc1 = {"id","2", "term_s","YYYY", "group_s", "group1", "test_ti", "50", "test_tl", "100", "test_tf", "200"};
assertU(adoc(doc1));
String[] doc2 = {"id","3", "term_s", "YYYY", "test_ti", "5000", "test_tl", "100", "test_tf", "200"};
assertU(adoc(doc2));
assertU(commit());
String[] doc3 = {"id","4", "term_s", "YYYY", "test_ti", "500", "test_tl", "1000", "test_tf", "2000"};
assertU(adoc(doc3));
String[] doc4 = {"id","5", "term_s", "YYYY", "group_s", "group2", "test_ti", "4", "test_tl", "10", "test_tf", "2000"};
assertU(adoc(doc4));
assertU(commit());
String[] doc5 = {"id","6", "term_s","YYYY", "group_s", "group2", "test_ti", "10", "test_tl", "100", "test_tf", "200"};
assertU(adoc(doc5));
assertU(commit());
ModifiableSolrParams params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=200}");
params.add("q", "term_s:YYYY");
params.add("rqq", "{!edismax bf=$bff}*:*");
params.add("bff", "field(test_ti)");
params.add("start", "0");
params.add("rows", "6");
assertQ(req(params), "*[count(//doc)=6]",
"//result/doc[1]/float[@name='id'][.='3.0']",
"//result/doc[2]/float[@name='id'][.='4.0']",
"//result/doc[3]/float[@name='id'][.='2.0']",
"//result/doc[4]/float[@name='id'][.='6.0']",
"//result/doc[5]/float[@name='id'][.='1.0']",
"//result/doc[6]/float[@name='id'][.='5.0']"
);
params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=6}");
params.add("q", "{!edismax bq=$bqq1}*:*");
params.add("bqq1", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("rqq", "{!edismax bq=$bqq2}*:*");
params.add("bqq2", "test_ti:50^1000");
params.add("fl", "id,score");
params.add("start", "0");
params.add("rows", "10");
assertQ(req(params), "*[count(//doc)=6]",
"//result/doc[1]/float[@name='id'][.='2.0']",
"//result/doc[2]/float[@name='id'][.='6.0']",
"//result/doc[3]/float[@name='id'][.='5.0']",
"//result/doc[4]/float[@name='id'][.='4.0']",
"//result/doc[5]/float[@name='id'][.='3.0']",
"//result/doc[6]/float[@name='id'][.='1.0']"
);
//Test with sort by score.
params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=6}");
params.add("q", "{!edismax bq=$bqq1}*:*");
params.add("bqq1", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("rqq", "{!edismax bq=$bqq2}*:*");
params.add("bqq2", "test_ti:50^1000");
params.add("fl", "id,score");
params.add("start", "0");
params.add("rows", "10");
params.add("sort", "score desc");
assertQ(req(params), "*[count(//doc)=6]",
"//result/doc[1]/float[@name='id'][.='2.0']",
"//result/doc[2]/float[@name='id'][.='6.0']",
"//result/doc[3]/float[@name='id'][.='5.0']",
"//result/doc[4]/float[@name='id'][.='4.0']",
"//result/doc[5]/float[@name='id'][.='3.0']",
"//result/doc[6]/float[@name='id'][.='1.0']"
);
//Test with compound sort.
params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=6}");
params.add("q", "{!edismax bq=$bqq1}*:*");
params.add("bqq1", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("rqq", "{!edismax bq=$bqq2}*:*");
params.add("bqq2", "test_ti:50^1000");
params.add("fl", "id,score");
params.add("start", "0");
params.add("rows", "10");
params.add("sort", "score desc,test_ti asc");
assertQ(req(params), "*[count(//doc)=6]",
"//result/doc[1]/float[@name='id'][.='2.0']",
"//result/doc[2]/float[@name='id'][.='6.0']",
"//result/doc[3]/float[@name='id'][.='5.0']",
"//result/doc[4]/float[@name='id'][.='4.0']",
"//result/doc[5]/float[@name='id'][.='3.0']",
"//result/doc[6]/float[@name='id'][.='1.0']"
);
//Test with elevation
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=6 reRankWeight=50}");
params.add("q", "{!edismax bq=$bqq1}*:*");
params.add("bqq1", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("rqq", "{!edismax bq=$bqq2}*:*");
params.add("bqq2", "test_ti:50^1000");
params.add("fl", "id,score");
params.add("start", "0");
params.add("rows", "10");
params.add("qt", "/elevate");
params.add("elevateIds", "1");
assertQ(req(params), "*[count(//doc)=6]",
"//result/doc[1]/float[@name='id'][.='1.0']",
"//result/doc[2]/float[@name='id'][.='2.0']",
"//result/doc[3]/float[@name='id'][.='6.0']",
"//result/doc[4]/float[@name='id'][.='5.0']",
"//result/doc[5]/float[@name='id'][.='4.0']",
"//result/doc[6]/float[@name='id'][.='3.0']"
);
//Test TermQuery rqq
params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=6 reRankWeight=2}");
params.add("q", "{!edismax bq=$bqq1}*:*");
params.add("bqq1", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("rqq", "test_ti:50^1000");
params.add("fl", "id,score");
params.add("start", "0");
params.add("rows", "10");
assertQ(req(params), "*[count(//doc)=6]",
"//result/doc[1]/float[@name='id'][.='2.0']",
"//result/doc[2]/float[@name='id'][.='6.0']",
"//result/doc[3]/float[@name='id'][.='5.0']",
"//result/doc[4]/float[@name='id'][.='4.0']",
"//result/doc[5]/float[@name='id'][.='3.0']",
"//result/doc[6]/float[@name='id'][.='1.0']"
);
//Test Elevation
params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=6 reRankWeight=2}");
params.add("q", "{!edismax bq=$bqq1}*:*");
params.add("bqq1", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("rqq", "test_ti:50^1000");
params.add("fl", "id,score");
params.add("start", "0");
params.add("rows", "10");
params.add("qt","/elevate");
params.add("elevateIds", "1,4");
assertQ(req(params), "*[count(//doc)=6]",
"//result/doc[1]/float[@name='id'][.='1.0']", //Elevated
"//result/doc[2]/float[@name='id'][.='4.0']", //Elevated
"//result/doc[3]/float[@name='id'][.='2.0']", //Boosted during rerank.
"//result/doc[4]/float[@name='id'][.='6.0']",
"//result/doc[5]/float[@name='id'][.='5.0']",
"//result/doc[6]/float[@name='id'][.='3.0']"
);
//Test Elevation swapped
params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=6 reRankWeight=2}");
params.add("q", "{!edismax bq=$bqq1}*:*");
params.add("bqq1", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("rqq", "test_ti:50^1000");
params.add("fl", "id,score");
params.add("start", "0");
params.add("rows", "10");
params.add("qt","/elevate");
params.add("elevateIds", "4,1");
assertQ(req(params), "*[count(//doc)=6]",
"//result/doc[1]/float[@name='id'][.='4.0']", //Elevated
"//result/doc[2]/float[@name='id'][.='1.0']", //Elevated
"//result/doc[3]/float[@name='id'][.='2.0']", //Boosted during rerank.
"//result/doc[4]/float[@name='id'][.='6.0']",
"//result/doc[5]/float[@name='id'][.='5.0']",
"//result/doc[6]/float[@name='id'][.='3.0']"
);
//Pass in reRankDocs lower then the length being collected.
params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=0 reRankWeight=2}");
params.add("q", "{!edismax bq=$bqq1}*:*");
params.add("bqq1", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("rqq", "test_ti:50^1000");
params.add("fl", "id,score");
params.add("start", "0");
params.add("rows", "10");
assertQ(req(params), "*[count(//doc)=6]",
"//result/doc[1]/float[@name='id'][.='2.0']",
"//result/doc[2]/float[@name='id'][.='6.0']",
"//result/doc[3]/float[@name='id'][.='5.0']",
"//result/doc[4]/float[@name='id'][.='4.0']",
"//result/doc[5]/float[@name='id'][.='3.0']",
"//result/doc[6]/float[@name='id'][.='1.0']"
);
//Test reRankWeight of 0, reranking will have no effect.
params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=6 reRankWeight=0}");
params.add("q", "{!edismax bq=$bqq1}*:*");
params.add("bqq1", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("rqq", "test_ti:50^1000");
params.add("fl", "id,score");
params.add("start", "0");
params.add("rows", "5");
assertQ(req(params), "*[count(//doc)=5]",
"//result/doc[1]/float[@name='id'][.='6.0']",
"//result/doc[2]/float[@name='id'][.='5.0']",
"//result/doc[3]/float[@name='id'][.='4.0']",
"//result/doc[4]/float[@name='id'][.='3.0']",
"//result/doc[5]/float[@name='id'][.='2.0']"
);
//Test with start beyond reRankDocs
params = new ModifiableSolrParams();
params.add("rq", "{!rerank reRankQuery=$rqq reRankDocs=3 reRankWeight=2}");
params.add("q", "*:*");
params.add("rqq", "id:1^10 id:2^20 id:3^30 id:4^40 id:5^50 id:6^60");
params.add("fl", "id,score");
params.add("start", "4");
params.add("rows", "5");
assertQ(req(params), "*[count(//doc)=2]",
"//result/doc[1]/float[@name='id'][.='2.0']",
"//result/doc[2]/float[@name='id'][.='1.0']"
);
}
}