mirror of https://github.com/apache/lucene.git
Use multi-select instead of sort for Dynamic Ranges
This commit is contained in:
parent
282945998d
commit
d197f012ef
|
@ -0,0 +1,407 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.lucene.util;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.SplittableRandom;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adaptive selection algorithm based on the introspective quick select algorithm. The quick select
|
||||||
|
* algorithm uses an interpolation variant of Tukey's ninther median-of-medians for pivot, and
|
||||||
|
* Bentley-McIlroy 3-way partitioning. For the introspective protection, it shuffles the sub-range
|
||||||
|
* if the max recursive depth is exceeded.
|
||||||
|
*
|
||||||
|
* <p>This selection algorithm is fast on most data shapes, especially on nearly sorted data, or
|
||||||
|
* when k is close to the boundaries. It runs in linear time on average.
|
||||||
|
*
|
||||||
|
* @lucene.internal
|
||||||
|
*/
|
||||||
|
public abstract class WeightedSelector {
|
||||||
|
|
||||||
|
// This selector is used repeatedly by the radix selector for sub-ranges of less than
|
||||||
|
// 100 entries. This means this selector is also optimized to be fast on small ranges.
|
||||||
|
// It uses the variant of medians-of-medians and 3-way partitioning, and finishes the
|
||||||
|
// last tiny range (3 entries or less) with a very specialized sort.
|
||||||
|
|
||||||
|
private SplittableRandom random;
|
||||||
|
|
||||||
|
protected abstract long getWeight(int i);
|
||||||
|
|
||||||
|
protected abstract long getValue(int i);
|
||||||
|
|
||||||
|
public final WeightRangeInfo[] select(
|
||||||
|
int from,
|
||||||
|
int to,
|
||||||
|
long rangeTotalValue,
|
||||||
|
long beforeTotalValue,
|
||||||
|
long rangeWeight,
|
||||||
|
long beforeWeight,
|
||||||
|
double[] kWeights) {
|
||||||
|
WeightRangeInfo[] kIndexResults = new WeightRangeInfo[kWeights.length];
|
||||||
|
Arrays.fill(kIndexResults, new WeightRangeInfo(-1, 0, 0));
|
||||||
|
checkArgs(rangeWeight, beforeWeight, kWeights);
|
||||||
|
select(
|
||||||
|
from,
|
||||||
|
to,
|
||||||
|
rangeTotalValue,
|
||||||
|
beforeTotalValue,
|
||||||
|
rangeWeight,
|
||||||
|
beforeWeight,
|
||||||
|
kWeights,
|
||||||
|
0,
|
||||||
|
kWeights.length,
|
||||||
|
kIndexResults,
|
||||||
|
2 * MathUtil.log(to - from, 2));
|
||||||
|
return kIndexResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
void checkArgs(long rangeWeight, long beforeWeight, double[] kWeights) {
|
||||||
|
if (kWeights.length < 1) {
|
||||||
|
throw new IllegalArgumentException("There must be at least one k to select, none given");
|
||||||
|
}
|
||||||
|
Arrays.sort(kWeights);
|
||||||
|
if (kWeights[0] < beforeWeight) {
|
||||||
|
throw new IllegalArgumentException("All kWeights must be >= beforeWeight");
|
||||||
|
}
|
||||||
|
if (kWeights[kWeights.length - 1] > beforeWeight + rangeWeight) {
|
||||||
|
throw new IllegalArgumentException("All kWeights must be < beforeWeight + rangeWeight");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visible for testing.
|
||||||
|
void select(
|
||||||
|
int from,
|
||||||
|
int to,
|
||||||
|
long rangeTotalValue,
|
||||||
|
long beforeTotalValue,
|
||||||
|
long rangeWeight,
|
||||||
|
long beforeWeight,
|
||||||
|
double[] kWeights,
|
||||||
|
int kFrom,
|
||||||
|
int kTo,
|
||||||
|
WeightRangeInfo[] kIndexResults,
|
||||||
|
int maxDepth) {
|
||||||
|
|
||||||
|
// This code is inspired from IntroSorter#sort, adapted to loop on a single partition.
|
||||||
|
|
||||||
|
// For efficiency, we must enter the loop with at least 4 entries to be able to skip
|
||||||
|
// some boundary tests during the 3-way partitioning.
|
||||||
|
int size;
|
||||||
|
if ((size = to - from) > 3) {
|
||||||
|
|
||||||
|
if (--maxDepth == -1) {
|
||||||
|
// Max recursion depth exceeded: shuffle (only once) and continue.
|
||||||
|
shuffle(from, to);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pivot selection based on medians.
|
||||||
|
int last = to - 1;
|
||||||
|
int mid = (from + last) >>> 1;
|
||||||
|
int pivot;
|
||||||
|
if (size <= IntroSorter.SINGLE_MEDIAN_THRESHOLD) {
|
||||||
|
// Select the pivot with a single median around the middle element.
|
||||||
|
// Do not take the median between [from, mid, last] because it hurts performance
|
||||||
|
// if the order is descending in conjunction with the 3-way partitioning.
|
||||||
|
int range = size >> 2;
|
||||||
|
pivot = median(mid - range, mid, mid + range);
|
||||||
|
} else {
|
||||||
|
// Select the pivot with a variant of the Tukey's ninther median of medians.
|
||||||
|
// If k is close to the boundaries, select either the lowest or highest median (this variant
|
||||||
|
// is inspired from the interpolation search).
|
||||||
|
int range = size >> 3;
|
||||||
|
int doubleRange = range << 1;
|
||||||
|
int medianFirst = median(from, from + range, from + doubleRange);
|
||||||
|
int medianMiddle = median(mid - range, mid, mid + range);
|
||||||
|
int medianLast = median(last - doubleRange, last - range, last);
|
||||||
|
|
||||||
|
double avgWeight = ((double) rangeWeight) / (to - from);
|
||||||
|
double middleWeight = kWeights[(kFrom + kTo - 1) >> 1];
|
||||||
|
// Approximate the k we are trying to find by assuming an equal weight amongst values
|
||||||
|
int middleK = from + (int) ((middleWeight - beforeWeight) / avgWeight);
|
||||||
|
if (middleK - from < range) {
|
||||||
|
// k is close to 'from': select the lowest median.
|
||||||
|
pivot = min(medianFirst, medianMiddle, medianLast);
|
||||||
|
} else if (to - middleK <= range) {
|
||||||
|
// k is close to 'to': select the highest median.
|
||||||
|
pivot = max(medianFirst, medianMiddle, medianLast);
|
||||||
|
} else {
|
||||||
|
// Otherwise select the median of medians.
|
||||||
|
pivot = median(medianFirst, medianMiddle, medianLast);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bentley-McIlroy 3-way partitioning.
|
||||||
|
setPivot(pivot);
|
||||||
|
swap(from, pivot);
|
||||||
|
int i = from;
|
||||||
|
int j = to;
|
||||||
|
int p = from + 1;
|
||||||
|
int q = last;
|
||||||
|
long leftTotalValue = 0;
|
||||||
|
long leftWeight = 0;
|
||||||
|
long rightTotalValue = 0;
|
||||||
|
long rightWeight = 0;
|
||||||
|
while (true) {
|
||||||
|
int leftCmp, rightCmp;
|
||||||
|
while ((leftCmp = comparePivot(++i)) > 0) {
|
||||||
|
leftTotalValue += getValue(i);
|
||||||
|
leftWeight += getWeight(i);
|
||||||
|
}
|
||||||
|
while ((rightCmp = comparePivot(--j)) < 0) {
|
||||||
|
rightTotalValue += getValue(j);
|
||||||
|
rightWeight += getWeight(j);
|
||||||
|
}
|
||||||
|
if (i >= j) {
|
||||||
|
if (i == j && rightCmp == 0) {
|
||||||
|
swap(i, p);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
swap(i, j);
|
||||||
|
if (rightCmp == 0) {
|
||||||
|
swap(i, p++);
|
||||||
|
} else {
|
||||||
|
leftTotalValue += getValue(i);
|
||||||
|
leftWeight += getWeight(i);
|
||||||
|
}
|
||||||
|
if (leftCmp == 0) {
|
||||||
|
swap(j, q--);
|
||||||
|
} else {
|
||||||
|
rightTotalValue += getValue(j);
|
||||||
|
rightWeight += getWeight(j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i = j + 1;
|
||||||
|
for (int l = from; l < p; ) {
|
||||||
|
swap(l++, j--);
|
||||||
|
}
|
||||||
|
for (int l = last; l > q; ) {
|
||||||
|
swap(l--, i++);
|
||||||
|
}
|
||||||
|
long leftWeightEnd = beforeWeight + leftWeight;
|
||||||
|
long rightWeightStart = beforeWeight + rangeWeight - rightWeight;
|
||||||
|
|
||||||
|
// Select the K weight values contained in the bottom and top partitions.
|
||||||
|
int topKFrom = kTo;
|
||||||
|
int bottomKTo = kFrom;
|
||||||
|
for (int ki = kTo - 1; ki >= kFrom; ki--) {
|
||||||
|
if (kWeights[ki] >= rightWeightStart) {
|
||||||
|
topKFrom = ki;
|
||||||
|
}
|
||||||
|
if (kWeights[ki] <= leftWeightEnd) {
|
||||||
|
bottomKTo = ki + 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Recursively select the relevant k-values from the bottom group, if there are any k-values
|
||||||
|
// to select there
|
||||||
|
if (bottomKTo > kFrom) {
|
||||||
|
select(
|
||||||
|
from,
|
||||||
|
j + 1,
|
||||||
|
leftTotalValue,
|
||||||
|
beforeTotalValue,
|
||||||
|
leftWeight,
|
||||||
|
beforeWeight,
|
||||||
|
kWeights,
|
||||||
|
kFrom,
|
||||||
|
bottomKTo,
|
||||||
|
kIndexResults,
|
||||||
|
maxDepth);
|
||||||
|
}
|
||||||
|
// Recursively select the relevant k-values from the top group, if there are any k-values to
|
||||||
|
// select there
|
||||||
|
if (topKFrom < kTo) {
|
||||||
|
select(
|
||||||
|
i,
|
||||||
|
to,
|
||||||
|
rightTotalValue,
|
||||||
|
beforeTotalValue + rangeTotalValue - rightTotalValue,
|
||||||
|
rightWeight,
|
||||||
|
beforeWeight + rangeWeight - rightWeight,
|
||||||
|
kWeights,
|
||||||
|
topKFrom,
|
||||||
|
kTo,
|
||||||
|
kIndexResults,
|
||||||
|
maxDepth);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Choose the k result indexes for this partition
|
||||||
|
if (bottomKTo < topKFrom) {
|
||||||
|
findKIndexes(
|
||||||
|
j + 1,
|
||||||
|
i,
|
||||||
|
beforeTotalValue + leftTotalValue,
|
||||||
|
beforeWeight + leftWeight,
|
||||||
|
bottomKTo,
|
||||||
|
topKFrom,
|
||||||
|
kWeights,
|
||||||
|
kIndexResults);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the final tiny range (3 entries or less) with a very specialized sort.
|
||||||
|
switch (size) {
|
||||||
|
case 1:
|
||||||
|
kIndexResults[kTo - 1] =
|
||||||
|
new WeightRangeInfo(
|
||||||
|
from, beforeTotalValue + getValue(from), beforeWeight + getWeight(from));
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
if (compare(from, from + 1) > 0) {
|
||||||
|
swap(from, from + 1);
|
||||||
|
}
|
||||||
|
findKIndexes(
|
||||||
|
from, from + 2, beforeTotalValue, beforeWeight, kFrom, kTo, kWeights, kIndexResults);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
sort3(from);
|
||||||
|
findKIndexes(
|
||||||
|
from, from + 3, beforeTotalValue, beforeWeight, kFrom, kTo, kWeights, kIndexResults);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void findKIndexes(
|
||||||
|
int from,
|
||||||
|
int to,
|
||||||
|
long beforeTotalValue,
|
||||||
|
long beforeWeight,
|
||||||
|
int kFrom,
|
||||||
|
int kTo,
|
||||||
|
double[] kWeights,
|
||||||
|
WeightRangeInfo[] kIndexResults) {
|
||||||
|
long runningWeight = beforeWeight;
|
||||||
|
long runningTotalValue = beforeTotalValue;
|
||||||
|
int kIdx = kFrom;
|
||||||
|
for (int listIdx = from; listIdx < to && kIdx < kTo; listIdx++) {
|
||||||
|
runningWeight += getWeight(listIdx);
|
||||||
|
runningTotalValue += getValue(listIdx);
|
||||||
|
// Skip ahead in the weight list if the same value is used for multiple weights, we will only
|
||||||
|
// record a result index for the last weight that matches it.
|
||||||
|
while (++kIdx < kTo && kWeights[kIdx] <= runningWeight) {}
|
||||||
|
if (kWeights[--kIdx] <= runningWeight) {
|
||||||
|
kIndexResults[kIdx] = new WeightRangeInfo(listIdx, runningTotalValue, runningWeight);
|
||||||
|
// Now that we have recorded the resultIndex for this weight, go to the next weight
|
||||||
|
kIdx++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the index of the min element among three elements at provided indices. */
|
||||||
|
private int min(int i, int j, int k) {
|
||||||
|
if (compare(i, j) <= 0) {
|
||||||
|
return compare(i, k) <= 0 ? i : k;
|
||||||
|
}
|
||||||
|
return compare(j, k) <= 0 ? j : k;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the index of the max element among three elements at provided indices. */
|
||||||
|
private int max(int i, int j, int k) {
|
||||||
|
if (compare(i, j) <= 0) {
|
||||||
|
return compare(j, k) < 0 ? k : j;
|
||||||
|
}
|
||||||
|
return compare(i, k) < 0 ? k : i;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Copy of {@code IntroSorter#median}. */
|
||||||
|
private int median(int i, int j, int k) {
|
||||||
|
if (compare(i, j) < 0) {
|
||||||
|
if (compare(j, k) <= 0) {
|
||||||
|
return j;
|
||||||
|
}
|
||||||
|
return compare(i, k) < 0 ? k : i;
|
||||||
|
}
|
||||||
|
if (compare(j, k) >= 0) {
|
||||||
|
return j;
|
||||||
|
}
|
||||||
|
return compare(i, k) < 0 ? i : k;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sorts 3 entries starting at from (inclusive). This specialized method is more efficient than
|
||||||
|
* calling {@link Sorter#insertionSort(int, int)}.
|
||||||
|
*/
|
||||||
|
private void sort3(int from) {
|
||||||
|
final int mid = from + 1;
|
||||||
|
final int last = from + 2;
|
||||||
|
if (compare(from, mid) <= 0) {
|
||||||
|
if (compare(mid, last) > 0) {
|
||||||
|
swap(mid, last);
|
||||||
|
if (compare(from, mid) > 0) {
|
||||||
|
swap(from, mid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (compare(mid, last) >= 0) {
|
||||||
|
swap(from, last);
|
||||||
|
} else {
|
||||||
|
swap(from, mid);
|
||||||
|
if (compare(mid, last) > 0) {
|
||||||
|
swap(mid, last);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shuffles the entries between from (inclusive) and to (exclusive) with Durstenfeld's algorithm.
|
||||||
|
*/
|
||||||
|
private void shuffle(int from, int to) {
|
||||||
|
if (this.random == null) {
|
||||||
|
this.random = new SplittableRandom();
|
||||||
|
}
|
||||||
|
SplittableRandom random = this.random;
|
||||||
|
for (int i = to - 1; i > from; i--) {
|
||||||
|
swap(i, random.nextInt(from, i + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Swap values at slots <code>i</code> and <code>j</code>. */
|
||||||
|
protected abstract void swap(int i, int j);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save the value at slot <code>i</code> so that it can later be used as a pivot, see {@link
|
||||||
|
* #comparePivot(int)}.
|
||||||
|
*/
|
||||||
|
protected abstract void setPivot(int i);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compare the pivot with the slot at <code>j</code>, similarly to {@link #compare(int, int)
|
||||||
|
* compare(i, j)}.
|
||||||
|
*/
|
||||||
|
protected abstract int comparePivot(int j);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compare entries found in slots <code>i</code> and <code>j</code>. The contract for the returned
|
||||||
|
* value is the same as {@link Comparator#compare(Object, Object)}.
|
||||||
|
*/
|
||||||
|
protected int compare(int i, int j) {
|
||||||
|
setPivot(i);
|
||||||
|
return comparePivot(j);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Holds information for a returned weight index result
|
||||||
|
*
|
||||||
|
* @param index the index at which the weight range limit was found
|
||||||
|
* @param runningValueSum the sum of values from the start of the list to the end of the range
|
||||||
|
* limit
|
||||||
|
* @param runningWeight the sum of weights from the start of the list to the end of the range
|
||||||
|
* limit
|
||||||
|
*/
|
||||||
|
public record WeightRangeInfo(int index, long runningValueSum, long runningWeight) {}
|
||||||
|
}
|
|
@ -28,7 +28,7 @@ import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.LongValues;
|
import org.apache.lucene.search.LongValues;
|
||||||
import org.apache.lucene.search.LongValuesSource;
|
import org.apache.lucene.search.LongValuesSource;
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
import org.apache.lucene.util.InPlaceMergeSorter;
|
import org.apache.lucene.util.WeightedSelector;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Methods to create dynamic ranges for numeric fields.
|
* Methods to create dynamic ranges for numeric fields.
|
||||||
|
@ -66,6 +66,7 @@ public final class DynamicRangeUtil {
|
||||||
matchingDocsList.stream().mapToInt(FacetsCollector.MatchingDocs::totalHits).sum();
|
matchingDocsList.stream().mapToInt(FacetsCollector.MatchingDocs::totalHits).sum();
|
||||||
long[] values = new long[totalDoc];
|
long[] values = new long[totalDoc];
|
||||||
long[] weights = new long[totalDoc];
|
long[] weights = new long[totalDoc];
|
||||||
|
long totalValue = 0;
|
||||||
long totalWeight = 0;
|
long totalWeight = 0;
|
||||||
int overallLength = 0;
|
int overallLength = 0;
|
||||||
|
|
||||||
|
@ -107,6 +108,7 @@ public final class DynamicRangeUtil {
|
||||||
assert curSegmentOutput.values.length == curSegmentOutput.weights.length;
|
assert curSegmentOutput.values.length == curSegmentOutput.weights.length;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
totalValue = Math.addExact(curSegmentOutput.segmentTotalValue, totalValue);
|
||||||
totalWeight = Math.addExact(curSegmentOutput.segmentTotalWeight, totalWeight);
|
totalWeight = Math.addExact(curSegmentOutput.segmentTotalWeight, totalWeight);
|
||||||
} catch (ArithmeticException ae) {
|
} catch (ArithmeticException ae) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
|
@ -118,7 +120,8 @@ public final class DynamicRangeUtil {
|
||||||
System.arraycopy(curSegmentOutput.weights, 0, weights, overallLength, currSegmentLen);
|
System.arraycopy(curSegmentOutput.weights, 0, weights, overallLength, currSegmentLen);
|
||||||
overallLength += currSegmentLen;
|
overallLength += currSegmentLen;
|
||||||
}
|
}
|
||||||
return computeDynamicNumericRanges(values, weights, overallLength, totalWeight, topN);
|
return computeDynamicNumericRanges(
|
||||||
|
values, weights, overallLength, totalValue, totalWeight, topN);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class SegmentTask implements Callable<Void> {
|
private static class SegmentTask implements Callable<Void> {
|
||||||
|
@ -165,6 +168,8 @@ public final class DynamicRangeUtil {
|
||||||
segmentOutput.values[segmentOutput.segmentIdx] = curValue;
|
segmentOutput.values[segmentOutput.segmentIdx] = curValue;
|
||||||
segmentOutput.weights[segmentOutput.segmentIdx] = curWeight;
|
segmentOutput.weights[segmentOutput.segmentIdx] = curWeight;
|
||||||
try {
|
try {
|
||||||
|
segmentOutput.segmentTotalValue =
|
||||||
|
Math.addExact(segmentOutput.segmentTotalValue, curValue);
|
||||||
segmentOutput.segmentTotalWeight =
|
segmentOutput.segmentTotalWeight =
|
||||||
Math.addExact(segmentOutput.segmentTotalWeight, curWeight);
|
Math.addExact(segmentOutput.segmentTotalWeight, curWeight);
|
||||||
} catch (ArithmeticException ae) {
|
} catch (ArithmeticException ae) {
|
||||||
|
@ -180,6 +185,7 @@ public final class DynamicRangeUtil {
|
||||||
private static final class SegmentOutput {
|
private static final class SegmentOutput {
|
||||||
private final long[] values;
|
private final long[] values;
|
||||||
private final long[] weights;
|
private final long[] weights;
|
||||||
|
private long segmentTotalValue;
|
||||||
private long segmentTotalWeight;
|
private long segmentTotalWeight;
|
||||||
private int segmentIdx;
|
private int segmentIdx;
|
||||||
|
|
||||||
|
@ -202,7 +208,7 @@ public final class DynamicRangeUtil {
|
||||||
* is used to compute the equi-weight per bin.
|
* is used to compute the equi-weight per bin.
|
||||||
*/
|
*/
|
||||||
public static List<DynamicRangeInfo> computeDynamicNumericRanges(
|
public static List<DynamicRangeInfo> computeDynamicNumericRanges(
|
||||||
long[] values, long[] weights, int len, long totalWeight, int topN) {
|
long[] values, long[] weights, int len, long totalValue, long totalWeight, int topN) {
|
||||||
assert values.length == weights.length && len <= values.length && len >= 0;
|
assert values.length == weights.length && len <= values.length && len >= 0;
|
||||||
assert topN >= 0;
|
assert topN >= 0;
|
||||||
List<DynamicRangeInfo> dynamicRangeResult = new ArrayList<>();
|
List<DynamicRangeInfo> dynamicRangeResult = new ArrayList<>();
|
||||||
|
@ -210,16 +216,25 @@ public final class DynamicRangeUtil {
|
||||||
return dynamicRangeResult;
|
return dynamicRangeResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
new InPlaceMergeSorter() {
|
double rangeWeightTarget = (double) totalWeight / topN;
|
||||||
@Override
|
double[] kWeights = new double[topN];
|
||||||
protected int compare(int index1, int index2) {
|
for (int i = 0; i < topN; i++) {
|
||||||
int cmp = Long.compare(values[index1], values[index2]);
|
kWeights[i] = (i == 0 ? 0 : kWeights[i - 1]) + rangeWeightTarget;
|
||||||
if (cmp == 0) {
|
|
||||||
// If the values are equal, sort based on the weights.
|
|
||||||
// Any weight order is correct as long as it's deterministic.
|
|
||||||
return Long.compare(weights[index1], weights[index2]);
|
|
||||||
}
|
}
|
||||||
return cmp;
|
|
||||||
|
WeightedSelector.WeightRangeInfo[] kIndexResults =
|
||||||
|
new WeightedSelector() {
|
||||||
|
private long pivotValue;
|
||||||
|
private long pivotWeight;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected long getWeight(int i) {
|
||||||
|
return weights[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected long getValue(int i) {
|
||||||
|
return values[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -231,37 +246,45 @@ public final class DynamicRangeUtil {
|
||||||
weights[index1] = weights[index2];
|
weights[index1] = weights[index2];
|
||||||
weights[index2] = tmp;
|
weights[index2] = tmp;
|
||||||
}
|
}
|
||||||
}.sort(0, len);
|
|
||||||
|
|
||||||
long accuWeight = 0;
|
@Override
|
||||||
long valueSum = 0;
|
protected void setPivot(int i) {
|
||||||
int count = 0;
|
pivotValue = values[i];
|
||||||
int minIdx = 0;
|
pivotWeight = weights[i];
|
||||||
|
}
|
||||||
|
|
||||||
double rangeWeightTarget = (double) totalWeight / Math.min(topN, len);
|
@Override
|
||||||
|
protected int comparePivot(int j) {
|
||||||
|
int cmp = Long.compare(pivotValue, values[j]);
|
||||||
|
if (cmp == 0) {
|
||||||
|
// If the values are equal, sort based on the weights.
|
||||||
|
// Any weight order is correct as long as it's deterministic.
|
||||||
|
return Long.compare(pivotWeight, weights[j]);
|
||||||
|
}
|
||||||
|
return cmp;
|
||||||
|
}
|
||||||
|
}.select(0, len, totalValue, 0, totalWeight, 0, kWeights);
|
||||||
|
|
||||||
for (int i = 0; i < len; i++) {
|
int lastIdx = -1;
|
||||||
accuWeight += weights[i];
|
long lastTotalValue = 0;
|
||||||
valueSum += values[i];
|
long lastTotalWeight = 0;
|
||||||
count++;
|
for (int kIdx = 0; kIdx < topN; kIdx++) {
|
||||||
|
WeightedSelector.WeightRangeInfo weightRangeInfo = kIndexResults[kIdx];
|
||||||
if (accuWeight >= rangeWeightTarget) {
|
if (weightRangeInfo.index() > -1) {
|
||||||
|
int count = weightRangeInfo.index() - lastIdx;
|
||||||
dynamicRangeResult.add(
|
dynamicRangeResult.add(
|
||||||
new DynamicRangeInfo(
|
new DynamicRangeInfo(
|
||||||
count, accuWeight, values[minIdx], values[i], (double) valueSum / count));
|
count,
|
||||||
count = 0;
|
(weightRangeInfo.runningWeight() - lastTotalWeight),
|
||||||
accuWeight = 0;
|
values[lastIdx + 1],
|
||||||
valueSum = 0;
|
values[weightRangeInfo.index()],
|
||||||
minIdx = i + 1;
|
(double) (weightRangeInfo.runningValueSum() - lastTotalValue) / count));
|
||||||
|
lastIdx = weightRangeInfo.index();
|
||||||
|
lastTotalValue = weightRangeInfo.runningValueSum();
|
||||||
|
lastTotalWeight = weightRangeInfo.runningWeight();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// capture the remaining values to create the last range
|
|
||||||
if (minIdx < len) {
|
|
||||||
dynamicRangeResult.add(
|
|
||||||
new DynamicRangeInfo(
|
|
||||||
count, accuWeight, values[minIdx], values[len - 1], (double) valueSum / count));
|
|
||||||
}
|
|
||||||
return dynamicRangeResult;
|
return dynamicRangeResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,10 +26,12 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
|
||||||
long[] values = new long[1000];
|
long[] values = new long[1000];
|
||||||
long[] weights = new long[1000];
|
long[] weights = new long[1000];
|
||||||
|
|
||||||
|
long totalValue = 0;
|
||||||
long totalWeight = 0;
|
long totalWeight = 0;
|
||||||
for (int i = 0; i < 1000; i++) {
|
for (int i = 0; i < 1000; i++) {
|
||||||
values[i] = i + 1;
|
values[i] = i + 1;
|
||||||
weights[i] = i;
|
weights[i] = i;
|
||||||
|
totalValue += i + 1;
|
||||||
totalWeight += i;
|
totalWeight += i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,7 +42,8 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
|
||||||
new DynamicRangeUtil.DynamicRangeInfo(159, 125133L, 709L, 867L, 788D));
|
new DynamicRangeUtil.DynamicRangeInfo(159, 125133L, 709L, 867L, 788D));
|
||||||
expectedRangeInfoList.add(
|
expectedRangeInfoList.add(
|
||||||
new DynamicRangeUtil.DynamicRangeInfo(133, 124089L, 868L, 1000L, 934D));
|
new DynamicRangeUtil.DynamicRangeInfo(133, 124089L, 868L, 1000L, 934D));
|
||||||
assertDynamicNumericRangeResults(values, weights, 4, totalWeight, expectedRangeInfoList);
|
assertDynamicNumericRangeResults(
|
||||||
|
values, weights, 4, totalValue, totalWeight, expectedRangeInfoList);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testComputeDynamicNumericRangesWithSameValues() {
|
public void testComputeDynamicNumericRangesWithSameValues() {
|
||||||
|
@ -55,11 +58,12 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(51, 1275L, 50L, 50L, 50D));
|
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(51, 1275L, 50L, 50L, 50D));
|
||||||
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(21, 1281L, 50L, 50L, 50D));
|
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(20, 1210L, 50L, 50L, 50D));
|
||||||
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(16, 1272L, 50L, 50L, 50D));
|
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(16, 1256L, 50L, 50L, 50D));
|
||||||
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(12, 1122L, 50L, 50L, 50D));
|
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(13, 1209L, 50L, 50L, 50D));
|
||||||
|
|
||||||
assertDynamicNumericRangeResults(values, weights, 4, totalWeight, expectedRangeInfoList);
|
assertDynamicNumericRangeResults(
|
||||||
|
values, weights, 4, 50 * values.length, totalWeight, expectedRangeInfoList);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testComputeDynamicNumericRangesWithOneValue() {
|
public void testComputeDynamicNumericRangesWithOneValue() {
|
||||||
|
@ -68,7 +72,7 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
|
||||||
List<DynamicRangeUtil.DynamicRangeInfo> expectedRangeInfoList = new ArrayList<>();
|
List<DynamicRangeUtil.DynamicRangeInfo> expectedRangeInfoList = new ArrayList<>();
|
||||||
|
|
||||||
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(1, 1L, 50L, 50L, 50D));
|
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(1, 1L, 50L, 50L, 50D));
|
||||||
assertDynamicNumericRangeResults(values, weights, 4, 1, expectedRangeInfoList);
|
assertDynamicNumericRangeResults(values, weights, 4, 50, 1, expectedRangeInfoList);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testComputeDynamicNumericRangesWithOneLargeWeight() {
|
public void testComputeDynamicNumericRangesWithOneLargeWeight() {
|
||||||
|
@ -80,24 +84,25 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
|
||||||
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(1, 52343, 14L, 14L, 14D));
|
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(1, 52343, 14L, 14L, 14D));
|
||||||
expectedRangeInfoList.add(
|
expectedRangeInfoList.add(
|
||||||
new DynamicRangeUtil.DynamicRangeInfo(6, 2766, 32L, 455L, 163.16666666666666D));
|
new DynamicRangeUtil.DynamicRangeInfo(6, 2766, 32L, 455L, 163.16666666666666D));
|
||||||
assertDynamicNumericRangeResults(values, weights, 4, 55109, expectedRangeInfoList);
|
assertDynamicNumericRangeResults(values, weights, 4, 993, 55109, expectedRangeInfoList);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void assertDynamicNumericRangeResults(
|
private static void assertDynamicNumericRangeResults(
|
||||||
long[] values,
|
long[] values,
|
||||||
long[] weights,
|
long[] weights,
|
||||||
int topN,
|
int topN,
|
||||||
|
long totalValue,
|
||||||
long totalWeight,
|
long totalWeight,
|
||||||
List<DynamicRangeUtil.DynamicRangeInfo> expectedDynamicRangeResult) {
|
List<DynamicRangeUtil.DynamicRangeInfo> expectedDynamicRangeResult) {
|
||||||
List<DynamicRangeUtil.DynamicRangeInfo> mockDynamicRangeResult =
|
List<DynamicRangeUtil.DynamicRangeInfo> mockDynamicRangeResult =
|
||||||
DynamicRangeUtil.computeDynamicNumericRanges(
|
DynamicRangeUtil.computeDynamicNumericRanges(
|
||||||
values, weights, values.length, totalWeight, topN);
|
values, weights, values.length, totalValue, totalWeight, topN);
|
||||||
assertTrue(compareDynamicRangeResult(mockDynamicRangeResult, expectedDynamicRangeResult));
|
compareDynamicRangeResult(mockDynamicRangeResult, expectedDynamicRangeResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static boolean compareDynamicRangeResult(
|
private static void compareDynamicRangeResult(
|
||||||
List<DynamicRangeUtil.DynamicRangeInfo> mockResult,
|
List<DynamicRangeUtil.DynamicRangeInfo> mockResult,
|
||||||
List<DynamicRangeUtil.DynamicRangeInfo> expectedResult) {
|
List<DynamicRangeUtil.DynamicRangeInfo> expectedResult) {
|
||||||
return mockResult.size() == expectedResult.size() && mockResult.containsAll(expectedResult);
|
assertEquals(expectedResult, mockResult);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue