SOLR-3936: Fixed QueryElevationComponent sorting when used with Grouping

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1514795 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Chris M. Hostetter 2013-08-16 17:14:02 +00:00
parent b8097e9e53
commit 3399d7ec73
3 changed files with 242 additions and 14 deletions

View File

@ -131,6 +131,9 @@ Bug Fixes
of divide by zero, and makes estimated hit counts meaningful in non-optimized
indexes. (hossman)
* SOLR-3936: Fixed QueryElevationComponent sorting when used with Grouping
(Michael Garski via hossman)
Optimizations
----------------------

View File

@ -44,6 +44,7 @@ import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.QueryElevationParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.search.grouping.GroupingSpecification;
import org.apache.solr.util.DOMUtil;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SimpleOrderedMap;
@ -424,23 +425,25 @@ public class QueryElevationComponent extends SearchComponent implements SolrCore
}));
} else {
// Check if the sort is based on score
boolean modify = false;
SortField[] current = sortSpec.getSort().getSort();
ArrayList<SortField> sorts = new ArrayList<SortField>(current.length + 1);
// Perhaps force it to always sort by score
if (force && current[0].getType() != SortField.Type.SCORE) {
sorts.add(new SortField("_elevate_", comparator, true));
modify = true;
Sort modified = this.modifySort(current, force, comparator);
if(modified != null) {
sortSpec.setSort(modified);
}
for (SortField sf : current) {
if (sf.getType() == SortField.Type.SCORE) {
sorts.add(new SortField("_elevate_", comparator, !sf.getReverse()));
modify = true;
}
sorts.add(sf);
}
// alter the sorting in the grouping specification if there is one
GroupingSpecification groupingSpec = rb.getGroupingSpec();
if(groupingSpec != null) {
SortField[] groupSort = groupingSpec.getGroupSort().getSort();
Sort modGroupSort = this.modifySort(groupSort, force, comparator);
if(modGroupSort != null) {
groupingSpec.setGroupSort(modGroupSort);
}
if (modify) {
sortSpec.setSort(new Sort(sorts.toArray(new SortField[sorts.size()])));
SortField[] withinGroupSort = groupingSpec.getSortWithinGroup().getSort();
Sort modWithinGroupSort = this.modifySort(withinGroupSort, force, comparator);
if(modWithinGroupSort != null) {
groupingSpec.setSortWithinGroup(modWithinGroupSort);
}
}
}
@ -466,6 +469,25 @@ public class QueryElevationComponent extends SearchComponent implements SolrCore
}
}
private Sort modifySort(SortField[] current, boolean force, ElevationComparatorSource comparator) {
boolean modify = false;
ArrayList<SortField> sorts = new ArrayList<SortField>(current.length + 1);
// Perhaps force it to always sort by score
if (force && current[0].getType() != SortField.Type.SCORE) {
sorts.add(new SortField("_elevate_", comparator, true));
modify = true;
}
for (SortField sf : current) {
if (sf.getType() == SortField.Type.SCORE) {
sorts.add(new SortField("_elevate_", comparator, !sf.getReverse()));
modify = true;
}
sorts.add(sf);
}
return modify ? new Sort(sorts.toArray(new SortField[sorts.size()])) : null;
}
@Override
public void process(ResponseBuilder rb) throws IOException {
// Do nothing -- the real work is modifying the input query

View File

@ -22,6 +22,7 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.GroupParams;
import org.apache.solr.common.params.MapSolrParams;
import org.apache.solr.common.params.QueryElevationParams;
import org.apache.solr.util.FileUtils;
@ -105,6 +106,208 @@ public class QueryElevationComponentTest extends SolrTestCaseJ4 {
}
}
@Test
public void testGroupedQuery() throws Exception {
try {
init("schema11.xml");
clearIndex();
assertU(commit());
assertU(adoc("id", "1", "text", "XXXX XXXX", "str_s", "a"));
assertU(adoc("id", "2", "text", "XXXX AAAA", "str_s", "b"));
assertU(adoc("id", "3", "text", "ZZZZ", "str_s", "c"));
assertU(adoc("id", "4", "text", "XXXX ZZZZ", "str_s", "d"));
assertU(adoc("id", "5", "text", "ZZZZ ZZZZ", "str_s", "e"));
assertU(adoc("id", "6", "text", "AAAA AAAA AAAA", "str_s", "f"));
assertU(adoc("id", "7", "text", "AAAA AAAA ZZZZ", "str_s", "g"));
assertU(adoc("id", "8", "text", "XXXX", "str_s", "h"));
assertU(adoc("id", "9", "text", "YYYY ZZZZ", "str_s", "i"));
assertU(adoc("id", "22", "text", "XXXX ZZZZ AAAA", "str_s", "b"));
assertU(adoc("id", "66", "text", "XXXX ZZZZ AAAA", "str_s", "f"));
assertU(adoc("id", "77", "text", "XXXX ZZZZ AAAA", "str_s", "g"));
assertU(commit());
final String groups = "//arr[@name='groups']";
assertQ("non-elevated group query",
req(CommonParams.Q, "AAAA",
CommonParams.QT, "/elevate",
GroupParams.GROUP_FIELD, "str_s",
GroupParams.GROUP, "true",
GroupParams.GROUP_TOTAL_COUNT, "true",
GroupParams.GROUP_LIMIT, "100",
QueryElevationParams.ENABLE, "false",
CommonParams.FL, "id, score, [elevated]")
, "//*[@name='ngroups'][.='3']"
, "//*[@name='matches'][.='6']"
, groups +"/lst[1]//doc[1]/float[@name='id'][.='6.0']"
, groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[1]//doc[2]/float[@name='id'][.='66.0']"
, groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[1]/float[@name='id'][.='7.0']"
, groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[2]/float[@name='id'][.='77.0']"
, groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[1]/float[@name='id'][.='2.0']"
, groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[2]/float[@name='id'][.='22.0']"
, groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']"
);
assertQ("elevated group query",
req(CommonParams.Q, "AAAA",
CommonParams.QT, "/elevate",
GroupParams.GROUP_FIELD, "str_s",
GroupParams.GROUP, "true",
GroupParams.GROUP_TOTAL_COUNT, "true",
GroupParams.GROUP_LIMIT, "100",
CommonParams.FL, "id, score, [elevated]")
, "//*[@name='ngroups'][.='3']"
, "//*[@name='matches'][.='6']"
, groups +"/lst[1]//doc[1]/float[@name='id'][.='7.0']"
, groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='true']"
, groups +"/lst[1]//doc[2]/float[@name='id'][.='77.0']"
, groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[1]/float[@name='id'][.='6.0']"
, groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[2]/float[@name='id'][.='66.0']"
, groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[1]/float[@name='id'][.='2.0']"
, groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[2]/float[@name='id'][.='22.0']"
, groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']"
);
assertQ("non-elevated because sorted group query",
req(CommonParams.Q, "AAAA",
CommonParams.QT, "/elevate",
CommonParams.SORT, "id asc",
GroupParams.GROUP_FIELD, "str_s",
GroupParams.GROUP, "true",
GroupParams.GROUP_TOTAL_COUNT, "true",
GroupParams.GROUP_LIMIT, "100",
CommonParams.FL, "id, score, [elevated]")
, "//*[@name='ngroups'][.='3']"
, "//*[@name='matches'][.='6']"
, groups +"/lst[1]//doc[1]/float[@name='id'][.='2.0']"
, groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[1]//doc[2]/float[@name='id'][.='22.0']"
, groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[1]/float[@name='id'][.='6.0']"
, groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[2]/float[@name='id'][.='66.0']"
, groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[1]/float[@name='id'][.='7.0']"
, groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='true']"
, groups +"/lst[3]//doc[2]/float[@name='id'][.='77.0']"
, groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']"
);
assertQ("force-elevated sorted group query",
req(CommonParams.Q, "AAAA",
CommonParams.QT, "/elevate",
CommonParams.SORT, "id asc",
QueryElevationParams.FORCE_ELEVATION, "true",
GroupParams.GROUP_FIELD, "str_s",
GroupParams.GROUP, "true",
GroupParams.GROUP_TOTAL_COUNT, "true",
GroupParams.GROUP_LIMIT, "100",
CommonParams.FL, "id, score, [elevated]")
, "//*[@name='ngroups'][.='3']"
, "//*[@name='matches'][.='6']"
, groups +"/lst[1]//doc[1]/float[@name='id'][.='7.0']"
, groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='true']"
, groups +"/lst[1]//doc[2]/float[@name='id'][.='77.0']"
, groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[1]/float[@name='id'][.='2.0']"
, groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[2]/float[@name='id'][.='22.0']"
, groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[1]/float[@name='id'][.='6.0']"
, groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[2]/float[@name='id'][.='66.0']"
, groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']"
);
assertQ("non-elevated because of sort within group query",
req(CommonParams.Q, "AAAA",
CommonParams.QT, "/elevate",
CommonParams.SORT, "id asc",
GroupParams.GROUP_SORT, "id desc",
GroupParams.GROUP_FIELD, "str_s",
GroupParams.GROUP, "true",
GroupParams.GROUP_TOTAL_COUNT, "true",
GroupParams.GROUP_LIMIT, "100",
CommonParams.FL, "id, score, [elevated]")
, "//*[@name='ngroups'][.='3']"
, "//*[@name='matches'][.='6']"
, groups +"/lst[1]//doc[1]/float[@name='id'][.='22.0']"
, groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[1]//doc[2]/float[@name='id'][.='2.0']"
, groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[1]/float[@name='id'][.='66.0']"
, groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[2]/float[@name='id'][.='6.0']"
, groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[1]/float[@name='id'][.='77.0']"
, groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[2]/float[@name='id'][.='7.0']"
, groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='true']"
);
assertQ("force elevated sort within sorted group query",
req(CommonParams.Q, "AAAA",
CommonParams.QT, "/elevate",
CommonParams.SORT, "id asc",
GroupParams.GROUP_SORT, "id desc",
QueryElevationParams.FORCE_ELEVATION, "true",
GroupParams.GROUP_FIELD, "str_s",
GroupParams.GROUP, "true",
GroupParams.GROUP_TOTAL_COUNT, "true",
GroupParams.GROUP_LIMIT, "100",
CommonParams.FL, "id, score, [elevated]")
, "//*[@name='ngroups'][.='3']"
, "//*[@name='matches'][.='6']"
, groups +"/lst[1]//doc[1]/float[@name='id'][.='7.0']"
, groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='true']"
, groups +"/lst[1]//doc[2]/float[@name='id'][.='77.0']"
, groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[1]/float[@name='id'][.='22.0']"
, groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[2]//doc[2]/float[@name='id'][.='2.0']"
, groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[1]/float[@name='id'][.='66.0']"
, groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']"
, groups +"/lst[3]//doc[2]/float[@name='id'][.='6.0']"
, groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']"
);
} finally {
delete();
}
}
@Test
public void testTrieFieldType() throws Exception {
try {