Minor refactor for HNSW graph merging logic (#12616)

This is a minor refactor of HNSW graph merging logic.

Instead of directly checking the KnnVectorReader version, this commit adjusts the logic to see if a specific interface is satisfied for returning a view of the HnswGraph.
This commit is contained in:
Benjamin Trent 2023-10-03 11:28:10 -07:00 committed by GitHub
parent 1baae3629a
commit 6b7d311c0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 9 deletions

View File

@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs;
import java.io.IOException;
import org.apache.lucene.util.hnsw.HnswGraph;
/**
* An interface that provides an HNSW graph. This interface is useful when gathering multiple HNSW
* graphs to bootstrap segment merging. The graph may be off the JVM heap.
*
* @lucene.experimental
*/
public interface HnswGraphProvider {
/**
* Return the stored HnswGraph for the given field.
*
* @param field the field containing the graph
* @return the HnswGraph for the given field if found
* @throws IOException when reading potentially off-heap graph fails
*/
HnswGraph getGraph(String field) throws IOException;
}

View File

@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
@ -55,7 +56,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
*
* @lucene.experimental
*/
public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider {
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene95HnswVectorsFormat.class);
@ -308,6 +309,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
}
/** Get knn graph values; used for testing */
@Override
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {

View File

@ -30,6 +30,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
@ -506,7 +507,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
throws IOException {
// Find the KnnVectorReader with the most docs that meets the following criteria:
// 1. Does not contain any deleted docs
// 2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader
// 2. Is a HnswGraphProvider/PerFieldKnnVectorReader
// If no readers exist that meet this criteria, return -1. If they do, return their index in
// merge state
int maxCandidateVectorCount = 0;
@ -520,21 +521,23 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
}
if (!allMatch(mergeState.liveDocs[i])
|| !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader candidateReader)) {
|| !(currKnnVectorsReader instanceof HnswGraphProvider)) {
continue;
}
int candidateVectorCount = 0;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> {
ByteVectorValues byteVectorValues = candidateReader.getByteVectorValues(fieldInfo.name);
ByteVectorValues byteVectorValues =
currKnnVectorsReader.getByteVectorValues(fieldInfo.name);
if (byteVectorValues == null) {
continue;
}
candidateVectorCount = byteVectorValues.size();
}
case FLOAT32 -> {
FloatVectorValues vectorValues = candidateReader.getFloatVectorValues(fieldInfo.name);
FloatVectorValues vectorValues =
currKnnVectorsReader.getFloatVectorValues(fieldInfo.name);
if (vectorValues == null) {
continue;
}
@ -553,13 +556,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader)
throws IOException {
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader
&& perFieldReader.getFieldReader(fieldName)
instanceof Lucene95HnswVectorsReader fieldReader) {
&& perFieldReader.getFieldReader(fieldName) instanceof HnswGraphProvider fieldReader) {
return fieldReader.getGraph(fieldName);
}
if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
return ((Lucene95HnswVectorsReader) knnVectorsReader).getGraph(fieldName);
if (knnVectorsReader instanceof HnswGraphProvider provider) {
return provider.getGraph(fieldName);
}
// We should not reach here because knnVectorsReader's type is checked in