SOLR-1297: fix weighting when sorting by function

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1061065 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Yonik Seeley 2011-01-19 23:48:28 +00:00
parent 1099f8465d
commit 0a30dcd945
7 changed files with 135 additions and 45 deletions

View File

@ -443,7 +443,7 @@ public class QueryComponent extends SearchComponent
// take the documents given and re-derive the sort values. // take the documents given and re-derive the sort values.
boolean fsv = req.getParams().getBool(ResponseBuilder.FIELD_SORT_VALUES,false); boolean fsv = req.getParams().getBool(ResponseBuilder.FIELD_SORT_VALUES,false);
if(fsv){ if(fsv){
Sort sort = rb.getSortSpec().getSort(); Sort sort = searcher.weightSort(rb.getSortSpec().getSort());
SortField[] sortFields = sort==null ? new SortField[]{SortField.FIELD_SCORE} : sort.getSort(); SortField[] sortFields = sort==null ? new SortField[]{SortField.FIELD_SCORE} : sort.getSort();
NamedList sortVals = new NamedList(); // order is important for the sort fields NamedList sortVals = new NamedList(); // order is important for the sort fields
Field field = new Field("dummy", "", Field.Store.YES, Field.Index.NO); // a dummy Field Field field = new Field("dummy", "", Field.Store.YES, Field.Index.NO); // a dummy Field

View File

@ -162,7 +162,7 @@ public class Grouping {
// if we aren't going to return any groups, disregard the offset // if we aren't going to return any groups, disregard the offset
if (numGroups == 0) maxGroupToFind = 0; if (numGroups == 0) maxGroupToFind = 0;
collector = new TopGroupCollector(groupBy, context, normalizeSort(sort), maxGroupToFind); collector = new TopGroupCollector(groupBy, context, searcher.weightSort(normalizeSort(sort)), maxGroupToFind);
/*** if we need a different algorithm when sort != group.sort /*** if we need a different algorithm when sort != group.sort
if (compareSorts(sort, groupSort)) { if (compareSorts(sort, groupSort)) {
@ -185,9 +185,9 @@ public class Grouping {
int collectorOffset = format==Format.Simple ? 0 : offset; int collectorOffset = format==Format.Simple ? 0 : offset;
if (groupBy instanceof StrFieldSource) { if (groupBy instanceof StrFieldSource) {
collector2 = new Phase2StringGroupCollector(collector, groupBy, context, groupSort, docsToCollect, needScores, collectorOffset); collector2 = new Phase2StringGroupCollector(collector, groupBy, context, searcher.weightSort(groupSort), docsToCollect, needScores, collectorOffset);
} else { } else {
collector2 = new Phase2GroupCollector(collector, groupBy, context, groupSort, docsToCollect, needScores, collectorOffset); collector2 = new Phase2GroupCollector(collector, groupBy, context, searcher.weightSort(groupSort), docsToCollect, needScores, collectorOffset);
} }
return collector2; return collector2;
} }
@ -306,11 +306,11 @@ public class Grouping {
return v; return v;
} }
static TopDocsCollector newCollector(Sort sort, int numHits, boolean fillFields, boolean needScores) throws IOException { TopDocsCollector newCollector(Sort sort, int numHits, boolean fillFields, boolean needScores) throws IOException {
if (sort==null || sort==byScoreDesc) { if (sort==null || sort==byScoreDesc) {
return TopScoreDocCollector.create(numHits, true); return TopScoreDocCollector.create(numHits, true);
} else { } else {
return TopFieldCollector.create(sort, numHits, false, needScores, needScores, true); return TopFieldCollector.create(searcher.weightSort(sort), numHits, false, needScores, needScores, true);
} }
} }
@ -505,12 +505,12 @@ class TopGroupCollector extends GroupCollector {
int matches; int matches;
public TopGroupCollector(ValueSource groupByVS, Map vsContext, Sort sort, int nGroups) throws IOException { public TopGroupCollector(ValueSource groupByVS, Map vsContext, Sort weightedSort, int nGroups) throws IOException {
this.vs = groupByVS; this.vs = groupByVS;
this.context = vsContext; this.context = vsContext;
this.nGroups = nGroups = Math.max(1,nGroups); // we need a minimum of 1 for this collector this.nGroups = nGroups = Math.max(1,nGroups); // we need a minimum of 1 for this collector
SortField[] sortFields = sort.getSort(); SortField[] sortFields = weightedSort.getSort();
this.comparators = new FieldComparator[sortFields.length]; this.comparators = new FieldComparator[sortFields.length];
this.reversed = new int[sortFields.length]; this.reversed = new int[sortFields.length];
for (int i = 0; i < sortFields.length; i++) { for (int i = 0; i < sortFields.length; i++) {
@ -719,7 +719,7 @@ class Phase2GroupCollector extends Collector {
int docBase; int docBase;
// TODO: may want to decouple from the phase1 collector // TODO: may want to decouple from the phase1 collector
public Phase2GroupCollector(TopGroupCollector topGroups, ValueSource groupByVS, Map vsContext, Sort sort, int docsPerGroup, boolean getScores, int offset) throws IOException { public Phase2GroupCollector(TopGroupCollector topGroups, ValueSource groupByVS, Map vsContext, Sort weightedSort, int docsPerGroup, boolean getScores, int offset) throws IOException {
boolean getSortFields = false; boolean getSortFields = false;
if (topGroups.orderedGroups == null) if (topGroups.orderedGroups == null)
@ -733,10 +733,10 @@ class Phase2GroupCollector extends Collector {
} }
SearchGroupDocs groupDocs = new SearchGroupDocs(); SearchGroupDocs groupDocs = new SearchGroupDocs();
groupDocs.groupValue = group.groupValue; groupDocs.groupValue = group.groupValue;
if (sort==null) if (weightedSort==null)
groupDocs.collector = TopScoreDocCollector.create(docsPerGroup, true); groupDocs.collector = TopScoreDocCollector.create(docsPerGroup, true);
else else
groupDocs.collector = TopFieldCollector.create(sort, docsPerGroup, getSortFields, getScores, getScores, true); groupDocs.collector = TopFieldCollector.create(weightedSort, docsPerGroup, getSortFields, getScores, getScores, true);
groupMap.put(groupDocs.groupValue, groupDocs); groupMap.put(groupDocs.groupValue, groupDocs);
} }
@ -791,8 +791,8 @@ class Phase2StringGroupCollector extends Phase2GroupCollector {
final SearchGroupDocs[] groups; final SearchGroupDocs[] groups;
final BytesRef spare = new BytesRef(); final BytesRef spare = new BytesRef();
public Phase2StringGroupCollector(TopGroupCollector topGroups, ValueSource groupByVS, Map vsContext, Sort sort, int docsPerGroup, boolean getScores, int offset) throws IOException { public Phase2StringGroupCollector(TopGroupCollector topGroups, ValueSource groupByVS, Map vsContext, Sort weightedSort, int docsPerGroup, boolean getScores, int offset) throws IOException {
super(topGroups, groupByVS, vsContext,sort,docsPerGroup,getScores,offset); super(topGroups, groupByVS, vsContext,weightedSort,docsPerGroup,getScores,offset);
ordSet = new SentinelIntSet(groupMap.size(), -1); ordSet = new SentinelIntSet(groupMap.size(), -1);
groups = new SearchGroupDocs[ordSet.keys.length]; groups = new SearchGroupDocs[ordSet.keys.length];
} }

View File

@ -460,6 +460,30 @@ public class SolrIndexSearcher extends IndexSearcher implements SolrInfoMBean {
return fieldValueCache; return fieldValueCache;
} }
/** Returns a weighted sort according to this searcher */
public Sort weightSort(Sort sort) throws IOException {
if (sort == null) return null;
SortField[] sorts = sort.getSort();
boolean needsWeighting = false;
for (SortField sf : sorts) {
if (sf instanceof SolrSortField) {
needsWeighting = true;
break;
}
}
if (!needsWeighting) return sort;
SortField[] newSorts = Arrays.copyOf(sorts, sorts.length);
for (int i=0; i<newSorts.length; i++) {
if (newSorts[i] instanceof SolrSortField) {
newSorts[i] = ((SolrSortField)newSorts[i]).weight(this);
}
}
return new Sort(newSorts);
}
/** /**
* Returns the first document number containing the term <code>t</code> * Returns the first document number containing the term <code>t</code>
@ -1156,7 +1180,7 @@ public class SolrIndexSearcher extends IndexSearcher implements SolrInfoMBean {
if (cmd.getSort() == null) { if (cmd.getSort() == null) {
topCollector = TopScoreDocCollector.create(len, true); topCollector = TopScoreDocCollector.create(len, true);
} else { } else {
topCollector = TopFieldCollector.create(cmd.getSort(), len, false, needScores, needScores, true); topCollector = TopFieldCollector.create(weightSort(cmd.getSort()), len, false, needScores, needScores, true);
} }
Collector collector = topCollector; Collector collector = topCollector;
if( timeAllowed > 0 ) { if( timeAllowed > 0 ) {
@ -1266,7 +1290,7 @@ public class SolrIndexSearcher extends IndexSearcher implements SolrInfoMBean {
if (cmd.getSort() == null) { if (cmd.getSort() == null) {
topCollector = TopScoreDocCollector.create(len, true); topCollector = TopScoreDocCollector.create(len, true);
} else { } else {
topCollector = TopFieldCollector.create(cmd.getSort(), len, false, needScores, needScores, true); topCollector = TopFieldCollector.create(weightSort(cmd.getSort()), len, false, needScores, needScores, true);
} }
DocSetCollector setCollector = new DocSetDelegateCollector(maxDoc>>6, maxDoc, topCollector); DocSetCollector setCollector = new DocSetDelegateCollector(maxDoc>>6, maxDoc, topCollector);
@ -1548,7 +1572,7 @@ public class SolrIndexSearcher extends IndexSearcher implements SolrInfoMBean {
// bit of a hack to tell if a set is sorted - do it better in the futute. // bit of a hack to tell if a set is sorted - do it better in the futute.
boolean inOrder = set instanceof BitDocSet || set instanceof SortedIntDocSet; boolean inOrder = set instanceof BitDocSet || set instanceof SortedIntDocSet;
TopDocsCollector topCollector = TopFieldCollector.create(sort, nDocs, false, false, false, inOrder); TopDocsCollector topCollector = TopFieldCollector.create(weightSort(sort), nDocs, false, false, false, inOrder);
DocIterator iter = set.iterator(); DocIterator iter = set.iterator();
int base=0; int base=0;

View File

@ -0,0 +1,31 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.SortField;
import java.io.IOException;
/**@lucene.internal
*
*/
public interface SolrSortField {
public SortField weight(IndexSearcher searcher) throws IOException;
}

View File

@ -26,12 +26,13 @@ import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortField;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
import org.apache.lucene.index.MultiFields; import org.apache.lucene.index.MultiFields;
import org.apache.solr.common.SolrException;
import org.apache.solr.search.SolrSortField;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
import java.util.IdentityHashMap; import java.util.IdentityHashMap;
import java.util.Map; import java.util.Map;
import java.util.Collections;
/** /**
* Instantiates {@link org.apache.solr.search.function.DocValues} for a particular reader. * Instantiates {@link org.apache.solr.search.function.DocValues} for a particular reader.
@ -61,24 +62,6 @@ public abstract class ValueSource implements Serializable {
return description(); return description();
} }
/**
* EXPERIMENTAL: This method is subject to change.
* <br>WARNING: Sorted function queries are not currently weighted.
* <p>
* Get the SortField for this ValueSource. Uses the {@link #getValues(java.util.Map, AtomicReaderContext)}
* to populate the SortField.
*
* @param reverse true if this is a reverse sort.
* @return The {@link org.apache.lucene.search.SortField} for the ValueSource
* @throws IOException if there was a problem reading the values.
*/
public SortField getSortField(boolean reverse) throws IOException {
//should we pass in the description for the field name?
//Hmm, Lucene is going to intern whatever we pass in, not sure I like that
//and we can't pass in null, either, as that throws an illegal arg. exception
return new SortField(description(), new ValueSourceComparatorSource(), reverse);
}
/** /**
* Implementations should propagate createWeight to sub-ValueSources which can optionally store * Implementations should propagate createWeight to sub-ValueSources which can optionally store
@ -97,16 +80,56 @@ public abstract class ValueSource implements Serializable {
return context; return context;
} }
//
// Sorting by function
//
/**
* EXPERIMENTAL: This method is subject to change.
* <br>WARNING: Sorted function queries are not currently weighted.
* <p>
* Get the SortField for this ValueSource. Uses the {@link #getValues(java.util.Map, AtomicReaderContext)}
* to populate the SortField.
*
* @param reverse true if this is a reverse sort.
* @return The {@link org.apache.lucene.search.SortField} for the ValueSource
* @throws IOException if there was a problem reading the values.
*/
public SortField getSortField(boolean reverse) throws IOException {
return new ValueSourceSortField(reverse);
}
private static FieldComparatorSource dummyComparator = new FieldComparatorSource() {
@Override
public FieldComparator newComparator(String fieldname, int numHits, int sortPos, boolean reversed) throws IOException {
throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Unweighted use of sort " + fieldname);
}
};
class ValueSourceSortField extends SortField implements SolrSortField {
public ValueSourceSortField(boolean reverse) {
super(description(), dummyComparator, reverse);
}
@Override
public SortField weight(IndexSearcher searcher) throws IOException {
Map context = newContext(searcher);
createWeight(context, searcher);
return new SortField(getField(), new ValueSourceComparatorSource(context), getReverse());
}
}
class ValueSourceComparatorSource extends FieldComparatorSource { class ValueSourceComparatorSource extends FieldComparatorSource {
private final Map context;
public ValueSourceComparatorSource(Map context) {
public ValueSourceComparatorSource() { this.context = context;
} }
public FieldComparator newComparator(String fieldname, int numHits, public FieldComparator newComparator(String fieldname, int numHits,
int sortPos, boolean reversed) throws IOException { int sortPos, boolean reversed) throws IOException {
return new ValueSourceComparator(numHits); return new ValueSourceComparator(context, numHits);
} }
} }
@ -119,8 +142,10 @@ public abstract class ValueSource implements Serializable {
private final double[] values; private final double[] values;
private DocValues docVals; private DocValues docVals;
private double bottom; private double bottom;
private Map fcontext;
ValueSourceComparator(int numHits) { ValueSourceComparator(Map fcontext, int numHits) {
this.fcontext = fcontext;
values = new double[numHits]; values = new double[numHits];
} }
@ -153,7 +178,7 @@ public abstract class ValueSource implements Serializable {
} }
public FieldComparator setNextReader(AtomicReaderContext context) throws IOException { public FieldComparator setNextReader(AtomicReaderContext context) throws IOException {
docVals = getValues(Collections.emptyMap(), context); docVals = getValues(fcontext, context);
return this; return this;
} }
@ -162,7 +187,7 @@ public abstract class ValueSource implements Serializable {
} }
public Comparable value(int slot) { public Comparable value(int slot) {
return Double.valueOf(values[slot]); return values[slot];
} }
} }
} }

