From d197f012ef4577891233c220305ced0e525a5a10 Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Mon, 14 Oct 2024 19:10:01 -0500 Subject: [PATCH] Use multi-select instead of sort for Dynamic Ranges --- .../apache/lucene/util/WeightedSelector.java | 407 ++++++++++++++++++ .../lucene/facet/range/DynamicRangeUtil.java | 113 +++-- .../facet/range/TestDynamicRangeUtil.java | 27 +- 3 files changed, 491 insertions(+), 56 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/util/WeightedSelector.java diff --git a/lucene/core/src/java/org/apache/lucene/util/WeightedSelector.java b/lucene/core/src/java/org/apache/lucene/util/WeightedSelector.java new file mode 100644 index 00000000000..723e8a98cf4 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/WeightedSelector.java @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.util; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.SplittableRandom; + +/** + * Adaptive selection algorithm based on the introspective quick select algorithm. The quick select + * algorithm uses an interpolation variant of Tukey's ninther median-of-medians for pivot, and + * Bentley-McIlroy 3-way partitioning. For the introspective protection, it shuffles the sub-range + * if the max recursive depth is exceeded. + * + *

This selection algorithm is fast on most data shapes, especially on nearly sorted data, or + * when k is close to the boundaries. It runs in linear time on average. + * + * @lucene.internal + */ +public abstract class WeightedSelector { + + // This selector is used repeatedly by the radix selector for sub-ranges of less than + // 100 entries. This means this selector is also optimized to be fast on small ranges. + // It uses the variant of medians-of-medians and 3-way partitioning, and finishes the + // last tiny range (3 entries or less) with a very specialized sort. + + private SplittableRandom random; + + protected abstract long getWeight(int i); + + protected abstract long getValue(int i); + + public final WeightRangeInfo[] select( + int from, + int to, + long rangeTotalValue, + long beforeTotalValue, + long rangeWeight, + long beforeWeight, + double[] kWeights) { + WeightRangeInfo[] kIndexResults = new WeightRangeInfo[kWeights.length]; + Arrays.fill(kIndexResults, new WeightRangeInfo(-1, 0, 0)); + checkArgs(rangeWeight, beforeWeight, kWeights); + select( + from, + to, + rangeTotalValue, + beforeTotalValue, + rangeWeight, + beforeWeight, + kWeights, + 0, + kWeights.length, + kIndexResults, + 2 * MathUtil.log(to - from, 2)); + return kIndexResults; + } + + void checkArgs(long rangeWeight, long beforeWeight, double[] kWeights) { + if (kWeights.length < 1) { + throw new IllegalArgumentException("There must be at least one k to select, none given"); + } + Arrays.sort(kWeights); + if (kWeights[0] < beforeWeight) { + throw new IllegalArgumentException("All kWeights must be >= beforeWeight"); + } + if (kWeights[kWeights.length - 1] > beforeWeight + rangeWeight) { + throw new IllegalArgumentException("All kWeights must be < beforeWeight + rangeWeight"); + } + } + + // Visible for testing. + void select( + int from, + int to, + long rangeTotalValue, + long beforeTotalValue, + long rangeWeight, + long beforeWeight, + double[] kWeights, + int kFrom, + int kTo, + WeightRangeInfo[] kIndexResults, + int maxDepth) { + + // This code is inspired from IntroSorter#sort, adapted to loop on a single partition. + + // For efficiency, we must enter the loop with at least 4 entries to be able to skip + // some boundary tests during the 3-way partitioning. + int size; + if ((size = to - from) > 3) { + + if (--maxDepth == -1) { + // Max recursion depth exceeded: shuffle (only once) and continue. + shuffle(from, to); + } + + // Pivot selection based on medians. + int last = to - 1; + int mid = (from + last) >>> 1; + int pivot; + if (size <= IntroSorter.SINGLE_MEDIAN_THRESHOLD) { + // Select the pivot with a single median around the middle element. + // Do not take the median between [from, mid, last] because it hurts performance + // if the order is descending in conjunction with the 3-way partitioning. + int range = size >> 2; + pivot = median(mid - range, mid, mid + range); + } else { + // Select the pivot with a variant of the Tukey's ninther median of medians. + // If k is close to the boundaries, select either the lowest or highest median (this variant + // is inspired from the interpolation search). + int range = size >> 3; + int doubleRange = range << 1; + int medianFirst = median(from, from + range, from + doubleRange); + int medianMiddle = median(mid - range, mid, mid + range); + int medianLast = median(last - doubleRange, last - range, last); + + double avgWeight = ((double) rangeWeight) / (to - from); + double middleWeight = kWeights[(kFrom + kTo - 1) >> 1]; + // Approximate the k we are trying to find by assuming an equal weight amongst values + int middleK = from + (int) ((middleWeight - beforeWeight) / avgWeight); + if (middleK - from < range) { + // k is close to 'from': select the lowest median. + pivot = min(medianFirst, medianMiddle, medianLast); + } else if (to - middleK <= range) { + // k is close to 'to': select the highest median. + pivot = max(medianFirst, medianMiddle, medianLast); + } else { + // Otherwise select the median of medians. + pivot = median(medianFirst, medianMiddle, medianLast); + } + } + + // Bentley-McIlroy 3-way partitioning. + setPivot(pivot); + swap(from, pivot); + int i = from; + int j = to; + int p = from + 1; + int q = last; + long leftTotalValue = 0; + long leftWeight = 0; + long rightTotalValue = 0; + long rightWeight = 0; + while (true) { + int leftCmp, rightCmp; + while ((leftCmp = comparePivot(++i)) > 0) { + leftTotalValue += getValue(i); + leftWeight += getWeight(i); + } + while ((rightCmp = comparePivot(--j)) < 0) { + rightTotalValue += getValue(j); + rightWeight += getWeight(j); + } + if (i >= j) { + if (i == j && rightCmp == 0) { + swap(i, p); + } + break; + } + swap(i, j); + if (rightCmp == 0) { + swap(i, p++); + } else { + leftTotalValue += getValue(i); + leftWeight += getWeight(i); + } + if (leftCmp == 0) { + swap(j, q--); + } else { + rightTotalValue += getValue(j); + rightWeight += getWeight(j); + } + } + i = j + 1; + for (int l = from; l < p; ) { + swap(l++, j--); + } + for (int l = last; l > q; ) { + swap(l--, i++); + } + long leftWeightEnd = beforeWeight + leftWeight; + long rightWeightStart = beforeWeight + rangeWeight - rightWeight; + + // Select the K weight values contained in the bottom and top partitions. + int topKFrom = kTo; + int bottomKTo = kFrom; + for (int ki = kTo - 1; ki >= kFrom; ki--) { + if (kWeights[ki] >= rightWeightStart) { + topKFrom = ki; + } + if (kWeights[ki] <= leftWeightEnd) { + bottomKTo = ki + 1; + break; + } + } + // Recursively select the relevant k-values from the bottom group, if there are any k-values + // to select there + if (bottomKTo > kFrom) { + select( + from, + j + 1, + leftTotalValue, + beforeTotalValue, + leftWeight, + beforeWeight, + kWeights, + kFrom, + bottomKTo, + kIndexResults, + maxDepth); + } + // Recursively select the relevant k-values from the top group, if there are any k-values to + // select there + if (topKFrom < kTo) { + select( + i, + to, + rightTotalValue, + beforeTotalValue + rangeTotalValue - rightTotalValue, + rightWeight, + beforeWeight + rangeWeight - rightWeight, + kWeights, + topKFrom, + kTo, + kIndexResults, + maxDepth); + } + + // Choose the k result indexes for this partition + if (bottomKTo < topKFrom) { + findKIndexes( + j + 1, + i, + beforeTotalValue + leftTotalValue, + beforeWeight + leftWeight, + bottomKTo, + topKFrom, + kWeights, + kIndexResults); + } + } + + // Sort the final tiny range (3 entries or less) with a very specialized sort. + switch (size) { + case 1: + kIndexResults[kTo - 1] = + new WeightRangeInfo( + from, beforeTotalValue + getValue(from), beforeWeight + getWeight(from)); + break; + case 2: + if (compare(from, from + 1) > 0) { + swap(from, from + 1); + } + findKIndexes( + from, from + 2, beforeTotalValue, beforeWeight, kFrom, kTo, kWeights, kIndexResults); + break; + case 3: + sort3(from); + findKIndexes( + from, from + 3, beforeTotalValue, beforeWeight, kFrom, kTo, kWeights, kIndexResults); + break; + } + } + + private void findKIndexes( + int from, + int to, + long beforeTotalValue, + long beforeWeight, + int kFrom, + int kTo, + double[] kWeights, + WeightRangeInfo[] kIndexResults) { + long runningWeight = beforeWeight; + long runningTotalValue = beforeTotalValue; + int kIdx = kFrom; + for (int listIdx = from; listIdx < to && kIdx < kTo; listIdx++) { + runningWeight += getWeight(listIdx); + runningTotalValue += getValue(listIdx); + // Skip ahead in the weight list if the same value is used for multiple weights, we will only + // record a result index for the last weight that matches it. + while (++kIdx < kTo && kWeights[kIdx] <= runningWeight) {} + if (kWeights[--kIdx] <= runningWeight) { + kIndexResults[kIdx] = new WeightRangeInfo(listIdx, runningTotalValue, runningWeight); + // Now that we have recorded the resultIndex for this weight, go to the next weight + kIdx++; + } + } + } + + /** Returns the index of the min element among three elements at provided indices. */ + private int min(int i, int j, int k) { + if (compare(i, j) <= 0) { + return compare(i, k) <= 0 ? i : k; + } + return compare(j, k) <= 0 ? j : k; + } + + /** Returns the index of the max element among three elements at provided indices. */ + private int max(int i, int j, int k) { + if (compare(i, j) <= 0) { + return compare(j, k) < 0 ? k : j; + } + return compare(i, k) < 0 ? k : i; + } + + /** Copy of {@code IntroSorter#median}. */ + private int median(int i, int j, int k) { + if (compare(i, j) < 0) { + if (compare(j, k) <= 0) { + return j; + } + return compare(i, k) < 0 ? k : i; + } + if (compare(j, k) >= 0) { + return j; + } + return compare(i, k) < 0 ? i : k; + } + + /** + * Sorts 3 entries starting at from (inclusive). This specialized method is more efficient than + * calling {@link Sorter#insertionSort(int, int)}. + */ + private void sort3(int from) { + final int mid = from + 1; + final int last = from + 2; + if (compare(from, mid) <= 0) { + if (compare(mid, last) > 0) { + swap(mid, last); + if (compare(from, mid) > 0) { + swap(from, mid); + } + } + } else if (compare(mid, last) >= 0) { + swap(from, last); + } else { + swap(from, mid); + if (compare(mid, last) > 0) { + swap(mid, last); + } + } + } + + /** + * Shuffles the entries between from (inclusive) and to (exclusive) with Durstenfeld's algorithm. + */ + private void shuffle(int from, int to) { + if (this.random == null) { + this.random = new SplittableRandom(); + } + SplittableRandom random = this.random; + for (int i = to - 1; i > from; i--) { + swap(i, random.nextInt(from, i + 1)); + } + } + + /** Swap values at slots i and j. */ + protected abstract void swap(int i, int j); + + /** + * Save the value at slot i so that it can later be used as a pivot, see {@link + * #comparePivot(int)}. + */ + protected abstract void setPivot(int i); + + /** + * Compare the pivot with the slot at j, similarly to {@link #compare(int, int) + * compare(i, j)}. + */ + protected abstract int comparePivot(int j); + + /** + * Compare entries found in slots i and j. The contract for the returned + * value is the same as {@link Comparator#compare(Object, Object)}. + */ + protected int compare(int i, int j) { + setPivot(i); + return comparePivot(j); + } + + /** + * Holds information for a returned weight index result + * + * @param index the index at which the weight range limit was found + * @param runningValueSum the sum of values from the start of the list to the end of the range + * limit + * @param runningWeight the sum of weights from the start of the list to the end of the range + * limit + */ + public record WeightRangeInfo(int index, long runningValueSum, long runningWeight) {} +} diff --git a/lucene/facet/src/java/org/apache/lucene/facet/range/DynamicRangeUtil.java b/lucene/facet/src/java/org/apache/lucene/facet/range/DynamicRangeUtil.java index b6ae71217f2..6d54c139710 100644 --- a/lucene/facet/src/java/org/apache/lucene/facet/range/DynamicRangeUtil.java +++ b/lucene/facet/src/java/org/apache/lucene/facet/range/DynamicRangeUtil.java @@ -28,7 +28,7 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.LongValues; import org.apache.lucene.search.LongValuesSource; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.InPlaceMergeSorter; +import org.apache.lucene.util.WeightedSelector; /** * Methods to create dynamic ranges for numeric fields. @@ -66,6 +66,7 @@ public final class DynamicRangeUtil { matchingDocsList.stream().mapToInt(FacetsCollector.MatchingDocs::totalHits).sum(); long[] values = new long[totalDoc]; long[] weights = new long[totalDoc]; + long totalValue = 0; long totalWeight = 0; int overallLength = 0; @@ -107,6 +108,7 @@ public final class DynamicRangeUtil { assert curSegmentOutput.values.length == curSegmentOutput.weights.length; try { + totalValue = Math.addExact(curSegmentOutput.segmentTotalValue, totalValue); totalWeight = Math.addExact(curSegmentOutput.segmentTotalWeight, totalWeight); } catch (ArithmeticException ae) { throw new IllegalArgumentException( @@ -118,7 +120,8 @@ public final class DynamicRangeUtil { System.arraycopy(curSegmentOutput.weights, 0, weights, overallLength, currSegmentLen); overallLength += currSegmentLen; } - return computeDynamicNumericRanges(values, weights, overallLength, totalWeight, topN); + return computeDynamicNumericRanges( + values, weights, overallLength, totalValue, totalWeight, topN); } private static class SegmentTask implements Callable { @@ -165,6 +168,8 @@ public final class DynamicRangeUtil { segmentOutput.values[segmentOutput.segmentIdx] = curValue; segmentOutput.weights[segmentOutput.segmentIdx] = curWeight; try { + segmentOutput.segmentTotalValue = + Math.addExact(segmentOutput.segmentTotalValue, curValue); segmentOutput.segmentTotalWeight = Math.addExact(segmentOutput.segmentTotalWeight, curWeight); } catch (ArithmeticException ae) { @@ -180,6 +185,7 @@ public final class DynamicRangeUtil { private static final class SegmentOutput { private final long[] values; private final long[] weights; + private long segmentTotalValue; private long segmentTotalWeight; private int segmentIdx; @@ -202,7 +208,7 @@ public final class DynamicRangeUtil { * is used to compute the equi-weight per bin. */ public static List computeDynamicNumericRanges( - long[] values, long[] weights, int len, long totalWeight, int topN) { + long[] values, long[] weights, int len, long totalValue, long totalWeight, int topN) { assert values.length == weights.length && len <= values.length && len >= 0; assert topN >= 0; List dynamicRangeResult = new ArrayList<>(); @@ -210,58 +216,75 @@ public final class DynamicRangeUtil { return dynamicRangeResult; } - new InPlaceMergeSorter() { - @Override - protected int compare(int index1, int index2) { - int cmp = Long.compare(values[index1], values[index2]); - if (cmp == 0) { - // If the values are equal, sort based on the weights. - // Any weight order is correct as long as it's deterministic. - return Long.compare(weights[index1], weights[index2]); - } - return cmp; - } + double rangeWeightTarget = (double) totalWeight / topN; + double[] kWeights = new double[topN]; + for (int i = 0; i < topN; i++) { + kWeights[i] = (i == 0 ? 0 : kWeights[i - 1]) + rangeWeightTarget; + } - @Override - protected void swap(int index1, int index2) { - long tmp = values[index1]; - values[index1] = values[index2]; - values[index2] = tmp; - tmp = weights[index1]; - weights[index1] = weights[index2]; - weights[index2] = tmp; - } - }.sort(0, len); + WeightedSelector.WeightRangeInfo[] kIndexResults = + new WeightedSelector() { + private long pivotValue; + private long pivotWeight; - long accuWeight = 0; - long valueSum = 0; - int count = 0; - int minIdx = 0; + @Override + protected long getWeight(int i) { + return weights[i]; + } - double rangeWeightTarget = (double) totalWeight / Math.min(topN, len); + @Override + protected long getValue(int i) { + return values[i]; + } - for (int i = 0; i < len; i++) { - accuWeight += weights[i]; - valueSum += values[i]; - count++; + @Override + protected void swap(int index1, int index2) { + long tmp = values[index1]; + values[index1] = values[index2]; + values[index2] = tmp; + tmp = weights[index1]; + weights[index1] = weights[index2]; + weights[index2] = tmp; + } - if (accuWeight >= rangeWeightTarget) { + @Override + protected void setPivot(int i) { + pivotValue = values[i]; + pivotWeight = weights[i]; + } + + @Override + protected int comparePivot(int j) { + int cmp = Long.compare(pivotValue, values[j]); + if (cmp == 0) { + // If the values are equal, sort based on the weights. + // Any weight order is correct as long as it's deterministic. + return Long.compare(pivotWeight, weights[j]); + } + return cmp; + } + }.select(0, len, totalValue, 0, totalWeight, 0, kWeights); + + int lastIdx = -1; + long lastTotalValue = 0; + long lastTotalWeight = 0; + for (int kIdx = 0; kIdx < topN; kIdx++) { + WeightedSelector.WeightRangeInfo weightRangeInfo = kIndexResults[kIdx]; + if (weightRangeInfo.index() > -1) { + int count = weightRangeInfo.index() - lastIdx; dynamicRangeResult.add( new DynamicRangeInfo( - count, accuWeight, values[minIdx], values[i], (double) valueSum / count)); - count = 0; - accuWeight = 0; - valueSum = 0; - minIdx = i + 1; + count, + (weightRangeInfo.runningWeight() - lastTotalWeight), + values[lastIdx + 1], + values[weightRangeInfo.index()], + (double) (weightRangeInfo.runningValueSum() - lastTotalValue) / count)); + lastIdx = weightRangeInfo.index(); + lastTotalValue = weightRangeInfo.runningValueSum(); + lastTotalWeight = weightRangeInfo.runningWeight(); } } - // capture the remaining values to create the last range - if (minIdx < len) { - dynamicRangeResult.add( - new DynamicRangeInfo( - count, accuWeight, values[minIdx], values[len - 1], (double) valueSum / count)); - } return dynamicRangeResult; } diff --git a/lucene/facet/src/test/org/apache/lucene/facet/range/TestDynamicRangeUtil.java b/lucene/facet/src/test/org/apache/lucene/facet/range/TestDynamicRangeUtil.java index db78b03e6e3..866f1352ff3 100644 --- a/lucene/facet/src/test/org/apache/lucene/facet/range/TestDynamicRangeUtil.java +++ b/lucene/facet/src/test/org/apache/lucene/facet/range/TestDynamicRangeUtil.java @@ -26,10 +26,12 @@ public class TestDynamicRangeUtil extends LuceneTestCase { long[] values = new long[1000]; long[] weights = new long[1000]; + long totalValue = 0; long totalWeight = 0; for (int i = 0; i < 1000; i++) { values[i] = i + 1; weights[i] = i; + totalValue += i + 1; totalWeight += i; } @@ -40,7 +42,8 @@ public class TestDynamicRangeUtil extends LuceneTestCase { new DynamicRangeUtil.DynamicRangeInfo(159, 125133L, 709L, 867L, 788D)); expectedRangeInfoList.add( new DynamicRangeUtil.DynamicRangeInfo(133, 124089L, 868L, 1000L, 934D)); - assertDynamicNumericRangeResults(values, weights, 4, totalWeight, expectedRangeInfoList); + assertDynamicNumericRangeResults( + values, weights, 4, totalValue, totalWeight, expectedRangeInfoList); } public void testComputeDynamicNumericRangesWithSameValues() { @@ -55,11 +58,12 @@ public class TestDynamicRangeUtil extends LuceneTestCase { } expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(51, 1275L, 50L, 50L, 50D)); - expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(21, 1281L, 50L, 50L, 50D)); - expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(16, 1272L, 50L, 50L, 50D)); - expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(12, 1122L, 50L, 50L, 50D)); + expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(20, 1210L, 50L, 50L, 50D)); + expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(16, 1256L, 50L, 50L, 50D)); + expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(13, 1209L, 50L, 50L, 50D)); - assertDynamicNumericRangeResults(values, weights, 4, totalWeight, expectedRangeInfoList); + assertDynamicNumericRangeResults( + values, weights, 4, 50 * values.length, totalWeight, expectedRangeInfoList); } public void testComputeDynamicNumericRangesWithOneValue() { @@ -68,7 +72,7 @@ public class TestDynamicRangeUtil extends LuceneTestCase { List expectedRangeInfoList = new ArrayList<>(); expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(1, 1L, 50L, 50L, 50D)); - assertDynamicNumericRangeResults(values, weights, 4, 1, expectedRangeInfoList); + assertDynamicNumericRangeResults(values, weights, 4, 50, 1, expectedRangeInfoList); } public void testComputeDynamicNumericRangesWithOneLargeWeight() { @@ -80,24 +84,25 @@ public class TestDynamicRangeUtil extends LuceneTestCase { expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(1, 52343, 14L, 14L, 14D)); expectedRangeInfoList.add( new DynamicRangeUtil.DynamicRangeInfo(6, 2766, 32L, 455L, 163.16666666666666D)); - assertDynamicNumericRangeResults(values, weights, 4, 55109, expectedRangeInfoList); + assertDynamicNumericRangeResults(values, weights, 4, 993, 55109, expectedRangeInfoList); } private static void assertDynamicNumericRangeResults( long[] values, long[] weights, int topN, + long totalValue, long totalWeight, List expectedDynamicRangeResult) { List mockDynamicRangeResult = DynamicRangeUtil.computeDynamicNumericRanges( - values, weights, values.length, totalWeight, topN); - assertTrue(compareDynamicRangeResult(mockDynamicRangeResult, expectedDynamicRangeResult)); + values, weights, values.length, totalValue, totalWeight, topN); + compareDynamicRangeResult(mockDynamicRangeResult, expectedDynamicRangeResult); } - private static boolean compareDynamicRangeResult( + private static void compareDynamicRangeResult( List mockResult, List expectedResult) { - return mockResult.size() == expectedResult.size() && mockResult.containsAll(expectedResult); + assertEquals(expectedResult, mockResult); } }