diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index 8b5edf33c1d..0f2583be610 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -131,6 +131,9 @@ Bug Fixes of divide by zero, and makes estimated hit counts meaningful in non-optimized indexes. (hossman) +* SOLR-3936: Fixed QueryElevationComponent sorting when used with Grouping + (Michael Garski via hossman) + Optimizations ---------------------- diff --git a/solr/core/src/java/org/apache/solr/handler/component/QueryElevationComponent.java b/solr/core/src/java/org/apache/solr/handler/component/QueryElevationComponent.java index ec245d20061..87bd748110f 100644 --- a/solr/core/src/java/org/apache/solr/handler/component/QueryElevationComponent.java +++ b/solr/core/src/java/org/apache/solr/handler/component/QueryElevationComponent.java @@ -44,6 +44,7 @@ import org.apache.solr.common.SolrException; import org.apache.solr.common.params.QueryElevationParams; import org.apache.solr.common.params.SolrParams; import org.apache.solr.schema.IndexSchema; +import org.apache.solr.search.grouping.GroupingSpecification; import org.apache.solr.util.DOMUtil; import org.apache.solr.common.util.NamedList; import org.apache.solr.common.util.SimpleOrderedMap; @@ -424,23 +425,25 @@ public class QueryElevationComponent extends SearchComponent implements SolrCore })); } else { // Check if the sort is based on score - boolean modify = false; SortField[] current = sortSpec.getSort().getSort(); - ArrayList sorts = new ArrayList(current.length + 1); - // Perhaps force it to always sort by score - if (force && current[0].getType() != SortField.Type.SCORE) { - sorts.add(new SortField("_elevate_", comparator, true)); - modify = true; + Sort modified = this.modifySort(current, force, comparator); + if(modified != null) { + sortSpec.setSort(modified); } - for (SortField sf : current) { - if (sf.getType() == SortField.Type.SCORE) { - sorts.add(new SortField("_elevate_", comparator, !sf.getReverse())); - modify = true; - } - sorts.add(sf); + } + + // alter the sorting in the grouping specification if there is one + GroupingSpecification groupingSpec = rb.getGroupingSpec(); + if(groupingSpec != null) { + SortField[] groupSort = groupingSpec.getGroupSort().getSort(); + Sort modGroupSort = this.modifySort(groupSort, force, comparator); + if(modGroupSort != null) { + groupingSpec.setGroupSort(modGroupSort); } - if (modify) { - sortSpec.setSort(new Sort(sorts.toArray(new SortField[sorts.size()]))); + SortField[] withinGroupSort = groupingSpec.getSortWithinGroup().getSort(); + Sort modWithinGroupSort = this.modifySort(withinGroupSort, force, comparator); + if(modWithinGroupSort != null) { + groupingSpec.setSortWithinGroup(modWithinGroupSort); } } } @@ -466,6 +469,25 @@ public class QueryElevationComponent extends SearchComponent implements SolrCore } } + private Sort modifySort(SortField[] current, boolean force, ElevationComparatorSource comparator) { + boolean modify = false; + ArrayList sorts = new ArrayList(current.length + 1); + // Perhaps force it to always sort by score + if (force && current[0].getType() != SortField.Type.SCORE) { + sorts.add(new SortField("_elevate_", comparator, true)); + modify = true; + } + for (SortField sf : current) { + if (sf.getType() == SortField.Type.SCORE) { + sorts.add(new SortField("_elevate_", comparator, !sf.getReverse())); + modify = true; + } + sorts.add(sf); + } + + return modify ? new Sort(sorts.toArray(new SortField[sorts.size()])) : null; + } + @Override public void process(ResponseBuilder rb) throws IOException { // Do nothing -- the real work is modifying the input query diff --git a/solr/core/src/test/org/apache/solr/handler/component/QueryElevationComponentTest.java b/solr/core/src/test/org/apache/solr/handler/component/QueryElevationComponentTest.java index 66982fb2263..576b8f78207 100644 --- a/solr/core/src/test/org/apache/solr/handler/component/QueryElevationComponentTest.java +++ b/solr/core/src/test/org/apache/solr/handler/component/QueryElevationComponentTest.java @@ -22,6 +22,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.IOUtils; import org.apache.solr.SolrTestCaseJ4; import org.apache.solr.common.params.CommonParams; +import org.apache.solr.common.params.GroupParams; import org.apache.solr.common.params.MapSolrParams; import org.apache.solr.common.params.QueryElevationParams; import org.apache.solr.util.FileUtils; @@ -105,6 +106,208 @@ public class QueryElevationComponentTest extends SolrTestCaseJ4 { } } + @Test + public void testGroupedQuery() throws Exception { + try { + init("schema11.xml"); + clearIndex(); + assertU(commit()); + assertU(adoc("id", "1", "text", "XXXX XXXX", "str_s", "a")); + assertU(adoc("id", "2", "text", "XXXX AAAA", "str_s", "b")); + assertU(adoc("id", "3", "text", "ZZZZ", "str_s", "c")); + assertU(adoc("id", "4", "text", "XXXX ZZZZ", "str_s", "d")); + assertU(adoc("id", "5", "text", "ZZZZ ZZZZ", "str_s", "e")); + assertU(adoc("id", "6", "text", "AAAA AAAA AAAA", "str_s", "f")); + assertU(adoc("id", "7", "text", "AAAA AAAA ZZZZ", "str_s", "g")); + assertU(adoc("id", "8", "text", "XXXX", "str_s", "h")); + assertU(adoc("id", "9", "text", "YYYY ZZZZ", "str_s", "i")); + + assertU(adoc("id", "22", "text", "XXXX ZZZZ AAAA", "str_s", "b")); + assertU(adoc("id", "66", "text", "XXXX ZZZZ AAAA", "str_s", "f")); + assertU(adoc("id", "77", "text", "XXXX ZZZZ AAAA", "str_s", "g")); + + assertU(commit()); + + final String groups = "//arr[@name='groups']"; + + assertQ("non-elevated group query", + req(CommonParams.Q, "AAAA", + CommonParams.QT, "/elevate", + GroupParams.GROUP_FIELD, "str_s", + GroupParams.GROUP, "true", + GroupParams.GROUP_TOTAL_COUNT, "true", + GroupParams.GROUP_LIMIT, "100", + QueryElevationParams.ENABLE, "false", + CommonParams.FL, "id, score, [elevated]") + , "//*[@name='ngroups'][.='3']" + , "//*[@name='matches'][.='6']" + + , groups +"/lst[1]//doc[1]/float[@name='id'][.='6.0']" + , groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[1]//doc[2]/float[@name='id'][.='66.0']" + , groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[2]//doc[1]/float[@name='id'][.='7.0']" + , groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[2]//doc[2]/float[@name='id'][.='77.0']" + , groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[3]//doc[1]/float[@name='id'][.='2.0']" + , groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[3]//doc[2]/float[@name='id'][.='22.0']" + , groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']" + ); + + assertQ("elevated group query", + req(CommonParams.Q, "AAAA", + CommonParams.QT, "/elevate", + GroupParams.GROUP_FIELD, "str_s", + GroupParams.GROUP, "true", + GroupParams.GROUP_TOTAL_COUNT, "true", + GroupParams.GROUP_LIMIT, "100", + CommonParams.FL, "id, score, [elevated]") + , "//*[@name='ngroups'][.='3']" + , "//*[@name='matches'][.='6']" + + , groups +"/lst[1]//doc[1]/float[@name='id'][.='7.0']" + , groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='true']" + , groups +"/lst[1]//doc[2]/float[@name='id'][.='77.0']" + , groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[2]//doc[1]/float[@name='id'][.='6.0']" + , groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[2]//doc[2]/float[@name='id'][.='66.0']" + , groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[3]//doc[1]/float[@name='id'][.='2.0']" + , groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[3]//doc[2]/float[@name='id'][.='22.0']" + , groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']" + ); + + assertQ("non-elevated because sorted group query", + req(CommonParams.Q, "AAAA", + CommonParams.QT, "/elevate", + CommonParams.SORT, "id asc", + GroupParams.GROUP_FIELD, "str_s", + GroupParams.GROUP, "true", + GroupParams.GROUP_TOTAL_COUNT, "true", + GroupParams.GROUP_LIMIT, "100", + CommonParams.FL, "id, score, [elevated]") + , "//*[@name='ngroups'][.='3']" + , "//*[@name='matches'][.='6']" + + , groups +"/lst[1]//doc[1]/float[@name='id'][.='2.0']" + , groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[1]//doc[2]/float[@name='id'][.='22.0']" + , groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[2]//doc[1]/float[@name='id'][.='6.0']" + , groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[2]//doc[2]/float[@name='id'][.='66.0']" + , groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[3]//doc[1]/float[@name='id'][.='7.0']" + , groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='true']" + , groups +"/lst[3]//doc[2]/float[@name='id'][.='77.0']" + , groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']" + ); + + assertQ("force-elevated sorted group query", + req(CommonParams.Q, "AAAA", + CommonParams.QT, "/elevate", + CommonParams.SORT, "id asc", + QueryElevationParams.FORCE_ELEVATION, "true", + GroupParams.GROUP_FIELD, "str_s", + GroupParams.GROUP, "true", + GroupParams.GROUP_TOTAL_COUNT, "true", + GroupParams.GROUP_LIMIT, "100", + CommonParams.FL, "id, score, [elevated]") + , "//*[@name='ngroups'][.='3']" + , "//*[@name='matches'][.='6']" + + , groups +"/lst[1]//doc[1]/float[@name='id'][.='7.0']" + , groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='true']" + , groups +"/lst[1]//doc[2]/float[@name='id'][.='77.0']" + , groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[2]//doc[1]/float[@name='id'][.='2.0']" + , groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[2]//doc[2]/float[@name='id'][.='22.0']" + , groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[3]//doc[1]/float[@name='id'][.='6.0']" + , groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[3]//doc[2]/float[@name='id'][.='66.0']" + , groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']" + ); + + + assertQ("non-elevated because of sort within group query", + req(CommonParams.Q, "AAAA", + CommonParams.QT, "/elevate", + CommonParams.SORT, "id asc", + GroupParams.GROUP_SORT, "id desc", + GroupParams.GROUP_FIELD, "str_s", + GroupParams.GROUP, "true", + GroupParams.GROUP_TOTAL_COUNT, "true", + GroupParams.GROUP_LIMIT, "100", + CommonParams.FL, "id, score, [elevated]") + , "//*[@name='ngroups'][.='3']" + , "//*[@name='matches'][.='6']" + + , groups +"/lst[1]//doc[1]/float[@name='id'][.='22.0']" + , groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[1]//doc[2]/float[@name='id'][.='2.0']" + , groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[2]//doc[1]/float[@name='id'][.='66.0']" + , groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[2]//doc[2]/float[@name='id'][.='6.0']" + , groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[3]//doc[1]/float[@name='id'][.='77.0']" + , groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[3]//doc[2]/float[@name='id'][.='7.0']" + , groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='true']" + ); + + + assertQ("force elevated sort within sorted group query", + req(CommonParams.Q, "AAAA", + CommonParams.QT, "/elevate", + CommonParams.SORT, "id asc", + GroupParams.GROUP_SORT, "id desc", + QueryElevationParams.FORCE_ELEVATION, "true", + GroupParams.GROUP_FIELD, "str_s", + GroupParams.GROUP, "true", + GroupParams.GROUP_TOTAL_COUNT, "true", + GroupParams.GROUP_LIMIT, "100", + CommonParams.FL, "id, score, [elevated]") + , "//*[@name='ngroups'][.='3']" + , "//*[@name='matches'][.='6']" + + , groups +"/lst[1]//doc[1]/float[@name='id'][.='7.0']" + , groups +"/lst[1]//doc[1]/bool[@name='[elevated]'][.='true']" + , groups +"/lst[1]//doc[2]/float[@name='id'][.='77.0']" + , groups +"/lst[1]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[2]//doc[1]/float[@name='id'][.='22.0']" + , groups +"/lst[2]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[2]//doc[2]/float[@name='id'][.='2.0']" + , groups +"/lst[2]//doc[2]/bool[@name='[elevated]'][.='false']" + + , groups +"/lst[3]//doc[1]/float[@name='id'][.='66.0']" + , groups +"/lst[3]//doc[1]/bool[@name='[elevated]'][.='false']" + , groups +"/lst[3]//doc[2]/float[@name='id'][.='6.0']" + , groups +"/lst[3]//doc[2]/bool[@name='[elevated]'][.='false']" + ); + + } finally { + delete(); + } + } + @Test public void testTrieFieldType() throws Exception { try {