diff --git a/lucene/contrib/CHANGES.txt b/lucene/contrib/CHANGES.txt
index e070a69a3f2..20b8042e9cb 100644
--- a/lucene/contrib/CHANGES.txt
+++ b/lucene/contrib/CHANGES.txt
@@ -72,6 +72,8 @@ New Features
    start/endOffset, if offsets are indexed. (Alan Woodward via Mike McCandless)
 
+ * LUCENE-3802: Support for grouped faceting. (Martijn van Groningen)
+
 API Changes
 
  * LUCENE-2606: Changed RegexCapabilities interface to fix thread
diff --git a/lucene/core/src/java/org/apache/lucene/index/DocTermOrds.java b/lucene/core/src/java/org/apache/lucene/index/DocTermOrds.java
index 038c62f4de2..677facbf6e1 100644
--- a/lucene/core/src/java/org/apache/lucene/index/DocTermOrds.java
+++ b/lucene/core/src/java/org/apache/lucene/index/DocTermOrds.java
@@ -216,6 +216,13 @@ public class DocTermOrds {
     }
   }
 
+  /**
+   * @return The number of terms in this field
+   */
+  public int numTerms() {
+    return numTermsInField;
+  }
+
   /** Subclass can override this */
   protected void visitTerm(TermsEnum te, int termNum) throws IOException {
   }
