This commit is contained in:
Houston Putman 2024-10-15 14:53:14 -05:00
parent 5aa46e5cae
commit 08196572bf
6 changed files with 36 additions and 27 deletions

View File

@ -616,16 +616,17 @@ public final class ArrayUtil {
} }
/** /**
* Reorganize {@code arr[from:to[} so that the elements at the offsets included in {@code k} are at the same position as if * Reorganize {@code arr[from:to[} so that the elements at the offsets included in {@code k} are
* {@code arr[from:to]} was sorted, and all elements on their left are less than or equal to them, and * at the same position as if {@code arr[from:to]} was sorted, and all elements on their left are
* all elements on their right are greater than or equal to them. * less than or equal to them, and all elements on their right are greater than or equal to them.
* *
* <p>This runs in linear time on average and in {@code n log(n)} time in the worst case. * <p>This runs in linear time on average and in {@code n log(n)} time in the worst case.
* *
* @param arr Array to be re-organized. * @param arr Array to be re-organized.
* @param from Starting index for re-organization. Elements before this index will be left as is. * @param from Starting index for re-organization. Elements before this index will be left as is.
* @param to Ending index. Elements after this index will be left as is. * @param to Ending index. Elements after this index will be left as is.
* @param k Array containing the Indexes of elements to sort from. Values must be less than 'to' and greater than or equal to 'from'. This list will be sorted during the call. * @param k Array containing the Indexes of elements to sort from. Values must be less than 'to'
* and greater than or equal to 'from'. This list will be sorted during the call.
* @param comparator Comparator to use for sorting * @param comparator Comparator to use for sorting
*/ */
public static <T> void multiSelect( public static <T> void multiSelect(

View File

@ -154,7 +154,8 @@ public abstract class IntroSelector extends Selector {
// Visible for testing. // Visible for testing.
void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) { void multiSelect(int from, int to, int[] k, int kFrom, int kTo, int maxDepth) {
// If there is only 1 k value to select in this group, then use the single-k select method, which does not do recursion // If there is only 1 k value to select in this group, then use the single-k select method,
// which does not do recursion
if (kTo - kFrom == 1) { if (kTo - kFrom == 1) {
select(from, to, k[kFrom], maxDepth); select(from, to, k[kFrom], maxDepth);
return; return;
@ -240,7 +241,7 @@ public abstract class IntroSelector extends Selector {
// Select the K values contained in the bottom and top partitions. // Select the K values contained in the bottom and top partitions.
int topKFrom = kTo; int topKFrom = kTo;
int bottomKTo = kFrom; int bottomKTo = kFrom;
for (int ki = kTo-1; ki >= kFrom; ki--) { for (int ki = kTo - 1; ki >= kFrom; ki--) {
if (k[ki] >= i) { if (k[ki] >= i) {
topKFrom = ki; topKFrom = ki;
} }
@ -249,11 +250,13 @@ public abstract class IntroSelector extends Selector {
break; break;
} }
} }
// Recursively select the relevant k-values from the bottom group, if there are any k-values to select there // Recursively select the relevant k-values from the bottom group, if there are any k-values
// to select there
if (bottomKTo > kFrom) { if (bottomKTo > kFrom) {
multiSelect(from, j + 1, k, kFrom, bottomKTo, maxDepth); multiSelect(from, j + 1, k, kFrom, bottomKTo, maxDepth);
} }
// Recursively select the relevant k-values from the top group, if there are any k-values to select there // Recursively select the relevant k-values from the top group, if there are any k-values to
// select there
if (topKFrom < kTo) { if (topKFrom < kTo) {
multiSelect(i, to, k, topKFrom, kTo, maxDepth); multiSelect(i, to, k, topKFrom, kTo, maxDepth);
} }

View File

@ -224,12 +224,14 @@ public abstract class RadixSelector extends Selector {
} }
final int bucketTo = bucketFrom + histogram[bucket]; final int bucketTo = bucketFrom + histogram[bucket];
int bucketKTo = bucketKFrom; int bucketKTo = bucketKFrom;
// Move the right-side of the k-window up until the k-value is no longer in the current histogram bucket // Move the right-side of the k-window up until the k-value is no longer in the current
// histogram bucket
while (bucketKTo < kTo && k[bucketKTo] < bucketTo) { while (bucketKTo < kTo && k[bucketKTo] < bucketTo) {
bucketKTo++; bucketKTo++;
} }
// If there are any k-values captured in this histogram, continue down this path with those k-values // If there are any k-values captured in this histogram, continue down this path with those
// k-values
if (bucketKFrom < bucketKTo) { if (bucketKFrom < bucketKTo) {
partition(from, to, bucket, bucketFrom, bucketTo, d); partition(from, to, bucket, bucketFrom, bucketTo, d);

View File

@ -33,22 +33,23 @@ public abstract class Selector {
public abstract void select(int from, int to, int k); public abstract void select(int from, int to, int k);
/** /**
* Reorder elements so that the elements at all positions in {@code k} are the same as if all elements were * Reorder elements so that the elements at all positions in {@code k} are the same as if all
* sorted and all other elements are partitioned around it: {@code [from, k[n])} only contains * elements were sorted and all other elements are partitioned around it: {@code [from, k[n])}
* elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only contains elements * only contains elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only
* that are greater than or equal to {@code k[n]}. * contains elements that are greater than or equal to {@code k[n]}.
*/ */
public void multiSelect(int from, int to, int[] k) { public void multiSelect(int from, int to, int[] k) {
multiSelect(from, to, k, 0, k.length); multiSelect(from, to, k, 0, k.length);
} }
/** /**
* Reorder elements so that the elements at all positions in {@code k} are the same as if all elements were * Reorder elements so that the elements at all positions in {@code k} are the same as if all
* sorted and all other elements are partitioned around it: {@code [from, k[n])} only contains * elements were sorted and all other elements are partitioned around it: {@code [from, k[n])}
* elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only contains elements * only contains elements that are less than or equal to {@code k[n]} and {@code (k[n], to)} only
* that are greater than or equal to {@code k[n]}. * contains elements that are greater than or equal to {@code k[n]}.
* *
* The array {@code k} will be sorted, so {@code kFrom} and {@code kTo} must be referring to the sorted order. * <p>The array {@code k} will be sorted, so {@code kFrom} and {@code kTo} must be referring to
* the sorted order.
*/ */
public void multiSelect(int from, int to, int[] k, int kFrom, int kTo) { public void multiSelect(int from, int to, int[] k, int kFrom, int kTo) {
// Default implementation only uses select(), so it is not optimal // Default implementation only uses select(), so it is not optimal

View File

@ -439,7 +439,8 @@ public class ScalarQuantizer {
double[] lowerSum) { double[] lowerSum) {
assert confidenceIntervals.length == upperSum.length assert confidenceIntervals.length == upperSum.length
&& confidenceIntervals.length == lowerSum.length; && confidenceIntervals.length == lowerSum.length;
float[][] upperAndLowerQuantiles = getUpperAndLowerQuantiles(quantileGatheringScratch, confidenceIntervals); float[][] upperAndLowerQuantiles =
getUpperAndLowerQuantiles(quantileGatheringScratch, confidenceIntervals);
for (int i = 0; i < confidenceIntervals.length; i++) { for (int i = 0; i < confidenceIntervals.length; i++) {
upperSum[i] += upperAndLowerQuantiles[i][1]; upperSum[i] += upperAndLowerQuantiles[i][1];
lowerSum[i] += upperAndLowerQuantiles[i][0]; lowerSum[i] += upperAndLowerQuantiles[i][0];
@ -591,8 +592,8 @@ public class ScalarQuantizer {
// After the selection process, pick out the given quantile values // After the selection process, pick out the given quantile values
for (int i = 0; i < confidenceIntervals.length; i++) { for (int i = 0; i < confidenceIntervals.length; i++) {
minAndMaxPerInterval[i][0] = arr[selectorIndexes[2*i]]; minAndMaxPerInterval[i][0] = arr[selectorIndexes[2 * i]];
minAndMaxPerInterval[i][1] = arr[selectorIndexes[2*i + 1]]; minAndMaxPerInterval[i][1] = arr[selectorIndexes[2 * i + 1]];
} }
return minAndMaxPerInterval; return minAndMaxPerInterval;
} }

View File

@ -125,22 +125,23 @@ public class TestScalarQuantizer extends LuceneTestCase {
percs[i] = (float) i; percs[i] = (float) i;
} }
shuffleArray(percs); shuffleArray(percs);
float[][] upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[]{0.9f}); float[][] upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[] {0.9f});
assertEquals(50f, upperAndLower[0][0], 1e-7); assertEquals(50f, upperAndLower[0][0], 1e-7);
assertEquals(949f, upperAndLower[0][1], 1e-7); assertEquals(949f, upperAndLower[0][1], 1e-7);
shuffleArray(percs); shuffleArray(percs);
upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[]{0.95f}); upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[] {0.95f});
assertEquals(25f, upperAndLower[0][0], 1e-7); assertEquals(25f, upperAndLower[0][0], 1e-7);
assertEquals(974f, upperAndLower[0][1], 1e-7); assertEquals(974f, upperAndLower[0][1], 1e-7);
shuffleArray(percs); shuffleArray(percs);
upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[]{0.99f}); upperAndLower = ScalarQuantizer.getUpperAndLowerQuantiles(percs, new float[] {0.99f});
assertEquals(5f, upperAndLower[0][0], 1e-7); assertEquals(5f, upperAndLower[0][0], 1e-7);
assertEquals(994f, upperAndLower[0][1], 1e-7); assertEquals(994f, upperAndLower[0][1], 1e-7);
} }
public void testEdgeCase() { public void testEdgeCase() {
float[][] upperAndLower = float[][] upperAndLower =
ScalarQuantizer.getUpperAndLowerQuantiles(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, new float[]{0.9f}); ScalarQuantizer.getUpperAndLowerQuantiles(
new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, new float[] {0.9f});
assertEquals(1f, upperAndLower[0][0], 1e-7f); assertEquals(1f, upperAndLower[0][0], 1e-7f);
assertEquals(1f, upperAndLower[0][1], 1e-7f); assertEquals(1f, upperAndLower[0][1], 1e-7f);
} }
@ -194,7 +195,7 @@ public class TestScalarQuantizer extends LuceneTestCase {
public void testFromVectorsAutoInterval4Bit() throws IOException { public void testFromVectorsAutoInterval4Bit() throws IOException {
int dims = 128; int dims = 128;
int numVecs = 100; int numVecs = 1000;
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
float[][] floats = randomFloats(numVecs, dims); float[][] floats = randomFloats(numVecs, dims);