mirror of
https://github.com/apache/lucene.git
synced 2025-02-28 21:39:25 +00:00
Minor refactor for HNSW graph merging logic (#12616)
This is a minor refactor of HNSW graph merging logic. Instead of directly checking the KnnVectorReader version, this commit adjusts the logic to see if a specific interface is satisfied for returning a view of the HnswGraph.
This commit is contained in:
parent
1baae3629a
commit
6b7d311c0c
@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
|
||||
/**
|
||||
* An interface that provides an HNSW graph. This interface is useful when gathering multiple HNSW
|
||||
* graphs to bootstrap segment merging. The graph may be off the JVM heap.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public interface HnswGraphProvider {
|
||||
/**
|
||||
* Return the stored HnswGraph for the given field.
|
||||
*
|
||||
* @param field the field containing the graph
|
||||
* @return the HnswGraph for the given field if found
|
||||
* @throws IOException when reading potentially off-heap graph fails
|
||||
*/
|
||||
HnswGraph getGraph(String field) throws IOException;
|
||||
}
|
@ -24,6 +24,7 @@ import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.HnswGraphProvider;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
@ -55,7 +56,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
|
||||
public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider {
|
||||
|
||||
private static final long SHALLOW_SIZE =
|
||||
RamUsageEstimator.shallowSizeOfInstance(Lucene95HnswVectorsFormat.class);
|
||||
@ -308,6 +309,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
|
||||
}
|
||||
|
||||
/** Get knn graph values; used for testing */
|
||||
@Override
|
||||
public HnswGraph getGraph(String field) throws IOException {
|
||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
if (info == null) {
|
||||
|
@ -30,6 +30,7 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.HnswGraphProvider;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
@ -506,7 +507,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||
throws IOException {
|
||||
// Find the KnnVectorReader with the most docs that meets the following criteria:
|
||||
// 1. Does not contain any deleted docs
|
||||
// 2. Is a Lucene95HnswVectorsReader/PerFieldKnnVectorReader
|
||||
// 2. Is a HnswGraphProvider/PerFieldKnnVectorReader
|
||||
// If no readers exist that meet this criteria, return -1. If they do, return their index in
|
||||
// merge state
|
||||
int maxCandidateVectorCount = 0;
|
||||
@ -520,21 +521,23 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||
}
|
||||
|
||||
if (!allMatch(mergeState.liveDocs[i])
|
||||
|| !(currKnnVectorsReader instanceof Lucene95HnswVectorsReader candidateReader)) {
|
||||
|| !(currKnnVectorsReader instanceof HnswGraphProvider)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int candidateVectorCount = 0;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> {
|
||||
ByteVectorValues byteVectorValues = candidateReader.getByteVectorValues(fieldInfo.name);
|
||||
ByteVectorValues byteVectorValues =
|
||||
currKnnVectorsReader.getByteVectorValues(fieldInfo.name);
|
||||
if (byteVectorValues == null) {
|
||||
continue;
|
||||
}
|
||||
candidateVectorCount = byteVectorValues.size();
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
FloatVectorValues vectorValues = candidateReader.getFloatVectorValues(fieldInfo.name);
|
||||
FloatVectorValues vectorValues =
|
||||
currKnnVectorsReader.getFloatVectorValues(fieldInfo.name);
|
||||
if (vectorValues == null) {
|
||||
continue;
|
||||
}
|
||||
@ -553,13 +556,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||
private HnswGraph getHnswGraphFromReader(String fieldName, KnnVectorsReader knnVectorsReader)
|
||||
throws IOException {
|
||||
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader
|
||||
&& perFieldReader.getFieldReader(fieldName)
|
||||
instanceof Lucene95HnswVectorsReader fieldReader) {
|
||||
&& perFieldReader.getFieldReader(fieldName) instanceof HnswGraphProvider fieldReader) {
|
||||
return fieldReader.getGraph(fieldName);
|
||||
}
|
||||
|
||||
if (knnVectorsReader instanceof Lucene95HnswVectorsReader) {
|
||||
return ((Lucene95HnswVectorsReader) knnVectorsReader).getGraph(fieldName);
|
||||
if (knnVectorsReader instanceof HnswGraphProvider provider) {
|
||||
return provider.getGraph(fieldName);
|
||||
}
|
||||
|
||||
// We should not reach here because knnVectorsReader's type is checked in
|
||||
|
Loading…
x
Reference in New Issue
Block a user