Clean up KNN related backward-codecs changes (#12019)

This commit is contained in:
Benjamin Trent 2022-12-20 08:04:42 -05:00 committed by GitHub
parent 3ac71adbdf
commit 1412e559d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 116 additions and 55 deletions

View File

@ -57,7 +57,7 @@ public final class Lucene90HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without // we need two sources of vectors in order to perform diversity check comparisons without
// colliding // colliding
private RandomAccessVectorValues buildVectors; private final RandomAccessVectorValues buildVectors;
/** /**
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense

View File

@ -55,15 +55,12 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
*/ */
public final class Lucene90HnswVectorsReader extends KnnVectorsReader { public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>(); private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData; private final IndexInput vectorData;
private final IndexInput vectorIndex; private final IndexInput vectorIndex;
private final long checksumSeed; private final long checksumSeed;
Lucene90HnswVectorsReader(SegmentReadState state) throws IOException { Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
this.fieldInfos = state.fieldInfos;
int versionMeta = readMetadata(state); int versionMeta = readMetadata(state);
long[] checksumRef = new long[1]; long[] checksumRef = new long[1];
boolean success = false; boolean success = false;
@ -305,20 +302,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}; };
} }
/** Get knn graph values; used for testing */
public HnswGraph getGraphValues(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.indexDataLength > 0) {
return getGraphValues(entry);
} else {
return HnswGraph.EMPTY;
}
}
private HnswGraph getGraphValues(FieldEntry entry) throws IOException { private HnswGraph getGraphValues(FieldEntry entry) throws IOException {
IndexInput bytesSlice = IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength); vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);

View File

@ -57,13 +57,11 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
*/ */
public final class Lucene91HnswVectorsReader extends KnnVectorsReader { public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>(); private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData; private final IndexInput vectorData;
private final IndexInput vectorIndex; private final IndexInput vectorIndex;
Lucene91HnswVectorsReader(SegmentReadState state) throws IOException { Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
this.fieldInfos = state.fieldInfos;
int versionMeta = readMetadata(state); int versionMeta = readMetadata(state);
boolean success = false; boolean success = false;
try { try {
@ -299,20 +297,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}; };
} }
/** Get knn graph values; used for testing */
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.vectorIndexLength > 0) {
return getGraph(entry);
} else {
return HnswGraph.EMPTY;
}
}
private HnswGraph getGraph(FieldEntry entry) throws IOException { private HnswGraph getGraph(FieldEntry entry) throws IOException {
IndexInput bytesSlice = IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength); vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
@ -581,12 +565,12 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
} }
@Override @Override
public int numLevels() throws IOException { public int numLevels() {
return numLevels; return numLevels;
} }
@Override @Override
public int entryNode() throws IOException { public int entryNode() {
return entryNode; return entryNode;
} }

View File

@ -440,12 +440,12 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
} }
@Override @Override
public int numLevels() throws IOException { public int numLevels() {
return numLevels; return numLevels;
} }
@Override @Override
public int entryNode() throws IOException { public int entryNode() {
return entryNode; return entryNode;
} }

View File

@ -63,7 +63,7 @@ public class Lucene94Codec extends Codec {
private final Lucene90StoredFieldsFormat.Mode storedMode; private final Lucene90StoredFieldsFormat.Mode storedMode;
private Mode(Lucene90StoredFieldsFormat.Mode storedMode) { Mode(Lucene90StoredFieldsFormat.Mode storedMode) {
this.storedMode = Objects.requireNonNull(storedMode); this.storedMode = Objects.requireNonNull(storedMode);
} }
} }
@ -164,7 +164,7 @@ public class Lucene94Codec extends Codec {
} }
@Override @Override
public final KnnVectorsFormat knnVectorsFormat() { public KnnVectorsFormat knnVectorsFormat() {
return knnVectorsFormat; return knnVectorsFormat;
} }

View File