diff --git a/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractGroupFacetCollector.java b/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractGroupFacetCollector.java
new file mode 100644
index 00000000000..23e855a5fdc
--- /dev/null
+++ b/modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractGroupFacetCollector.java
@@ -0,0 +1,245 @@
+package org.apache.lucene.search.grouping;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.util.BytesRef;
+
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * Base class for computing grouped facets.
+ *
+ * @lucene.experimental
+ */
+public abstract class AbstractGroupFacetCollector extends Collector {
+
+  protected final String groupField;
+  protected final String facetField;
+  protected final BytesRef facetPrefix;
+
+  protected AbstractGroupFacetCollector(String groupField, String facetField, BytesRef facetPrefix) {
+    this.groupField = groupField;
+    this.facetField = facetField;
+    this.facetPrefix = facetPrefix;
+  }
+
+  /**
+   * Returns grouped facet results that were computed over zero or more segments.
+   * Grouped facet counts are merged from zero or more segment results.
+   *
+   * @param size The total number of facets to include. This is typically offset + limit
+   * @param minCount The minimum count a facet entry should have to be included in the grouped facet result
+   * @param orderByCount Whether to sort the facet entries by facet entry count. If false then the facets
+   *                     are sorted lexicographically in ascending order.
+   * @return grouped facet results
+   * @throws IOException If I/O related errors occur during merging segment grouped facet counts.
+   */
+  public abstract GroupedFacetResult mergeSegmentResults(int size, int minCount, boolean orderByCount) throws IOException;
+
+  public void setScorer(Scorer scorer) throws IOException {
+  }
+
+  public boolean acceptsDocsOutOfOrder() {
+    return true;
+  }
+
+  /**
+   * The grouped facet result. Containing grouped facet entries, total count and total missing count.
+   */
+  public static class GroupedFacetResult {
+
+    private final static Comparator<FacetEntry> orderByCountAndValue = new Comparator<FacetEntry>() {
+
+      public int compare(FacetEntry a, FacetEntry b) {
+        // Highest count first! Compared explicitly because a subtraction can overflow.
+        if (a.count != b.count) {
+          return a.count < b.count ? 1 : -1;
+        }
+        return a.value.compareTo(b.value);
+      }
+
+    };
+
+    private final static Comparator<FacetEntry> orderByValue = new Comparator<FacetEntry>() {
+
+      public int compare(FacetEntry a, FacetEntry b) {
+        return a.value.compareTo(b.value);
+      }
+
+    };
+
+    private final int maxSize;
+    private final boolean orderByCount;
+    private final NavigableSet<FacetEntry> facetEntries;
+    private final int totalMissingCount;
+    private final int totalCount;
+
+    private int currentMin;
+
+    /**
+     * Creates an empty result that retains at most {@code size} facet entries.
+     *
+     * @param size The maximum number of facet entries to retain
+     * @param minCount The minimum count a facet entry needs in order to be retained
+     * @param orderByCount Whether the entries are ordered by descending count (ties by value) or by value only
+     * @param totalCount The value reported by {@link #getTotalCount()}
+     * @param totalMissingCount The value reported by {@link #getTotalMissingCount()}
+     */
+    public GroupedFacetResult(int size, int minCount, boolean orderByCount, int totalCount, int totalMissingCount) {
+      this.facetEntries = new TreeSet<FacetEntry>(orderByCount ? orderByCountAndValue : orderByValue);
+      this.orderByCount = orderByCount;
+      this.totalMissingCount = totalMissingCount;
+      this.totalCount = totalCount;
+      maxSize = size;
+      currentMin = minCount;
+    }
+
+    /**
+     * Offers a facet entry to this result. The entry is dropped when its count is below the current
+     * threshold, or when the result is full and the entry sorts after every retained entry.
+     * Otherwise it is added, evicting the last retained entry if the result is full.
+     *
+     * @param facetValue The facet value
+     * @param count The count of the facet value
+     */
+    public void addFacetCount(BytesRef facetValue, int count) {
+      if (count < currentMin) {
+        return;
+      }
+
+      FacetEntry facetEntry = new FacetEntry(facetValue, count);
+      if (facetEntries.size() == maxSize) {
+        if (facetEntries.higher(facetEntry) == null) {
+          return;
+        }
+        facetEntries.pollLast();
+      }
+      facetEntries.add(facetEntry);
+
+      // Only when ordered by count is the last entry the one with the lowest count, so only
+      // then may the threshold be raised to it.
+      if (orderByCount && facetEntries.size() == maxSize) {
+        currentMin = facetEntries.last().count;
+      }
+    }
+
+    /**
+     * Returns a list of facet entries to be rendered based on the specified offset and limit.
+     * The facet entries are retrieved from the facet entries collected during merging.
+     *
+     * @param offset The offset in the collected facet entries during merging
+     * @param limit The number of facets to return starting from the offset.
+     * @return a list of facet entries to be rendered based on the specified offset and limit
+     */
+    public List<FacetEntry> getFacetEntries(int offset, int limit) {
+      List<FacetEntry> entries = new ArrayList<FacetEntry>();
+      limit += offset;
+
+      int i = 0;
+      for (FacetEntry facetEntry : facetEntries) {
+        if (i < offset) {
+          i++;
+          continue;
+        }
+        if (i++ >= limit) {
+          break;
+        }
+        entries.add(facetEntry);
+      }
+      return entries;
+    }
+
+    /**
+     * Returns the sum of all facet entries counts.
+     *
+     * @return the sum of all facet entries counts
+     */
+    public int getTotalCount() {
+      return totalCount;
+    }
+
+    /**
+     * Returns the number of groups that didn't have a facet value.
+     *
+     * @return the number of groups that didn't have a facet value
+     */
+    public int getTotalMissingCount() {
+      return totalMissingCount;
+    }
+  }
+
+  /**
+   * Represents a facet entry with a value and a count.
+   */
+  public static class FacetEntry {
+
+    private final BytesRef value;
+    private final int count;
+
+    public FacetEntry(BytesRef value, int count) {
+      this.value = value;
+      this.count = count;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+
+      FacetEntry that = (FacetEntry) o;
+
+      if (count != that.count) return false;
+      if (!value.equals(that.value)) return false;
+
+      return true;
+    }
+
+    @Override
+    public int hashCode() {
+      int result = value.hashCode();
+      result = 31 * result + count;
+      return result;
+    }
+
+    @Override
+    public String toString() {
+      return "FacetEntry{" +
+          "value=" + value.utf8ToString() +
+          ", count=" + count +
+          '}';
+    }
+
+    /**
+     * @return The value of this facet entry
+     */
+    public BytesRef getValue() {
+      return value;
+    }
+
+    /**
+     * @return The count (number of groups) of this facet entry.
+     */
+    public int getCount() {
+      return count;
+    }
+  }
+
+}
diff --git a/modules/grouping/src/java/org/apache/lucene/search/grouping/term/TermGroupFacetCollector.java b/modules/grouping/src/java/org/apache/lucene/search/grouping/term/TermGroupFacetCollector.java
new file mode 100644
index 00000000000..4a9326e38e9
--- /dev/null
+++ b/modules/grouping/src/java/org/apache/lucene/search/grouping/term/TermGroupFacetCollector.java
@@ -0,0 +1,391 @@
+package org.apache.lucene.search.grouping.term;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.index.AtomicReaderContext;
+import org.apache.lucene.index.DocTermOrds;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.FieldCache;
+import org.apache.lucene.search.grouping.AbstractGroupFacetCollector;
+import org.apache.lucene.util.*;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * An implementation of {@link AbstractGroupFacetCollector} that computes grouped facets based on the indexed terms
+ * from the {@link FieldCache}.
+ *
+ * @lucene.experimental
+ */
+public abstract class TermGroupFacetCollector extends AbstractGroupFacetCollector {
+
+  final List<GroupedFacetHit> groupedFacetHits; // Every (group, facet) combination counted so far, kept across segments
+  final SentinelIntSet segmentGroupedFacetHits; // The same combinations, encoded with the ords of the current segment
+  final List<SegmentResult> segmentResults; // One entry per segment that has been collected completely
+  final BytesRef spare = new BytesRef();
+
+  FieldCache.DocTermsIndex groupFieldTermsIndex;
+  int[] segmentFacetCounts;
+  int segmentTotalCount;
+  int startFacetOrd;
+  int endFacetOrd;
+
+  /**
+   * Factory method that creates the right implementation depending on whether the facet field contains
+   * multiple tokens per document.
+   *
+   * @param groupField The group field
+   * @param facetField The facet field
+   * @param facetFieldMultivalued Whether the facet field has multiple tokens per document
+   * @param facetPrefix The facet prefix a facet entry should start with to be included.
+   * @param initialSize The initial allocation size of the internal int set and group facet list which should roughly
+   *                    match the total number of expected unique groups. Be aware that the heap usage is
+   *                    4 bytes * initialSize.
+   * @return TermGroupFacetCollector implementation
+   */
+  public static TermGroupFacetCollector createTermGroupFacetCollector(String groupField,
+                                                                      String facetField,
+                                                                      boolean facetFieldMultivalued,
+                                                                      BytesRef facetPrefix,
+                                                                      int initialSize) {
+    if (facetFieldMultivalued) {
+      return new MV(groupField, facetField, facetPrefix, initialSize);
+    } else {
+      return new SV(groupField, facetField, facetPrefix, initialSize);
+    }
+  }
+
+  TermGroupFacetCollector(String groupField, String facetField, BytesRef facetPrefix, int initialSize) {
+    super(groupField, facetField, facetPrefix);
+    groupedFacetHits = new ArrayList<GroupedFacetHit>(initialSize);
+    segmentGroupedFacetHits = new SentinelIntSet(initialSize, -1);
+    segmentResults = new ArrayList<SegmentResult>();
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  public GroupedFacetResult mergeSegmentResults(int size, int minCount, boolean orderByCount) throws IOException {
+    if (segmentFacetCounts != null) { // The segment collected last has not been turned into a result yet
+      segmentResults.add(createSegmentResult());
+      segmentFacetCounts = null; // reset, so that this segment is not added a second time
+    }
+
+    int totalCount = 0;
+    int missingCount = 0;
+    SegmentResultPriorityQueue segments = new SegmentResultPriorityQueue(segmentResults.size());
+    for (SegmentResult segmentResult : segmentResults) {
+      missingCount += segmentResult.missing;
+      if (segmentResult.mergePos >= segmentResult.maxTermPos) {
+        continue; // Empty ord range, nothing to merge for this segment
+      }
+      totalCount += segmentResult.total;
+      segmentResult.initializeForMerge();
+      segments.add(segmentResult);
+    }
+
+    GroupedFacetResult facetResult = new GroupedFacetResult(size, minCount, orderByCount, totalCount, missingCount);
+    while (segments.size() > 0) {
+      SegmentResult segmentResult = segments.top();
+      BytesRef currentFacetValue = BytesRef.deepCopyOf(segmentResult.mergeTerm);
+      int count = 0;
+
+      do { // Sum the counts of every segment that is positioned on currentFacetValue
+        count += segmentResult.counts[segmentResult.mergePos++];
+        if (segmentResult.mergePos < segmentResult.maxTermPos) {
+          segmentResult.nextTerm();
+          segmentResult = segments.updateTop();
+        } else {
+          segments.pop();
+          segmentResult = segments.top();
+          if (segmentResult == null) {
+            break;
+          }
+        }
+      } while (currentFacetValue.equals(segmentResult.mergeTerm));
+      facetResult.addFacetCount(currentFacetValue, count);
+    }
+    return facetResult;
+  }
+
+  protected abstract SegmentResult createSegmentResult();
+
+  // Implementation for single valued facet fields.
+  static class SV extends TermGroupFacetCollector {
+
+    private FieldCache.DocTermsIndex facetFieldTermsIndex;
+
+    SV(String groupField, String facetField, BytesRef facetPrefix, int initialSize) {
+      super(groupField, facetField, facetPrefix, initialSize);
+    }
+
+    public void collect(int doc) throws IOException {
+      int facetOrd = facetFieldTermsIndex.getOrd(doc);
+      if (facetOrd < startFacetOrd || facetOrd >= endFacetOrd) {
+        return;
+      }
+
+      int groupOrd = groupFieldTermsIndex.getOrd(doc);
+      int segmentGroupedFacetsIndex = (groupOrd * facetFieldTermsIndex.numOrd()) + facetOrd; // NOTE(review): int product, may overflow for very large ord counts
+      if (segmentGroupedFacetHits.exists(segmentGroupedFacetsIndex)) {
+        return;
+      }
+
+      segmentTotalCount++;
+      segmentFacetCounts[facetOrd]++;
+
+      segmentGroupedFacetHits.put(segmentGroupedFacetsIndex);
+      groupedFacetHits.add(
+          new GroupedFacetHit(
+              groupOrd == 0 ? null : groupFieldTermsIndex.lookup(groupOrd, new BytesRef()),
+              facetOrd == 0 ? null : facetFieldTermsIndex.lookup(facetOrd, new BytesRef())
+          )
+      );
+    }
+
+    public void setNextReader(AtomicReaderContext context) throws IOException {
+      if (segmentFacetCounts != null) {
+        segmentResults.add(createSegmentResult());
+      }
+
+      groupFieldTermsIndex = FieldCache.DEFAULT.getTermsIndex(context.reader(), groupField);
+      facetFieldTermsIndex = FieldCache.DEFAULT.getTermsIndex(context.reader(), facetField);
+      segmentFacetCounts = new int[facetFieldTermsIndex.numOrd()];
+      segmentTotalCount = 0;
+
+      segmentGroupedFacetHits.clear();
+      for (GroupedFacetHit groupedFacetHit : groupedFacetHits) { // Re-register earlier hits with the ords of this segment
+        int facetOrd = facetFieldTermsIndex.binarySearchLookup(groupedFacetHit.facetValue, spare);
+        if (facetOrd < 0) {
+          continue;
+        }
+
+        int groupOrd = groupFieldTermsIndex.binarySearchLookup(groupedFacetHit.groupValue, spare);
+        if (groupOrd < 0) {
+          continue;
+        }
+
+        int segmentGroupedFacetsIndex = (groupOrd * facetFieldTermsIndex.numOrd()) + facetOrd;
+        segmentGroupedFacetHits.put(segmentGroupedFacetsIndex);
+      }
+
+      if (facetPrefix != null) {
+        startFacetOrd = facetFieldTermsIndex.binarySearchLookup(facetPrefix, spare);
+        if (startFacetOrd < 0) {
+          // Points to the ord one higher than facetPrefix
+          startFacetOrd = -startFacetOrd - 1;
+        }
+        BytesRef facetEndPrefix = BytesRef.deepCopyOf(facetPrefix);
+        facetEndPrefix.append(UnicodeUtil.BIG_TERM);
+        endFacetOrd = facetFieldTermsIndex.binarySearchLookup(facetEndPrefix, spare);
+        endFacetOrd = endFacetOrd < 0 ? -endFacetOrd - 1 : endFacetOrd + 1; // Exclusive bound; an exact hit (ord >= 0) must not be negated
+      } else {
+        startFacetOrd = 0;
+        endFacetOrd = facetFieldTermsIndex.numOrd();
+      }
+    }
+
+    protected SegmentResult createSegmentResult() {
+      return new SegmentResult(segmentFacetCounts, segmentTotalCount, facetFieldTermsIndex.getTermsEnum(), startFacetOrd, endFacetOrd);
+    }
+  }
+
+  // Implementation for multi valued facet fields.
+  static class MV extends TermGroupFacetCollector {
+
+    private DocTermOrds facetFieldDocTermOrds;
+    private TermsEnum facetOrdTermsEnum;
+    private DocTermOrds.TermOrdsIterator reuse;
+    private final int[] buffer = new int[5]; // Scratch space for reading facet ords, reused across collect() calls
+
+    MV(String groupField, String facetField, BytesRef facetPrefix, int initialSize) {
+      super(groupField, facetField, facetPrefix, initialSize);
+    }
+
+    public void collect(int doc) throws IOException {
+      int groupOrd = groupFieldTermsIndex.getOrd(doc);
+      reuse = facetFieldDocTermOrds.lookup(doc, reuse);
+      int chunk;
+      boolean first = true;
+      do {
+        chunk = reuse.read(buffer);
+        if (first && chunk == 0) {
+          chunk = 1;
+          buffer[0] = facetFieldDocTermOrds.numTerms(); // this facet ord is reserved for docs not containing facet field.
+        }
+        first = false;
+
+        for (int pos = 0; pos < chunk; pos++) {
+          int facetOrd = buffer[pos];
+          if (facetOrd < startFacetOrd || facetOrd >= endFacetOrd) {
+            continue;
+          }
+
+          int segmentGroupedFacetsIndex = (groupOrd * (facetFieldDocTermOrds.numTerms() + 1)) + facetOrd;
+          if (segmentGroupedFacetHits.exists(segmentGroupedFacetsIndex)) {
+            continue;
+          }
+
+          segmentTotalCount++;
+          segmentFacetCounts[facetOrd]++;
+
+          segmentGroupedFacetHits.put(segmentGroupedFacetsIndex);
+          groupedFacetHits.add(
+              new GroupedFacetHit(
+                  groupOrd == 0 ? null : groupFieldTermsIndex.lookup(groupOrd, new BytesRef()),
+                  facetOrd == facetFieldDocTermOrds.numTerms() ? null : BytesRef.deepCopyOf(facetFieldDocTermOrds.lookupTerm(facetOrdTermsEnum, facetOrd))
+              )
+          );
+        }
+      } while (chunk >= buffer.length); // A completely filled buffer means there may be more ords to read
+    }
+
+    public void setNextReader(AtomicReaderContext context) throws IOException {
+      if (segmentFacetCounts != null) {
+        segmentResults.add(createSegmentResult());
+      }
+
+      reuse = null;
+      groupFieldTermsIndex = FieldCache.DEFAULT.getTermsIndex(context.reader(), groupField);
+      facetFieldDocTermOrds = FieldCache.DEFAULT.getDocTermOrds(context.reader(), facetField);
+      facetOrdTermsEnum = facetFieldDocTermOrds.getOrdTermsEnum(context.reader());
+      // [facetFieldDocTermOrds.numTerms() + 1] for all possible facet values and docs not containing facet field
+      segmentFacetCounts = new int[facetFieldDocTermOrds.numTerms() + 1];
+      segmentTotalCount = 0;
+
+      segmentGroupedFacetHits.clear();
+      for (GroupedFacetHit groupedFacetHit : groupedFacetHits) {
+        int groupOrd = groupFieldTermsIndex.binarySearchLookup(groupedFacetHit.groupValue, spare);
+        if (groupOrd < 0) {
+          continue;
+        }
+
+        int facetOrd;
+        if (groupedFacetHit.facetValue != null) {
+          if (!facetOrdTermsEnum.seekExact(groupedFacetHit.facetValue, true)) {
+            continue;
+          }
+          facetOrd = (int) facetOrdTermsEnum.ord();
+        } else {
+          facetOrd = facetFieldDocTermOrds.numTerms();
+        }
+
+        // (facetFieldDocTermOrds.numTerms() + 1) for all possible facet values and docs not containing facet field
+        int segmentGroupedFacetsIndex = (groupOrd * (facetFieldDocTermOrds.numTerms() + 1)) + facetOrd;
+        segmentGroupedFacetHits.put(segmentGroupedFacetsIndex);
+      }
+
+      if (facetPrefix != null) {
+        TermsEnum.SeekStatus seekStatus = facetOrdTermsEnum.seekCeil(facetPrefix, true);
+        if (seekStatus != TermsEnum.SeekStatus.END) {
+          startFacetOrd = (int) facetOrdTermsEnum.ord();
+        } else {
+          startFacetOrd = 0;
+          endFacetOrd = 0;
+          return;
+        }
+
+        BytesRef facetEndPrefix = BytesRef.deepCopyOf(facetPrefix);
+        facetEndPrefix.append(UnicodeUtil.BIG_TERM);
+        seekStatus = facetOrdTermsEnum.seekCeil(facetEndPrefix, true);
+        if (seekStatus != TermsEnum.SeekStatus.END) {
+          endFacetOrd = (int) facetOrdTermsEnum.ord();
+        } else {
+          endFacetOrd = facetFieldDocTermOrds.numTerms(); // Don't include null...
+        }
+      } else {
+        startFacetOrd = 0;
+        endFacetOrd = facetFieldDocTermOrds.numTerms() + 1;
+      }
+    }
+
+    protected SegmentResult createSegmentResult() {
+      return new SegmentResult(segmentFacetCounts, segmentTotalCount, facetFieldDocTermOrds.numTerms(), facetOrdTermsEnum, startFacetOrd, endFacetOrd);
+    }
+  }
+
+}
+
+class SegmentResult {
+
+  final int[] counts;
+  final int total; // The total handed to the constructor minus the missing count
+  final int missing; // The count stored at the ord that stands for "no facet value"
+
+  // Used for merging the segment results
+  BytesRef mergeTerm;
+  int mergePos;
+  final int maxTermPos;
+  final TermsEnum tenum;
+
+  SegmentResult(int[] counts, int total, TermsEnum tenum, int startFacetOrd, int endFacetOrd) {
+    this.counts = counts;
+    this.missing = counts[0];
+    this.total = total - missing;
+    this.tenum = tenum;
+    this.mergePos = startFacetOrd == 0 ? 1 : startFacetOrd; // Position 0 holds the missing count, so merging starts at 1
+    this.maxTermPos = endFacetOrd;
+  }
+
+  SegmentResult(int[] counts, int total, int missingCountIndex, TermsEnum tenum, int startFacetOrd, int endFacetOrd) {
+    this.counts = counts;
+    this.missing = counts[missingCountIndex];
+    this.total = total - missing;
+    this.tenum = tenum;
+    this.mergePos = startFacetOrd;
+    if (endFacetOrd == missingCountIndex + 1) { // The range covers the missing slot, which must not be merged as a term
+      this.maxTermPos = missingCountIndex;
+    } else {
+      this.maxTermPos = endFacetOrd;
+    }
+  }
+
+  void initializeForMerge() throws IOException {
+    tenum.seekExact(mergePos);
+    mergeTerm = tenum.term();
+  }
+
+  void nextTerm() throws IOException {
+    mergeTerm = tenum.next();
+  }
+
+}
+
+class GroupedFacetHit {
+
+  final BytesRef groupValue;
+  final BytesRef facetValue;
+
+  GroupedFacetHit(BytesRef groupValue, BytesRef facetValue) {
+    this.groupValue = groupValue;
+    this.facetValue = facetValue;
+  }
+}
+
+class SegmentResultPriorityQueue extends PriorityQueue<SegmentResult> {
+
+  SegmentResultPriorityQueue(int maxSize) {
+    super(maxSize);
+  }
+
+  protected boolean lessThan(SegmentResult a, SegmentResult b) {
+    return a.mergeTerm.compareTo(b.mergeTerm) < 0;
+  }
+}
diff --git a/modules/grouping/src/test/org/apache/lucene/search/grouping/AbstractGroupingTestCase.java b/modules/grouping/src/test/org/apache/lucene/search/grouping/AbstractGroupingTestCase.java
new file mode 100644
index 00000000000..b2a33f02ec1
--- /dev/null
+++ b/modules/grouping/src/test/org/apache/lucene/search/grouping/AbstractGroupingTestCase.java
@@ -0,0 +1,52 @@
+package org.apache.lucene.search.grouping;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.FieldType;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util._TestUtil;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Random;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * Base class for grouping related tests.
+ */
+// TODO (MvG) : The grouping tests contain a lot of code duplication.
Try to move the common code to this class..
+public class AbstractGroupingTestCase extends LuceneTestCase {
+
+  protected String generateRandomNonEmptyString() {
+    // The DV based implementation cannot tell an empty string from a null value,
+    // so empty candidates are discarded and a new value is drawn.
+    String candidate = _TestUtil.randomRealisticUnicodeString(random);
+    while ("".equals(candidate)) {
+      candidate = _TestUtil.randomRealisticUnicodeString(random);
+    }
+    return candidate;
+  }
+
+}
diff --git a/modules/grouping/src/test/org/apache/lucene/search/grouping/TermGroupFacetCollectorTest.java b/modules/grouping/src/test/org/apache/lucene/search/grouping/TermGroupFacetCollectorTest.java
new file mode 100644
index 00000000000..4dfea352a5f
--- /dev/null
+++ b/modules/grouping/src/test/org/apache/lucene/search/grouping/TermGroupFacetCollectorTest.java
@@ -0,0 +1,600 @@
+package org.apache.lucene.search.grouping;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.document.*; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.grouping.term.TermGroupFacetCollector; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util._TestUtil; + +import java.io.IOException; +import java.util.*; + +public class TermGroupFacetCollectorTest extends AbstractGroupingTestCase { + + public void testSimple() throws Exception { + final String groupField = "hotel"; + FieldType customType = new FieldType(); + customType.setStored(true); + + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter( + random, + dir, + newIndexWriterConfig(TEST_VERSION_CURRENT, + new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy())); + boolean canUseIDV = false;// Enable later... 
!"Lucene3x".equals(w.w.getConfig().getCodec().getName()); + + // 0 + Document doc = new Document(); + addGroupField(doc, groupField, "a", canUseIDV); + doc.add(new Field("airport", "ams", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "5", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 1 + doc = new Document(); + addGroupField(doc, groupField, "a", canUseIDV); + doc.add(new Field("airport", "dus", TextField.TYPE_STORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 2 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "ams", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + w.commit(); // To ensure a second segment + + // 3 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "ams", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "5", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 4 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "ams", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "5", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + IndexSearcher indexSearcher = new IndexSearcher(w.getReader()); + TermGroupFacetCollector groupedAirportFacetCollector = + TermGroupFacetCollector.createTermGroupFacetCollector(groupField, "airport", false, null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedAirportFacetCollector); + TermGroupFacetCollector.GroupedFacetResult airportResult = groupedAirportFacetCollector.mergeSegmentResults(10, 0, false); + assertEquals(3, airportResult.getTotalCount()); + assertEquals(0, airportResult.getTotalMissingCount()); + + List entries = airportResult.getFacetEntries(0, 10); + assertEquals(2, entries.size()); + assertEquals("ams", entries.get(0).getValue().utf8ToString()); + assertEquals(2, 
entries.get(0).getCount()); + assertEquals("dus", entries.get(1).getValue().utf8ToString()); + assertEquals(1, entries.get(1).getCount()); + + + TermGroupFacetCollector groupedDurationFacetCollector = + TermGroupFacetCollector.createTermGroupFacetCollector(groupField, "duration", false, null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedDurationFacetCollector); + TermGroupFacetCollector.GroupedFacetResult durationResult = groupedDurationFacetCollector.mergeSegmentResults(10, 0, false); + assertEquals(4, durationResult.getTotalCount()); + assertEquals(0, durationResult.getTotalMissingCount()); + + entries = durationResult.getFacetEntries(0, 10); + assertEquals(2, entries.size()); + assertEquals("10", entries.get(0).getValue().utf8ToString()); + assertEquals(2, entries.get(0).getCount()); + assertEquals("5", entries.get(1).getValue().utf8ToString()); + assertEquals(2, entries.get(1).getCount()); + + // 5 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("duration", "5", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 6 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "bru", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 7 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "bru", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "15", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 8 + doc = new Document(); + addGroupField(doc, groupField, "a", canUseIDV); + doc.add(new Field("airport", "bru", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + indexSearcher.getIndexReader().close(); + indexSearcher = new IndexSearcher(w.getReader()); + groupedAirportFacetCollector = TermGroupFacetCollector.createTermGroupFacetCollector(groupField, 
"airport", true, null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedAirportFacetCollector); + airportResult = groupedAirportFacetCollector.mergeSegmentResults(3, 0, true); + assertEquals(5, airportResult.getTotalCount()); + assertEquals(1, airportResult.getTotalMissingCount()); + + entries = airportResult.getFacetEntries(1, 2); + assertEquals(2, entries.size()); + assertEquals("bru", entries.get(0).getValue().utf8ToString()); + assertEquals(2, entries.get(0).getCount()); + assertEquals("dus", entries.get(1).getValue().utf8ToString()); + assertEquals(1, entries.get(1).getCount()); + + groupedDurationFacetCollector = TermGroupFacetCollector.createTermGroupFacetCollector(groupField, "duration", false, null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedDurationFacetCollector); + durationResult = groupedDurationFacetCollector.mergeSegmentResults(10, 2, true); + assertEquals(5, durationResult.getTotalCount()); + assertEquals(0, durationResult.getTotalMissingCount()); + + entries = durationResult.getFacetEntries(1, 1); + assertEquals(1, entries.size()); + assertEquals("5", entries.get(0).getValue().utf8ToString()); + assertEquals(2, entries.get(0).getCount()); + + // 9 + doc = new Document(); + addGroupField(doc, groupField, "c", canUseIDV); + doc.add(new Field("airport", "bru", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "15", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 10 + doc = new Document(); + addGroupField(doc, groupField, "c", canUseIDV); + doc.add(new Field("airport", "dus", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + indexSearcher.getIndexReader().close(); + indexSearcher = new IndexSearcher(w.getReader()); + groupedAirportFacetCollector = TermGroupFacetCollector.createTermGroupFacetCollector(groupField, "airport", false, null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedAirportFacetCollector); + airportResult = 
groupedAirportFacetCollector.mergeSegmentResults(10, 0, false); + assertEquals(7, airportResult.getTotalCount()); + assertEquals(1, airportResult.getTotalMissingCount()); + + entries = airportResult.getFacetEntries(0, 10); + assertEquals(3, entries.size()); + assertEquals("ams", entries.get(0).getValue().utf8ToString()); + assertEquals(2, entries.get(0).getCount()); + assertEquals("bru", entries.get(1).getValue().utf8ToString()); + assertEquals(3, entries.get(1).getCount()); + assertEquals("dus", entries.get(2).getValue().utf8ToString()); + assertEquals(2, entries.get(2).getCount()); + + groupedDurationFacetCollector = TermGroupFacetCollector.createTermGroupFacetCollector(groupField, "duration", false, new BytesRef("1"), 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedDurationFacetCollector); + durationResult = groupedDurationFacetCollector.mergeSegmentResults(10, 0, true); + assertEquals(5, durationResult.getTotalCount()); + assertEquals(0, durationResult.getTotalMissingCount()); + + entries = durationResult.getFacetEntries(0, 10); + assertEquals(2, entries.size()); + assertEquals("10", entries.get(0).getValue().utf8ToString()); + assertEquals(3, entries.get(0).getCount()); + assertEquals("15", entries.get(1).getValue().utf8ToString()); + assertEquals(2, entries.get(1).getCount()); + + w.close(); + indexSearcher.getIndexReader().close(); + dir.close(); + } + + private void addGroupField(Document doc, String groupField, String value, boolean canUseIDV) { + doc.add(new Field(groupField, value, TextField.TYPE_UNSTORED)); + if (canUseIDV) { + doc.add(new DocValuesField(groupField, new BytesRef(value), DocValues.Type.BYTES_VAR_SORTED)); + } + } + + public void testRandom() throws Exception { + int numberOfRuns = _TestUtil.nextInt(random, 3, 6); + for (int indexIter = 0; indexIter < numberOfRuns; indexIter++) { + boolean multipleFacetsPerDocument = random.nextBoolean(); + IndexContext context = createIndexContext(multipleFacetsPerDocument); + final 
IndexSearcher searcher = newSearcher(context.indexReader); + + for (int searchIter = 0; searchIter < 100; searchIter++) { + String searchTerm = context.contentStrings[random.nextInt(context.contentStrings.length)]; + int limit = random.nextInt(context.facetValues.size()); + int offset = random.nextInt(context.facetValues.size() - limit); + int size = offset + limit; + int minCount = random.nextBoolean() ? 0 : random.nextInt(1 + context.facetWithMostGroups / 10); + boolean orderByCount = random.nextBoolean(); + String randomStr = getFromSet(context.facetValues, random.nextInt(context.facetValues.size())); + final String facetPrefix; + if (randomStr == null) { + facetPrefix = null; + } else { + int codePointLen = randomStr.codePointCount(0, randomStr.length()); + int randomLen = random.nextInt(codePointLen); + if (codePointLen == randomLen - 1) { + facetPrefix = null; + } else { + int end = randomStr.offsetByCodePoints(0, randomLen); + facetPrefix = random.nextBoolean() ? null : randomStr.substring(end); + } + } + + GroupedFacetResult expectedFacetResult = createExpectedFacetResult(searchTerm, context, offset, limit, minCount, orderByCount, facetPrefix); + TermGroupFacetCollector groupFacetCollector = createRandomCollector("group", "facet", facetPrefix, multipleFacetsPerDocument); + searcher.search(new TermQuery(new Term("content", searchTerm)), groupFacetCollector); + TermGroupFacetCollector.GroupedFacetResult actualFacetResult = groupFacetCollector.mergeSegmentResults(size, minCount, orderByCount); + + List expectedFacetEntries = expectedFacetResult.getFacetEntries(); + List actualFacetEntries = actualFacetResult.getFacetEntries(offset, limit); + + if (VERBOSE) { + System.out.println("Collector: " + groupFacetCollector.getClass().getSimpleName()); + System.out.println("Num group: " + context.numGroups); + System.out.println("Num doc: " + context.numDocs); + System.out.println("Index iter: " + indexIter); + System.out.println("multipleFacetsPerDocument: " + 
multipleFacetsPerDocument); + System.out.println("Search iter: " + searchIter); + + System.out.println("Search term: " + searchTerm); + System.out.println("Min count: " + minCount); + System.out.println("Facet offset: " + offset); + System.out.println("Facet limit: " + limit); + System.out.println("Facet prefix: " + facetPrefix); + System.out.println("Order by count: " + orderByCount); + + System.out.println("\n=== Expected: \n"); + System.out.println("Total count " + expectedFacetResult.getTotalCount()); + System.out.println("Total missing count " + expectedFacetResult.getTotalMissingCount()); + int counter = 1; + for (TermGroupFacetCollector.FacetEntry expectedFacetEntry : expectedFacetEntries) { + System.out.println( + String.format( + "%d. Expected facet value %s with count %d", + counter++, expectedFacetEntry.getValue().utf8ToString(), expectedFacetEntry.getCount() + ) + ); + } + + System.out.println("\n=== Actual: \n"); + System.out.println("Total count " + actualFacetResult.getTotalCount()); + System.out.println("Total missing count " + actualFacetResult.getTotalMissingCount()); + counter = 1; + for (TermGroupFacetCollector.FacetEntry actualFacetEntry : actualFacetEntries) { + System.out.println( + String.format( + "%d. 
Actual facet value %s with count %d", + counter++, actualFacetEntry.getValue().utf8ToString(), actualFacetEntry.getCount() + ) + ); + } + System.out.println("\n==================================================================================="); + } + + assertEquals(expectedFacetResult.getTotalCount(), actualFacetResult.getTotalCount()); + assertEquals(expectedFacetResult.getTotalMissingCount(), actualFacetResult.getTotalMissingCount()); + assertEquals(expectedFacetEntries.size(), actualFacetEntries.size()); + for (int i = 0; i < expectedFacetEntries.size(); i++) { + TermGroupFacetCollector.FacetEntry expectedFacetEntry = expectedFacetEntries.get(i); + TermGroupFacetCollector.FacetEntry actualFacetEntry = actualFacetEntries.get(i); + assertEquals(expectedFacetEntry.getValue().utf8ToString() + " != " + actualFacetEntry.getValue().utf8ToString(), expectedFacetEntry.getValue(), actualFacetEntry.getValue()); + assertEquals(expectedFacetEntry.getCount() + " != " + actualFacetEntry.getCount(), expectedFacetEntry.getCount(), actualFacetEntry.getCount()); + } + } + + context.indexReader.close(); + context.dir.close(); + } + } + + private IndexContext createIndexContext(boolean multipleFacetValuesPerDocument) throws IOException { + final int numDocs = _TestUtil.nextInt(random, 138, 1145) * RANDOM_MULTIPLIER; + final int numGroups = _TestUtil.nextInt(random, 1, numDocs / 4); + final int numFacets = _TestUtil.nextInt(random, 1, numDocs / 6); + + if (VERBOSE) { + System.out.println("TEST: numDocs=" + numDocs + " numGroups=" + numGroups); + } + + final List groups = new ArrayList(); + for (int i = 0; i < numGroups; i++) { + groups.add(generateRandomNonEmptyString()); + } + final List facetValues = new ArrayList(); + for (int i = 0; i < numFacets; i++) { + facetValues.add(generateRandomNonEmptyString()); + } + final String[] contentBrs = new String[_TestUtil.nextInt(random, 2, 20)]; + if (VERBOSE) { + System.out.println("TEST: create fake content"); + } + for (int contentIDX = 
0; contentIDX < contentBrs.length; contentIDX++) { + contentBrs[contentIDX] = generateRandomNonEmptyString(); + if (VERBOSE) { + System.out.println(" content=" + contentBrs[contentIDX]); + } + } + + Directory dir = newDirectory(); + RandomIndexWriter writer = new RandomIndexWriter( + random, + dir, + newIndexWriterConfig( + TEST_VERSION_CURRENT, + new MockAnalyzer(random) + ) + ); + + Document doc = new Document(); + Document docNoGroup = new Document(); + Document docNoFacet = new Document(); + Document docNoGroupNoFacet = new Document(); + Field group = newField("group", "", StringField.TYPE_UNSTORED); + doc.add(group); + docNoFacet.add(group); + Field[] facetFields = multipleFacetValuesPerDocument? new Field[2 + random.nextInt(6)] : new Field[1]; + for (int i = 0; i < facetFields.length; i++) { + facetFields[i] = newField("facet", "", StringField.TYPE_UNSTORED); + doc.add(facetFields[i]); + docNoGroup.add(facetFields[i]); + } + Field content = newField("content", "", StringField.TYPE_UNSTORED); + doc.add(content); + docNoGroup.add(content); + docNoFacet.add(content); + docNoGroupNoFacet.add(content); + + NavigableSet uniqueFacetValues = new TreeSet(new Comparator() { + + public int compare(String a, String b) { + if (a == b) { + return 0; + } else if (a == null) { + return -1; + } else if (b == null) { + return 1; + } else { + return a.compareTo(b); + } + } + + }); + Map>> searchTermToFacetToGroups = new HashMap>>(); + int facetWithMostGroups = 0; + for (int i = 0; i < numDocs; i++) { + final String groupValue; + if (random.nextInt(24) == 17) { + // So we test the "doc doesn't have the group'd + // field" case: + groupValue = null; + } else { + groupValue = groups.get(random.nextInt(groups.size())); + } + + String contentStr = contentBrs[random.nextInt(contentBrs.length)]; + if (!searchTermToFacetToGroups.containsKey(contentStr)) { + searchTermToFacetToGroups.put(contentStr, new HashMap>()); + } + Map> facetToGroups = searchTermToFacetToGroups.get(contentStr); + 
+ List facetVals = new ArrayList(); + if (random.nextInt(24) != 18) { + for (Field facetField : facetFields) { + String facetValue = facetValues.get(random.nextInt(facetValues.size())); + uniqueFacetValues.add(facetValue); + if (!facetToGroups.containsKey(facetValue)) { + facetToGroups.put(facetValue, new HashSet()); + } + Set groupsInFacet = facetToGroups.get(facetValue); + groupsInFacet.add(groupValue); + if (groupsInFacet.size() > facetWithMostGroups) { + facetWithMostGroups = groupsInFacet.size(); + } + facetField.setStringValue(facetValue); + facetVals.add(facetValue); + } + } else { + uniqueFacetValues.add(null); + if (!facetToGroups.containsKey(null)) { + facetToGroups.put(null, new HashSet()); + } + Set groupsInFacet = facetToGroups.get(null); + groupsInFacet.add(groupValue); + if (groupsInFacet.size() > facetWithMostGroups) { + facetWithMostGroups = groupsInFacet.size(); + } + } + + if (VERBOSE) { + System.out.println(" doc content=" + contentStr + " group=" + (groupValue == null ? 
"null" : groupValue) + " facetVals=" + facetVals); + } + + if (groupValue != null) { + group.setStringValue(groupValue); + } + content.setStringValue(contentStr); + if (groupValue == null && facetVals.isEmpty()) { + writer.addDocument(docNoGroupNoFacet); + } else if (facetVals.isEmpty()) { + writer.addDocument(docNoFacet); + } else if (groupValue == null) { + writer.addDocument(docNoGroup); + } else { + writer.addDocument(doc); + } + } + + DirectoryReader reader = writer.getReader(); + writer.close(); + + return new IndexContext(searchTermToFacetToGroups, reader, numDocs, dir, facetWithMostGroups, numGroups, contentBrs, uniqueFacetValues); + } + + private GroupedFacetResult createExpectedFacetResult(String searchTerm, IndexContext context, int offset, int limit, int minCount, final boolean orderByCount, String facetPrefix) { + Map> facetGroups = context.searchTermToFacetGroups.get(searchTerm); + if (facetGroups == null) { + facetGroups = new HashMap>(); + } + + int totalCount = 0; + int totalMissCount = 0; + Set facetValues; + if (facetPrefix != null) { + facetValues = new HashSet(); + for (String facetValue : context.facetValues) { + if (facetValue != null && facetValue.startsWith(facetPrefix)) { + facetValues.add(facetValue); + } + } + } else { + facetValues = context.facetValues; + } + + List entries = new ArrayList(facetGroups.size()); + // also includes facets with count 0 + for (String facetValue : facetValues) { + if (facetValue == null) { + continue; + } + + Set groups = facetGroups.get(facetValue); + int count = groups != null ? 
groups.size() : 0; + if (count >= minCount) { + entries.add(new TermGroupFacetCollector.FacetEntry(new BytesRef(facetValue), count)); + } + totalCount += count; + } + + // Only include null count when no facet prefix is specified + if (facetPrefix == null) { + Set groups = facetGroups.get(null); + if (groups != null) { + totalMissCount = groups.size(); + } + } + + Collections.sort(entries, new Comparator() { + + public int compare(TermGroupFacetCollector.FacetEntry a, TermGroupFacetCollector.FacetEntry b) { + if (orderByCount) { + int cmp = b.getCount() - a.getCount(); + if (cmp != 0) { + return cmp; + } + } + return a.getValue().compareTo(b.getValue()); + } + + }); + + int endOffset = offset + limit; + List entriesResult; + if (offset >= entries.size()) { + entriesResult = Collections.emptyList(); + } else if (endOffset >= entries.size()) { + entriesResult = entries.subList(offset, entries.size()); + } else { + entriesResult = entries.subList(offset, endOffset); + } + return new GroupedFacetResult(totalCount, totalMissCount, entriesResult); + } + + private TermGroupFacetCollector createRandomCollector(String groupField, String facetField, String facetPrefix, boolean multipleFacetsPerDocument) { + BytesRef facetPrefixBR = facetPrefix == null ? 
null : new BytesRef(facetPrefix); + return TermGroupFacetCollector.createTermGroupFacetCollector(groupField, facetField, multipleFacetsPerDocument, facetPrefixBR, random.nextInt(1024)); + } + + private String getFromSet(Set set, int index) { + int currentIndex = 0; + for (String bytesRef : set) { + if (currentIndex++ == index) { + return bytesRef; + } + } + + return null; + } + + private class IndexContext { + + final int numDocs; + final DirectoryReader indexReader; + final Map>> searchTermToFacetGroups; + final NavigableSet facetValues; + final Directory dir; + final int facetWithMostGroups; + final int numGroups; + final String[] contentStrings; + + public IndexContext(Map>> searchTermToFacetGroups, DirectoryReader r, + int numDocs, Directory dir, int facetWithMostGroups, int numGroups, String[] contentStrings, NavigableSet facetValues) { + this.searchTermToFacetGroups = searchTermToFacetGroups; + this.indexReader = r; + this.numDocs = numDocs; + this.dir = dir; + this.facetWithMostGroups = facetWithMostGroups; + this.numGroups = numGroups; + this.contentStrings = contentStrings; + this.facetValues = facetValues; + } + } + + private class GroupedFacetResult { + + final int totalCount; + final int totalMissingCount; + final List facetEntries; + + private GroupedFacetResult(int totalCount, int totalMissingCount, List facetEntries) { + this.totalCount = totalCount; + this.totalMissingCount = totalMissingCount; + this.facetEntries = facetEntries; + } + + public int getTotalCount() { + return totalCount; + } + + public int getTotalMissingCount() { + return totalMissingCount; + } + + public List getFacetEntries() { + return facetEntries; + } + } + +}