diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
index 8244026060c..6a69ab9f4e2 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -259,7 +259,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
       float score = results.topScore();
       results.pop();
       if (reversed) {
-        score = (float) Math.exp(-score / target.length);
+        score = 1 / (1 + score);
       }
       scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
     }
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
new file mode 100644
index 00000000000..5dccb8042e4
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest neighbor search. */
+public class KnnVectorQuery extends Query {
+
+  private static final TopDocs NO_RESULTS =
+      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+  private final String field;
+  private final float[] target;
+  private final int k;
+
+  /**
+   * Find the k nearest documents to the target vector according to the vectors in the given
+   * field.
+   *
+   * @param field a field that has been indexed as a {@link KnnVectorField}.
+   * @param target the target of the search
+   * @param k the number of documents to find
+   * @throws IllegalArgumentException if k is less than 1
+   */
+  public KnnVectorQuery(String field, float[] target, int k) {
+    this.field = field;
+    this.target = target;
+    this.k = k;
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be at least 1, got: " + k);
+    }
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+    for (LeafReaderContext ctx : reader.leaves()) {
+      perLeafResults[ctx.ord] = searchLeaf(ctx, Math.min(k, reader.numDocs()));
+    }
+    // Merge sort the results
+    TopDocs topK = TopDocs.merge(k, perLeafResults);
+    if (topK.scoreDocs.length == 0) {
+      return new MatchNoDocsQuery();
+    }
+    return createRewrittenQuery(reader, topK);
+  }
+
+  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+    if (results == null) {
+      return NO_RESULTS;
+    }
+    if (ctx.docBase > 0) {
+      for (ScoreDoc scoreDoc : results.scoreDocs) {
+        scoreDoc.doc += ctx.docBase;
+      }
+    }
+    return results;
+  }
+
+  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+    int len = topK.scoreDocs.length;
+    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+    int[] docs = new int[len];
+    float[] scores = new float[len];
+    for (int i = 0; i < len; i++) {
+      docs[i] = topK.scoreDocs[i].doc;
+      scores[i] = topK.scoreDocs[i].score;
+    }
+    int[] segmentStarts = findSegmentStarts(reader, docs);
+    return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.hashCode());
+  }
+
+  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+    int[] starts = new int[reader.leaves().size() + 1];
+    starts[starts.length - 1] = docs.length;
+    if (starts.length == 2) {
+      return starts;
+    }
+    int resultIndex = 0;
+    for (int i = 1; i < starts.length - 1; i++) {
+      int upper = reader.leaves().get(i).docBase;
+      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+      if (resultIndex < 0) {
+        resultIndex = -1 - resultIndex;
+      }
+      starts[i] = resultIndex;
+    }
+    return starts;
+  }
+
+  @Override
+  public String toString(String field) {
+    return "";
+  }
+
+  @Override
+  public void visit(QueryVisitor visitor) {
+    if (visitor.acceptField(field)) {
+      visitor.visitLeaf(this);
+    }
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof KnnVectorQuery
+        && ((KnnVectorQuery) obj).k == k
+        && ((KnnVectorQuery) obj).field.equals(field)
+        && Arrays.equals(((KnnVectorQuery) obj).target, target);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(field, k, Arrays.hashCode(target));
+  }
+
+  /** Caches the results of a KnnVector search: a list of docs and their scores */
+  static class DocAndScoreQuery extends Query {
+
+    private final int k;
+    private final int[] docs;
+    private final float[] scores;
+    private final int[] segmentStarts;
+    private final int readerHash;
+
+    /**
+     * Constructor
+     *
+     * @param k the number of documents requested
+     * @param docs the global docids of documents that match, in ascending order
+     * @param scores the scores of the matching documents
+     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+     *     document in each segment. If a segment has no matching documents, it should be assigned
+     *     the index of the next segment that does. There should be a final entry that is always
+     *     docs.length.
+     * @param readerHash a hash code identifying the IndexReader used to create this query
+     */
+    DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, int readerHash) {
+      this.k = k;
+      this.docs = docs;
+      this.scores = scores;
+      this.segmentStarts = segmentStarts;
+      this.readerHash = readerHash;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+        throws IOException {
+      if (searcher.getIndexReader().hashCode() != readerHash) {
+        throw new IllegalStateException("This DocAndScore query was created by a different reader");
+      }
+      return new Weight(this) {
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) {
+          int found = Arrays.binarySearch(docs, doc);
+          if (found < 0) {
+            return Explanation.noMatch("not in top " + k);
+          }
+          return Explanation.match(scores[found], "within top " + k);
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) {
+
+          return new Scorer(this) {
+            final int lower = segmentStarts[context.ord];
+            final int upper = segmentStarts[context.ord + 1];
+            int upTo = -1;
+
+            @Override
+            public DocIdSetIterator iterator() {
+              return new DocIdSetIterator() {
+                @Override
+                public int docID() {
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int nextDoc() {
+                  if (upTo == -1) {
+                    upTo = lower;
+                  } else {
+                    ++upTo;
+                  }
+                  return docIdNoShadow();
+                }
+
+                @Override
+                public int advance(int target) throws IOException {
+                  return slowAdvance(target);
+                }
+
+                @Override
+                public long cost() {
+                  return upper - lower;
+                }
+              };
+            }
+
+            @Override
+            public float getMaxScore(int docid) {
+              docid += context.docBase;
+              float maxScore = 0;
+              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+                maxScore = Math.max(maxScore, scores[idx]);
+              }
+              return maxScore;
+            }
+
+            @Override
+            public float score() {
+              return scores[upTo];
+            }
+
+            @Override
+            public int advanceShallow(int docid) {
+              int start = Math.max(upTo, lower);
+              int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
+              if (docidIndex < 0) {
+                docidIndex = -1 - docidIndex;
+              }
+              if (docidIndex >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[docidIndex];
+            }
+
+            /**
+             * Move the implementation of docID() into a differently-named method so we can call it
+             * from DocIdSetIterator.docID() even though this class is anonymous
+             *
+             * @return the current docid
+             */
+            private int docIdNoShadow() {
+              if (upTo == -1) {
+                return -1;
+              }
+              if (upTo >= upper) {
+                return NO_MORE_DOCS;
+              }
+              return docs[upTo] - context.docBase;
+            }
+
+            @Override
+            public int docID() {
+              return docIdNoShadow();
+            }
+          };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {
+          return true;
+        }
+      };
+    }
+
+    @Override
+    public String toString(String field) {
+      return "DocAndScore[" + k + "]";
+    }
+
+    @Override
+    public void visit(QueryVisitor visitor) {
+      visitor.visitLeaf(this);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj instanceof DocAndScoreQuery == false) {
+        return false;
+      }
+      return Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
+          && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(
+          DocAndScoreQuery.class.hashCode(), Arrays.hashCode(docs), Arrays.hashCode(scores));
+    }
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
new file mode 100644
index 00000000000..862f8f7c114
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.util.TestVectorUtil.randomVector;
+
+import java.io.IOException;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+
+/** TestKnnVectorQuery tests KnnVectorQuery. */
+public class TestKnnVectorQuery extends LuceneTestCase {
+
+  public void testEquals() {
+    KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
+
+    assertEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 10));
+
+    assertNotEquals(null, q1);
+
+    assertNotEquals(q1, new TermQuery(new Term("f1", "x")));
+
+    assertNotEquals(q1, new KnnVectorQuery("f2", new float[] {0, 1}, 10));
+    assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {1, 1}, 10));
+    assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 2));
+    assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0}, 10));
+  }
+
+  public void testToString() {
+    KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
+    assertEquals("", q1.toString("ignored"));
+  }
+
+  /**
+   * Tests if a KnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no documents to
+   * match.
+   */
+  public void testEmptyIndex() throws IOException {
+    try (Directory indexStore = getIndexStore("field");
+        IndexReader reader = DirectoryReader.open(indexStore)) {
+      IndexSearcher searcher = newSearcher(reader);
+      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {1, 2}, 10);
+      assertMatches(searcher, kvq, 0);
+      Query q = searcher.rewrite(kvq);
+      assertTrue(q instanceof MatchNoDocsQuery);
+    }
+  }
+
+  /**
+   * Tests that a KnnVectorQuery whose topK >= numDocs returns all the documents in score order
+   */
+  public void testFindAll() throws IOException {
+    try (Directory indexStore =
+            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+        IndexReader reader = DirectoryReader.open(indexStore)) {
+      IndexSearcher searcher = newSearcher(reader);
+      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10);
+      assertMatches(searcher, kvq, reader.numDocs());
+      TopDocs topDocs = searcher.search(kvq, 3);
+      assertEquals(2, topDocs.scoreDocs[0].doc);
+      assertEquals(0, topDocs.scoreDocs[1].doc);
+      assertEquals(1, topDocs.scoreDocs[2].doc);
+    }
+  }
+
+  /** Tests that a target vector whose dimension differs from the indexed vectors' is rejected */
+  public void testDimensionMismatch() throws IOException {
+    try (Directory indexStore =
+            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+        IndexReader reader = DirectoryReader.open(indexStore)) {
+      IndexSearcher searcher = newSearcher(reader);
+      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
+      IllegalArgumentException e =
+          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+      assertEquals("vector dimensions differ: 1!=2", e.getMessage());
+    }
+  }
+
+  /** Tests that searching a missing field, or a field that is not a vector field, matches nothing */
+  public void testNonVectorField() throws IOException {
+    try (Directory indexStore =
+            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+        IndexReader reader = DirectoryReader.open(indexStore)) {
+      IndexSearcher searcher = newSearcher(reader);
+      assertMatches(searcher, new KnnVectorQuery("xyzzy", new float[] {0}, 10), 0);
+      assertMatches(searcher, new KnnVectorQuery("id", new float[] {0}, 10), 0);
+    }
+  }
+
+  /** Test bad parameters */
+  public void testIllegalArguments() throws IOException {
+    expectThrows(
+        IllegalArgumentException.class, () -> new KnnVectorQuery("xx", new float[] {1}, 0));
+  }
+
+  public void testDifferentReader() throws IOException {
+    try (Directory indexStore =
+            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+        IndexReader reader = DirectoryReader.open(indexStore)) {
+      KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+      Query dasq = query.rewrite(reader);
+      IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader());
+      expectThrows(
+          IllegalStateException.class,
+          () -> dasq.createWeight(leafSearcher, ScoreMode.COMPLETE, 1));
+    }
+  }
+
+  public void testAdvanceShallow() throws IOException {
+    try (Directory d = newDirectory()) {
+      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+        for (int j = 0; j < 5; j++) {
+          Document doc = new Document();
+          doc.add(new KnnVectorField("field", new float[] {j, j}));
+          w.addDocument(doc);
+        }
+      }
+      try (IndexReader reader = DirectoryReader.open(d)) {
+        IndexSearcher searcher = new IndexSearcher(reader);
+        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+        Query dasq = query.rewrite(reader);
+        Scorer scorer =
+            dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
+        // before advancing the iterator
+        assertEquals(1, scorer.advanceShallow(0));
+        assertEquals(1, scorer.advanceShallow(1));
+        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
+
+        // after advancing the iterator
+        scorer.iterator().advance(2);
+        assertEquals(2, scorer.advanceShallow(0));
+        assertEquals(2, scorer.advanceShallow(2));
+        assertEquals(3, scorer.advanceShallow(3));
+        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
+      }
+    }
+  }
+
+  public void testScore() throws IOException {
+    try (Directory d = newDirectory()) {
+      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+        for (int j = 0; j < 5; j++) {
+          Document doc = new Document();
+          doc.add(new KnnVectorField("field", new float[] {j, j}));
+          w.addDocument(doc);
+        }
+      }
+      try (IndexReader reader = DirectoryReader.open(d)) {
+        assertEquals(1, reader.leaves().size());
+        IndexSearcher searcher = new IndexSearcher(reader);
+        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+        Query rewritten = query.rewrite(reader);
+        Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
+        Scorer scorer = weight.scorer(reader.leaves().get(0));
+
+        // prior to advancing, the scorer is unpositioned and score() is undefined
+        assertEquals(-1, scorer.docID());
+        expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());
+
+        // test getMaxScore
+        assertEquals(0, scorer.getMaxScore(-1), 0);
+        assertEquals(0, scorer.getMaxScore(0), 0);
+        // This is 1 / (1 + the squared l2 distance from (2, 3) to (2, 2)) = 1 / (1 + 1) = 0.5
+        assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
+        assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
+
+        DocIdSetIterator it = scorer.iterator();
+        assertEquals(3, it.cost());
+        assertEquals(1, it.nextDoc());
+        assertEquals(1 / 6f, scorer.score(), 0);
+        assertEquals(3, it.advance(3));
+        assertEquals(1 / 2f, scorer.score(), 0);
+        assertEquals(NO_MORE_DOCS, it.advance(4));
+        expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());
+      }
+    }
+  }
+
+  public void testExplain() throws IOException {
+    try (Directory d = newDirectory()) {
+      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+        for (int j = 0; j < 5; j++) {
+          Document doc = new Document();
+          doc.add(new KnnVectorField("field", new float[] {j, j}));
+          w.addDocument(doc);
+        }
+      }
+      try (IndexReader reader = DirectoryReader.open(d)) {
+        IndexSearcher searcher = new IndexSearcher(reader);
+        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+        Explanation matched = searcher.explain(query, 2);
+        assertTrue(matched.isMatch());
+        assertEquals(1 / 2f, matched.getValue());
+        assertEquals(0, matched.getDetails().length);
+        assertEquals("within top 3", matched.getDescription());
+
+        Explanation nomatch = searcher.explain(query, 4);
+        assertFalse(nomatch.isMatch());
+        assertEquals(0f, nomatch.getValue());
+        assertEquals(0, nomatch.getDetails().length);
+        assertEquals("not in top 3", nomatch.getDescription());
+      }
+    }
+  }
+
+  /** Test that when vectors are abnormally distributed among segments, we still find the top K */
+  public void testSkewedIndex() throws IOException {
+    /* We have to choose the numbers carefully here so that some segment has more than the expected
+     * number of top K documents, but no more than K documents in total (otherwise we might occasionally
+     * randomly fail to find one).
+ */ + try (Directory d = newDirectory()) { + try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { + int r = 0; + for (int i = 0; i < 5; i++) { + for (int j = 0; j < 5; j++) { + Document doc = new Document(); + doc.add(new KnnVectorField("field", new float[] {r, r})); + w.addDocument(doc); + ++r; + } + w.flush(); + } + } + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + TopDocs results = searcher.search(new KnnVectorQuery("field", new float[] {0, 0}, 8), 10); + assertEquals(8, results.scoreDocs.length); + assertEquals(0, results.scoreDocs[0].doc); + assertEquals(7, results.scoreDocs[7].doc); + + // test some results in the middle of the sequence - also tests docid tiebreaking + results = searcher.search(new KnnVectorQuery("field", new float[] {10, 10}, 8), 10); + assertEquals(8, results.scoreDocs.length); + assertEquals(10, results.scoreDocs[0].doc); + assertEquals(6, results.scoreDocs[7].doc); + } + } + } + + /** Tests with random vectors, number of documents, etc. Uses RandomIndexWriter. */ + public void testRandom() throws IOException { + int numDocs = atLeast(100); + int dimension = atLeast(5); + int numIters = atLeast(10); + boolean everyDocHasAVector = random().nextBoolean(); + try (Directory d = newDirectory()) { + RandomIndexWriter w = new RandomIndexWriter(random(), d); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (everyDocHasAVector || random().nextInt(10) != 2) { + doc.add(new KnnVectorField("field", randomVector(dimension))); + } + w.addDocument(doc); + } + w.close(); + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(100) + 1; + KnnVectorQuery query = new KnnVectorQuery("field", randomVector(dimension), k); + int n = random().nextInt(100) + 1; + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), reader.numDocs()); + // we may get fewer results than requested if there are deletions, but this test doesn't + // test that + assert reader.hasDeletions() == false; + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + private Directory getIndexStore(String field, float[]... 
+    Directory indexStore = newDirectory();
+    RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
+    for (int i = 0; i < contents.length; ++i) {
+      Document doc = new Document();
+      doc.add(new KnnVectorField(field, contents[i]));
+      doc.add(new StringField("id", "id" + i, Field.Store.NO));
+      writer.addDocument(doc);
+    }
+    writer.close();
+    return indexStore;
+  }
+
+  private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches)
+      throws IOException {
+    ScoreDoc[] result = searcher.search(q, 1000).scoreDocs;
+    assertEquals(expectedMatches, result.length);
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
index d373fa1d1a4..1ebd0562893 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
@@ -16,6 +16,8 @@
  */
 package org.apache.lucene.util;
 
+import java.util.Random;
+
 public class TestVectorUtil extends LuceneTestCase {
 
   public static final double DELTA = 1e-4;
@@ -81,7 +83,7 @@ public class TestVectorUtil extends LuceneTestCase {
     expectThrows(IllegalArgumentException.class, () -> VectorUtil.l2normalize(v));
   }
 
-  private float l2(float[] v) {
+  private static float l2(float[] v) {
     float l2 = 0;
     for (float x : v) {
       l2 += x * x;
@@ -89,7 +91,7 @@ public class TestVectorUtil extends LuceneTestCase {
     return l2;
   }
 
-  private float[] negative(float[] v) {
+  private static float[] negative(float[] v) {
     float[] u = new float[v.length];
     for (int i = 0; i < v.length; i++) {
       u[i] = -v[i];
@@ -97,10 +99,15 @@ public class TestVectorUtil extends LuceneTestCase {
     return u;
   }
 
-  private float[] randomVector() {
-    float[] v = new float[random().nextInt(100) + 1];
-    for (int i = 0; i < v.length; i++) {
-      v[i] = random().nextFloat();
+  private static float[] randomVector() {
+    return randomVector(random().nextInt(100) + 1);
+  }
+
+  public static float[] randomVector(int dim) {
+    float[] v = new float[dim];
+    Random random = random();
+    for (int i = 0; i < dim; i++) {
+      v[i] = random.nextFloat();
     }
     return v;
   }
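
Below is a minimal, hypothetical usage sketch of the new query; it is not part of the patch above. It indexes a handful of documents with a KnnVectorField and then runs a KnnVectorQuery against them. The field name "vector", the sample vectors, k = 3, and the use of ByteBuffersDirectory are illustrative choices, not anything prescribed by this change.

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class KnnVectorQueryExample {
  public static void main(String[] args) throws Exception {
    try (Directory dir = new ByteBuffersDirectory()) {
      // Index a few documents, each carrying a 2-dimensional vector.
      try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
        for (int i = 0; i < 5; i++) {
          Document doc = new Document();
          doc.add(new KnnVectorField("vector", new float[] {i, i}));
          writer.addDocument(doc);
        }
      }
      // Find the 3 nearest neighbors of the target vector {2, 3}.
      try (DirectoryReader reader = DirectoryReader.open(dir)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        TopDocs topDocs = searcher.search(new KnnVectorQuery("vector", new float[] {2, 3}, 3), 3);
        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
          System.out.println("doc=" + scoreDoc.doc + " score=" + scoreDoc.score);
        }
      }
    }
  }
}

The query does its per-segment vector search at rewrite time and caches the resulting doc/score pairs in the DocAndScoreQuery shown in the patch, so the rewritten query is only valid against the reader it was rewritten with.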