Add more backwards compatibility tests for Scalar quantization (#13298)

This adds more backwards compatibility coverage for scalar quantization. It adds a test that forces the older metadata version to be written and ensures that it can still be read.
This commit is contained in:
Benjamin Trent 2024-04-16 09:08:30 -04:00 committed by GitHub
parent dcb512289f
commit 3ba7ebbad8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 152 additions and 13 deletions

View File

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.codecs.FlatVectorsFormat;
import org.apache.lucene.codecs.FlatVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
import org.apache.lucene.index.SegmentWriteState;
/**
 * Read-write variant of {@link Lucene99HnswScalarQuantizedVectorsFormat} for backward-compatibility
 * tests. Its writer is assembled from a {@link Lucene99ScalarQuantizedVectorsWriter} created through
 * the three-argument constructor rather than from whatever writer the parent format would build.
 */
class Lucene99RWHnswScalarQuantizationVectorsFormat
    extends Lucene99HnswScalarQuantizedVectorsFormat {

  /** Flat format that supplies the scalar-quantized writer handed to the HNSW writer. */
  private final FlatVectorsFormat quantizedFlatFormat = new Lucene99RWScalarQuantizedFormat();

  /** Sole constructor */
  protected Lucene99RWHnswScalarQuantizationVectorsFormat() {
    super();
  }

  /**
   * Builds an HNSW writer with the default max-conn and beam-width settings, layered on top of the
   * writer produced by the nested read-write scalar-quantized format.
   */
  @Override
  public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
    final FlatVectorsWriter quantizedWriter = quantizedFlatFormat.fieldsWriter(state);
    return new Lucene99HnswVectorsWriter(
        state,
        Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
        Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
        quantizedWriter,
        1,
        null);
  }

  /** Returns a fixed limit of 1024 dimensions regardless of the field name. */
  @Override
  public int getMaxDimensions(String fieldName) {
    return 1024;
  }

  /**
   * Scalar-quantized flat format whose writer is created through the three-argument {@link
   * Lucene99ScalarQuantizedVectorsWriter} constructor with a {@code null} confidence interval.
   */
  static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {

    /** Shared format used to create the raw-vector delegate writer. */
    private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat();

    @Override
    public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
      final FlatVectorsWriter rawWriter = rawVectorFormat.fieldsWriter(state);
      return new Lucene99ScalarQuantizedVectorsWriter(state, null, rawWriter);
    }
  }
}

View File

@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene99;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
/**
 * Runs the {@link BaseKnnVectorsFormatTestCase} suite against a {@link Lucene99Codec} whose vector
 * format for every field is {@code Lucene99RWHnswScalarQuantizationVectorsFormat}.
 */
public class TestLucene99HnswScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {

  @Override
  protected Codec getCodec() {
    final Codec codecUnderTest =
        new Lucene99Codec() {
          @Override
          public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
            // A fresh format instance is created on every lookup, for any field name.
            final KnnVectorsFormat rwFormat = new Lucene99RWHnswScalarQuantizationVectorsFormat();
            return rwFormat;
          }
        };
    return codecUnderTest;
  }
}

View File

@ -42,7 +42,9 @@ import org.apache.lucene.util.hnsw.HnswGraph;
*
* @lucene.experimental
*/
public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
public class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
public static final String NAME = "Lucene99HnswScalarQuantizedVectorsFormat";
/**
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
@ -103,7 +105,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo
boolean compress,
Float confidenceInterval,
ExecutorService mergeExec) {
super("Lucene99HnswScalarQuantizedVectorsFormat");
super(NAME);
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to "

View File

@ -29,14 +29,14 @@ import org.apache.lucene.index.SegmentWriteState;
*
* @lucene.experimental
*/
public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
// The bits that are allowed for scalar quantization
// We only allow unsigned byte (8), signed byte (7), and half-byte (4)
private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4);
public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";
public static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";
static final int VERSION_START = 0;
static final int VERSION_ADD_BITS = 1;

View File

@ -98,8 +98,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private final FlatVectorsWriter rawVectorDelegate;
private final byte bits;
private final boolean compress;
private final int version;
private boolean finished;
/**
 * Creates a writer pinned to {@link Lucene99ScalarQuantizedVectorsFormat#VERSION_START}.
 *
 * <p>Delegates to the version-aware constructor with 7 bits and compression disabled.
 *
 * @param state the segment write state, forwarded unchanged
 * @param confidenceInterval the confidence interval, forwarded unchanged; may be {@code null}
 * @param rawVectorDelegate the delegate writer, forwarded unchanged
 * @throws IOException if the delegated constructor throws it
 */
public Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state, Float confidenceInterval, FlatVectorsWriter rawVectorDelegate)
throws IOException {
this(
state,
// Force the oldest on-disk format version rather than the current one.
Lucene99ScalarQuantizedVectorsFormat.VERSION_START,
confidenceInterval,
(byte) 7,
false,
rawVectorDelegate);
}
public Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state,
Float confidenceInterval,
@ -107,9 +120,27 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
boolean compress,
FlatVectorsWriter rawVectorDelegate)
throws IOException {
this(
state,
Lucene99ScalarQuantizedVectorsFormat.VERSION_ADD_BITS,
confidenceInterval,
bits,
compress,
rawVectorDelegate);
}
private Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state,
int version,
Float confidenceInterval,
byte bits,
boolean compress,
FlatVectorsWriter rawVectorDelegate)
throws IOException {
this.confidenceInterval = confidenceInterval;
this.bits = bits;
this.compress = compress;
this.version = version;
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
@ -132,13 +163,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
CodecUtil.writeIndexHeader(
meta,
Lucene99ScalarQuantizedVectorsFormat.META_CODEC_NAME,
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
version,
state.segmentInfo.getId(),
state.segmentSuffix);
CodecUtil.writeIndexHeader(
quantizedVectorData,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
version,
state.segmentInfo.getId(),
state.segmentSuffix);
success = true;
@ -301,9 +332,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
meta.writeInt(count);
if (count > 0) {
assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile);
meta.writeInt(confidenceInterval == null ? -1 : Float.floatToIntBits(confidenceInterval));
meta.writeByte(bits);
meta.writeByte(compress ? (byte) 1 : (byte) 0);
if (version >= Lucene99ScalarQuantizedVectorsFormat.VERSION_ADD_BITS) {
meta.writeInt(confidenceInterval == null ? -1 : Float.floatToIntBits(confidenceInterval));
meta.writeByte(bits);
meta.writeByte(compress ? (byte) 1 : (byte) 0);
} else {
meta.writeInt(
Float.floatToIntBits(
confidenceInterval == null
? calculateDefaultConfidenceInterval(field.getVectorDimension())
: confidenceInterval));
}
meta.writeInt(Float.floatToIntBits(lowerQuantile));
meta.writeInt(Float.floatToIntBits(upperQuantile));
}
@ -453,10 +492,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength());
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
CodecUtil.retrieveChecksum(quantizationDataInput);
float confidenceInterval =
this.confidenceInterval == null
? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension())
: this.confidenceInterval;
writeMeta(
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),