mirror of https://github.com/apache/lucene.git
Implement radix multiSelect, add tests, rename, make default method
This commit is contained in:
parent
f5fa804c0f
commit
f985f75e62
|
@ -46,9 +46,9 @@ public abstract class IntroSelector extends Selector {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final void select(int from, int to, int[] k) {
|
||||
checkArgs(from, to, k);
|
||||
select(from, to, k, 0, k.length, 2 * MathUtil.log(to - from, 2));
|
||||
public final void multiSelect(int from, int to, int[] k, int kFrom, int kTo) {
|
||||
checkMultiArgs(from, to, k, kFrom, kTo);
|
||||
multiSelect(from, to, k, kFrom, kTo, 2 * MathUtil.log(to - from, 2));
|
||||
}
|
||||
|
||||
// Visible for testing.
|
||||
|
@ -153,8 +153,8 @@ public abstract class IntroSelector extends Selector {
|
|||
}
|
||||
|
||||
// Visible for testing.
|
||||
void select(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) {
|
||||
// If there is only 1 k value to select in this group, then use the single-k select method
|
||||
void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) {
|
||||
// If there is only 1 k value to select in this group, then use the single-k select method, which does not do recursion
|
||||
if (kTo - kFrom == 1) {
|
||||
select(from, to, k[kFrom], maxDepth);
|
||||
return;
|
||||
|
@ -251,11 +251,11 @@ public abstract class IntroSelector extends Selector {
|
|||
}
|
||||
// Recursively select the relevant k-values from the bottom group, if there are any k-values to select there
|
||||
if (bottomKTo > kFrom) {
|
||||
select(from, j + 1, k, kFrom, bottomKTo, maxDepth);
|
||||
multiSelect(from, j + 1, k, kFrom, bottomKTo, maxDepth);
|
||||
}
|
||||
// Recursively select the relevant k-values from the top group, if there are any k-values to select there
|
||||
if (topKFrom < kTo) {
|
||||
select(i, to, k, topKFrom, kTo, maxDepth);
|
||||
multiSelect(i, to, k, topKFrom, kTo, maxDepth);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
*/
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
|
@ -124,6 +125,12 @@ public abstract class RadixSelector extends Selector {
|
|||
select(from, to, k, 0, 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void multiSelect(int from, int to, int[] k, int kFrom, int kTo) {
|
||||
checkMultiArgs(from, to, k, kFrom, kTo);
|
||||
multiSelect(from, to, k, kFrom, kTo, 0, 0);
|
||||
}
|
||||
|
||||
private void select(int from, int to, int k, int d, int l) {
|
||||
if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) {
|
||||
getFallbackSelector(d).select(from, to, k);
|
||||
|
@ -132,6 +139,22 @@ public abstract class RadixSelector extends Selector {
|
|||
}
|
||||
}
|
||||
|
||||
private void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int d, int l) {
|
||||
if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) {
|
||||
if (kTo - kFrom == 1) {
|
||||
getFallbackSelector(d).select(from, to, k[kFrom]);
|
||||
} else {
|
||||
getFallbackSelector(d).multiSelect(from, to, k, kFrom, kTo);
|
||||
}
|
||||
} else {
|
||||
if (kTo - kFrom == 1) {
|
||||
radixSelect(from, to, k[kFrom], d, l);
|
||||
} else {
|
||||
radixMultiSelect(from, to, k, kFrom, kTo, d, l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param d the character number to compare
|
||||
* @param l the level of recursion
|
||||
|
@ -171,6 +194,61 @@ public abstract class RadixSelector extends Selector {
|
|||
throw new AssertionError("Unreachable code");
|
||||
}
|
||||
|
||||
/**
|
||||
* @param d the character number to compare
|
||||
* @param l the level of recursion
|
||||
*/
|
||||
private void radixMultiSelect(int from, int to, int[] k, int kFrom, int kTo, int d, int l) {
|
||||
final int[] histogram = this.histogram;
|
||||
Arrays.fill(histogram, 0);
|
||||
|
||||
final int commonPrefixLength =
|
||||
computeCommonPrefixLengthAndBuildHistogram(from, to, d, histogram);
|
||||
if (commonPrefixLength > 0) {
|
||||
// if there are no more chars to compare or if all entries fell into the
|
||||
// first bucket (which means strings are shorter than d) then we are done
|
||||
// otherwise recurse
|
||||
if (d + commonPrefixLength < maxLength && histogram[0] < to - from) {
|
||||
radixMultiSelect(from, to, k, kFrom, kTo, d + commonPrefixLength, l);
|
||||
}
|
||||
return;
|
||||
}
|
||||
assert assertHistogram(commonPrefixLength, histogram);
|
||||
|
||||
int bucketFrom = from;
|
||||
int bucketKFrom = kFrom;
|
||||
ArrayList<Bucket> bucketsToRecurse = new ArrayList<>(kTo - kFrom);
|
||||
for (int bucket = 0; bucket < HISTOGRAM_SIZE && bucketKFrom < kTo; ++bucket) {
|
||||
if (histogram[bucket] == 0) {
|
||||
continue;
|
||||
}
|
||||
final int bucketTo = bucketFrom + histogram[bucket];
|
||||
int bucketKTo = bucketKFrom;
|
||||
// Move the right-side of the k-window up until the k-value is no longer in the current histogram bucket
|
||||
while (bucketKTo < kTo && k[bucketKTo] < bucketTo) {
|
||||
bucketKTo++;
|
||||
}
|
||||
|
||||
// If there are any k-values captured in this histogram, continue down this path with those k-values
|
||||
if (bucketKFrom < bucketKTo) {
|
||||
partition(from, to, bucket, bucketFrom, bucketTo, d);
|
||||
|
||||
// all elements in bucket 0 are equal so we only need to recurse if bucket != 0
|
||||
if (bucket != 0 && d + 1 < maxLength) {
|
||||
// Recurse after the loop, so that we do not override the histogram
|
||||
bucketsToRecurse.add(new Bucket(bucketFrom, bucketTo, bucketKFrom, bucketKTo));
|
||||
}
|
||||
}
|
||||
bucketFrom = bucketTo;
|
||||
bucketKFrom = bucketKTo;
|
||||
}
|
||||
for (Bucket b : bucketsToRecurse) {
|
||||
multiSelect(b.from, b.to, k, b.kFrom, b.kTo, d + 1, l + 1);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * A histogram bucket queued for recursion by {@code radixMultiSelect}: the element range
 * {@code [from, to)} together with the window {@code [kFrom, kTo)} of the k array whose values
 * fall inside that range.
 */
private record Bucket(int from, int to, int kFrom, int kTo) {}
|
||||
|
||||
// only used from assert
|
||||
private boolean assertHistogram(int commonPrefixLength, int[] histogram) {
|
||||
int numberOfUniqueBytes = 0;
|
||||
|
|
|
@ -38,8 +38,31 @@ public abstract class Selector {
|
|||
* elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only contains elements
|
||||
* that are greater than or equal to {@code k[n]}.
|
||||
*/
|
||||
public void select(int from, int to, int[] k) {
|
||||
select(from, to, k[0]);
|
||||
public void multiSelect(int from, int to, int[] k) {
|
||||
multiSelect(from, to, k, 0, k.length);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reorder elements so that the elements at all positions in {@code k} are the same as if all elements were
|
||||
* sorted and all other elements are partitioned around it: {@code [from, k[n])} only contains
|
||||
* elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only contains elements
|
||||
* that are greater than or equal to {@code k[n]}.
|
||||
*
|
||||
* The array {@code k} will be sorted, so {@code kFrom} and {@code kTo} must be referring to the sorted order.
|
||||
*/
|
||||
public void multiSelect(int from, int to, int[] k, int kFrom, int kTo) {
|
||||
// Default implementation only uses select(), so it is not optimal
|
||||
checkMultiArgs(from, to, k, kFrom, kTo);
|
||||
int nextFrom = from;
|
||||
for (int i = kFrom; i < kTo; i++) {
|
||||
int currentK = k[i];
|
||||
if (currentK < nextFrom) {
|
||||
// This is a duplicate k
|
||||
continue;
|
||||
}
|
||||
select(nextFrom, to, currentK);
|
||||
nextFrom = currentK + 1;
|
||||
}
|
||||
}
|
||||
|
||||
void checkArgs(int from, int to, int k) {
|
||||
|
@ -51,15 +74,18 @@ public abstract class Selector {
|
|||
}
|
||||
}
|
||||
|
||||
void checkArgs(int from, int to, int[] k) {
|
||||
if (k.length < 1) {
|
||||
throw new IllegalArgumentException("There must be at least one k to select, none given");
|
||||
void checkMultiArgs(int from, int to, int[] k, int kFrom, int kTo) {
|
||||
if (kFrom < 0) {
|
||||
throw new IllegalArgumentException("kFrom must be >= 0");
|
||||
}
|
||||
if (kTo > k.length) {
|
||||
throw new IllegalArgumentException("kFrom must be <= k.length");
|
||||
}
|
||||
Arrays.sort(k);
|
||||
if (k[0] < from) {
|
||||
if (k[kFrom] < from) {
|
||||
throw new IllegalArgumentException("All k must be >= from");
|
||||
}
|
||||
if (k[k.length - 1] >= to) {
|
||||
if (k[kTo - 1] >= to) {
|
||||
throw new IllegalArgumentException("All k must be < to");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -587,7 +587,7 @@ public class ScalarQuantizer {
|
|||
selectorIndexes[2 * i + 1] = arr.length - selectorIndex - 1;
|
||||
}
|
||||
Selector selector = new FloatSelector(arr);
|
||||
selector.select(0, arr.length, selectorIndexes);
|
||||
selector.multiSelect(0, arr.length, selectorIndexes);
|
||||
|
||||
// After the selection process, pick out the given quantile values
|
||||
for (int i = 0; i < confidenceIntervals.length; i++) {
|
||||
|
|
|
@ -80,5 +80,30 @@ public class TestIntroSelector extends LuceneTestCase {
|
|||
assertTrue(actual[i] >= actual[k]);
|
||||
}
|
||||
}
|
||||
|
||||
final int[] kArr = new int[TestUtil.nextInt(random, 1, 10)];
|
||||
for (int i = 0; i < kArr.length; i++) {
|
||||
kArr[i] = TestUtil.nextInt(random, from, to - 1);
|
||||
}
|
||||
selector.multiSelect(from, to, kArr);
|
||||
|
||||
int nextKIdx = 0;
|
||||
Arrays.sort(kArr);
|
||||
for (int i = 0; i < actual.length; ++i) {
|
||||
if (i < from || i >= to) {
|
||||
assertSame(arr[i], actual[i]);
|
||||
} else if (nextKIdx < kArr.length) {
|
||||
if (i == kArr[nextKIdx]) {
|
||||
assertEquals(expected[i], actual[i]);
|
||||
while (nextKIdx < kArr.length && i == kArr[nextKIdx]) {
|
||||
nextKIdx++;
|
||||
}
|
||||
} else {
|
||||
assertTrue(actual[i].compareTo(expected[kArr[nextKIdx]]) <= 0);
|
||||
}
|
||||
} else {
|
||||
assertTrue(actual[i].compareTo(expected[kArr[kArr.length - 1]]) >= 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -108,5 +108,30 @@ public class TestRadixSelector extends LuceneTestCase {
|
|||
assertTrue(actual[i].compareTo(actual[k]) >= 0);
|
||||
}
|
||||
}
|
||||
|
||||
final int[] kArr = new int[TestUtil.nextInt(random(), 1, 10)];
|
||||
for (int i = 0; i < kArr.length; i++) {
|
||||
kArr[i] = TestUtil.nextInt(random(), from, to - 1);
|
||||
}
|
||||
selector.multiSelect(from, to, kArr);
|
||||
|
||||
int nextKIdx = 0;
|
||||
Arrays.sort(kArr);
|
||||
for (int i = 0; i < actual.length; ++i) {
|
||||
if (i < from || i >= to) {
|
||||
assertSame(arr[i], actual[i]);
|
||||
} else if (nextKIdx < kArr.length) {
|
||||
if (i == kArr[nextKIdx]) {
|
||||
assertEquals(expected[i], actual[i]);
|
||||
while (nextKIdx < kArr.length && i == kArr[nextKIdx]) {
|
||||
nextKIdx++;
|
||||
}
|
||||
} else {
|
||||
assertTrue(actual[i].compareTo(expected[kArr[nextKIdx]]) <= 0);
|
||||
}
|
||||
} else {
|
||||
assertTrue(actual[i].compareTo(expected[kArr[kArr.length - 1]]) >= 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue