LUCENE-10016: Remove VectorValues#getSimilarityFunction. (#213)

VectorValues is only about iterating over vectors in doc ID order, so it feels
wrong to tie it to the similarity function.
This commit is contained in:
Adrien Grand 2021-07-19 09:48:09 +02:00 committed by GitHub
parent 9b5e233960
commit acf45d8a31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 312 additions and 353 deletions

View File

@ -30,7 +30,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory;
@ -214,7 +214,7 @@ public final class Lucene60FieldInfosFormat extends FieldInfosFormat {
pointIndexDimensionCount,
pointNumBytes,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
isSoftDeletesField);
} catch (IllegalStateException e) {
throw new CorruptIndexException(

View File

@ -28,7 +28,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
@ -158,7 +158,7 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
SimpleTextUtil.readLine(input, scratch);
assert StringHelper.startsWith(scratch.get(), VECTOR_SEARCH_STRATEGY);
String scoreFunction = readString(VECTOR_SEARCH_STRATEGY.length, scratch);
VectorValues.SimilarityFunction vectorDistFunc = distanceFunction(scoreFunction);
VectorSimilarityFunction vectorDistFunc = distanceFunction(scoreFunction);
SimpleTextUtil.readLine(input, scratch);
assert StringHelper.startsWith(scratch.get(), SOFT_DELETES);
@ -201,8 +201,8 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
return DocValuesType.valueOf(dvType);
}
public VectorValues.SimilarityFunction distanceFunction(String scoreFunction) {
return VectorValues.SimilarityFunction.valueOf(scoreFunction);
public VectorSimilarityFunction distanceFunction(String scoreFunction) {
return VectorSimilarityFunction.valueOf(scoreFunction);
}
private String readString(int offset, BytesRefBuilder scratch) {

View File

@ -81,9 +81,6 @@ public class SimpleTextVectorReader extends VectorReader {
int fieldNumber = readInt(in, FIELD_NUMBER);
while (fieldNumber != -1) {
String fieldName = readString(in, FIELD_NAME);
String scoreFunctionName = readString(in, SCORE_FUNCTION);
VectorValues.SimilarityFunction similarityFunction =
VectorValues.SimilarityFunction.valueOf(scoreFunctionName);
long vectorDataOffset = readLong(in, VECTOR_DATA_OFFSET);
long vectorDataLength = readLong(in, VECTOR_DATA_LENGTH);
int dimension = readInt(in, VECTOR_DIMENSION);
@ -94,9 +91,7 @@ public class SimpleTextVectorReader extends VectorReader {
}
assert fieldEntries.containsKey(fieldName) == false;
fieldEntries.put(
fieldName,
new FieldEntry(
dimension, similarityFunction, vectorDataOffset, vectorDataLength, docIds));
fieldName, new FieldEntry(dimension, vectorDataOffset, vectorDataLength, docIds));
fieldNumber = readInt(in, FIELD_NUMBER);
}
SimpleTextUtil.checkFooter(in);
@ -205,20 +200,13 @@ public class SimpleTextVectorReader extends VectorReader {
private static class FieldEntry {
final int dimension;
final VectorValues.SimilarityFunction similarityFunction;
final long vectorDataOffset;
final long vectorDataLength;
final int[] ordToDoc;
FieldEntry(
int dimension,
VectorValues.SimilarityFunction similarityFunction,
long vectorDataOffset,
long vectorDataLength,
int[] ordToDoc) {
FieldEntry(int dimension, long vectorDataOffset, long vectorDataLength, int[] ordToDoc) {
this.dimension = dimension;
this.similarityFunction = similarityFunction;
this.vectorDataOffset = vectorDataOffset;
this.vectorDataLength = vectorDataLength;
this.ordToDoc = ordToDoc;
@ -260,11 +248,6 @@ public class SimpleTextVectorReader extends VectorReader {
return entry.size();
}
@Override
public SimilarityFunction similarityFunction() {
return entry.similarityFunction;
}
@Override
public float[] vectorValue() {
return values[curOrd];

View File

@ -38,7 +38,6 @@ public class SimpleTextVectorWriter extends VectorWriter {
static final BytesRef FIELD_NUMBER = new BytesRef("field-number ");
static final BytesRef FIELD_NAME = new BytesRef("field-name ");
static final BytesRef SCORE_FUNCTION = new BytesRef("score-function ");
static final BytesRef VECTOR_DATA_OFFSET = new BytesRef("vector-data-offset ");
static final BytesRef VECTOR_DATA_LENGTH = new BytesRef("vector-data-length ");
static final BytesRef VECTOR_DIMENSION = new BytesRef("vector-dimension ");
@ -96,7 +95,6 @@ public class SimpleTextVectorWriter extends VectorWriter {
throws IOException {
writeField(meta, FIELD_NUMBER, field.number);
writeField(meta, FIELD_NAME, field.name);
writeField(meta, SCORE_FUNCTION, field.getVectorSimilarityFunction().name());
writeField(meta, VECTOR_DATA_OFFSET, vectorDataOffset);
writeField(meta, VECTOR_DATA_LENGTH, vectorDataLength);
writeField(meta, VECTOR_DIMENSION, field.getVectorDimension());

View File

@ -23,7 +23,7 @@ import org.apache.lucene.codecs.lucene90.MockTermStateFactory;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDataOutput;
import org.apache.lucene.store.ByteBuffersIndexOutput;
import org.apache.lucene.util.BytesRef;
@ -116,7 +116,7 @@ public class TestBlockWriter extends LuceneTestCase {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
true);
}
}

View File

@ -41,7 +41,7 @@ import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.Directory;
@ -203,7 +203,7 @@ public class TestSTBlockReader extends LuceneTestCase {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
false);
}

View File

@ -29,6 +29,7 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
@ -68,14 +69,14 @@ public abstract class VectorWriter implements Closeable {
}
List<VectorValuesSub> subs = new ArrayList<>();
int dimension = -1;
VectorValues.SimilarityFunction similarityFunction = null;
VectorSimilarityFunction similarityFunction = null;
int nonEmptySegmentIndex = 0;
for (int i = 0; i < mergeState.vectorReaders.length; i++) {
VectorReader vectorReader = mergeState.vectorReaders[i];
if (vectorReader != null) {
if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
int segmentDimension = mergeFieldInfo.getVectorDimension();
VectorValues.SimilarityFunction segmentSimilarityFunction =
VectorSimilarityFunction segmentSimilarityFunction =
mergeFieldInfo.getVectorSimilarityFunction();
if (dimension == -1) {
dimension = segmentDimension;
@ -238,11 +239,6 @@ public abstract class VectorWriter implements Closeable {
return subs.get(0).values.dimension();
}
@Override
public SimilarityFunction similarityFunction() {
return subs.get(0).values.similarityFunction();
}
class MergerRandomAccess implements RandomAccessVectorValues {
private final List<RandomAccessVectorValues> raSubs;
@ -269,11 +265,6 @@ public abstract class VectorWriter implements Closeable {
return VectorValuesMerger.this.dimension();
}
@Override
public SimilarityFunction similarityFunction() {
return VectorValuesMerger.this.similarityFunction();
}
@Override
public float[] vectorValue(int target) throws IOException {
int unmappedOrd = ordMap[target];

View File

@ -29,7 +29,7 @@ import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.VectorValues.SimilarityFunction;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory;
@ -102,8 +102,8 @@ import org.apache.lucene.store.IndexOutput;
* <li>VectorDistFunction: a byte containing distance function used for similarity calculation.
* <ul>
* <li>0: no distance function is defined for this field.
* <li>1: EUCLIDEAN_HNSW distance. ({@link SimilarityFunction#EUCLIDEAN})
* <li>2: DOT_PRODUCT_HNSW score. ({@link SimilarityFunction#DOT_PRODUCT})
* <li>1: EUCLIDEAN_HNSW distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
* <li>2: DOT_PRODUCT_HNSW score. ({@link VectorSimilarityFunction#DOT_PRODUCT})
* </ul>
* </ul>
*
@ -172,7 +172,7 @@ public final class Lucene90FieldInfosFormat extends FieldInfosFormat {
pointNumBytes = 0;
}
final int vectorDimension = input.readVInt();
final SimilarityFunction vectorDistFunc = getDistFunc(input, input.readByte());
final VectorSimilarityFunction vectorDistFunc = getDistFunc(input, input.readByte());
try {
infos[i] =
@ -253,11 +253,11 @@ public final class Lucene90FieldInfosFormat extends FieldInfosFormat {
}
}
private static SimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException {
if (b < 0 || b >= SimilarityFunction.values().length) {
private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException {
if (b < 0 || b >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException("invalid distance function: " + b, input);
}
return SimilarityFunction.values()[b];
return VectorSimilarityFunction.values()[b];
}
static {

View File

@ -35,6 +35,7 @@ import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
@ -187,19 +188,18 @@ public final class Lucene90HnswVectorReader extends VectorReader {
}
}
private VectorValues.SimilarityFunction readSimilarityFunction(DataInput input)
throws IOException {
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
int similarityFunctionId = input.readInt();
if (similarityFunctionId < 0
|| similarityFunctionId >= VectorValues.SimilarityFunction.values().length) {
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException(
"Invalid similarity function id: " + similarityFunctionId, input);
}
return VectorValues.SimilarityFunction.values()[similarityFunctionId];
return VectorSimilarityFunction.values()[similarityFunctionId];
}
private FieldEntry readField(DataInput input) throws IOException {
VectorValues.SimilarityFunction similarityFunction = readSimilarityFunction(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
switch (similarityFunction) {
case NONE:
return new FieldEntry(input, similarityFunction);
@ -252,7 +252,14 @@ public final class Lucene90HnswVectorReader extends VectorReader {
// use a seed that is fixed for the index so we get reproducible results for the same query
final Random random = new Random(checksumSeed);
NeighborQueue results =
HnswGraph.search(target, k, k, vectorValues, getGraphValues(fieldEntry), random);
HnswGraph.search(
target,
k,
k,
vectorValues,
fieldEntry.similarityFunction,
getGraphValues(fieldEntry),
random);
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
boolean reversed = fieldEntry.similarityFunction.reversed;
@ -292,7 +299,7 @@ public final class Lucene90HnswVectorReader extends VectorReader {
}
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
if (entry.similarityFunction != VectorValues.SimilarityFunction.NONE) {
if (entry.similarityFunction != VectorSimilarityFunction.NONE) {
HnswGraphFieldEntry graphEntry = (HnswGraphFieldEntry) entry;
IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
@ -310,7 +317,7 @@ public final class Lucene90HnswVectorReader extends VectorReader {
private static class FieldEntry {
final int dimension;
final VectorValues.SimilarityFunction similarityFunction;
final VectorSimilarityFunction similarityFunction;
final long vectorDataOffset;
final long vectorDataLength;
@ -318,8 +325,7 @@ public final class Lucene90HnswVectorReader extends VectorReader {
final long indexDataLength;
final int[] ordToDoc;
FieldEntry(DataInput input, VectorValues.SimilarityFunction similarityFunction)
throws IOException {
FieldEntry(DataInput input, VectorSimilarityFunction similarityFunction) throws IOException {
this.similarityFunction = similarityFunction;
vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong();
@ -343,7 +349,7 @@ public final class Lucene90HnswVectorReader extends VectorReader {
final long[] ordOffsets;
HnswGraphFieldEntry(DataInput input, VectorValues.SimilarityFunction similarityFunction)
HnswGraphFieldEntry(DataInput input, VectorSimilarityFunction similarityFunction)
throws IOException {
super(input, similarityFunction);
ordOffsets = new long[size()];
@ -389,11 +395,6 @@ public final class Lucene90HnswVectorReader extends VectorReader {
return fieldEntry.size();
}
@Override
public SimilarityFunction similarityFunction() {
return fieldEntry.similarityFunction;
}
@Override
public float[] vectorValue() throws IOException {
dataIn.seek((long) ord * byteSize);

View File

@ -27,6 +27,7 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
@ -126,11 +127,12 @@ public final class Lucene90HnswVectorWriter extends VectorWriter {
long[] offsets = new long[count];
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexOffset = vectorIndex.getFilePointer();
if (vectors.similarityFunction() != VectorValues.SimilarityFunction.NONE) {
if (fieldInfo.getVectorSimilarityFunction() != VectorSimilarityFunction.NONE) {
if (vectors instanceof RandomAccessVectorValuesProducer) {
writeGraph(
vectorIndex,
(RandomAccessVectorValuesProducer) vectors,
fieldInfo.getVectorSimilarityFunction(),
vectorIndexOffset,
offsets,
count,
@ -150,7 +152,7 @@ public final class Lucene90HnswVectorWriter extends VectorWriter {
vectorIndexLength,
count,
docIds);
if (vectors.similarityFunction() != VectorValues.SimilarityFunction.NONE) {
if (fieldInfo.getVectorSimilarityFunction() != VectorSimilarityFunction.NONE) {
writeGraphOffsets(meta, offsets);
}
}
@ -196,6 +198,7 @@ public final class Lucene90HnswVectorWriter extends VectorWriter {
private void writeGraph(
IndexOutput graphData,
RandomAccessVectorValuesProducer vectorValues,
VectorSimilarityFunction similarityFunction,
long graphDataOffset,
long[] offsets,
int count,
@ -203,7 +206,8 @@ public final class Lucene90HnswVectorWriter extends VectorWriter {
int beamWidth)
throws IOException {
HnswGraphBuilder hnswGraphBuilder =
new HnswGraphBuilder(vectorValues, maxConn, beamWidth, HnswGraphBuilder.randSeed);
new HnswGraphBuilder(
vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());

View File

@ -23,6 +23,7 @@ import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
/** Describes the properties of a field. */
@ -42,8 +43,7 @@ public class FieldType implements IndexableFieldType {
private int indexDimensionCount;
private int dimensionNumBytes;
private int vectorDimension;
private VectorValues.SimilarityFunction vectorSimilarityFunction =
VectorValues.SimilarityFunction.NONE;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.NONE;
private Map<String, String> attributes;
/** Create a new mutable FieldType with all of the properties from <code>ref</code> */
@ -371,7 +371,7 @@ public class FieldType implements IndexableFieldType {
/** Enable vector indexing, with the specified number of dimensions and distance function. */
public void setVectorDimensionsAndSimilarityFunction(
int numDimensions, VectorValues.SimilarityFunction distFunc) {
int numDimensions, VectorSimilarityFunction distFunc) {
checkIfFrozen();
if (numDimensions <= 0) {
throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions);
@ -393,7 +393,7 @@ public class FieldType implements IndexableFieldType {
}
@Override
public VectorValues.SimilarityFunction vectorSimilarityFunction() {
public VectorSimilarityFunction vectorSimilarityFunction() {
return vectorSimilarityFunction;
}

View File

@ -17,6 +17,7 @@
package org.apache.lucene.document;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
/**
@ -33,8 +34,7 @@ import org.apache.lucene.index.VectorValues;
*/
public class VectorField extends Field {
private static FieldType createType(
float[] v, VectorValues.SimilarityFunction similarityFunction) {
private static FieldType createType(float[] v, VectorSimilarityFunction similarityFunction) {
if (v == null) {
throw new IllegalArgumentException("vector value must not be null");
}
@ -63,7 +63,7 @@ public class VectorField extends Field {
* @throws IllegalArgumentException if any parameter is null, or has dimension &gt; 1024.
*/
public static FieldType createFieldType(
int dimension, VectorValues.SimilarityFunction similarityFunction) {
int dimension, VectorSimilarityFunction similarityFunction) {
FieldType type = new FieldType();
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
type.freeze();
@ -82,8 +82,7 @@ public class VectorField extends Field {
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension &gt; 1024.
*/
public VectorField(
String name, float[] vector, VectorValues.SimilarityFunction similarityFunction) {
public VectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
}
@ -99,7 +98,7 @@ public class VectorField extends Field {
* dimension &gt; 1024.
*/
public VectorField(String name, float[] vector) {
this(name, vector, VectorValues.SimilarityFunction.EUCLIDEAN);
this(name, vector, VectorSimilarityFunction.EUCLIDEAN);
}
/**

View File

@ -56,7 +56,7 @@ public final class FieldInfo {
// if it is a positive value, it means this field indexes vectors
private final int vectorDimension;
private final VectorValues.SimilarityFunction vectorSimilarityFunction;
private final VectorSimilarityFunction vectorSimilarityFunction;
// whether this field is used as the soft-deletes field
private final boolean softDeletesField;
@ -80,7 +80,7 @@ public final class FieldInfo {
int pointIndexDimensionCount,
int pointNumBytes,
int vectorDimension,
VectorValues.SimilarityFunction vectorSimilarityFunction,
VectorSimilarityFunction vectorSimilarityFunction,
boolean softDeletesField) {
this.name = Objects.requireNonNull(name);
this.number = number;
@ -202,7 +202,7 @@ public final class FieldInfo {
throw new IllegalArgumentException(
"vectorDimension must be >=0; got " + vectorDimension + " (field: '" + name + "')");
}
if (vectorDimension == 0 && vectorSimilarityFunction != VectorValues.SimilarityFunction.NONE) {
if (vectorDimension == 0 && vectorSimilarityFunction != VectorSimilarityFunction.NONE) {
throw new IllegalArgumentException(
"vector similarity function must be NONE when dimension = 0; got "
+ vectorSimilarityFunction
@ -355,9 +355,9 @@ public final class FieldInfo {
static void verifySameVectorOptions(
String fieldName,
int vd1,
VectorValues.SimilarityFunction vsf1,
VectorSimilarityFunction vsf1,
int vd2,
VectorValues.SimilarityFunction vsf2) {
VectorSimilarityFunction vsf2) {
if (vd1 != vd2 || vsf1 != vsf2) {
throw new IllegalArgumentException(
"cannot change field \""
@ -478,8 +478,8 @@ public final class FieldInfo {
return vectorDimension;
}
/** Returns {@link VectorValues.SimilarityFunction} for the field */
public VectorValues.SimilarityFunction getVectorSimilarityFunction() {
/** Returns {@link VectorSimilarityFunction} for the field */
public VectorSimilarityFunction getVectorSimilarityFunction() {
return vectorSimilarityFunction;
}

View File

@ -299,9 +299,9 @@ public class FieldInfos implements Iterable<FieldInfo> {
static final class FieldVectorProperties {
final int numDimensions;
final VectorValues.SimilarityFunction similarityFunction;
final VectorSimilarityFunction similarityFunction;
FieldVectorProperties(int numDimensions, VectorValues.SimilarityFunction similarityFunction) {
FieldVectorProperties(int numDimensions, VectorSimilarityFunction similarityFunction) {
this.numDimensions = numDimensions;
this.similarityFunction = similarityFunction;
}
@ -486,7 +486,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
(softDeletesFieldName != null && softDeletesFieldName.equals(fieldName)));
addOrGet(fi);
}
@ -567,7 +567,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
isSoftDeletesField);
}

View File

@ -101,8 +101,8 @@ public interface IndexableFieldType {
/** The number of dimensions of the field's vector value */
int vectorDimension();
/** The {@link VectorValues.SimilarityFunction} of the field's vector value */
VectorValues.SimilarityFunction vectorSimilarityFunction();
/** The {@link VectorSimilarityFunction} of the field's vector value */
VectorSimilarityFunction vectorSimilarityFunction();
/**
* Attributes for the field type.

View File

@ -1327,8 +1327,7 @@ final class IndexingChain implements Accountable {
private int pointIndexDimensionCount = 0;
private int pointNumBytes = 0;
private int vectorDimension = 0;
private VectorValues.SimilarityFunction vectorSimilarityFunction =
VectorValues.SimilarityFunction.NONE;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.NONE;
private static String errMsg =
"Inconsistency of field data structures across documents for field ";
@ -1383,8 +1382,8 @@ final class IndexingChain implements Accountable {
}
}
void setVectors(VectorValues.SimilarityFunction similarityFunction, int dimension) {
if (vectorSimilarityFunction == VectorValues.SimilarityFunction.NONE) {
void setVectors(VectorSimilarityFunction similarityFunction, int dimension) {
if (vectorSimilarityFunction == VectorSimilarityFunction.NONE) {
this.vectorDimension = dimension;
this.vectorSimilarityFunction = similarityFunction;
} else {
@ -1403,7 +1402,7 @@ final class IndexingChain implements Accountable {
pointIndexDimensionCount = 0;
pointNumBytes = 0;
vectorDimension = 0;
vectorSimilarityFunction = VectorValues.SimilarityFunction.NONE;
vectorSimilarityFunction = VectorSimilarityFunction.NONE;
}
void assertSameSchema(FieldInfo fi) {

View File

@ -33,9 +33,6 @@ public interface RandomAccessVectorValues {
/** Return the dimension of the returned vector values */
int dimension();
/** Return the similarity function used to compare these vectors */
VectorValues.SimilarityFunction similarityFunction();
/**
* Return the vector value indexed at the given ordinal. The provided floating point array may be
* shared and overwritten by subsequent calls to this method and {@link #binaryValue(int)}.

View File

@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import static org.apache.lucene.util.VectorUtil.dotProduct;
import static org.apache.lucene.util.VectorUtil.squareDistance;
import org.apache.lucene.codecs.VectorReader;
/**
 * Labels the function used to score a pair of vectors. The same function is applied while indexing
 * and while searching, so that the top K vectors most similar to a target vector can be found.
 */
public enum VectorSimilarityFunction {

  /**
   * No similarity function is provided. Note: {@link VectorReader#search(String, float[], int)} is
   * not supported for fields specifying this, and {@link #compare(float[], float[])} throws.
   */
  NONE(false),

  /** HNSW graph built using Euclidean distance; scores are in reverse order. */
  EUCLIDEAN(true),

  /** HNSW graph built using dot product. */
  DOT_PRODUCT(false);

  /**
   * Whether scores produced by this function are in reverse order. When {@code true}, lower scores
   * represent more similar vectors; when {@code false}, higher scores represent more similar
   * vectors.
   */
  public final boolean reversed;

  VectorSimilarityFunction(boolean reversed) {
    this.reversed = reversed;
  }

  /**
   * Scores two vectors with this similarity function.
   *
   * @param v1 a vector
   * @param v2 another vector, of the same dimension
   * @return the value of the similarity function applied to the two vectors
   * @throws IllegalStateException if this function cannot compare vectors, as is the case for
   *     {@link #NONE}
   */
  public float compare(float[] v1, float[] v2) {
    if (this == EUCLIDEAN) {
      // Delegates to VectorUtil.squareDistance.
      return squareDistance(v1, v2);
    }
    if (this == DOT_PRODUCT) {
      // Delegates to VectorUtil.dotProduct.
      return dotProduct(v1, v2);
    }
    // Only NONE reaches this point today; any constant without a scoring rule is rejected.
    throw new IllegalStateException("Incomparable similarity function: " + this);
  }
}

View File

@ -16,11 +16,7 @@
*/
package org.apache.lucene.index;
import static org.apache.lucene.util.VectorUtil.dotProduct;
import static org.apache.lucene.util.VectorUtil.squareDistance;
import java.io.IOException;
import org.apache.lucene.codecs.VectorReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
@ -50,9 +46,6 @@ public abstract class VectorValues extends DocIdSetIterator {
*/
public abstract int size();
/** Return the similarity function used to compare these vectors */
public abstract SimilarityFunction similarityFunction();
/**
* Return the vector value for the current document ID. It is illegal to call this method when the
* iterator is not positioned: before advancing, or after failing to advance. The returned array
@ -74,60 +67,6 @@ public abstract class VectorValues extends DocIdSetIterator {
throw new UnsupportedOperationException();
}
/**
* Vector similarity function; used in search to return top K most similar vectors to a target
* vector. This is a label describing the method used during indexing and searching of the vectors
* in order to determine the nearest neighbors.
*/
public enum SimilarityFunction {
/**
* No similarity function is provided. Note: {@link VectorReader#search(String, float[], int)}
* is not supported for fields specifying this.
*/
NONE,
/** HNSW graph built using Euclidean distance */
EUCLIDEAN(true),
/** HNSW graph buit using dot product */
DOT_PRODUCT;
/**
* If true, the scores associated with vector comparisons are in reverse order; that is, lower
* scores represent more similar vectors. Otherwise, if false, higher scores represent more
* similar vectors.
*/
public final boolean reversed;
SimilarityFunction(boolean reversed) {
this.reversed = reversed;
}
SimilarityFunction() {
reversed = false;
}
/**
* Calculates a similarity score between the two vectors with a specified function.
*
* @param v1 a vector
* @param v2 another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public float compare(float[] v1, float[] v2) {
switch (this) {
case EUCLIDEAN:
return squareDistance(v1, v2);
case DOT_PRODUCT:
return dotProduct(v1, v2);
case NONE:
default:
throw new IllegalStateException("Incomparable similarity function: " + this);
}
}
}
/**
* Represents the lack of vector values. It is returned by providers that do not support
* VectorValues.
@ -145,11 +84,6 @@ public abstract class VectorValues extends DocIdSetIterator {
return 0;
}
@Override
public SimilarityFunction similarityFunction() {
return SimilarityFunction.NONE;
}
@Override
public float[] vectorValue() {
throw new IllegalStateException(

View File

@ -108,11 +108,7 @@ class VectorValuesWriter {
*/
public void flush(Sorter.DocMap sortMap, VectorWriter vectorWriter) throws IOException {
VectorValues vectorValues =
new BufferedVectorValues(
docsWithField,
vectors,
fieldInfo.getVectorDimension(),
fieldInfo.getVectorSimilarityFunction());
new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
if (sortMap != null) {
vectorWriter.writeField(fieldInfo, new SortingVectorValues(vectorValues, sortMap));
} else {
@ -189,11 +185,6 @@ class VectorValuesWriter {
return delegate.size();
}
@Override
public SimilarityFunction similarityFunction() {
return delegate.similarityFunction();
}
@Override
public int advance(int target) throws IOException {
throw new UnsupportedOperationException();
@ -223,11 +214,6 @@ class VectorValuesWriter {
return delegateRA.dimension();
}
@Override
public SimilarityFunction similarityFunction() {
return delegateRA.similarityFunction();
}
@Override
public float[] vectorValue(int targetOrd) throws IOException {
return delegateRA.vectorValue(ordMap[targetOrd]);
@ -248,7 +234,6 @@ class VectorValuesWriter {
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
final List<float[]> vectors;
final SimilarityFunction similarityFunction;
final int dimension;
final ByteBuffer buffer;
@ -259,15 +244,10 @@ class VectorValuesWriter {
DocIdSetIterator docsWithFieldIter;
int ord = -1;
BufferedVectorValues(
DocsWithFieldSet docsWithField,
List<float[]> vectors,
int dimension,
SimilarityFunction similarityFunction) {
BufferedVectorValues(DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
this.docsWithField = docsWithField;
this.vectors = vectors;
this.dimension = dimension;
this.similarityFunction = similarityFunction;
buffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
binaryValue = new BytesRef(buffer.array());
raBuffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
@ -277,7 +257,7 @@ class VectorValuesWriter {
/**
 * Returns a newly constructed {@code BufferedVectorValues} built from this instance's {@code
 * docsWithField}, {@code vectors} and {@code dimension}; the underlying vector list is passed
 * through as-is rather than copied here.
 */
@Override
public RandomAccessVectorValues randomAccess() {
return new BufferedVectorValues(docsWithField, vectors, dimension, similarityFunction);
return new BufferedVectorValues(docsWithField, vectors, dimension);
}
@Override
@ -290,11 +270,6 @@ class VectorValuesWriter {
return vectors.size();
}
@Override
public SimilarityFunction similarityFunction() {
return similarityFunction;
}
@Override
public BytesRef binaryValue() {
buffer.asFloatBuffer().put(vectorValue());

View File

@ -25,7 +25,7 @@ import java.util.List;
import java.util.Random;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.SparseFixedBitSet;
/**
@ -91,10 +91,10 @@ public final class HnswGraph extends KnnGraphValues {
int topK,
int numSeed,
RandomAccessVectorValues vectors,
VectorSimilarityFunction similarityFunction,
KnnGraphValues graphValues,
Random random)
throws IOException {
VectorValues.SimilarityFunction similarityFunction = vectors.similarityFunction();
int size = graphValues.size();
// MIN heap, holding the top results

View File

@ -22,7 +22,7 @@ import java.util.Locale;
import java.util.Random;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
/**
@ -42,7 +42,7 @@ public final class HnswGraphBuilder {
private final int beamWidth;
private final NeighborArray scratch;
private final VectorValues.SimilarityFunction similarityFunction;
private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues;
private final Random random;
private final BoundsChecker bound;
@ -67,11 +67,15 @@ public final class HnswGraphBuilder {
* to ensure repeatable construction.
*/
public HnswGraphBuilder(
RandomAccessVectorValuesProducer vectors, int maxConn, int beamWidth, long seed) {
RandomAccessVectorValuesProducer vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
long seed) {
vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess();
similarityFunction = vectorValues.similarityFunction();
if (similarityFunction == VectorValues.SimilarityFunction.NONE) {
this.similarityFunction = similarityFunction;
if (similarityFunction == VectorSimilarityFunction.NONE) {
throw new IllegalStateException("No distance function");
}
if (maxConn <= 0) {
@ -133,7 +137,8 @@ public final class HnswGraphBuilder {
/** Inserts a doc with the given vector value into the graph */
void addGraphNode(float[] value) throws IOException {
NeighborQueue candidates =
HnswGraph.search(value, beamWidth, beamWidth, vectorValues, hnsw, random);
HnswGraph.search(
value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, random);
int node = hnsw.addNode();

View File

@ -32,7 +32,7 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
@ -88,10 +88,10 @@ public class TestPerFieldConsistency extends LuceneTestCase {
}
private static Field randomVectorField(Random random, String fieldName) {
VectorValues.SimilarityFunction similarityFunction =
RandomPicks.randomFrom(random, VectorValues.SimilarityFunction.values());
while (similarityFunction == VectorValues.SimilarityFunction.NONE) {
similarityFunction = RandomPicks.randomFrom(random, VectorValues.SimilarityFunction.values());
VectorSimilarityFunction similarityFunction =
RandomPicks.randomFrom(random, VectorSimilarityFunction.values());
while (similarityFunction == VectorSimilarityFunction.NONE) {
similarityFunction = RandomPicks.randomFrom(random, VectorSimilarityFunction.values());
}
float[] values = new float[randomIntBetween(1, 10)];
for (int i = 0; i < values.length; i++) {

View File

@ -112,7 +112,7 @@ public class TestCodecs extends LuceneTestCase {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
false));
}
this.terms = terms;

View File

@ -381,7 +381,6 @@ public class TestDocumentWriter extends LuceneTestCase {
/**
 * Runs the shared {@code doTestRAMUsage} check with a factory that creates a four-dimensional
 * {@code VectorField} using EUCLIDEAN similarity for the supplied field name.
 */
public void testRAMUsageVector() throws IOException {
doTestRAMUsage(
field ->
new VectorField(
field, new float[] {1, 2, 3, 4}, VectorValues.SimilarityFunction.EUCLIDEAN));
new VectorField(field, new float[] {1, 2, 3, 4}, VectorSimilarityFunction.EUCLIDEAN));
}
}

View File

@ -260,7 +260,7 @@ public class TestFieldInfos extends LuceneTestCase {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
false));
}
int idx =
@ -279,7 +279,7 @@ public class TestFieldInfos extends LuceneTestCase {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
false));
assertEquals("Field numbers 0 through 9 were allocated", 10, idx);
@ -300,7 +300,7 @@ public class TestFieldInfos extends LuceneTestCase {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
false));
assertEquals("Field numbers should reset after clear()", 0, idx);
}

View File

@ -63,7 +63,7 @@ public class TestFieldsReader extends LuceneTestCase {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
field.name().equals(softDeletesFieldName)));
}
dir = newDirectory();

View File

@ -113,8 +113,8 @@ public class TestIndexableField extends LuceneTestCase {
}
@Override
public VectorValues.SimilarityFunction vectorSimilarityFunction() {
return VectorValues.SimilarityFunction.NONE;
public VectorSimilarityFunction vectorSimilarityFunction() {
return VectorSimilarityFunction.NONE;
}
@Override

View File

@ -38,7 +38,6 @@ import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.index.VectorValues.SimilarityFunction;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
@ -58,7 +57,7 @@ public class TestKnnGraph extends LuceneTestCase {
private static int maxConn = Lucene90HnswVectorFormat.DEFAULT_MAX_CONN;
private Codec codec;
private SimilarityFunction similarityFunction;
private VectorSimilarityFunction similarityFunction;
@Before
public void setup() {
@ -76,8 +75,8 @@ public class TestKnnGraph extends LuceneTestCase {
}
};
int similarity = random().nextInt(SimilarityFunction.values().length - 1) + 1;
similarityFunction = SimilarityFunction.values()[similarity];
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
similarityFunction = VectorSimilarityFunction.values()[similarity];
}
@After
@ -227,7 +226,7 @@ public class TestKnnGraph extends LuceneTestCase {
/** Verify that searching does something reasonable */
public void testSearch() throws Exception {
// We can't use dot product here since the vectors are laid out on a grid, not a sphere.
similarityFunction = SimilarityFunction.EUCLIDEAN;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
IndexWriterConfig config = newIndexWriterConfig();
config.setCodec(codec); // test is not compatible with simpletext
try (Directory dir = newDirectory();
@ -454,7 +453,8 @@ public class TestKnnGraph extends LuceneTestCase {
add(iw, id, vector, similarityFunction);
}
private void add(IndexWriter iw, int id, float[] vector, SimilarityFunction similarityFunction)
private void add(
IndexWriter iw, int id, float[] vector, VectorSimilarityFunction similarityFunction)
throws IOException {
Document doc = new Document();
if (vector != null) {

View File

@ -17,7 +17,7 @@
package org.apache.lucene.index;
import static org.apache.lucene.index.VectorValues.SimilarityFunction.NONE;
import static org.apache.lucene.index.VectorSimilarityFunction.NONE;
import java.io.IOException;
import java.util.Arrays;

View File

@ -53,7 +53,7 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
@ -73,8 +73,8 @@ public class KnnGraphTester {
private static final String KNN_FIELD = "knn";
private static final String ID_FIELD = "id";
private static final VectorValues.SimilarityFunction SIMILARITY_FUNCTION =
VectorValues.SimilarityFunction.DOT_PRODUCT;
private static final VectorSimilarityFunction SIMILARITY_FUNCTION =
VectorSimilarityFunction.DOT_PRODUCT;
private int numDocs;
private int dim;
@ -251,7 +251,8 @@ public class KnnGraphTester {
private void dumpGraph(Path docsPath) throws IOException {
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
RandomAccessVectorValues values = vectors.randomAccess();
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, maxConn, beamWidth, 0);
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0);
// start at node 1
for (int i = 1; i < numDocs; i++) {
builder.addGraphNode(values.vectorValue(i));
@ -580,8 +581,7 @@ public class KnnGraphTester {
iwc.setRAMBufferSizeMB(1994d);
// iwc.setMaxBufferedDocs(10000);
FieldType fieldType =
VectorField.createFieldType(dim, VectorValues.SimilarityFunction.DOT_PRODUCT);
FieldType fieldType = VectorField.createFieldType(dim, VectorSimilarityFunction.DOT_PRODUCT);
if (quiet == false) {
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("creating index in " + indexPath);
@ -675,11 +675,6 @@ public class KnnGraphTester {
return dim;
}
@Override
public VectorValues.SimilarityFunction similarityFunction() {
return SIMILARITY_FUNCTION;
}
@Override
public float[] vectorValue(int targetOrd) {
int pos = targetOrd * dim;

View File

@ -30,13 +30,11 @@ class MockVectorValues extends VectorValues
protected final int dimension;
protected final float[][] denseValues;
protected final float[][] values;
protected final SimilarityFunction similarityFunction;
private final int numVectors;
private int pos = -1;
MockVectorValues(SimilarityFunction similarityFunction, float[][] values) {
this.similarityFunction = similarityFunction;
MockVectorValues(float[][] values) {
this.dimension = values[0].length;
this.values = values;
int maxDoc = values.length;
@ -52,7 +50,7 @@ class MockVectorValues extends VectorValues
}
/**
 * Returns a new {@code MockVectorValues} constructed over the same {@code values} array; the
 * array reference is handed to the constructor directly, not cloned in this method.
 */
public MockVectorValues copy() {
return new MockVectorValues(similarityFunction, values);
return new MockVectorValues(values);
}
@Override
@ -60,11 +58,6 @@ class MockVectorValues extends VectorValues
return numVectors;
}
@Override
public SimilarityFunction similarityFunction() {
return similarityFunction;
}
@Override
public int dimension() {
return dimension;

View File

@ -41,6 +41,7 @@ import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.ArrayUtil;
@ -61,7 +62,11 @@ public class TestHnsw extends LuceneTestCase {
int maxConn = random().nextInt(10) + 5;
int beamWidth = random().nextInt(10) + 5;
long seed = random().nextLong();
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, maxConn, beamWidth, seed);
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors);
// Recreate the graph while indexing with the same random seed and write it out
@ -87,7 +92,7 @@ public class TestHnsw extends LuceneTestCase {
indexedDoc++;
}
Document doc = new Document();
doc.add(new VectorField("field", v2.vectorValue(), v2.similarityFunction));
doc.add(new VectorField("field", v2.vectorValue()));
doc.add(new StoredField("id", v2.docID()));
iw.addDocument(doc);
nVec++;
@ -97,7 +102,6 @@ public class TestHnsw extends LuceneTestCase {
try (IndexReader reader = DirectoryReader.open(dir)) {
for (LeafReaderContext ctx : reader.leaves()) {
VectorValues values = ctx.reader().getVectorValues("field");
assertEquals(vectors.similarityFunction, values.similarityFunction());
assertEquals(dim, values.dimension());
assertEquals(nVec, values.size());
assertEquals(indexedDoc, ctx.reader().maxDoc());
@ -121,11 +125,20 @@ public class TestHnsw extends LuceneTestCase {
public void testAknnDiverse() throws IOException {
int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 16, 100, random().nextInt());
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
HnswGraph hnsw = builder.build(vectors);
// run some searches
NeighborQueue nn =
HnswGraph.search(new float[] {1, 0}, 10, 5, vectors.randomAccess(), hnsw, random());
HnswGraph.search(
new float[] {1, 0},
10,
5,
vectors.randomAccess(),
VectorSimilarityFunction.DOT_PRODUCT,
hnsw,
random());
int sum = 0;
for (int node : nn.nodes()) {
sum += node;
@ -168,20 +181,31 @@ public class TestHnsw extends LuceneTestCase {
}
/**
 * Checks argument validation in the {@code HnswGraphBuilder} constructor: a null vectors
 * producer is expected to raise {@code NullPointerException}, and a zero passed in the maxConn
 * position or in the beamWidth position is expected to raise {@code IllegalArgumentException}.
 */
public void testHnswGraphBuilderInvalid() {
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0));
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0));
expectThrows(
IllegalArgumentException.class,
() -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0));
() ->
new HnswGraphBuilder(
new RandomVectorValues(1, 1, random()),
VectorSimilarityFunction.EUCLIDEAN,
0,
10,
0));
expectThrows(
IllegalArgumentException.class,
() -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0));
() ->
new HnswGraphBuilder(
new RandomVectorValues(1, 1, random()),
VectorSimilarityFunction.EUCLIDEAN,
10,
0,
0));
}
public void testDiversity() throws IOException {
// Some carefully checked test cases with simple 2d vectors on the unit circle:
MockVectorValues vectors =
new MockVectorValues(
VectorValues.SimilarityFunction.DOT_PRODUCT,
new float[][] {
unitVector2d(0.5),
unitVector2d(0.75),
@ -191,7 +215,9 @@ public class TestHnsw extends LuceneTestCase {
unitVector2d(0.77),
});
// First add nodes until everybody gets a full neighbor list
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 2, 10, random().nextInt());
HnswGraphBuilder builder =
new HnswGraphBuilder(
vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
builder.addGraphNode(vectors.vectorValue(1));
@ -247,18 +273,22 @@ public class TestHnsw extends LuceneTestCase {
int dim = atLeast(10);
int topK = 5;
RandomVectorValues vectors = new RandomVectorValues(size, dim, random());
HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 10, 30, random().nextLong());
VectorSimilarityFunction similarityFunction =
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
HnswGraphBuilder builder =
new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
HnswGraph hnsw = builder.build(vectors);
int totalMatches = 0;
for (int i = 0; i < 100; i++) {
float[] query = randomVector(random(), dim);
NeighborQueue actual = HnswGraph.search(query, topK, 100, vectors, hnsw, random());
NeighborQueue expected = new NeighborQueue(topK, vectors.similarityFunction.reversed);
NeighborQueue actual =
HnswGraph.search(query, topK, 100, vectors, similarityFunction, hnsw, random());
NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed);
for (int j = 0; j < size; j++) {
float[] v = vectors.vectorValue(j);
if (v != null) {
expected.insertWithOverflow(
j, vectors.similarityFunction.compare(query, vectors.vectorValue(j)));
expected.insertWithOverflow(j, similarityFunction.compare(query, vectors.vectorValue(j)));
}
}
assertEquals(topK, actual.size());
@ -304,11 +334,6 @@ public class TestHnsw extends LuceneTestCase {
return new CircularVectorValues(size);
}
@Override
public SimilarityFunction similarityFunction() {
return SimilarityFunction.DOT_PRODUCT;
}
@Override
public int dimension() {
return 2;
@ -409,13 +434,11 @@ public class TestHnsw extends LuceneTestCase {
static class RandomVectorValues extends MockVectorValues {
RandomVectorValues(int size, int dimension, Random random) {
super(
SimilarityFunction.values()[random.nextInt(SimilarityFunction.values().length - 1) + 1],
createRandomVectors(size, dimension, random));
super(createRandomVectors(size, dimension, random));
}
RandomVectorValues(RandomVectorValues other) {
super(other.similarityFunction, other.values);
super(other.values);
}
@Override

View File

@ -35,6 +35,7 @@ import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.TermVectors;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.Bits;
@ -97,7 +98,7 @@ public class TermVectorLeafReader extends LeafReader {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
false);
fieldInfos = new FieldInfos(new FieldInfo[] {fieldInfo});
}

View File

@ -93,7 +93,6 @@ public class AssertingVectorFormat extends VectorFormat {
assert values.docID() == -1;
assert values.size() >= 0;
assert values.dimension() > 0;
assert values.similarityFunction() != null;
}
return values;
}

View File

@ -341,8 +341,8 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
if (r.nextBoolean()) {
int dimension = 1 + r.nextInt(VectorValues.MAX_DIMENSIONS);
VectorValues.SimilarityFunction similarityFunction =
RandomPicks.randomFrom(r, VectorValues.SimilarityFunction.values());
VectorSimilarityFunction similarityFunction =
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
}
@ -412,7 +412,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
false);
}
}

View File

@ -49,15 +49,14 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
@Override
protected void addRandomFields(Document doc) {
doc.add(new VectorField("v2", randomVector(30), VectorValues.SimilarityFunction.NONE));
doc.add(new VectorField("v2", randomVector(30), VectorSimilarityFunction.NONE));
}
/**
 * Verifies the two-argument {@code VectorField} constructor: for a one-element array the field
 * type reports a vector dimension of 1, the similarity function is EUCLIDEAN, and {@code
 * vectorValue()} returns the very same array instance that was passed in.
 */
public void testFieldConstructor() {
float[] v = new float[1];
VectorField field = new VectorField("f", v);
assertEquals(1, field.fieldType().vectorDimension());
assertEquals(
VectorValues.SimilarityFunction.EUCLIDEAN, field.fieldType().vectorSimilarityFunction());
assertEquals(VectorSimilarityFunction.EUCLIDEAN, field.fieldType().vectorSimilarityFunction());
assertSame(v, field.vectorValue());
}
@ -66,7 +65,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", null));
expectThrows(
IllegalArgumentException.class,
() -> new VectorField("f", new float[1], (VectorValues.SimilarityFunction) null));
() -> new VectorField("f", new float[1], (VectorSimilarityFunction) null));
expectThrows(IllegalArgumentException.class, () -> new VectorField("f", new float[0]));
expectThrows(
IllegalArgumentException.class,
@ -91,11 +90,11 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
Document doc2 = new Document();
doc2.add(new VectorField("f", new float[3], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc2.add(new VectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
@ -107,12 +106,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.commit();
Document doc2 = new Document();
doc2.add(new VectorField("f", new float[3], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc2.add(new VectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
@ -127,11 +126,11 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
Document doc2 = new Document();
doc2.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN));
doc2.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
@ -143,12 +142,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.commit();
Document doc2 = new Document();
doc2.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN));
doc2.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
String errMsg =
@ -162,13 +161,13 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
try (IndexWriter w2 = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc2 = new Document();
doc2.add(new VectorField("f", new float[1], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc2.add(new VectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
assertEquals(
@ -183,13 +182,13 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
try (IndexWriter w2 = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc2 = new Document();
doc2.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN));
doc2.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
assertEquals(
@ -203,7 +202,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
public void testAddIndexesDirectory0() throws Exception {
String fieldName = "field";
Document doc = new Document();
doc.add(new VectorField(fieldName, new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField(fieldName, new float[4], VectorSimilarityFunction.DOT_PRODUCT));
try (Directory dir = newDirectory();
Directory dir2 = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
@ -231,8 +230,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
w.addDocument(doc);
}
doc.add(
new VectorField(fieldName, new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField(fieldName, new float[4], VectorSimilarityFunction.DOT_PRODUCT));
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
w2.addDocument(doc);
w2.addIndexes(dir);
@ -252,7 +250,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
String fieldName = "field";
float[] vector = new float[1];
Document doc = new Document();
doc.add(new VectorField(fieldName, vector, VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField(fieldName, vector, VectorSimilarityFunction.DOT_PRODUCT));
try (Directory dir = newDirectory();
Directory dir2 = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
@ -283,12 +281,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Directory dir2 = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[5], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT));
w2.addDocument(doc);
IllegalArgumentException expected =
expectThrows(
@ -306,12 +304,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Directory dir2 = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN));
w2.addDocument(doc);
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w2.addIndexes(dir));
@ -328,12 +326,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Directory dir2 = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[5], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT));
w2.addDocument(doc);
try (DirectoryReader r = DirectoryReader.open(dir)) {
IllegalArgumentException expected =
@ -354,12 +352,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Directory dir2 = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN));
w2.addDocument(doc);
try (DirectoryReader r = DirectoryReader.open(dir)) {
IllegalArgumentException expected =
@ -380,12 +378,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Directory dir2 = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[5], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT));
w2.addDocument(doc);
try (DirectoryReader r = DirectoryReader.open(dir)) {
IllegalArgumentException expected =
@ -404,12 +402,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Directory dir2 = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN));
w2.addDocument(doc);
try (DirectoryReader r = DirectoryReader.open(dir)) {
IllegalArgumentException expected =
@ -427,8 +425,8 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
IllegalArgumentException expected =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc));
assertEquals(
@ -448,10 +446,10 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
new VectorField(
"f",
new float[VectorValues.MAX_DIMENSIONS + 1],
VectorValues.SimilarityFunction.DOT_PRODUCT)));
VectorSimilarityFunction.DOT_PRODUCT)));
Document doc2 = new Document();
doc2.add(new VectorField("f", new float[1], VectorValues.SimilarityFunction.EUCLIDEAN));
doc2.add(new VectorField("f", new float[1], VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc2);
}
}
@ -463,13 +461,11 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Exception e =
expectThrows(
IllegalArgumentException.class,
() ->
doc.add(
new VectorField("f", new float[0], VectorValues.SimilarityFunction.NONE)));
() -> doc.add(new VectorField("f", new float[0], VectorSimilarityFunction.NONE)));
assertEquals("cannot index an empty vector", e.getMessage());
Document doc2 = new Document();
doc2.add(new VectorField("f", new float[1], VectorValues.SimilarityFunction.NONE));
doc2.add(new VectorField("f", new float[1], VectorSimilarityFunction.NONE));
w.addDocument(doc2);
}
}
@ -479,14 +475,14 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setCodec(Codec.forName("SimpleText"));
try (IndexWriter w = new IndexWriter(dir, iwc)) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.forceMerge(1);
}
@ -500,12 +496,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, iwc)) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
}
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.forceMerge(1);
}
@ -513,8 +509,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
}
public void testInvalidVectorFieldUsage() {
VectorField field =
new VectorField("field", new float[2], VectorValues.SimilarityFunction.NONE);
VectorField field = new VectorField("field", new float[2], VectorSimilarityFunction.NONE);
expectThrows(IllegalArgumentException.class, () -> field.setIntValue(14));
@ -528,8 +523,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new StringField("id", "0", Field.Store.NO));
doc.add(
new VectorField("v", new float[] {2, 3, 5}, VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("v", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.addDocument(new Document());
w.commit();
@ -554,16 +548,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new StringField("id", "0", Field.Store.NO));
doc.add(
new VectorField(
"v0", new float[] {2, 3, 5}, VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("v0", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.commit();
doc = new Document();
doc.add(
new VectorField(
"v1", new float[] {2, 3, 5}, VectorValues.SimilarityFunction.DOT_PRODUCT));
doc.add(new VectorField("v1", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.forceMerge(1);
}
@ -575,13 +565,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
int[] fieldDocCounts = new int[numFields];
float[] fieldTotals = new float[numFields];
int[] fieldDims = new int[numFields];
VectorValues.SimilarityFunction[] fieldSearchStrategies =
new VectorValues.SimilarityFunction[numFields];
VectorSimilarityFunction[] fieldSearchStrategies = new VectorSimilarityFunction[numFields];
for (int i = 0; i < numFields; i++) {
fieldDims[i] = random().nextInt(20) + 1;
fieldSearchStrategies[i] =
VectorValues.SimilarityFunction.values()[
random().nextInt(VectorValues.SimilarityFunction.values().length)];
VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length)];
}
try (Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
@ -628,15 +617,15 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory();
IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc1 = new Document();
doc1.add(new VectorField(fieldName, v, VectorValues.SimilarityFunction.EUCLIDEAN));
doc1.add(new VectorField(fieldName, v, VectorSimilarityFunction.EUCLIDEAN));
v[0] = 1;
Document doc2 = new Document();
doc2.add(new VectorField(fieldName, v, VectorValues.SimilarityFunction.EUCLIDEAN));
doc2.add(new VectorField(fieldName, v, VectorSimilarityFunction.EUCLIDEAN));
iw.addDocument(doc1);
iw.addDocument(doc2);
v[0] = 2;
Document doc3 = new Document();
doc3.add(new VectorField(fieldName, v, VectorValues.SimilarityFunction.EUCLIDEAN));
doc3.add(new VectorField(fieldName, v, VectorSimilarityFunction.EUCLIDEAN));
iw.addDocument(doc3);
iw.forceMerge(1);
try (IndexReader reader = iw.getReader()) {
@ -691,16 +680,14 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
float[] v = new float[] {1};
doc.add(new VectorField("field1", v, VectorValues.SimilarityFunction.EUCLIDEAN));
doc.add(
new VectorField("field2", new float[] {1, 2, 3}, VectorValues.SimilarityFunction.NONE));
doc.add(new VectorField("field1", v, VectorSimilarityFunction.EUCLIDEAN));
doc.add(new VectorField("field2", new float[] {1, 2, 3}, VectorSimilarityFunction.NONE));
iw.addDocument(doc);
v[0] = 2;
iw.addDocument(doc);
doc = new Document();
doc.add(
new VectorField(
"field3", new float[] {1, 2, 3}, VectorValues.SimilarityFunction.DOT_PRODUCT));
new VectorField("field3", new float[] {1, 2, 3}, VectorSimilarityFunction.DOT_PRODUCT));
iw.addDocument(doc);
iw.forceMerge(1);
try (IndexReader reader = iw.getReader()) {
@ -761,9 +748,9 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
if (random().nextBoolean() && values[i] != null) {
// sometimes use a shared scratch array
System.arraycopy(values[i], 0, scratch, 0, scratch.length);
add(iw, fieldName, i, scratch, VectorValues.SimilarityFunction.NONE);
add(iw, fieldName, i, scratch, VectorSimilarityFunction.NONE);
} else {
add(iw, fieldName, i, values[i], VectorValues.SimilarityFunction.NONE);
add(iw, fieldName, i, values[i], VectorSimilarityFunction.NONE);
}
if (random().nextInt(10) == 2) {
// sometimes delete a random document
@ -834,7 +821,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
}
id2value[id] = value;
id2ord[id] = i;
add(iw, fieldName, id, value, VectorValues.SimilarityFunction.EUCLIDEAN);
add(iw, fieldName, id, value, VectorSimilarityFunction.EUCLIDEAN);
}
try (IndexReader reader = iw.getReader()) {
for (LeafReaderContext ctx : reader.leaves()) {
@ -871,14 +858,14 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
String field,
int id,
float[] vector,
VectorValues.SimilarityFunction similarityFunction)
VectorSimilarityFunction similarityFunction)
throws IOException {
add(iw, field, id, random().nextInt(100), vector, similarityFunction);
}
private void add(IndexWriter iw, String field, int id, int sortkey, float[] vector)
throws IOException {
add(iw, field, id, sortkey, vector, VectorValues.SimilarityFunction.NONE);
add(iw, field, id, sortkey, vector, VectorSimilarityFunction.NONE);
}
private void add(
@ -887,7 +874,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
int id,
int sortkey,
float[] vector,
VectorValues.SimilarityFunction similarityFunction)
VectorSimilarityFunction similarityFunction)
throws IOException {
Document doc = new Document();
if (vector != null) {
@ -913,10 +900,10 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("v1", randomVector(3), VectorValues.SimilarityFunction.NONE));
doc.add(new VectorField("v1", randomVector(3), VectorSimilarityFunction.NONE));
w.addDocument(doc);
doc.add(new VectorField("v2", randomVector(3), VectorValues.SimilarityFunction.NONE));
doc.add(new VectorField("v2", randomVector(3), VectorSimilarityFunction.NONE));
w.addDocument(doc);
}
@ -937,10 +924,10 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
public void testSimilarityFunctionIdentifiers() {
// make sure we don't accidentally mess up similarity function identifiers by re-ordering their
// enumerators
assertEquals(0, VectorValues.SimilarityFunction.NONE.ordinal());
assertEquals(1, VectorValues.SimilarityFunction.EUCLIDEAN.ordinal());
assertEquals(2, VectorValues.SimilarityFunction.DOT_PRODUCT.ordinal());
assertEquals(3, VectorValues.SimilarityFunction.values().length);
assertEquals(0, VectorSimilarityFunction.NONE.ordinal());
assertEquals(1, VectorSimilarityFunction.EUCLIDEAN.ordinal());
assertEquals(2, VectorSimilarityFunction.DOT_PRODUCT.ordinal());
assertEquals(3, VectorSimilarityFunction.values().length);
}
public void testAdvance() throws Exception {
@ -952,7 +939,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Document doc = new Document();
// randomly add a vector field
if (random().nextInt(4) == 3) {
doc.add(new VectorField(fieldName, new float[4], VectorValues.SimilarityFunction.NONE));
doc.add(new VectorField(fieldName, new float[4], VectorSimilarityFunction.NONE));
}
w.addDocument(doc);
}

View File

@ -140,7 +140,7 @@ public class RandomPostingsTester {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
false);
fieldUpto++;
@ -711,7 +711,7 @@ public class RandomPostingsTester {
0,
0,
0,
VectorValues.SimilarityFunction.NONE,
VectorSimilarityFunction.NONE,
false);
}