diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java index ce6fb5cb0d4..f41bc31b7c1 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java @@ -48,405 +48,405 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** Tests HNSW KNN graphs */ public class TestHnsw extends LuceneTestCase { - // test writing out and reading in a graph gives the same graph - public void testReadWrite() throws IOException { - int dim = random().nextInt(100) + 1; - int nDoc = random().nextInt(100) + 1; - RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random()); - RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy(); - long seed = random().nextLong(); - HnswGraphBuilder.randSeed = seed; - HnswGraph hnsw = HnswGraphBuilder.build((RandomAccessVectorValuesProducer) vectors); - // Recreate the graph while indexing with the same random seed and write it out - HnswGraphBuilder.randSeed = seed; - try (Directory dir = newDirectory()) { - int nVec = 0, indexedDoc = 0; - // Don't merge randomly, create a single segment because we rely on the docid ordering for this test - IndexWriterConfig iwc = new IndexWriterConfig() - .setCodec(Codec.forName("Lucene90")); - try (IndexWriter iw = new IndexWriter(dir, iwc)) { - while (v2.nextDoc() != NO_MORE_DOCS) { - while (indexedDoc < v2.docID()) { - // increment docId in the index by adding empty documents - iw.addDocument(new Document()); - indexedDoc++; - } - Document doc = new Document(); - doc.add(new VectorField("field", v2.vectorValue(), v2.searchStrategy)); - doc.add(new StoredField("id", v2.docID())); - iw.addDocument(doc); - nVec++; - indexedDoc++; - } - } - try (IndexReader reader = DirectoryReader.open(dir)) { - for (LeafReaderContext ctx : reader.leaves()) { - VectorValues values = ctx.reader().getVectorValues("field"); - assertEquals(vectors.searchStrategy, values.searchStrategy()); - assertEquals(dim, values.dimension()); - assertEquals(nVec, values.size()); - assertEquals(indexedDoc, ctx.reader().maxDoc()); - assertEquals(indexedDoc, ctx.reader().numDocs()); - assertVectorsEqual(v3, values); - KnnGraphValues graphValues = ((Lucene90VectorReader) ((CodecReader) ctx.reader()).getVectorReader()).getGraphValues("field"); - assertGraphEqual(hnsw.getGraphValues(), graphValues, nVec); - } - } + // test writing out and reading in a graph gives the same graph + public void testReadWrite() throws IOException { + int dim = random().nextInt(100) + 1; + int nDoc = random().nextInt(100) + 1; + RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random()); + RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy(); + long seed = random().nextLong(); + HnswGraphBuilder.randSeed = seed; + HnswGraph hnsw = HnswGraphBuilder.build((RandomAccessVectorValuesProducer) vectors); + // Recreate the graph while indexing with the same random seed and write it out + HnswGraphBuilder.randSeed = seed; + try (Directory dir = newDirectory()) { + int nVec = 0, indexedDoc = 0; + // Don't merge randomly, create a single segment because we rely on the docid ordering for this test + IndexWriterConfig iwc = new IndexWriterConfig() + .setCodec(Codec.forName("Lucene90")); + try (IndexWriter iw = new IndexWriter(dir, iwc)) { + while (v2.nextDoc() != NO_MORE_DOCS) { + while (indexedDoc < v2.docID()) { + // increment docId in the index by adding empty documents + iw.addDocument(new Document()); + indexedDoc++; + } + Document doc = new Document(); + doc.add(new VectorField("field", v2.vectorValue(), v2.searchStrategy)); + doc.add(new StoredField("id", v2.docID())); + iw.addDocument(doc); + nVec++; + indexedDoc++; } + } + try (IndexReader reader = DirectoryReader.open(dir)) { + for (LeafReaderContext ctx : reader.leaves()) { + VectorValues values = ctx.reader().getVectorValues("field"); + assertEquals(vectors.searchStrategy, values.searchStrategy()); + assertEquals(dim, values.dimension()); + assertEquals(nVec, values.size()); + assertEquals(indexedDoc, ctx.reader().maxDoc()); + assertEquals(indexedDoc, ctx.reader().numDocs()); + assertVectorsEqual(v3, values); + KnnGraphValues graphValues = ((Lucene90VectorReader) ((CodecReader) ctx.reader()).getVectorReader()).getGraphValues("field"); + assertGraphEqual(hnsw.getGraphValues(), graphValues, nVec); + } + } + } + } + + // Make sure we actually approximately find the closest k elements. Mostly this is about + // ensuring that we have all the distance functions, comparators, priority queues and so on + // oriented in the right directions + public void testAknn() throws IOException { + int nDoc = 100; + RandomAccessVectorValuesProducer vectors = new CircularVectorValues(nDoc); + HnswGraph hnsw = HnswGraphBuilder.build(vectors); + // run some searches + Neighbors nn = HnswGraph.search(new float[]{1, 0}, 10, 5, vectors.randomAccess(), hnsw.getGraphValues(), random()); + int sum = 0; + Neighbors.NeighborIterator it = nn.iterator(); + for (int node = it.next(); node != NO_MORE_DOCS; node = it.next()) { + sum += node; + } + // We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = 45 + assertTrue("sum(result docs)=" + sum, sum < 75); + } + + public void testMaxConnections() { + // verify that maxConnections is observed, and that the retained arcs point to the best-scoring neighbors + HnswGraph graph = new HnswGraph(1, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW); + graph.connectNodes(0, 1, 0); + assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0)); + assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); + graph.connectNodes(0, 2, 0.4f); + assertArrayEquals(new int[]{2}, graph.getNeighborNodes(0)); + assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); + assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2)); + graph.connectNodes(2, 3, 0); + assertArrayEquals(new int[]{2}, graph.getNeighborNodes(0)); + assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); + assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2)); + assertArrayEquals(new int[]{2}, graph.getNeighborNodes(3)); + + graph = new HnswGraph(1, VectorValues.SearchStrategy.EUCLIDEAN_HNSW); + graph.connectNodes(0, 1, 1); + assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0)); + assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); + graph.connectNodes(0, 2, 2); + assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0)); + assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); + assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2)); + graph.connectNodes(2, 3, 1); + assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0)); + assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); + assertArrayEquals(new int[]{3}, graph.getNeighborNodes(2)); + assertArrayEquals(new int[]{2}, graph.getNeighborNodes(3)); + } + + /** Returns vectors evenly distributed around the unit circle. + */ + class CircularVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer { + private final int size; + private final float[] value; + + int doc = -1; + + CircularVectorValues(int size) { + this.size = size; + value = new float[2]; } - // Make sure we actually approximately find the closest k elements. Mostly this is about - // ensuring that we have all the distance functions, comparators, priority queues and so on - // oriented in the right directions - public void testAknn() throws IOException { - int nDoc = 100; - RandomAccessVectorValuesProducer vectors = new CircularVectorValues(nDoc); - HnswGraph hnsw = HnswGraphBuilder.build(vectors); - // run some searches - Neighbors nn = HnswGraph.search(new float[]{1, 0}, 10, 5, vectors.randomAccess(), hnsw.getGraphValues(), random()); - int sum = 0; - Neighbors.NeighborIterator it = nn.iterator(); - for (int node = it.next(); node != NO_MORE_DOCS; node = it.next()) { - sum += node; - } - // We expect to get approximately 100% recall; the lowest docIds are closest to zero; sum(0,9) = 45 - assertTrue("sum(result docs)=" + sum, sum < 75); + public CircularVectorValues copy() { + return new CircularVectorValues(size); } - public void testMaxConnections() { - // verify that maxConnections is observed, and that the retained arcs point to the best-scoring neighbors - HnswGraph graph = new HnswGraph(1, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW); - graph.connectNodes(0, 1, 0); - assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0)); - assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); - graph.connectNodes(0, 2, 0.4f); - assertArrayEquals(new int[]{2}, graph.getNeighborNodes(0)); - assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); - assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2)); - graph.connectNodes(2, 3, 0); - assertArrayEquals(new int[]{2}, graph.getNeighborNodes(0)); - assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); - assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2)); - assertArrayEquals(new int[]{2}, graph.getNeighborNodes(3)); - - graph = new HnswGraph(1, VectorValues.SearchStrategy.EUCLIDEAN_HNSW); - graph.connectNodes(0, 1, 1); - assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0)); - assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); - graph.connectNodes(0, 2, 2); - assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0)); - assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); - assertArrayEquals(new int[]{0}, graph.getNeighborNodes(2)); - graph.connectNodes(2, 3, 1); - assertArrayEquals(new int[]{1}, graph.getNeighborNodes(0)); - assertArrayEquals(new int[]{0}, graph.getNeighborNodes(1)); - assertArrayEquals(new int[]{3}, graph.getNeighborNodes(2)); - assertArrayEquals(new int[]{2}, graph.getNeighborNodes(3)); + @Override + public SearchStrategy searchStrategy() { + return SearchStrategy.DOT_PRODUCT_HNSW; } - /** Returns vectors evenly distributed around the unit circle. - */ - class CircularVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer { - private final int size; - private final float[] value; - - int doc = -1; - - CircularVectorValues(int size) { - this.size = size; - value = new float[2]; - } - - public CircularVectorValues copy() { - return new CircularVectorValues(size); - } - - @Override - public SearchStrategy searchStrategy() { - return SearchStrategy.DOT_PRODUCT_HNSW; - } - - @Override - public int dimension() { - return 2; - } - - @Override - public int size() { - return size; - } - - @Override - public float[] vectorValue() { - return vectorValue(doc); - } - - @Override - public RandomAccessVectorValues randomAccess() { - return new CircularVectorValues(size); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() { - return advance(doc + 1); - } - - @Override - public int advance(int target) { - if (target >= 0 && target < size) { - doc = target; - } else { - doc = NO_MORE_DOCS; - } - return doc; - } - - @Override - public long cost() { - return size; - } - - @Override - public float[] vectorValue(int ord) { - value[0] = (float) Math.cos(Math.PI * ord / (double) size); - value[1] = (float) Math.sin(Math.PI * ord / (double) size); - return value; - } - - @Override - public BytesRef binaryValue(int ord) { - return null; - } - - @Override - public TopDocs search(float[] target, int k, int fanout) { - return null; - } - + @Override + public int dimension() { + return 2; } - private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException { - for (int node = 0; node < size; node ++) { - g.seek(node); - h.seek(node); - assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h)); - } + @Override + public int size() { + return size; } - private Set getNeighborNodes(KnnGraphValues g) throws IOException { - Set neighbors = new HashSet<>(); - for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) { - neighbors.add(n); - } - return neighbors; + @Override + public float[] vectorValue() { + return vectorValue(doc); } - private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException { - int uDoc, vDoc; - while (true) { - uDoc = u.nextDoc(); - vDoc = v.nextDoc(); - assertEquals(uDoc, vDoc); - if (uDoc == NO_MORE_DOCS) { - break; - } - assertArrayEquals("vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), 1e-4f); - } + @Override + public RandomAccessVectorValues randomAccess() { + return new CircularVectorValues(size); } - public void testNeighbors() { - // make sure we have the sign correct - Neighbors nn = Neighbors.create(2, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW); - assertTrue(nn.insertWithOverflow(2, 0.5f)); - assertTrue(nn.insertWithOverflow(1, 0.2f)); - assertTrue(nn.insertWithOverflow(3, 1f)); - assertEquals(0.5f, nn.topScore(), 0); - nn.pop(); - assertEquals(1f, nn.topScore(), 0); - nn.pop(); - - Neighbors fn = Neighbors.create(2, VectorValues.SearchStrategy.EUCLIDEAN_HNSW); - assertTrue(fn.insertWithOverflow(2, 2)); - assertTrue(fn.insertWithOverflow(1, 1)); - assertFalse(fn.insertWithOverflow(3, 3)); - assertEquals(2f, fn.topScore(), 0); - fn.pop(); - assertEquals(1f, fn.topScore(), 0); + @Override + public int docID() { + return doc; } - private static float[] randomVector(Random random, int dim) { - float[] vec = new float[dim]; - for (int i = 0; i < dim; i++) { - vec[i] = random.nextFloat(); - } - VectorUtil.l2normalize(vec); - return vec; + @Override + public int nextDoc() { + return advance(doc + 1); } - /** - * Produces random vectors and caches them for random-access. - */ - class RandomVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer { - - private final int dimension; - private final float[][] denseValues; - private final float[][] values; - private final float[] scratch; - private final SearchStrategy searchStrategy; - - final int numVectors; - final int maxDoc; - - private int pos = -1; - - RandomVectorValues(int size, int dimension, Random random) { - this.dimension = dimension; - values = new float[size][]; - denseValues = new float[size][]; - scratch = new float[dimension]; - int sz = 0; - int md = -1; - for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) { - values[offset] = randomVector(random, dimension); - denseValues[sz++] = values[offset]; - md = offset; - } - numVectors = sz; - maxDoc = md; - // get a random SearchStrategy other than NONE (0) - searchStrategy = SearchStrategy.values()[random.nextInt(SearchStrategy.values().length - 1) + 1]; - } - - private RandomVectorValues(int dimension, SearchStrategy searchStrategy, float[][] denseValues, float[][] values, int size) { - this.dimension = dimension; - this.searchStrategy = searchStrategy; - this.values = values; - this.denseValues = denseValues; - scratch = new float[dimension]; - numVectors = size; - maxDoc = values.length - 1; - } - - public RandomVectorValues copy() { - return new RandomVectorValues(dimension, searchStrategy, denseValues, values, numVectors); - } - - @Override - public int size() { - return numVectors; - } - - @Override - public SearchStrategy searchStrategy() { - return searchStrategy; - } - - @Override - public int dimension() { - return dimension; - } - - @Override - public float[] vectorValue() { - if(random().nextBoolean()) { - return values[pos]; - } else { - // Sometimes use the same scratch array repeatedly, mimicing what the codec will do. - // This should help us catch cases of aliasing where the same VectorValues source is used twice in a - // single computation. - System.arraycopy(values[pos], 0, scratch, 0, dimension); - return scratch; - } - } - - @Override - public RandomAccessVectorValues randomAccess() { - return copy(); - } - - @Override - public float[] vectorValue(int targetOrd) { - return denseValues[targetOrd]; - } - - @Override - public BytesRef binaryValue(int targetOrd) { - return null; - } - - @Override - public TopDocs search(float[] target, int k, int fanout) { - return null; - } - - private boolean seek(int target) { - if (target >= 0 && target < values.length && values[target] != null) { - pos = target; - return true; - } else { - return false; - } - } - - @Override - public int docID() { - return pos; - } - - @Override - public int nextDoc() { - return advance(pos + 1); - } - - public int advance(int target) { - while (++pos < values.length) { - if (seek(pos)) { - return pos; - } - } - return NO_MORE_DOCS; - } - - @Override - public long cost() { - return size(); - } - + @Override + public int advance(int target) { + if (target >= 0 && target < size) { + doc = target; + } else { + doc = NO_MORE_DOCS; + } + return doc; } - public void testBoundsCheckerMax() { - BoundsChecker max = BoundsChecker.create(false); - float f = random().nextFloat() - 0.5f; - // any float > -MAX_VALUE is in bounds - assertFalse(max.check(f)); - // f is now the bound (minus some delta) - max.update(f); - assertFalse(max.check(f)); // f is not out of bounds - assertFalse(max.check(f + 1)); // anything greater than f is in bounds - assertTrue(max.check(f - 1e-5f)); // delta is zero initially + @Override + public long cost() { + return size; } - public void testBoundsCheckerMin() { - BoundsChecker min = BoundsChecker.create(true); - float f = random().nextFloat() - 0.5f; - // any float < MAX_VALUE is in bounds - assertFalse(min.check(f)); - // f is now the bound (minus some delta) - min.update(f); - assertFalse(min.check(f)); // f is not out of bounds - assertFalse(min.check(f - 1)); // anything less than f is in bounds - assertTrue(min.check(f + 1e-5f)); // delta is zero initially + @Override + public float[] vectorValue(int ord) { + value[0] = (float) Math.cos(Math.PI * ord / (double) size); + value[1] = (float) Math.sin(Math.PI * ord / (double) size); + return value; } - public void testHnswGraphBuilderInvalid() { - expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0)); - expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0)); - expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0)); + @Override + public BytesRef binaryValue(int ord) { + return null; } + @Override + public TopDocs search(float[] target, int k, int fanout) { + return null; + } + + } + + private void assertGraphEqual(KnnGraphValues g, KnnGraphValues h, int size) throws IOException { + for (int node = 0; node < size; node ++) { + g.seek(node); + h.seek(node); + assertEquals("arcs differ for node " + node, getNeighborNodes(g), getNeighborNodes(h)); + } + } + + private Set getNeighborNodes(KnnGraphValues g) throws IOException { + Set neighbors = new HashSet<>(); + for (int n = g.nextNeighbor(); n != NO_MORE_DOCS; n = g.nextNeighbor()) { + neighbors.add(n); + } + return neighbors; + } + + private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException { + int uDoc, vDoc; + while (true) { + uDoc = u.nextDoc(); + vDoc = v.nextDoc(); + assertEquals(uDoc, vDoc); + if (uDoc == NO_MORE_DOCS) { + break; + } + assertArrayEquals("vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), 1e-4f); + } + } + + public void testNeighbors() { + // make sure we have the sign correct + Neighbors nn = Neighbors.create(2, VectorValues.SearchStrategy.DOT_PRODUCT_HNSW); + assertTrue(nn.insertWithOverflow(2, 0.5f)); + assertTrue(nn.insertWithOverflow(1, 0.2f)); + assertTrue(nn.insertWithOverflow(3, 1f)); + assertEquals(0.5f, nn.topScore(), 0); + nn.pop(); + assertEquals(1f, nn.topScore(), 0); + nn.pop(); + + Neighbors fn = Neighbors.create(2, VectorValues.SearchStrategy.EUCLIDEAN_HNSW); + assertTrue(fn.insertWithOverflow(2, 2)); + assertTrue(fn.insertWithOverflow(1, 1)); + assertFalse(fn.insertWithOverflow(3, 3)); + assertEquals(2f, fn.topScore(), 0); + fn.pop(); + assertEquals(1f, fn.topScore(), 0); + } + + private static float[] randomVector(Random random, int dim) { + float[] vec = new float[dim]; + for (int i = 0; i < dim; i++) { + vec[i] = random.nextFloat(); + } + VectorUtil.l2normalize(vec); + return vec; + } + + /** + * Produces random vectors and caches them for random-access. + */ + class RandomVectorValues extends VectorValues implements RandomAccessVectorValues, RandomAccessVectorValuesProducer { + + private final int dimension; + private final float[][] denseValues; + private final float[][] values; + private final float[] scratch; + private final SearchStrategy searchStrategy; + + final int numVectors; + final int maxDoc; + + private int pos = -1; + + RandomVectorValues(int size, int dimension, Random random) { + this.dimension = dimension; + values = new float[size][]; + denseValues = new float[size][]; + scratch = new float[dimension]; + int sz = 0; + int md = -1; + for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) { + values[offset] = randomVector(random, dimension); + denseValues[sz++] = values[offset]; + md = offset; + } + numVectors = sz; + maxDoc = md; + // get a random SearchStrategy other than NONE (0) + searchStrategy = SearchStrategy.values()[random.nextInt(SearchStrategy.values().length - 1) + 1]; + } + + private RandomVectorValues(int dimension, SearchStrategy searchStrategy, float[][] denseValues, float[][] values, int size) { + this.dimension = dimension; + this.searchStrategy = searchStrategy; + this.values = values; + this.denseValues = denseValues; + scratch = new float[dimension]; + numVectors = size; + maxDoc = values.length - 1; + } + + public RandomVectorValues copy() { + return new RandomVectorValues(dimension, searchStrategy, denseValues, values, numVectors); + } + + @Override + public int size() { + return numVectors; + } + + @Override + public SearchStrategy searchStrategy() { + return searchStrategy; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public float[] vectorValue() { + if(random().nextBoolean()) { + return values[pos]; + } else { + // Sometimes use the same scratch array repeatedly, mimicing what the codec will do. + // This should help us catch cases of aliasing where the same VectorValues source is used twice in a + // single computation. + System.arraycopy(values[pos], 0, scratch, 0, dimension); + return scratch; + } + } + + @Override + public RandomAccessVectorValues randomAccess() { + return copy(); + } + + @Override + public float[] vectorValue(int targetOrd) { + return denseValues[targetOrd]; + } + + @Override + public BytesRef binaryValue(int targetOrd) { + return null; + } + + @Override + public TopDocs search(float[] target, int k, int fanout) { + return null; + } + + private boolean seek(int target) { + if (target >= 0 && target < values.length && values[target] != null) { + pos = target; + return true; + } else { + return false; + } + } + + @Override + public int docID() { + return pos; + } + + @Override + public int nextDoc() { + return advance(pos + 1); + } + + public int advance(int target) { + while (++pos < values.length) { + if (seek(pos)) { + return pos; + } + } + return NO_MORE_DOCS; + } + + @Override + public long cost() { + return size(); + } + + } + + public void testBoundsCheckerMax() { + BoundsChecker max = BoundsChecker.create(false); + float f = random().nextFloat() - 0.5f; + // any float > -MAX_VALUE is in bounds + assertFalse(max.check(f)); + // f is now the bound (minus some delta) + max.update(f); + assertFalse(max.check(f)); // f is not out of bounds + assertFalse(max.check(f + 1)); // anything greater than f is in bounds + assertTrue(max.check(f - 1e-5f)); // delta is zero initially + } + + public void testBoundsCheckerMin() { + BoundsChecker min = BoundsChecker.create(true); + float f = random().nextFloat() - 0.5f; + // any float < MAX_VALUE is in bounds + assertFalse(min.check(f)); + // f is now the bound (minus some delta) + min.update(f); + assertFalse(min.check(f)); // f is not out of bounds + assertFalse(min.check(f - 1)); // anything less than f is in bounds + assertTrue(min.check(f + 1e-5f)); // delta is zero initially + } + + public void testHnswGraphBuilderInvalid() { + expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0)); + expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0)); + expectThrows(IllegalArgumentException.class, () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0)); + } + }