Simplify codec setup in vector-related tests. ()

Many of vector-related tests set up a codec manually by extending the current
codec. This makes bumping the current codec a bit painful as all these files
need to be touched. This commit migrates to `TestUtil#alwaysKnnVectorsFormat`,
similarly to what we do for postings and doc values.
This commit is contained in:
Adrien Grand 2024-11-12 10:38:54 +01:00 committed by GitHub
parent 65457224fb
commit 6fe8165cac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 54 additions and 178 deletions
lucene
backward-codecs/src/test/org/apache/lucene
codecs/src/test/org/apache/lucene/codecs/bitvectors
core/src/test/org/apache/lucene
test-framework/src/java/org/apache/lucene/tests/util

View File

@ -18,18 +18,12 @@
package org.apache.lucene.backward_codecs.lucene99;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
public class TestLucene99HnswScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
@Override
protected Codec getCodec() {
return new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99RWHnswScalarQuantizationVectorsFormat();
}
};
return TestUtil.alwaysKnnVectorsFormat(new Lucene99RWHnswScalarQuantizationVectorsFormat());
}
}

View File

@ -21,9 +21,7 @@ import static org.apache.lucene.backward_index.TestBasicBackwardsCompatibility.a
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import java.io.IOException;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
@ -69,14 +67,10 @@ public class TestInt7HnswBackwardsCompatibility extends BackwardsCompatibilityTe
}
protected Codec getCodec() {
return new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswScalarQuantizedVectorsFormat(
return TestUtil.alwaysKnnVectorsFormat(
new Lucene99HnswScalarQuantizedVectorsFormat(
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
}
};
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH));
}
@Override

View File

@ -22,7 +22,6 @@ import java.io.IOException;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
@ -38,16 +37,12 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseIndexFileFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
public class TestHnswBitVectorsFormat extends BaseIndexFileFormatTestCase {
@Override
protected Codec getCodec() {
return new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new HnswBitVectorsFormat();
}
};
return TestUtil.alwaysKnnVectorsFormat(new HnswBitVectorsFormat());
}
@Override

View File

@ -28,7 +28,6 @@ import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
@ -48,6 +47,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.SameThreadExecutorService;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
@ -74,12 +74,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
@Override
protected Codec getCodec() {
return new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return format;
}
};
return TestUtil.alwaysKnnVectorsFormat(format);
}
private final KnnVectorsFormat getKnnFormat(int bits) {
@ -104,14 +99,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
try (IndexWriter w =
new IndexWriter(
dir,
newIndexWriterConfig()
.setCodec(
new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return getKnnFormat(4);
}
}))) {
newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(4))))) {
Document doc = new Document();
doc.add(
@ -124,14 +112,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
try (IndexWriter w =
new IndexWriter(
dir,
newIndexWriterConfig()
.setCodec(
new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return getKnnFormat(7);
}
}))) {
newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(7))))) {
Document doc = new Document();
doc.add(
@ -162,13 +143,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
new IndexWriter(
dir,
newIndexWriterConfig()
.setCodec(
new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat();
}
}))) {
.setCodec(TestUtil.alwaysKnnVectorsFormat(new Lucene99HnswVectorsFormat())))) {
Document doc = new Document();
doc.add(
@ -181,14 +156,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
try (IndexWriter w =
new IndexWriter(
dir,
newIndexWriterConfig()
.setCodec(
new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return getKnnFormat(7);
}
}))) {
newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(7))))) {
Document doc = new Document();
doc.add(
@ -216,13 +184,9 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
dir,
newIndexWriterConfig()
.setCodec(
new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswScalarQuantizedVectorsFormat(
16, 100, 1, (byte) 7, false, 0.9f, null);
}
}))) {
TestUtil.alwaysKnnVectorsFormat(
new Lucene99HnswScalarQuantizedVectorsFormat(
16, 100, 1, (byte) 7, false, 0.9f, null))))) {
for (float[] vector : vectors) {
Document doc = new Document();
doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.DOT_PRODUCT));

View File