View File

@ -95,6 +95,7 @@ public class TestDistributedSearch extends BaseDistributedSearchTestCase {
// these queries should be exactly ordered and scores should exactly match // these queries should be exactly ordered and scores should exactly match
query("q","*:*", "sort",i1+" desc"); query("q","*:*", "sort",i1+" desc");
query("q","*:*", "sort","{!func}add("+i1+",5)"+" desc");
query("q","*:*", "sort",i1+" asc"); query("q","*:*", "sort",i1+" asc");
query("q","*:*", "sort",i1+" desc", "fl","*,score"); query("q","*:*", "sort",i1+" desc", "fl","*,score");
query("q","*:*", "sort",tlong+" asc", "fl","score"); // test legacy behavior - "score"=="*,score" query("q","*:*", "sort",tlong+" asc", "fl","score"); // test legacy behavior - "score"=="*,score"

View File

@ -326,17 +326,18 @@ public class TestFunctionQuery extends SolrTestCaseJ4 {
assertU(adoc("id",""+i, "text","batman")); assertU(adoc("id",""+i, "text","batman"));
} }
assertU(commit()); assertU(commit());
assertU(adoc("id","120", "text","batman superman")); // in a segment by itself assertU(adoc("id","120", "text","batman superman")); // in a smaller segment
assertU(adoc("id","121", "text","superman"));
assertU(commit()); assertU(commit());
// batman and superman have the same idf in single-doc segment, but very different in the complete index. // superman has a higher df (thus lower idf) in one segment, but reversed in the complete index
String q ="{!func}query($qq)"; String q ="{!func}query($qq)";
String fq="id:120"; String fq="id:120";
assertQ(req("fl","*,score","q", q, "qq","text:batman", "fq",fq), "//float[@name='score']<'1.0'"); assertQ(req("fl","*,score","q", q, "qq","text:batman", "fq",fq), "//float[@name='score']<'1.0'");
assertQ(req("fl","*,score","q", q, "qq","text:superman", "fq",fq), "//float[@name='score']>'1.0'"); assertQ(req("fl","*,score","q", q, "qq","text:superman", "fq",fq), "//float[@name='score']>'1.0'");
// test weighting through a function range query // test weighting through a function range query
assertQ(req("fl","*,score", "q", "{!frange l=1 u=10}query($qq)", "qq","text:superman"), "//*[@numFound='1']"); assertQ(req("fl","*,score", "fq",fq, "q", "{!frange l=1 u=10}query($qq)", "qq","text:superman"), "//*[@numFound='1']");
// test weighting through a complex function // test weighting through a complex function
q ="{!func}sub(div(sum(0.0,product(1,query($qq))),1),0)"; q ="{!func}sub(div(sum(0.0,product(1,query($qq))),1),0)";
@ -360,6 +361,14 @@ public class TestFunctionQuery extends SolrTestCaseJ4 {
// OK // OK
} }
// test that sorting by function weights correctly. superman should sort higher than batman due to idf of the whole index
assertQ(req("q", "*:*", "fq","id:120 OR id:121", "sort","{!func v=$sortfunc} desc", "sortfunc","query($qq)", "qq","text:(batman OR superman)")
,"*//doc[1]/float[.='120.0']"
,"*//doc[2]/float[.='121.0']"
);
purgeFieldCache(FieldCache.DEFAULT); // avoid FC insanity purgeFieldCache(FieldCache.DEFAULT); // avoid FC insanity
} }