@ -96,7 +96,7 @@ import org.apache.lucene.util.hnsw.HnswGraph;
* *
* @lucene.experimental * @lucene.experimental
*/ */
public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat { public class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
static final String META_CODEC_NAME = "lucene94HnswVectorsFormatMeta"; static final String META_CODEC_NAME = "lucene94HnswVectorsFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "lucene94HnswVectorsFormatData"; static final String VECTOR_DATA_CODEC_NAME = "lucene94HnswVectorsFormatData";
@ -115,20 +115,18 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
*/ */
public static final int DEFAULT_BEAM_WIDTH = 100; public static final int DEFAULT_BEAM_WIDTH = 100;
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
/** /**
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
* {@link Lucene94HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. * {@link Lucene94HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
*/ */
private final int maxConn; final int maxConn;
/** /**
* The number of candidate neighbors to track while searching the graph for each newly inserted * The number of candidate neighbors to track while searching the graph for each newly inserted
* node. Defaults to to {@link Lucene94HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link * node. Defaults to to {@link Lucene94HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link
* HnswGraph} for details. * HnswGraph} for details.
*/ */
private final int beamWidth; final int beamWidth;
/** Constructs a format using default graph construction parameters */ /** Constructs a format using default graph construction parameters */
public Lucene94HnswVectorsFormat() { public Lucene94HnswVectorsFormat() {
@ -149,7 +147,7 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
@Override @Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene94HnswVectorsWriter(state, maxConn, beamWidth); throw new UnsupportedOperationException("Old codecs may only be used for reading");
} }
@Override @Override

View File

@ -504,12 +504,12 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
} }
@Override @Override
public int numLevels() throws IOException { public int numLevels() {
return numLevels; return numLevels;
} }
@Override @Override
public int entryNode() throws IOException { public int entryNode() {
return entryNode; return entryNode;
} }

View File

@ -17,7 +17,7 @@
package org.apache.lucene.backward_codecs.lucene94; package org.apache.lucene.backward_codecs.lucene94;
import static org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; import static org.apache.lucene.backward_codecs.lucene94.Lucene94RWHnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException; import java.io.IOException;

View File

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene94;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
/** Implements the Lucene 9.4 index format for backwards compat testing */
public class Lucene94RWCodec extends Lucene94Codec {
private final KnnVectorsFormat defaultKnnVectorsFormat;
private final KnnVectorsFormat knnVectorsFormat =
new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return defaultKnnVectorsFormat;
}
};
/** Instantiates a new codec. */
public Lucene94RWCodec() {
defaultKnnVectorsFormat =
new Lucene94RWHnswVectorsFormat(
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
}
@Override
public final KnnVectorsFormat knnVectorsFormat() {
return knnVectorsFormat;
}
}

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene94;
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
public final class Lucene94RWHnswVectorsFormat extends Lucene94HnswVectorsFormat {
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
public Lucene94RWHnswVectorsFormat(int maxConn, int beamWidth) {
super(maxConn, beamWidth);
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene94HnswVectorsWriter(state, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene94HnswVectorsReader(state);
}
@Override
public String toString() {
return "Lucene94RWHnswVectorsFormat(name=Lucene94RWHnswVectorsFormat, maxConn="
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ")";
}
}

View File

@ -19,24 +19,23 @@ package org.apache.lucene.backward_codecs.lucene94;
import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
public class TestLucene94HnswVectorsFormat extends BaseKnnVectorsFormatTestCase { public class TestLucene94HnswVectorsFormat extends BaseKnnVectorsFormatTestCase {
@Override @Override
protected Codec getCodec() { protected Codec getCodec() {
return TestUtil.getDefaultCodec(); return new Lucene94RWCodec();
} }
public void testToString() { public void testToString() {
Lucene94Codec customCodec = Lucene94RWCodec customCodec =
new Lucene94Codec() { new Lucene94RWCodec() {
@Override @Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene94HnswVectorsFormat(10, 20); return new Lucene94RWHnswVectorsFormat(10, 20);
} }
}; };
String expectedString = String expectedString =
"Lucene94HnswVectorsFormat(name=Lucene94HnswVectorsFormat, maxConn=10, beamWidth=20)"; "Lucene94RWHnswVectorsFormat(name=Lucene94RWHnswVectorsFormat, maxConn=10, beamWidth=20)";
assertEquals(expectedString, customCodec.getKnnVectorsFormatForField("bogus_field").toString()); assertEquals(expectedString, customCodec.getKnnVectorsFormatForField("bogus_field").toString());
} }
} }