@ -24,10 +24,8 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
@ -44,6 +42,7 @@ import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
@ -52,19 +51,15 @@ import org.apache.lucene.util.quantization.ScalarQuantizer;
public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
private static Codec getCodec(int bits, boolean compress) {
return new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswScalarQuantizedVectorsFormat(
return TestUtil.alwaysKnnVectorsFormat(
new Lucene99HnswScalarQuantizedVectorsFormat(
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
1,
bits,
compress,
0f,
null);
}
};
null));
}
public void testNonZeroScores() throws IOException {

View File

@ -28,7 +28,6 @@ import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
@ -43,6 +42,7 @@ import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
@ -70,12 +70,7 @@ public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsForm
@Override
protected Codec getCodec() {
return new Lucene101Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return format;
}
};
return TestUtil.alwaysKnnVectorsFormat(format);
}
public void testSearch() throws Exception {

View File

@ -30,8 +30,6 @@ import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
@ -85,33 +83,15 @@ public class TestKnnGraph extends LuceneTestCase {
vectorEncoding = randomVectorEncoding();
boolean quantized = randomBoolean();
codec =
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return quantized
? new Lucene99HnswScalarQuantizedVectorsFormat(
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH)
: new Lucene99HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
}
};
TestUtil.alwaysKnnVectorsFormat(
quantized
? new Lucene99HnswScalarQuantizedVectorsFormat(
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH)
: new Lucene99HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH));
float32Codec =
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
}
};
TestUtil.alwaysKnnVectorsFormat(
new Lucene99HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH));
}
private VectorEncoding randomVectorEncoding() {

View File

@ -50,7 +50,6 @@ import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.codecs.asserting.AssertingCodec;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.store.BaseDirectoryWrapper;
import org.apache.lucene.tests.util.LuceneTestCase;
@ -1084,13 +1083,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
IndexWriterConfig iwc = newIndexWriterConfig(mockAnalyzer);
KnnVectorsFormat format1 = randomVectorFormat(VectorEncoding.FLOAT32);
KnnVectorsFormat format2 = randomVectorFormat(VectorEncoding.FLOAT32);
iwc.setCodec(
new AssertingCodec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return format1;
}
});
iwc.setCodec(TestUtil.alwaysKnnVectorsFormat(format1));
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
Document doc = new Document();
@ -1104,13 +1097,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
}
iwc = newIndexWriterConfig(mockAnalyzer);
iwc.setCodec(
new AssertingCodec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return format2;
}
});
iwc.setCodec(TestUtil.alwaysKnnVectorsFormat(format2));
try (IndexWriter iwriter = new IndexWriter(directory, iwc)) {
Document doc = new Document();

View File

@ -38,8 +38,6 @@ import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
@ -152,19 +150,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
IndexWriterConfig iwc =
new IndexWriterConfig()
.setCodec(
new FilterCodec(
TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
})
TestUtil.alwaysKnnVectorsFormat(new Lucene99HnswVectorsFormat(M, beamWidth)))
// set a random merge policy
.setMergePolicy(newMergePolicy(random()));
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
@ -255,18 +241,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
IndexWriterConfig iwc =
new IndexWriterConfig()
.setCodec(
new FilterCodec(
TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
});
TestUtil.alwaysKnnVectorsFormat(new Lucene99HnswVectorsFormat(M, beamWidth)));
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
KnnVectorValues.DocIndexIterator it2 = v2.iterator();
while (it2.nextDoc() != NO_MORE_DOCS) {
@ -317,32 +292,10 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
HnswGraphBuilder.randSeed = seed;
IndexWriterConfig iwc =
new IndexWriterConfig()
.setCodec(
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
});
.setCodec(TestUtil.alwaysKnnVectorsFormat(new Lucene99HnswVectorsFormat(M, beamWidth)));
IndexWriterConfig iwc2 =
new IndexWriterConfig()
.setCodec(
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(M, beamWidth);
}
};
}
})
.setCodec(TestUtil.alwaysKnnVectorsFormat(new Lucene99HnswVectorsFormat(M, beamWidth)))
.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.LONG)));
try (Directory dir = newDirectory();

View File

@ -1310,6 +1310,25 @@ public final class TestUtil {
};
}
/**
* Return a Codec that can read any of the default codecs and formats, but always writes in the
* specified format.
*/
public static Codec alwaysKnnVectorsFormat(final KnnVectorsFormat format) {
// TODO: we really need for knn vectors impls etc to announce themselves
// (and maybe their params, too) to infostream on flush and merge.
// otherwise in a real debugging situation we won't know whats going on!
if (LuceneTestCase.VERBOSE) {
System.out.println("TestUtil: forcing knn vectors format to:" + format);
}
return new AssertingCodec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return format;
}
};
}
/**
* Returns the actual default codec (e.g. LuceneMNCodec) for this version of Lucene. This may be
* different from {@link Codec#getDefault()} because that is randomized.