From 49075985fb9d833839494fa0a6af6fd3d721f4ef Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Sun, 24 Jul 2011 19:06:51 +0000 Subject: [PATCH] LUCENE-3097: Added a new grouping collector that can be used to retrieve all most relevant documents per group. git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1150470 13f79535-47bb-0310-9956-ffa450edef68 --- lucene/CHANGES.txt | 4 + .../AbstractAllGroupHeadsCollector.java | 179 ++++++ .../grouping/TermAllGroupHeadsCollector.java | 540 ++++++++++++++++++ .../lucene/search/grouping/package.html | 15 + .../TermAllGroupHeadsCollectorTest.java | 492 ++++++++++++++++ 5 files changed, 1230 insertions(+) create mode 100644 modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java create mode 100644 modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java create mode 100644 modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index c3b81468f30..27aabff7227 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -553,6 +553,10 @@ New Features AbstractField.setOmitTermFrequenciesAndPositions is deprecated, you should use DOCS_ONLY instead. (Robert Muir) +* LUCENE-3097: Added a new grouping collector that can be used to retrieve all most relevant + documents per group. This can be useful in situations when one wants to compute grouping + based facets / statistics on the complete query result. (Martijn van Groningen) + Optimizations * LUCENE-3201, LUCENE-3218: CompoundFileSystem code has been consolidated diff --git a/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java b/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java new file mode 100644 index 00000000000..9022324ad4a --- /dev/null +++ b/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java @@ -0,0 +1,179 @@ +package org.apache.lucene.search.grouping; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.util.FixedBitSet; + +import java.io.IOException; +import java.util.Collection; + +/** + * This collector specializes in collecting the most relevant document (group head) for each group that match the query. 
+ * + * @lucene.experimental + */ +public abstract class AbstractAllGroupHeadsCollector extends Collector { + + protected final int[] reversed; + protected final int compIDXEnd; + protected final TemporalResult temporalResult; + + protected AbstractAllGroupHeadsCollector(int numberOfSorts) { + this.reversed = new int[numberOfSorts]; + this.compIDXEnd = numberOfSorts - 1; + temporalResult = new TemporalResult(); + } + + /** + * @param maxDoc The maxDoc of the top level {@link IndexReader}. + * @return an {@link FixedBitSet} containing all group heads. + */ + public FixedBitSet retrieveGroupHeads(int maxDoc) { + FixedBitSet bitSet = new FixedBitSet(maxDoc); + + Collection groupHeads = getCollectedGroupHeads(); + for (GroupHead groupHead : groupHeads) { + bitSet.set(groupHead.doc); + } + + return bitSet; + } + + /** + * @return an int array containing all group heads. The size of the array is equal to number of collected unique groups. + */ + public int[] retrieveGroupHeads() { + Collection groupHeads = getCollectedGroupHeads(); + int[] docHeads = new int[groupHeads.size()]; + + int i = 0; + for (GroupHead groupHead : groupHeads) { + docHeads[i++] = groupHead.doc; + } + + return docHeads; + } + + /** + * @return the number of group heads found for a query. + */ + public int groupHeadsSize() { + return getCollectedGroupHeads().size(); + } + + /** + * Returns the group head and puts it into {@link #temporalResult}. + * If the group head wasn't encountered before then it will be added to the collected group heads. + *

+ * The {@link TemporalResult#stop} property will be true if the group head wasn't encountered before + * otherwise false. + * + * @param doc The document to retrieve the group head for. + * @throws IOException If I/O related errors occur + */ + protected abstract void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException; + + /** + * Returns the collected group heads. + * Subsequent calls should return the same group heads. + * + * @return the collected group heads + */ + protected abstract Collection getCollectedGroupHeads(); + + public void collect(int doc) throws IOException { + retrieveGroupHeadAndAddIfNotExist(doc); + if (temporalResult.stop) { + return; + } + GH groupHead = temporalResult.groupHead; + + // Ok now we need to check if the current doc is more relevant then current doc for this group + for (int compIDX = 0; ; compIDX++) { + final int c = reversed[compIDX] * groupHead.compare(compIDX, doc); + if (c < 0) { + // Definitely not competitive. So don't even bother to continue + return; + } else if (c > 0) { + // Definitely competitive. + break; + } else if (compIDX == compIDXEnd) { + // Here c=0. If we're at the last comparator, this doc is not + // competitive, since docs are visited in doc Id order, which means + // this doc cannot compete with any other document in the queue. + return; + } + } + groupHead.updateDocHead(doc); + } + + public boolean acceptsDocsOutOfOrder() { + return true; + } + + /** + * Contains the result of group head retrieval. + * To prevent new object creations of this class for every collect. + */ + protected class TemporalResult { + + protected GH groupHead; + protected boolean stop; + + } + + /** + * Represents a group head. A group head is the most relevant document for a particular group. + * The relevancy is based is usually based on the sort. + * + * The group head contains a group value with its associated most relevant document id. + */ + public static abstract class GroupHead { + + public final GROUP_VALUE_TYPE groupValue; + public int doc; + + protected GroupHead(GROUP_VALUE_TYPE groupValue, int doc) { + this.groupValue = groupValue; + this.doc = doc; + } + + /** + * Compares the specified document for a specified comparator against the current most relevant document. + * + * @param compIDX The comparator index of the specified comparator. + * @param doc The specified document. + * @return -1 if the specified document wasn't competitive against the current most relevant document, 1 if the + * specified document was competitive against the current most relevant document. Otherwise 0. + * @throws IOException If I/O related errors occur + */ + protected abstract int compare(int compIDX, int doc) throws IOException; + + /** + * Updates the current most relevant document with the specified document. + * + * @param doc The specified document + * @throws IOException If I/O related errors occur + */ + protected abstract void updateDocHead(int doc) throws IOException; + + } + +} diff --git a/modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java b/modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java new file mode 100644 index 00000000000..edad23c2fc3 --- /dev/null +++ b/modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java @@ -0,0 +1,540 @@ +package org.apache.lucene.search.grouping; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.*; +import org.apache.lucene.util.BytesRef; + +import java.io.IOException; +import java.util.*; + +/** + * A base implementation of {@link AbstractAllGroupHeadsCollector} for retrieving the most relevant groups when grouping + * on a string based group field. More specifically this all concrete implementations of this base implementation + * use {@link org.apache.lucene.search.FieldCache.DocTermsIndex}. + * + * @lucene.experimental + */ +public abstract class TermAllGroupHeadsCollector extends AbstractAllGroupHeadsCollector { + + private static final int DEFAULT_INITIAL_SIZE = 128; + + final String groupField; + final BytesRef scratchBytesRef = new BytesRef(); + + FieldCache.DocTermsIndex groupIndex; + IndexReader.AtomicReaderContext readerContext; + + protected TermAllGroupHeadsCollector(String groupField, int numberOfSorts) { + super(numberOfSorts); + this.groupField = groupField; + } + + /** + * Creates an AbstractAllGroupHeadsCollector instance based on the supplied arguments. + * This factory method decides with implementation is best suited. + * + * Delegates to {@link #create(String, org.apache.lucene.search.Sort, int)} with an initialSize of 128. + * + * @param groupField The field to group by + * @param sortWithinGroup The sort within each group + * @return an AbstractAllGroupHeadsCollector instance based on the supplied arguments + * @throws IOException If I/O related errors occur + */ + public static AbstractAllGroupHeadsCollector create(String groupField, Sort sortWithinGroup) throws IOException { + return create(groupField, sortWithinGroup, DEFAULT_INITIAL_SIZE); + } + + /** + * Creates an AbstractAllGroupHeadsCollector instance based on the supplied arguments. + * This factory method decides with implementation is best suited. + * + * @param groupField The field to group by + * @param sortWithinGroup The sort within each group + * @param initialSize The initial allocation size of the internal int set and group list which should roughly match + * the total number of expected unique groups. Be aware that the heap usage is + * 4 bytes * initialSize. 
+ * @return an AbstractAllGroupHeadsCollector instance based on the supplied arguments + * @throws IOException If I/O related errors occur + */ + public static AbstractAllGroupHeadsCollector create(String groupField, Sort sortWithinGroup, int initialSize) throws IOException { + boolean sortAllScore = true; + boolean sortAllFieldValue = true; + + for (SortField sortField : sortWithinGroup.getSort()) { + if (sortField.getType() == SortField.Type.SCORE) { + sortAllFieldValue = false; + } else if (needGeneralImpl(sortField)) { + return new GeneralAllGroupHeadsCollector(groupField, sortWithinGroup); + } else { + sortAllScore = false; + } + } + + if (sortAllScore) { + return new ScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize); + } else if (sortAllFieldValue) { + return new OrdAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize); + } else { + return new OrdScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize); + } + } + + // Returns when a sort field needs the general impl. + private static boolean needGeneralImpl(SortField sortField) { + SortField.Type sortType = sortField.getType(); + // Note (MvG): We can also make an optimized impl when sorting is SortField.DOC + return sortType != SortField.Type.STRING_VAL && sortType != SortField.Type.STRING && sortType != SortField.Type.SCORE; + } + + // A general impl that works for any group sort. + static class GeneralAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final Sort sortWithinGroup; + private final Map groups; + + private Scorer scorer; + + GeneralAllGroupHeadsCollector(String groupField, Sort sortWithinGroup) throws IOException { + super(groupField, sortWithinGroup.getSort().length); + this.sortWithinGroup = sortWithinGroup; + groups = new HashMap(); + + final SortField[] sortFields = sortWithinGroup.getSort(); + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + } + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + final int ord = groupIndex.getOrd(doc); + final BytesRef groupValue = ord == 0 ? null : groupIndex.lookup(ord, scratchBytesRef); + GroupHead groupHead = groups.get(groupValue); + if (groupHead == null) { + groupHead = new GroupHead(groupValue, sortWithinGroup, doc); + groups.put(groupValue == null ? 
null : new BytesRef(groupValue), groupHead); + temporalResult.stop = true; + } else { + temporalResult.stop = false; + } + temporalResult.groupHead = groupHead; + } + + protected Collection getCollectedGroupHeads() { + return groups.values(); + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + + for (GroupHead groupHead : groups.values()) { + for (int i = 0; i < groupHead.comparators.length; i++) { + groupHead.comparators[i] = groupHead.comparators[i].setNextReader(context); + } + } + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + for (GroupHead groupHead : groups.values()) { + for (FieldComparator comparator : groupHead.comparators) { + comparator.setScorer(scorer); + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + final FieldComparator[] comparators; + + private GroupHead(BytesRef groupValue, Sort sort, int doc) throws IOException { + super(groupValue, doc + readerContext.docBase); + final SortField[] sortFields = sort.getSort(); + comparators = new FieldComparator[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + comparators[i] = sortFields[i].getComparator(1, i).setNextReader(readerContext); + comparators[i].setScorer(scorer); + comparators[i].copy(0, doc); + comparators[i].setBottom(0); + } + } + + public int compare(int compIDX, int doc) throws IOException { + return comparators[compIDX].compareBottom(doc); + } + + public void updateDocHead(int doc) throws IOException { + for (FieldComparator comparator : comparators) { + comparator.copy(0, doc); + comparator.setBottom(0); + } + this.doc = doc + readerContext.docBase; + } + } + } + + + // AbstractAllGroupHeadsCollector optimized for ord fields and scores. + static class OrdScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final SentinelIntSet ordSet; + private final List collectedGroups; + private final SortField[] fields; + + private FieldCache.DocTermsIndex[] sortsIndex; + private Scorer scorer; + private GroupHead[] segmentGroupHeads; + + OrdScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) { + super(groupField, sortWithinGroup.getSort().length); + ordSet = new SentinelIntSet(initialSize, -1); + collectedGroups = new ArrayList(initialSize); + + final SortField[] sortFields = sortWithinGroup.getSort(); + fields = new SortField[sortFields.length]; + sortsIndex = new FieldCache.DocTermsIndex[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + fields[i] = sortFields[i]; + } + } + + protected Collection getCollectedGroupHeads() { + return collectedGroups; + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + int key = groupIndex.getOrd(doc); + GroupHead groupHead; + if (!ordSet.exists(key)) { + ordSet.put(key); + BytesRef term = key == 0 ? 
null : groupIndex.getTerm(doc, new BytesRef()); + groupHead = new GroupHead(doc, term); + collectedGroups.add(groupHead); + segmentGroupHeads[key] = groupHead; + temporalResult.stop = true; + } else { + temporalResult.stop = false; + groupHead = segmentGroupHeads[key]; + } + temporalResult.groupHead = groupHead; + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + for (int i = 0; i < fields.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + continue; + } + + sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader, fields[i].getField()); + } + + // Clear ordSet and fill it with previous encountered groups that can occur in the current segment. + ordSet.clear(); + segmentGroupHeads = new GroupHead[groupIndex.numOrd()]; + for (GroupHead collectedGroup : collectedGroups) { + int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef); + if (ord >= 0) { + ordSet.put(ord); + segmentGroupHeads[ord] = collectedGroup; + + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + continue; + } + + collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], scratchBytesRef); + } + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + BytesRef[] sortValues; + int[] sortOrds; + float[] scores; + + private GroupHead(int doc, BytesRef groupValue) throws IOException { + super(groupValue, doc + readerContext.docBase); + sortValues = new BytesRef[sortsIndex.length]; + sortOrds = new int[sortsIndex.length]; + scores = new float[sortsIndex.length]; + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + scores[i] = scorer.score(); + } else { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + } + + } + + public int compare(int compIDX, int doc) throws IOException { + if (fields[compIDX].getType() == SortField.Type.SCORE) { + float score = scorer.score(); + if (scores[compIDX] < score) { + return 1; + } else if (scores[compIDX] > score) { + return -1; + } + return 0; + } else { + if (sortOrds[compIDX] < 0) { + // The current segment doesn't contain the sort value we encountered before. Therefore the ord is negative. + return sortValues[compIDX].compareTo(sortsIndex[compIDX].getTerm(doc, scratchBytesRef)); + } else { + return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc); + } + } + } + + public void updateDocHead(int doc) throws IOException { + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + scores[i] = scorer.score(); + } else { + sortValues[i] = sortsIndex[i].getTerm(doc, sortValues[i]); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + } + this.doc = doc + readerContext.docBase; + } + + } + + } + + + // AbstractAllGroupHeadsCollector optimized for ord fields. 
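+  // This impl is chosen by create(...) when every sort field within the group is string based,
+  // so scores never have to be computed.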
+ static class OrdAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final SentinelIntSet ordSet; + private final List collectedGroups; + private final SortField[] fields; + + private FieldCache.DocTermsIndex[] sortsIndex; + private GroupHead[] segmentGroupHeads; + + OrdAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) { + super(groupField, sortWithinGroup.getSort().length); + ordSet = new SentinelIntSet(initialSize, -1); + collectedGroups = new ArrayList(initialSize); + + final SortField[] sortFields = sortWithinGroup.getSort(); + fields = new SortField[sortFields.length]; + sortsIndex = new FieldCache.DocTermsIndex[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + fields[i] = sortFields[i]; + } + } + + protected Collection getCollectedGroupHeads() { + return collectedGroups; + } + + public void setScorer(Scorer scorer) throws IOException { + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + int key = groupIndex.getOrd(doc); + GroupHead groupHead; + if (!ordSet.exists(key)) { + ordSet.put(key); + BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef()); + groupHead = new GroupHead(doc, term); + collectedGroups.add(groupHead); + segmentGroupHeads[key] = groupHead; + temporalResult.stop = true; + } else { + temporalResult.stop = false; + groupHead = segmentGroupHeads[key]; + } + temporalResult.groupHead = groupHead; + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + for (int i = 0; i < fields.length; i++) { + sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader, fields[i].getField()); + } + + // Clear ordSet and fill it with previous encountered groups that can occur in the current segment. + ordSet.clear(); + segmentGroupHeads = new GroupHead[groupIndex.size()]; + for (GroupHead collectedGroup : collectedGroups) { + int groupOrd = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef); + if (groupOrd >= 0) { + ordSet.put(groupOrd); + segmentGroupHeads[groupOrd] = collectedGroup; + + for (int i = 0; i < sortsIndex.length; i++) { + collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], scratchBytesRef); + } + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + BytesRef[] sortValues; + int[] sortOrds; + + private GroupHead(int doc, BytesRef groupValue) throws IOException { + super(groupValue, doc + readerContext.docBase); + sortValues = new BytesRef[sortsIndex.length]; + sortOrds = new int[sortsIndex.length]; + for (int i = 0; i < sortsIndex.length; i++) { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + } + + public int compare(int compIDX, int doc) throws IOException { + if (sortOrds[compIDX] < 0) { + // The current segment doesn't contain the sort value we encountered before. Therefore the ord is negative. 
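+          // Fall back to comparing the cached term bytes for this sort against the segment's term of the given doc.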
+ return sortValues[compIDX].compareTo(sortsIndex[compIDX].getTerm(doc, scratchBytesRef)); + } else { + return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc); + } + } + + public void updateDocHead(int doc) throws IOException { + for (int i = 0; i < sortsIndex.length; i++) { + sortValues[i] = sortsIndex[i].getTerm(doc, sortValues[i]); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + this.doc = doc + readerContext.docBase; + } + + } + + } + + + // AbstractAllGroupHeadsCollector optimized for scores. + static class ScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final SentinelIntSet ordSet; + private final List collectedGroups; + private final SortField[] fields; + + private Scorer scorer; + private GroupHead[] segmentGroupHeads; + + ScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) { + super(groupField, sortWithinGroup.getSort().length); + ordSet = new SentinelIntSet(initialSize, -1); + collectedGroups = new ArrayList(initialSize); + + final SortField[] sortFields = sortWithinGroup.getSort(); + fields = new SortField[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + fields[i] = sortFields[i]; + } + } + + protected Collection getCollectedGroupHeads() { + return collectedGroups; + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + int key = groupIndex.getOrd(doc); + GroupHead groupHead; + if (!ordSet.exists(key)) { + ordSet.put(key); + BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef()); + groupHead = new GroupHead(doc, term); + collectedGroups.add(groupHead); + segmentGroupHeads[key] = groupHead; + temporalResult.stop = true; + } else { + temporalResult.stop = false; + groupHead = segmentGroupHeads[key]; + } + temporalResult.groupHead = groupHead; + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + + // Clear ordSet and fill it with previous encountered groups that can occur in the current segment. 
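+      // Ords are segment relative, so every previously collected group head has to be looked up again in the new segment's term index.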
+ ordSet.clear(); + segmentGroupHeads = new GroupHead[groupIndex.numOrd()]; + for (GroupHead collectedGroup : collectedGroups) { + int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef); + if (ord >= 0) { + ordSet.put(ord); + segmentGroupHeads[ord] = collectedGroup; + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + float[] scores; + + private GroupHead(int doc, BytesRef groupValue) throws IOException { + super(groupValue, doc + readerContext.docBase); + scores = new float[fields.length]; + float score = scorer.score(); + for (int i = 0; i < scores.length; i++) { + scores[i] = score; + } + } + + public int compare(int compIDX, int doc) throws IOException { + float score = scorer.score(); + if (scores[compIDX] < score) { + return 1; + } else if (scores[compIDX] > score) { + return -1; + } + return 0; + } + + public void updateDocHead(int doc) throws IOException { + float score = scorer.score(); + for (int i = 0; i < scores.length; i++) { + scores[i] = score; + } + this.doc = doc + readerContext.docBase; + } + + } + + } + +} \ No newline at end of file diff --git a/modules/grouping/src/java/org/apache/lucene/search/grouping/package.html b/modules/grouping/src/java/org/apache/lucene/search/grouping/package.html index 12156aceb0b..dea41e41155 100644 --- a/modules/grouping/src/java/org/apache/lucene/search/grouping/package.html +++ b/modules/grouping/src/java/org/apache/lucene/search/grouping/package.html @@ -164,5 +164,20 @@ will be null, so if you need to present this value you'll have to separately retrieve it (for example using stored fields, FieldCache, etc.). +

<p>
Another collector is the TermAllGroupHeadsCollector that can be used to retrieve the most relevant
+document per group, also known as the group head. This can be useful in situations when one wants to compute
+grouping based facets / statistics on the complete query result. The collector can be executed during the first
+or second phase.
+</p>
+
+<pre class="prettyprint">
+  AbstractAllGroupHeadsCollector c = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup);
+  s.search(new TermQuery(new Term("content", searchTerm)), c);
+  // Return all group heads as an int array
+  int[] groupHeadsArray = c.retrieveGroupHeads();
+  // Return all group heads as a FixedBitSet
+  int maxDoc = s.maxDoc();
+  FixedBitSet groupHeadsBitSet = c.retrieveGroupHeads(maxDoc);
+</pre>
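+
+<p>
+The returned FixedBitSet contains exactly one (most relevant) document per group, so group based statistics over
+the complete query result can be computed by walking the bit set with its DocIdSetIterator. The sketch below is
+illustrative only and not part of this module; the "price" field is a made-up example, and the variables s and c
+are carried over from the snippet above.
+</p>
+
+<pre class="prettyprint">
+  // Sum a per-document value over one representative document (the group head) per group.
+  FixedBitSet heads = c.retrieveGroupHeads(s.maxDoc());
+  // "price" is a hypothetical per-document int field used purely for illustration.
+  int[] prices = FieldCache.DEFAULT.getInts(s.getIndexReader(), "price");
+  long sum = 0;
+  int doc;
+  DocIdSetIterator it = heads.iterator();
+  while ((doc = it.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+    sum += prices[doc];
+  }
+</pre>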
+ diff --git a/modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java b/modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java new file mode 100644 index 00000000000..6ca0e4e5146 --- /dev/null +++ b/modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java @@ -0,0 +1,492 @@ +package org.apache.lucene.search.grouping; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.NumericField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.*; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util._TestUtil; + +import java.io.IOException; +import java.util.*; + +public class TermAllGroupHeadsCollectorTest extends LuceneTestCase { + + public void testBasic() throws Exception { + final String groupField = "author"; + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter( + random, + dir, + newIndexWriterConfig(TEST_VERSION_CURRENT, + new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy())); + + // 0 + Document doc = new Document(); + doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "random text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "1", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 1 + doc = new Document(); + doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some more random text blob", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "2", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 2 + doc = new Document(); + doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some more random textual data", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "3", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + w.commit(); // To ensure a second segment + + // 3 + doc = new Document(); + doc.add(new Field(groupField, "author2", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some random text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "4", Field.Store.YES, 
Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 4 + doc = new Document(); + doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some more random text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "5", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 5 + doc = new Document(); + doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "random blob", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 6 -- no author field + doc = new Document(); + doc.add(new Field("content", "random word stuck in alot of other text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 7 -- no author field + doc = new Document(); + doc.add(new Field("content", "random word stuck in alot of other text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "7", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + IndexSearcher indexSearcher = new IndexSearcher(w.getReader()); + w.close(); + int maxDoc = indexSearcher.maxDoc(); + + Sort sortWithinGroup = new Sort(new SortField("id", SortField.Type.INT, true)); + AbstractAllGroupHeadsCollector c1 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup); + indexSearcher.search(new TermQuery(new Term("content", "random")), c1); + assertTrue(arrayContains(new int[]{2, 3, 5, 7}, c1.retrieveGroupHeads())); + assertTrue(openBitSetContains(new int[]{2, 3, 5, 7}, c1.retrieveGroupHeads(maxDoc), maxDoc)); + + AbstractAllGroupHeadsCollector c2 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup); + indexSearcher.search(new TermQuery(new Term("content", "some")), c2); + assertTrue(arrayContains(new int[]{2, 3, 4}, c2.retrieveGroupHeads())); + assertTrue(openBitSetContains(new int[]{2, 3, 4}, c2.retrieveGroupHeads(maxDoc), maxDoc)); + + AbstractAllGroupHeadsCollector c3 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup); + indexSearcher.search(new TermQuery(new Term("content", "blob")), c3); + assertTrue(arrayContains(new int[]{1, 5}, c3.retrieveGroupHeads())); + assertTrue(openBitSetContains(new int[]{1, 5}, c3.retrieveGroupHeads(maxDoc), maxDoc)); + + // STRING sort type triggers different implementation + Sort sortWithinGroup2 = new Sort(new SortField("id", SortField.Type.STRING, true)); + AbstractAllGroupHeadsCollector c4 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup2); + indexSearcher.search(new TermQuery(new Term("content", "random")), c4); + assertTrue(arrayContains(new int[]{2, 3, 5, 7}, c4.retrieveGroupHeads())); + assertTrue(openBitSetContains(new int[]{2, 3, 5, 7}, c4.retrieveGroupHeads(maxDoc), maxDoc)); + + Sort sortWithinGroup3 = new Sort(new SortField("id", SortField.Type.STRING, false)); + AbstractAllGroupHeadsCollector c5 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup3); + indexSearcher.search(new TermQuery(new Term("content", "random")), c5); + // 7 b/c higher doc id wins, even if order of field is in not in reverse. 
+ assertTrue(arrayContains(new int[]{0, 3, 4, 6}, c5.retrieveGroupHeads())); + assertTrue(openBitSetContains(new int[]{0, 3, 4, 6}, c5.retrieveGroupHeads(maxDoc), maxDoc)); + + indexSearcher.getIndexReader().close(); + dir.close(); + } + + public void testRandom() throws Exception { + int numberOfRuns = _TestUtil.nextInt(random, 3, 6); + for (int iter = 0; iter < numberOfRuns; iter++) { + if (VERBOSE) { + System.out.println(String.format("TEST: iter=%d total=%d", iter, numberOfRuns)); + } + + final int numDocs = _TestUtil.nextInt(random, 100, 1000) * RANDOM_MULTIPLIER; + final int numGroups = _TestUtil.nextInt(random, 1, numDocs); + + if (VERBOSE) { + System.out.println("TEST: numDocs=" + numDocs + " numGroups=" + numGroups); + } + + final List groups = new ArrayList(); + for (int i = 0; i < numGroups; i++) { + groups.add(new BytesRef(_TestUtil.randomRealisticUnicodeString(random))); + } + final String[] contentStrings = new String[_TestUtil.nextInt(random, 2, 20)]; + if (VERBOSE) { + System.out.println("TEST: create fake content"); + } + for (int contentIDX = 0; contentIDX < contentStrings.length; contentIDX++) { + final StringBuilder sb = new StringBuilder(); + sb.append("real").append(random.nextInt(3)).append(' '); + final int fakeCount = random.nextInt(10); + for (int fakeIDX = 0; fakeIDX < fakeCount; fakeIDX++) { + sb.append("fake "); + } + contentStrings[contentIDX] = sb.toString(); + if (VERBOSE) { + System.out.println(" content=" + sb.toString()); + } + } + + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter( + random, + dir, + newIndexWriterConfig(TEST_VERSION_CURRENT, + new MockAnalyzer(random))); + + Document doc = new Document(); + Document docNoGroup = new Document(); + Field group = newField("group", "", Field.Index.NOT_ANALYZED); + doc.add(group); + Field sort1 = newField("sort1", "", Field.Index.NOT_ANALYZED); + doc.add(sort1); + docNoGroup.add(sort1); + Field sort2 = newField("sort2", "", Field.Index.NOT_ANALYZED); + doc.add(sort2); + docNoGroup.add(sort2); + Field sort3 = newField("sort3", "", Field.Index.NOT_ANALYZED); + doc.add(sort3); + docNoGroup.add(sort3); + Field content = newField("content", "", Field.Index.ANALYZED); + doc.add(content); + docNoGroup.add(content); + NumericField id = new NumericField("id"); + doc.add(id); + docNoGroup.add(id); + final GroupDoc[] groupDocs = new GroupDoc[numDocs]; + for (int i = 0; i < numDocs; i++) { + final BytesRef groupValue; + if (random.nextInt(24) == 17) { + // So we test the "doc doesn't have the group'd + // field" case: + groupValue = null; + } else { + groupValue = groups.get(random.nextInt(groups.size())); + } + + final GroupDoc groupDoc = new GroupDoc( + i, + groupValue, + groups.get(random.nextInt(groups.size())), + groups.get(random.nextInt(groups.size())), + new BytesRef(String.format("%05d", i)), + contentStrings[random.nextInt(contentStrings.length)] + ); + + if (VERBOSE) { + System.out.println(" doc content=" + groupDoc.content + " id=" + i + " group=" + (groupDoc.group == null ? 
"null" : groupDoc.group.utf8ToString()) + " sort1=" + groupDoc.sort1.utf8ToString() + " sort2=" + groupDoc.sort2.utf8ToString() + " sort3=" + groupDoc.sort3.utf8ToString()); + } + + groupDocs[i] = groupDoc; + if (groupDoc.group != null) { + group.setValue(groupDoc.group.utf8ToString()); + } + sort1.setValue(groupDoc.sort1.utf8ToString()); + sort2.setValue(groupDoc.sort2.utf8ToString()); + sort3.setValue(groupDoc.sort3.utf8ToString()); + content.setValue(groupDoc.content); + id.setIntValue(groupDoc.id); + if (groupDoc.group == null) { + w.addDocument(docNoGroup); + } else { + w.addDocument(doc); + } + } + + final IndexReader r = w.getReader(); + w.close(); + + // NOTE: intentional but temporary field cache insanity! + final int[] docIdToFieldId = FieldCache.DEFAULT.getInts(r, "id"); + final int[] fieldIdToDocID = new int[numDocs]; + for (int i = 0; i < docIdToFieldId.length; i++) { + int fieldId = docIdToFieldId[i]; + fieldIdToDocID[fieldId] = i; + } + + try { + final IndexSearcher s = newSearcher(r); + + for (int contentID = 0; contentID < 3; contentID++) { + final ScoreDoc[] hits = s.search(new TermQuery(new Term("content", "real" + contentID)), numDocs).scoreDocs; + for (ScoreDoc hit : hits) { + final GroupDoc gd = groupDocs[docIdToFieldId[hit.doc]]; + assertTrue(gd.score == 0.0); + gd.score = hit.score; + int docId = gd.id; + assertEquals(docId, docIdToFieldId[hit.doc]); + } + } + + for (GroupDoc gd : groupDocs) { + assertTrue(gd.score != 0.0); + } + + for (int searchIter = 0; searchIter < 100; searchIter++) { + + if (VERBOSE) { + System.out.println("TEST: searchIter=" + searchIter); + } + + final String searchTerm = "real" + random.nextInt(3); + boolean sortByScoreOnly = random.nextBoolean(); + Sort sortWithinGroup = getRandomSort(sortByScoreOnly); + AbstractAllGroupHeadsCollector allGroupHeadsCollector = TermAllGroupHeadsCollector.create("group", sortWithinGroup); + s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollector); + int[] expectedGroupHeads = createExpectedGroupHeads(searchTerm, groupDocs, sortWithinGroup, sortByScoreOnly, fieldIdToDocID); + int[] actualGroupHeads = allGroupHeadsCollector.retrieveGroupHeads(); + // The actual group heads contains Lucene ids. Need to change them into our id value. + for (int i = 0; i < actualGroupHeads.length; i++) { + actualGroupHeads[i] = docIdToFieldId[actualGroupHeads[i]]; + } + // Allows us the easily iterate and assert the actual and expected results. + Arrays.sort(expectedGroupHeads); + Arrays.sort(actualGroupHeads); + + if (VERBOSE) { + System.out.println("Collector: " + allGroupHeadsCollector.getClass().getSimpleName()); + System.out.println("Sort within group: " + sortWithinGroup); + System.out.println("Num group: " + numGroups); + System.out.println("Num doc: " + numDocs); + System.out.println("\n=== Expected: \n"); + for (int expectedDocId : expectedGroupHeads) { + GroupDoc expectedGroupDoc = groupDocs[expectedDocId]; + String expectedGroup = expectedGroupDoc.group == null ? null : expectedGroupDoc.group.utf8ToString(); + System.out.println( + String.format( + "Group:%10s score%5f Sort1:%10s Sort2:%10s Sort3:%10s doc:%5d", + expectedGroup, expectedGroupDoc.score, expectedGroupDoc.sort1.utf8ToString(), + expectedGroupDoc.sort2.utf8ToString(), expectedGroupDoc.sort3.utf8ToString(), expectedDocId + ) + ); + } + System.out.println("\n=== Actual: \n"); + for (int actualDocId : actualGroupHeads) { + GroupDoc actualGroupDoc = groupDocs[actualDocId]; + String actualGroup = actualGroupDoc.group == null ? 
null : actualGroupDoc.group.utf8ToString(); + System.out.println( + String.format( + "Group:%10s score%5f Sort1:%10s Sort2:%10s Sort3:%10s doc:%5d", + actualGroup, actualGroupDoc.score, actualGroupDoc.sort1.utf8ToString(), + actualGroupDoc.sort2.utf8ToString(), actualGroupDoc.sort3.utf8ToString(), actualDocId + ) + ); + } + System.out.println("\n==================================================================================="); + } + + assertEquals(expectedGroupHeads.length, actualGroupHeads.length); + for (int i = 0; i < expectedGroupHeads.length; i++) { + assertEquals(expectedGroupHeads[i], actualGroupHeads[i]); + } + } + s.close(); + } finally { + FieldCache.DEFAULT.purge(r); + } + + r.close(); + dir.close(); + } + } + + + private boolean arrayContains(int[] expected, int[] actual) { + if (expected.length != actual.length) { + return false; + } + + for (int e : expected) { + boolean found = false; + for (int a : actual) { + if (e == a) { + found = true; + } + } + + if (!found) { + return false; + } + } + + return true; + } + + private boolean openBitSetContains(int[] expectedDocs, FixedBitSet actual, int maxDoc) throws IOException { + if (expectedDocs.length != actual.cardinality()) { + return false; + } + + FixedBitSet expected = new FixedBitSet(maxDoc); + for (int expectedDoc : expectedDocs) { + expected.set(expectedDoc); + } + + int docId; + DocIdSetIterator iterator = expected.iterator(); + while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + if (!actual.get(docId)) { + return false; + } + } + + return true; + } + + private int[] createExpectedGroupHeads(String searchTerm, GroupDoc[] groupDocs, Sort docSort, boolean sortByScoreOnly, int[] fieldIdToDocID) throws IOException { + Map> groupHeads = new HashMap>(); + for (GroupDoc groupDoc : groupDocs) { + if (!groupDoc.content.startsWith(searchTerm)) { + continue; + } + + if (!groupHeads.containsKey(groupDoc.group)) { + List list = new ArrayList(); + list.add(groupDoc); + groupHeads.put(groupDoc.group, list); + continue; + } + groupHeads.get(groupDoc.group).add(groupDoc); + } + + int[] allGroupHeads = new int[groupHeads.size()]; + int i = 0; + for (BytesRef groupValue : groupHeads.keySet()) { + List docs = groupHeads.get(groupValue); + Collections.sort(docs, getComparator(docSort, sortByScoreOnly, fieldIdToDocID)); + allGroupHeads[i++] = docs.get(0).id; + } + + return allGroupHeads; + } + + private Sort getRandomSort(boolean scoreOnly) { + final List sortFields = new ArrayList(); + if (random.nextInt(7) == 2 || scoreOnly) { + sortFields.add(SortField.FIELD_SCORE); + } else { + if (random.nextBoolean()) { + if (random.nextBoolean()) { + sortFields.add(new SortField("sort1", SortField.Type.STRING, random.nextBoolean())); + } else { + sortFields.add(new SortField("sort2", SortField.Type.STRING, random.nextBoolean())); + } + } else if (random.nextBoolean()) { + sortFields.add(new SortField("sort1", SortField.Type.STRING, random.nextBoolean())); + sortFields.add(new SortField("sort2", SortField.Type.STRING, random.nextBoolean())); + } + } + // Break ties: + if (random.nextBoolean() && !scoreOnly) { + sortFields.add(new SortField("sort3", SortField.Type.STRING)); + } else if (!scoreOnly) { + sortFields.add(new SortField("id", SortField.Type.INT)); + } + return new Sort(sortFields.toArray(new SortField[sortFields.size()])); + } + + private Comparator getComparator(Sort sort, final boolean sortByScoreOnly, final int[] fieldIdToDocID) { + final SortField[] sortFields = sort.getSort(); + return new Comparator() { + 
@Override + public int compare(GroupDoc d1, GroupDoc d2) { + for (SortField sf : sortFields) { + final int cmp; + if (sf.getType() == SortField.Type.SCORE) { + if (d1.score > d2.score) { + cmp = -1; + } else if (d1.score < d2.score) { + cmp = 1; + } else { + cmp = sortByScoreOnly ? fieldIdToDocID[d1.id] - fieldIdToDocID[d2.id] : 0; + } + } else if (sf.getField().equals("sort1")) { + cmp = d1.sort1.compareTo(d2.sort1); + } else if (sf.getField().equals("sort2")) { + cmp = d1.sort2.compareTo(d2.sort2); + } else if (sf.getField().equals("sort3")) { + cmp = d1.sort3.compareTo(d2.sort3); + } else { + assertEquals(sf.getField(), "id"); + cmp = d1.id - d2.id; + } + if (cmp != 0) { + return sf.getReverse() ? -cmp : cmp; + } + } + // Our sort always fully tie breaks: + fail(); + return 0; + } + }; + } + + + private static class GroupDoc { + final int id; + final BytesRef group; + final BytesRef sort1; + final BytesRef sort2; + final BytesRef sort3; + // content must be "realN ..." + final String content; + float score; + + public GroupDoc(int id, BytesRef group, BytesRef sort1, BytesRef sort2, BytesRef sort3, String content) { + this.id = id; + this.group = group; + this.sort1 = sort1; + this.sort2 = sort2; + this.sort3 = sort3; + this.content = content; + } + + } + +} \ No newline at end of file