address some nocommits; add randomized test for 1D ints

This commit is contained in:
Mike McCandless 2016-02-23 18:41:03 -05:00
parent 8d88bb7a13
commit 3f5ed6eb67
2 changed files with 148 additions and 5 deletions

View File

@ -42,7 +42,6 @@ import org.apache.lucene.util.StringHelper;
/** Finds all documents whose point value, previously indexed with e.g. {@link org.apache.lucene.document.LongPoint}, is contained in the /** Finds all documents whose point value, previously indexed with e.g. {@link org.apache.lucene.document.LongPoint}, is contained in the
* specified set */ * specified set */
// nocommit make abstract
public class PointInSetQuery extends Query { public class PointInSetQuery extends Query {
// A little bit overkill for us, since all of our "terms" are always in the same field: // A little bit overkill for us, since all of our "terms" are always in the same field:
final PrefixCodedTerms sortedPackedPoints; final PrefixCodedTerms sortedPackedPoints;
@ -54,8 +53,13 @@ public class PointInSetQuery extends Query {
/** {@code packedPoints} must already be sorted! */ /** {@code packedPoints} must already be sorted! */
protected PointInSetQuery(String field, int numDims, int bytesPerDim, BytesRefIterator packedPoints) throws IOException { protected PointInSetQuery(String field, int numDims, int bytesPerDim, BytesRefIterator packedPoints) throws IOException {
this.field = field; this.field = field;
// nocommit validate these: if (bytesPerDim < 1 || bytesPerDim > PointValues.MAX_NUM_BYTES) {
throw new IllegalArgumentException("bytesPerDim must be > 0 and <= " + PointValues.MAX_NUM_BYTES + "; got " + bytesPerDim);
}
this.bytesPerDim = bytesPerDim; this.bytesPerDim = bytesPerDim;
if (numDims < 1 || bytesPerDim > PointValues.MAX_DIMENSIONS) {
throw new IllegalArgumentException("numDims must be > 0 and <= " + PointValues.MAX_DIMENSIONS + "; got " + numDims);
}
this.numDims = numDims; this.numDims = numDims;
// In the 1D case this works well (the more points, the more common prefixes they share, typically), but in // In the 1D case this works well (the more points, the more common prefixes they share, typically), but in
@ -64,9 +68,8 @@ public class PointInSetQuery extends Query {
BytesRefBuilder previous = null; BytesRefBuilder previous = null;
BytesRef current; BytesRef current;
while ((current = packedPoints.next()) != null) { while ((current = packedPoints.next()) != null) {
// nocommit make sure a test tests this:
if (current.length != numDims * bytesPerDim) { if (current.length != numDims * bytesPerDim) {
throw new IllegalArgumentException("packed point length should be " + (numDims * bytesPerDim) + " but got " + current.length + "; field=\"" + field + "\", numDims=" + numDims + " bytesPerDim=" + bytesPerDim); throw new IllegalArgumentException("packed point length should be " + (numDims * bytesPerDim) + " but got " + current.length + "; field=\"" + field + "\" numDims=" + numDims + " bytesPerDim=" + bytesPerDim);
} }
if (previous == null) { if (previous == null) {
previous = new BytesRefBuilder(); previous = new BytesRefBuilder();
@ -104,7 +107,7 @@ public class PointInSetQuery extends Query {
if (fieldInfo.getPointDimensionCount() != numDims) { if (fieldInfo.getPointDimensionCount() != numDims) {
throw new IllegalArgumentException("field=\"" + field + "\" was indexed with numDims=" + fieldInfo.getPointDimensionCount() + " but this query has numDims=" + numDims); throw new IllegalArgumentException("field=\"" + field + "\" was indexed with numDims=" + fieldInfo.getPointDimensionCount() + " but this query has numDims=" + numDims);
} }
if (bytesPerDim != fieldInfo.getPointNumBytes()) { if (fieldInfo.getPointNumBytes() != bytesPerDim) {
throw new IllegalArgumentException("field=\"" + field + "\" was indexed with bytesPerDim=" + fieldInfo.getPointNumBytes() + " but this query has bytesPerDim=" + bytesPerDim); throw new IllegalArgumentException("field=\"" + field + "\" was indexed with bytesPerDim=" + fieldInfo.getPointNumBytes() + " but this query has bytesPerDim=" + bytesPerDim);
} }
int bytesPerDim = fieldInfo.getPointNumBytes(); int bytesPerDim = fieldInfo.getPointNumBytes();

View File

@ -22,7 +22,9 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.BitSet; import java.util.BitSet;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@ -57,6 +59,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefIterator;
import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.NumericUtils; import org.apache.lucene.util.NumericUtils;
@ -1115,6 +1118,129 @@ public class TestPointQueries extends LuceneTestCase {
} }
private int[] toArray(Set<Integer> valuesSet) {
int[] values = new int[valuesSet.size()];
int upto = 0;
for(Integer value : valuesSet) {
values[upto++] = value;
}
return values;
}
public void testRandomPointInSetQuery() throws Exception {
final Set<Integer> valuesSet = new HashSet<>();
int numValues = TestUtil.nextInt(random(), 1, 100);
while (valuesSet.size() < numValues) {
valuesSet.add(random().nextInt());
}
int[] values = toArray(valuesSet);
int numDocs = TestUtil.nextInt(random(), 1, 10000);
if (VERBOSE) {
System.out.println("TEST: numValues=" + numValues + " numDocs=" + numDocs);
}
Directory dir;
if (numDocs > 100000) {
dir = newFSDirectory(createTempDir("TestPointQueries"));
} else {
dir = newDirectory();
}
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setCodec(getCodec());
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
// nocommit multi-valued too
int[] docValues = new int[numDocs];
for(int i=0;i<numDocs;i++) {
int x = values[random().nextInt(values.length)];
Document doc = new Document();
doc.add(new IntPoint("int", x));
docValues[i] = x;
w.addDocument(doc);
}
if (random().nextBoolean()) {
if (VERBOSE) {
System.out.println(" forceMerge(1)");
}
w.forceMerge(1);
}
final IndexReader r = w.getReader();
w.close();
IndexSearcher s = newSearcher(r);
int numThreads = TestUtil.nextInt(random(), 2, 5);
if (VERBOSE) {
System.out.println("TEST: use " + numThreads + " query threads; searcher=" + s);
}
List<Thread> threads = new ArrayList<>();
final int iters = atLeast(100);
final CountDownLatch startingGun = new CountDownLatch(1);
final AtomicBoolean failed = new AtomicBoolean();
for(int i=0;i<numThreads;i++) {
Thread thread = new Thread() {
@Override
public void run() {
try {
_run();
} catch (Exception e) {
failed.set(true);
throw new RuntimeException(e);
}
}
private void _run() throws Exception {
startingGun.await();
for (int iter=0;iter<iters && failed.get() == false;iter++) {
int numValidValuesToQuery = random().nextInt(values.length);
Set<Integer> valuesToQuery = new HashSet<>();
while (valuesToQuery.size() < numValidValuesToQuery) {
valuesToQuery.add(values[random().nextInt(values.length)]);
}
int numExtraValuesToQuery = random().nextInt(20);
while (valuesToQuery.size() < numValidValuesToQuery + numExtraValuesToQuery) {
// nocommit fix test to sometimes use "narrow" range of values
valuesToQuery.add(random().nextInt());
}
int expectedCount = 0;
for(int value : docValues) {
if (valuesToQuery.contains(value)) {
expectedCount++;
}
}
if (VERBOSE) {
System.out.println("TEST: thread=" + Thread.currentThread() + " values=" + valuesToQuery + " expectedCount=" + expectedCount);
}
assertEquals(expectedCount, s.count(IntPoint.newSetQuery("int", toArray(valuesToQuery))));
}
}
};
thread.setName("T" + i);
thread.start();
threads.add(thread);
}
startingGun.countDown();
for(Thread thread : threads) {
thread.join();
}
IOUtils.close(r, dir);
}
// nocommit fix existing randomized tests to sometimes randomly use PointInSet instead // nocommit fix existing randomized tests to sometimes randomly use PointInSet instead
// nocommit need 2D test too // nocommit need 2D test too
@ -1207,4 +1333,18 @@ public class TestPointQueries extends LuceneTestCase {
r.close(); r.close();
dir.close(); dir.close();
} }
public void testInvalidPointInSetQuery() throws Exception {
IllegalArgumentException expected = expectThrows(IllegalArgumentException.class,
() -> {
new PointInSetQuery("foo", 3, 4,
new BytesRefIterator() {
@Override
public BytesRef next() {
return new BytesRef(new byte[3]);
}
});
});
assertEquals("packed point length should be 12 but got 3; field=\"foo\" numDims=3 bytesPerDim=4", expected.getMessage());
}
} }