diff --git a/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java b/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java
index 9022324ad4a..78f415407c8 100644
--- a/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java
+++ b/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java
@@ -134,8 +134,8 @@ public abstract class AbstractAllGroupHeadsCollector
 getCommands() { return commands; }
@@ -296,16 +306,21 @@ public class Grouping {
       cmd.prepare();
     }
 
+    AbstractAllGroupHeadsCollector allGroupHeadsCollector = null;
     List collectors = new ArrayList(commands.size());
     for (Command cmd : commands) {
       Collector collector = cmd.createFirstPassCollector();
-      if (collector != null)
+      if (collector != null) {
         collectors.add(collector);
+      }
+      if (getGroupedDocSet && allGroupHeadsCollector == null) {
+        collectors.add(allGroupHeadsCollector = cmd.createAllGroupCollector());
+      }
     }
 
     Collector allCollectors = MultiCollector.wrap(collectors.toArray(new Collector[collectors.size()]));
     DocSetCollector setCollector = null;
-    if (getDocSet) {
+    if (getDocSet && allGroupHeadsCollector == null) {
       setCollector = new DocSetDelegateCollector(maxDoc >> 6, maxDoc, allCollectors);
       allCollectors = setCollector;
     }
@@ -329,7 +344,12 @@ public class Grouping {
       searcher.search(query, luceneFilter, allCollectors);
     }
 
-    if (getDocSet) {
+    if (getGroupedDocSet && allGroupHeadsCollector != null) {
+      FixedBitSet fixedBitSet = allGroupHeadsCollector.retrieveGroupHeads(maxDoc);
+      long[] bits = fixedBitSet.getBits();
+      OpenBitSet openBitSet = new OpenBitSet(bits, bits.length);
+      qr.setDocSet(new BitDocSet(openBitSet));
+    } else if (getDocSet) {
       qr.setDocSet(setCollector.getDocSet());
     }
 
@@ -383,8 +403,8 @@ public class Grouping {
    * Returns offset + len if len equals zero or higher. Otherwise returns max.
    *
    * @param offset The offset
-   * @param len The number of documents to return
-   * @param max The number of document to return if len < 0 or if offset + len < 0
+   * @param len The number of documents to return
+   * @param max The number of documents to return if len < 0 or if offset + len < 0
    * @return offset + len if len equals zero or higher. Otherwise returns max
    */
   int getMax(int offset, int len, int max) {
@@ -483,6 +503,17 @@ public class Grouping {
       return null;
     }
 
+    /**
+     * Returns a collector that is able to return the most relevant document of each group.
+     * Returns null if the command doesn't support this type of collector.
+     *
+     * @return a collector that is able to return the most relevant document of each group
+     * @throws IOException If I/O related errors occur
+     */
+    public AbstractAllGroupHeadsCollector createAllGroupCollector() throws IOException {
+      return null;
+    }
+
     /**
      * Performs any necessary post actions to prepare the response.
      *
@@ -585,7 +616,8 @@ public class Grouping {
         }
       }
 
-      int len = docsGathered - offset;int[] docs = ArrayUtils.toPrimitive(ids.toArray(new Integer[ids.size()]));
+      int len = docsGathered - offset;
+      int[] docs = ArrayUtils.toPrimitive(ids.toArray(new Integer[ids.size()]));
       float[] docScores = ArrayUtils.toPrimitive(scores.toArray(new Float[scores.size()]));
       DocSlice docSlice = new DocSlice(offset, len, docs, docScores, getMatches(), maxScore);
 
@@ -672,6 +704,15 @@ public class Grouping {
       }
     }
 
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public AbstractAllGroupHeadsCollector createAllGroupCollector() throws IOException {
+      Sort sortWithinGroup = groupSort != null ? groupSort : new Sort();
+      return TermAllGroupHeadsCollector.create(groupBy, sortWithinGroup);
+    }
+
     /**
      * {@inheritDoc}
      */
@@ -873,6 +914,12 @@ public class Grouping {
       }
     }
 
+    @Override
+    public AbstractAllGroupHeadsCollector createAllGroupCollector() throws IOException {
+      Sort sortWithinGroup = groupSort != null ? groupSort : new Sort();
+      return new FunctionAllGroupHeadsCollector(groupBy, context, sortWithinGroup);
+    }
+
     /**
      * {@inheritDoc}
      */
@@ -1091,4 +1138,102 @@ public class Grouping {
     }
   }
 
-}
+
+  static class FunctionAllGroupHeadsCollector extends AbstractAllGroupHeadsCollector {
+
+    private final ValueSource groupBy;
+    private final Map vsContext;
+    private final Map groups;
+    private final Sort sortWithinGroup;
+
+    private DocValues docValues;
+    private DocValues.ValueFiller filler;
+    private MutableValue mval;
+    private AtomicReaderContext readerContext;
+    private Scorer scorer;
+
+    FunctionAllGroupHeadsCollector(ValueSource groupBy, Map vsContext, Sort sortWithinGroup) {
+      super(sortWithinGroup.getSort().length);
+      groups = new HashMap();
+      this.sortWithinGroup = sortWithinGroup;
+      this.groupBy = groupBy;
+      this.vsContext = vsContext;
+
+      final SortField[] sortFields = sortWithinGroup.getSort();
+      for (int i = 0; i < sortFields.length; i++) {
+        reversed[i] = sortFields[i].getReverse() ? -1 : 1;
+      }
+    }
+
+    protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException {
+      filler.fillValue(doc);
+      GroupHead groupHead = groups.get(mval);
+      if (groupHead == null) {
+        MutableValue groupValue = mval.duplicate();
+        groupHead = new GroupHead(groupValue, sortWithinGroup, doc);
+        groups.put(groupValue, groupHead);
+        temporalResult.stop = true;
+      } else {
+        temporalResult.stop = false;
+      }
+      this.temporalResult.groupHead = groupHead;
+    }
+
+    protected Collection getCollectedGroupHeads() {
+      return groups.values();
+    }
+
+    public void setScorer(Scorer scorer) throws IOException {
+      this.scorer = scorer;
+      for (GroupHead groupHead : groups.values()) {
+        for (FieldComparator comparator : groupHead.comparators) {
+          comparator.setScorer(scorer);
+        }
+      }
+    }
+
+    public void setNextReader(AtomicReaderContext context) throws IOException {
+      this.readerContext = context;
+      docValues = groupBy.getValues(vsContext, context);
+      filler = docValues.getValueFiller();
+      mval = filler.getValue();
+
+      for (GroupHead groupHead : groups.values()) {
+        for (int i = 0; i < groupHead.comparators.length; i++) {
+          groupHead.comparators[i] = groupHead.comparators[i].setNextReader(context);
+        }
+      }
+    }
+
+    class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead {
+
+      final FieldComparator[] comparators;
+
+      private GroupHead(MutableValue groupValue, Sort sort, int doc) throws IOException {
+        super(groupValue, doc + readerContext.docBase);
+        final SortField[] sortFields = sort.getSort();
+        comparators = new FieldComparator[sortFields.length];
+        for (int i = 0; i < sortFields.length; i++) {
+          comparators[i] = sortFields[i].getComparator(1, i).setNextReader(readerContext);
+          comparators[i].setScorer(scorer);
+          comparators[i].copy(0, doc);
+          comparators[i].setBottom(0);
+        }
+      }
+
+      public int compare(int compIDX, int doc) throws IOException {
+        return comparators[compIDX].compareBottom(doc);
+      }
+
+      public void updateDocHead(int doc) throws IOException {
+        for (FieldComparator comparator : comparators) {
+          comparator.copy(0, doc);
+          comparator.setBottom(0);
+        }
+        this.doc = doc + readerContext.docBase;
+      }
+    }
+
+  }
+
+}
\ No newline at end of file
diff --git a/solr/core/src/test/org/apache/solr/TestGroupingSearch.java b/solr/core/src/test/org/apache/solr/TestGroupingSearch.java
index caad8953bfa..ea3451b91eb 100644
--- a/solr/core/src/test/org/apache/solr/TestGroupingSearch.java
+++ b/solr/core/src/test/org/apache/solr/TestGroupingSearch.java
@@ -234,6 +234,53 @@ public class TestGroupingSearch extends SolrTestCaseJ4 {
     );
   }
 
+  @Test
+  public void testGroupingGroupedBasedFaceting() throws Exception {
+    assertU(add(doc("id", "1", "value1_s1", "1", "value2_i", "1", "value3_s1", "a", "value4_i", "1")));
+    assertU(add(doc("id", "2", "value1_s1", "1", "value2_i", "2", "value3_s1", "a", "value4_i", "1")));
+    assertU(commit());
+    assertU(add(doc("id", "3", "value1_s1", "2", "value2_i", "3", "value3_s1", "b", "value4_i", "2")));
+    assertU(add(doc("id", "4", "value1_s1", "1", "value2_i", "4", "value3_s1", "a", "value4_i", "1")));
+    assertU(add(doc("id", "5", "value1_s1", "2", "value2_i", "5", "value3_s1", "b", "value4_i", "2")));
+    assertU(commit());
+
+    // Facet counts based on documents
+    SolrQueryRequest req = req("q", "*:*", "sort", "value2_i asc", "rows", "1", "group", "true", "group.field",
+        "value1_s1", "fl", "id", "facet", "true", "facet.field", "value3_s1", "group.truncate", "false");
+    assertJQ(
+        req,
"/grouped=={'value1_s1':{'matches':5,'groups':[{'groupValue':'1','doclist':{'numFound':3,'start':0,'docs':[{'id':'1'}]}}]}}", + "/facet_counts=={'facet_queries':{},'facet_fields':{'value3_s1':['a',3,'b',2]},'facet_dates':{},'facet_ranges':{}}" + ); + + // Facet counts based on groups + req = req("q", "*:*", "sort", "value2_i asc", "rows", "1", "group", "true", "group.field", + "value1_s1", "fl", "id", "facet", "true", "facet.field", "value3_s1", "group.truncate", "true"); + assertJQ( + req, + "/grouped=={'value1_s1':{'matches':5,'groups':[{'groupValue':'1','doclist':{'numFound':3,'start':0,'docs':[{'id':'1'}]}}]}}", + "/facet_counts=={'facet_queries':{},'facet_fields':{'value3_s1':['a',1,'b',1]},'facet_dates':{},'facet_ranges':{}}" + ); + + // Facet counts based on groups and with group.func. This should trigger FunctionAllGroupHeadsCollector + req = req("q", "*:*", "sort", "value2_i asc", "rows", "1", "group", "true", "group.func", + "strdist(1,value1_s1,edit)", "fl", "id", "facet", "true", "facet.field", "value3_s1", "group.truncate", "true"); + assertJQ( + req, + "/grouped=={'strdist(1,value1_s1,edit)':{'matches':5,'groups':[{'groupValue':1.0,'doclist':{'numFound':3,'start':0,'docs':[{'id':'1'}]}}]}}", + "/facet_counts=={'facet_queries':{},'facet_fields':{'value3_s1':['a',1,'b',1]},'facet_dates':{},'facet_ranges':{}}" + ); + + // Facet counts based on groups without sort on an int field. + req = req("q", "*:*", "rows", "1", "group", "true", "group.field", "value4_i", "fl", "id", "facet", "true", + "facet.field", "value3_s1", "group.truncate", "true"); + assertJQ( + req, + "/grouped=={'value4_i':{'matches':5,'groups':[{'groupValue':1,'doclist':{'numFound':3,'start':0,'docs':[{'id':'1'}]}}]}}", + "/facet_counts=={'facet_queries':{},'facet_fields':{'value3_s1':['a',1,'b',1]},'facet_dates':{},'facet_ranges':{}}" + ); + } + static String f = "foo_i"; static String f2 = "foo2_i"; @@ -474,7 +521,7 @@ public class TestGroupingSearch extends SolrTestCaseJ4 { types.add(new FldType("id",ONE_ONE, new SVal('A','Z',4,4))); types.add(new FldType("score_f",ONE_ONE, new FVal(1,100))); // field used to score types.add(new FldType("foo_i",ZERO_ONE, new IRange(0,indexSize))); - types.add(new FldType(FOO_STRING_FIELD,ZERO_ONE, new SVal('a','z',1,2))); + types.add(new FldType(FOO_STRING_FIELD,ONE_ONE, new SVal('a','z',1,2))); types.add(new FldType(SMALL_STRING_FIELD,ZERO_ONE, new SVal('a',(char)('c'+indexSize/10),1,1))); types.add(new FldType(SMALL_INT_FIELD,ZERO_ONE, new IRange(0,5+indexSize/10))); @@ -567,25 +614,61 @@ public class TestGroupingSearch extends SolrTestCaseJ4 { for (Grp grp : groups.values()) grp.setMaxDoc(sortComparator); } - List sortedGroups = new ArrayList(groups.values()); + List sortedGroups = new ArrayList(groups.values()); Collections.sort(sortedGroups, groupComparator==sortComparator ? createFirstDocComparator(sortComparator) : createMaxDocComparator(sortComparator)); boolean includeNGroups = random.nextBoolean(); Object modelResponse = buildGroupedResult(h.getCore().getSchema(), sortedGroups, start, rows, group_offset, group_limit, includeNGroups); + boolean truncateGroups = random.nextBoolean(); + Map facetCounts = new TreeMap(); + if (truncateGroups) { + for (Grp grp : sortedGroups) { + Doc doc = grp.docs.get(0); + if (doc.getValues(FOO_STRING_FIELD) == null) { + continue; + } + + String key = doc.getFirstValue(FOO_STRING_FIELD).toString(); + boolean exists = facetCounts.containsKey(key); + int count = exists ? 
facetCounts.get(key) : 0; + facetCounts.put(key, ++count); + } + } else { + for (Doc doc : model.values()) { + if (doc.getValues(FOO_STRING_FIELD) == null) { + continue; + } + + for (Comparable field : doc.getValues(FOO_STRING_FIELD)) { + String key = field.toString(); + boolean exists = facetCounts.containsKey(key); + int count = exists ? facetCounts.get(key) : 0; + facetCounts.put(key, ++count); + } + } + } + List expectedFacetResponse = new ArrayList(); + for (Map.Entry stringIntegerEntry : facetCounts.entrySet()) { + expectedFacetResponse.add(stringIntegerEntry.getKey()); + expectedFacetResponse.add(stringIntegerEntry.getValue()); + } + int randomPercentage = random.nextInt(101); // TODO: create a random filter too SolrQueryRequest req = req("group","true","wt","json","indent","true", "echoParams","all", "q","{!func}score_f", "group.field",groupField ,sortStr==null ? "nosort":"sort", sortStr ==null ? "": sortStr - ,(groupSortStr==null || groupSortStr==sortStr) ? "noGroupsort":"group.sort", groupSortStr==null ? "": groupSortStr + ,(groupSortStr == null || groupSortStr == sortStr) ? "noGroupsort":"group.sort", groupSortStr==null ? "": groupSortStr ,"rows",""+rows, "start",""+start, "group.offset",""+group_offset, "group.limit",""+group_limit, - GroupParams.GROUP_CACHE_PERCENTAGE, Integer.toString(randomPercentage), GroupParams.GROUP_TOTAL_COUNT, includeNGroups ? "true" : "false" + GroupParams.GROUP_CACHE_PERCENTAGE, Integer.toString(randomPercentage), GroupParams.GROUP_TOTAL_COUNT, includeNGroups ? "true" : "false", + "facet", "true", "facet.sort", "index", "facet.limit", "-1", "facet.field", FOO_STRING_FIELD, + GroupParams.GROUP_TRUNCATE, truncateGroups ? "true" : "false", "facet.mincount", "1", "facet.method", "fcs" // to avoid FC insanity ); String strResponse = h.query(req); Object realResponse = ObjectBuilder.fromJSON(strResponse); - String err = JSONTestUtil.matchObj("/grouped/"+groupField, realResponse, modelResponse); + String err = JSONTestUtil.matchObj("/grouped/" + groupField, realResponse, modelResponse); if (err != null) { log.error("GROUPING MISMATCH: " + err + "\n\trequest="+req @@ -599,6 +682,20 @@ public class TestGroupingSearch extends SolrTestCaseJ4 { fail(err); } + + // assert post / pre grouping facets + err = JSONTestUtil.matchObj("/facet_counts/facet_fields/"+FOO_STRING_FIELD, realResponse, expectedFacetResponse); + if (err != null) { + log.error("GROUPING MISMATCH: " + err + + "\n\trequest="+req + + "\n\tresult="+strResponse + + "\n\texpected="+ JSONUtil.toJSON(expectedFacetResponse) + ); + + // re-execute the request... good for putting a breakpoint here for debugging + h.query(req); + fail(err); + } } // end query iter } // end index iter diff --git a/solr/solrj/src/java/org/apache/solr/common/params/GroupParams.java b/solr/solrj/src/java/org/apache/solr/common/params/GroupParams.java index 806e147845f..4c619be3fbc 100755 --- a/solr/solrj/src/java/org/apache/solr/common/params/GroupParams.java +++ b/solr/solrj/src/java/org/apache/solr/common/params/GroupParams.java @@ -48,8 +48,9 @@ public interface GroupParams { // Note: Since you can supply multiple fields to group on, but only have a facets for the whole result. It only makes // sense to me to support these parameters for the first group. - /** Whether the docSet (for example for faceting) should be based on plain documents (a.k.a UNGROUPED) or on the groups (a.k.a GROUPED). 
-  public static final String GROUP_COLLAPSE = GROUP + ".collapse";
+  /** Whether the docSet (for example for faceting) should be based on plain documents (a.k.a UNGROUPED) or on the groups (a.k.a GROUPED).
+   * The docSet will only contain the most relevant documents per group. It is as if you query for everything with group.limit=1. */
+  public static final String GROUP_TRUNCATE = GROUP + ".truncate";
 
   /** Whether the group count should be included in the response. */
   public static final String GROUP_TOTAL_COUNT = GROUP + ".ngroups";
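
For reference, a minimal client-side sketch of how the group.truncate parameter introduced by this patch can be combined with faceting. This is not part of the patch itself: it assumes a SolrJ client and reuses the field names from the test above (value1_s1, value3_s1). GroupParams.GROUP and GroupParams.GROUP_FIELD are pre-existing constants in the same interface; only GROUP_TRUNCATE is new here.

import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.common.params.GroupParams;

public class GroupTruncateFacetExample {
  public static SolrQuery buildQuery() {
    SolrQuery query = new SolrQuery("*:*");
    // Group on value1_s1 (field name borrowed from TestGroupingSearch above).
    query.set(GroupParams.GROUP, true);
    query.set(GroupParams.GROUP_FIELD, "value1_s1");
    // Facet on value3_s1. With GROUP_TRUNCATE enabled the facet counts are based on
    // the group heads (the most relevant document per group) instead of on all matching documents.
    query.setFacet(true);
    query.addFacetField("value3_s1");
    query.set(GroupParams.GROUP_TRUNCATE, true);
    return query;
  }
}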