LUCENE-10289: Change DocIdSetBuilder#grow() from taking an int to a long (#520)

This commit is contained in:
Ignacio Vera 2021-12-07 07:41:09 +01:00 committed by GitHub
parent 35eff443a7
commit af1e68b891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 9 deletions

View File

@ -41,6 +41,8 @@ API Changes
* LUCENE-10244: MultiCollector::getCollectors is now public, allowing users to access the wrapped * LUCENE-10244: MultiCollector::getCollectors is now public, allowing users to access the wrapped
collectors. (Andriy Redko) collectors. (Andriy Redko)
* LUCENE-10289: DocIdSetBuilder#grow() now takes a long instead of an int. (Ignacio Vera)
New Features New Features
--------------------- ---------------------

View File

@ -166,9 +166,9 @@ public final class DocIdSetBuilder {
bitSet.or(iter); bitSet.or(iter);
return; return;
} }
int cost = (int) Math.min(Integer.MAX_VALUE, iter.cost()); long cost = iter.cost();
BulkAdder adder = grow(cost); BulkAdder adder = grow(cost);
for (int i = 0; i < cost; ++i) { for (long i = 0; i < cost; ++i) {
int doc = iter.nextDoc(); int doc = iter.nextDoc();
if (doc == DocIdSetIterator.NO_MORE_DOCS) { if (doc == DocIdSetIterator.NO_MORE_DOCS) {
return; return;
@ -184,20 +184,35 @@ public final class DocIdSetBuilder {
* Reserve space and return a {@link BulkAdder} object that can be used to add up to {@code * Reserve space and return a {@link BulkAdder} object that can be used to add up to {@code
* numDocs} documents. * numDocs} documents.
*/ */
public BulkAdder grow(int numDocs) { public BulkAdder grow(long numDocs) {
if (bitSet == null) { if (bitSet == null) {
if ((long) totalAllocated + numDocs <= threshold) { if ((long) totalAllocated + checkTotalAllocatedOverflow(numDocs) <= threshold) {
ensureBufferCapacity(numDocs); // For extra safety we use toIntExact
ensureBufferCapacity(Math.toIntExact(numDocs));
} else { } else {
upgradeToBitSet(); upgradeToBitSet();
counter += numDocs; counter += checkCounterOverflow(numDocs);
} }
} else { } else {
counter += numDocs; counter += checkCounterOverflow(numDocs);
} }
return adder; return adder;
} }
/**
 * Verifies that adding {@code numDocs} to {@code totalAllocated} does not wrap around the
 * {@code long} range.
 *
 * @param numDocs the number of documents about to be reserved
 * @return {@code numDocs}, unchanged, so the call can be used inline in an expression
 * @throws ArithmeticException if the widened sum is smaller than {@code totalAllocated}
 */
private long checkTotalAllocatedOverflow(long numDocs) {
  // Widen to long before adding so the addition itself is carried out in 64 bits.
  final long projectedTotal = (long) totalAllocated + numDocs;
  if (projectedTotal < totalAllocated) {
    throw new ArithmeticException("long overflow");
  }
  return numDocs;
}
/**
 * Verifies that adding {@code numDocs} to {@code counter} does not wrap around.
 *
 * @param numDocs the number of documents about to be added to the counter
 * @return {@code numDocs}, unchanged, so the call can be used inline in an expression
 * @throws ArithmeticException if the sum is smaller than the current {@code counter}
 */
private long checkCounterOverflow(long numDocs) {
  final long projectedCounter = counter + numDocs;
  // A sum below the starting value is treated as wraparound.
  if (projectedCounter < counter) {
    throw new ArithmeticException("long overflow");
  }
  return numDocs;
}
private void ensureBufferCapacity(int numDocs) { private void ensureBufferCapacity(int numDocs) {
if (buffers.isEmpty()) { if (buffers.isEmpty()) {
addBuffer(additionalCapacity(numDocs)); addBuffer(additionalCapacity(numDocs));

View File

@ -128,9 +128,9 @@ public class TestDocIdSetBuilder extends LuceneTestCase {
for (j = 0; j < array.length; ) { for (j = 0; j < array.length; ) {
final int l = TestUtil.nextInt(random(), 1, array.length - j); final int l = TestUtil.nextInt(random(), 1, array.length - j);
DocIdSetBuilder.BulkAdder adder = null; DocIdSetBuilder.BulkAdder adder = null;
for (int k = 0, budget = 0; k < l; ++k) { for (long k = 0, budget = 0; k < l; ++k) {
if (budget == 0 || rarely()) { if (budget == 0 || rarely()) {
budget = TestUtil.nextInt(random(), 1, l - k + 5); budget = TestUtil.nextLong(random(), 1, l - k + 5);
adder = builder.grow(budget); adder = builder.grow(budget);
} }
adder.add(array[j++]); adder.add(array[j++]);
@ -241,6 +241,38 @@ public class TestDocIdSetBuilder extends LuceneTestCase {
assertTrue(builder.multivalued); assertTrue(builder.multivalued);
} }
/**
 * Reserves room for {@code Integer.MAX_VALUE + 1} documents in a single {@code grow} call, adds
 * doc id 0 that many times, and checks that the built set is a {@link BitDocIdSet} whose
 * iterator reports a cost of 1.
 */
@Nightly
public void testLotsOfDocs() throws IOException {
  final long totalAdds = (long) Integer.MAX_VALUE + 1;
  final int distinctDocs = 1;
  PointValues pointValues = new DummyPointValues(distinctDocs, totalAdds);
  DocIdSetBuilder setBuilder = new DocIdSetBuilder(100, pointValues, "foo");
  DocIdSetBuilder.BulkAdder bulkAdder = setBuilder.grow(totalAdds);
  // Add the same doc id once per reserved slot.
  for (long remaining = totalAdds; remaining > 0; --remaining) {
    bulkAdder.add(0);
  }
  DocIdSet built = setBuilder.build();
  assertTrue(built instanceof BitDocIdSet);
  assertEquals(1, built.iterator().cost());
}
/**
 * Checks that {@code grow} throws an {@link ArithmeticException} with the message "long
 * overflow" when the requested sizes, taken together, exceed {@code Long.MAX_VALUE}.
 */
public void testLongOverflow() throws IOException {
  // Case 1: a request of 1 followed by a request of Long.MAX_VALUE.
  final DocIdSetBuilder afterSmallGrow = new DocIdSetBuilder(100);
  afterSmallGrow.grow(1L);
  Exception firstFailure =
      expectThrows(ArithmeticException.class, () -> afterSmallGrow.grow(Long.MAX_VALUE));
  assertEquals("long overflow", firstFailure.getMessage());

  // Case 2: a request just above the int range followed by one that pushes the total past
  // Long.MAX_VALUE.
  final DocIdSetBuilder afterLargeGrow = new DocIdSetBuilder(100);
  afterLargeGrow.grow((long) Integer.MAX_VALUE + 1);
  final long secondRequest = Long.MAX_VALUE - Integer.MAX_VALUE;
  Exception secondFailure =
      expectThrows(ArithmeticException.class, () -> afterLargeGrow.grow(secondRequest));
  assertEquals("long overflow", secondFailure.getMessage());
}
private static class DummyTerms extends Terms { private static class DummyTerms extends Terms {
private final int docCount; private final int docCount;