Ability to compute vector similarity scores with DoubleValuesSource (#12548)

### Description

This PR addresses the issue #12394. It adds an API **`similarityToQueryVector`** to `DoubleValuesSource` to compute vector similarity scores between the query vector and the `KnnByteVectorField`/`KnnFloatVectorField` for documents using the 2 new DVS implementations (`ByteVectorSimilarityValuesSource` for byte vectors and `FloatVectorSimilarityValuesSource` for float vectors). Below are the method signatures added to DVS in this PR:

- `DoubleValues similarityToQueryVector(LeafReaderContext ctx, float[] queryVector, String vectorField)` *(uses ByteVectorSimilarityValuesSource)*
- `DoubleValues similarityToQueryVector(LeafReaderContext ctx, byte[] queryVector, String vectorField)` *(uses FloatVectorSimilarityValuesSource)*

Closes #12394
This commit is contained in:
Shubham Chaudhary 2023-10-12 23:04:37 +05:30 committed by GitHub
parent 268dd54a86
commit 52dfe50e8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 644 additions and 1 deletions

View File

@ -147,7 +147,8 @@ API Changes
New Features
---------------------
(No changes)
* GITHUB#12548: Added similarityToQueryVector API to compute vector similarity scores
with DoubleValuesSource. (Shubham Chaudhary)
Improvements
---------------------

View File

@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
/**
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
* and the {@link org.apache.lucene.document.KnnByteVectorField} for documents.
*/
class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
private final byte[] queryVector;
public ByteVectorSimilarityValuesSource(byte[] vector, String fieldName) {
super(fieldName);
this.queryVector = vector;
}
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
return function.compare(queryVector, vectorValues.vectorValue());
}
@Override
public boolean advanceExact(int doc) throws IOException {
return doc >= vectorValues.docID()
&& (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
}
};
}
@Override
public int hashCode() {
return Objects.hash(fieldName, Arrays.hashCode(queryVector));
}
@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
ByteVectorSimilarityValuesSource other = (ByteVectorSimilarityValuesSource) obj;
return Objects.equals(fieldName, other.fieldName)
&& Arrays.equals(queryVector, other.queryVector);
}
@Override
public String toString() {
return "ByteVectorSimilarityValuesSource(fieldName="
+ fieldName
+ " queryVector="
+ Arrays.toString(queryVector)
+ ")";
}
}

View File

@ -24,6 +24,7 @@ import java.util.function.LongToDoubleFunction;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.comparators.DoubleComparator;
/**
@ -172,6 +173,52 @@ public abstract class DoubleValuesSource implements SegmentCacheable {
}
}
/**
* Returns a DoubleValues instance for computing the vector similarity score per document against
* the byte query vector
*
* @param ctx the context for which to return the DoubleValues
* @param queryVector byte query vector
* @param vectorField knn byte field name
* @return DoubleValues instance
* @throws IOException if an {@link IOException} occurs
*/
public static DoubleValues similarityToQueryVector(
LeafReaderContext ctx, byte[] queryVector, String vectorField) throws IOException {
if (ctx.reader().getFieldInfos().fieldInfo(vectorField).getVectorEncoding()
!= VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"Field "
+ vectorField
+ " does not have the expected vector encoding: "
+ VectorEncoding.BYTE);
}
return new ByteVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
}
/**
* Returns a DoubleValues instance for computing the vector similarity score per document against
* the float query vector
*
* @param ctx the context for which to return the DoubleValues
* @param queryVector float query vector
* @param vectorField knn float field name
* @return DoubleValues instance
* @throws IOException if an {@link IOException} occurs
*/
public static DoubleValues similarityToQueryVector(
LeafReaderContext ctx, float[] queryVector, String vectorField) throws IOException {
if (ctx.reader().getFieldInfos().fieldInfo(vectorField).getVectorEncoding()
!= VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"Field "
+ vectorField
+ " does not have the expected vector encoding: "
+ VectorEncoding.FLOAT32);
}
return new FloatVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
}
/**
* Creates a DoubleValuesSource that wraps a generic NumericDocValues field
*

View File

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
/**
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
* and the {@link org.apache.lucene.document.KnnFloatVectorField} for documents.
*/
class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
private final float[] queryVector;
public FloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
super(fieldName);
this.queryVector = vector;
}
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
VectorSimilarityFunction function =
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
return function.compare(queryVector, vectorValues.vectorValue());
}
@Override
public boolean advanceExact(int doc) throws IOException {
return doc >= vectorValues.docID()
&& (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
}
};
}
@Override
public int hashCode() {
return Objects.hash(fieldName, Arrays.hashCode(queryVector));
}
@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
FloatVectorSimilarityValuesSource other = (FloatVectorSimilarityValuesSource) obj;
return Objects.equals(fieldName, other.fieldName)
&& Arrays.equals(queryVector, other.queryVector);
}
@Override
public String toString() {
return "FloatVectorSimilarityValuesSource(fieldName="
+ fieldName
+ " queryVector="
+ Arrays.toString(queryVector)
+ ")";
}
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
/**
* An abstract class that provides the vector similarity scores between the query vector and the
* {@link org.apache.lucene.document.KnnFloatVectorField} or {@link
* org.apache.lucene.document.KnnByteVectorField} for documents.
*/
abstract class VectorSimilarityValuesSource extends DoubleValuesSource {
protected final String fieldName;
public VectorSimilarityValuesSource(String fieldName) {
this.fieldName = fieldName;
}
@Override
public abstract DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores)
throws IOException;
@Override
public boolean needsScores() {
return false;
}
@Override
public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
return this;
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
}

View File

@ -0,0 +1,381 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.junit.AfterClass;
import org.junit.BeforeClass;
public class TestVectorSimilarityValuesSource extends LuceneTestCase {
private static Directory dir;
private static Analyzer analyzer;
private static IndexReader reader;
private static IndexSearcher searcher;
@BeforeClass
public static void beforeClass() throws Exception {
dir = newDirectory();
analyzer = new MockAnalyzer(random());
IndexWriterConfig iwConfig = newIndexWriterConfig(analyzer);
iwConfig.setMergePolicy(newLogMergePolicy());
RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwConfig);
Document document = new Document();
document.add(new StringField("id", "1", Field.Store.NO));
document.add(new SortedDocValuesField("id", new BytesRef("1")));
document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f}));
document.add(
new KnnFloatVectorField(
"knnFloatField2",
new float[] {2.2f, -3.2f, -3.1f},
VectorSimilarityFunction.DOT_PRODUCT));
document.add(
new KnnFloatVectorField(
"knnFloatField3", new float[] {4.5f, 10.3f, -7.f}, VectorSimilarityFunction.COSINE));
document.add(
new KnnFloatVectorField(
"knnFloatField4",
new float[] {-1.3f, 1.0f, 1.0f},
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
document.add(new KnnFloatVectorField("knnFloatField5", new float[] {-6.7f, -1.0f, -0.9f}));
document.add(new KnnByteVectorField("knnByteField1", new byte[] {106, 80, 127}));
document.add(
new KnnByteVectorField(
"knnByteField2", new byte[] {4, 2, 3}, VectorSimilarityFunction.DOT_PRODUCT));
document.add(
new KnnByteVectorField(
"knnByteField3", new byte[] {-121, -64, -1}, VectorSimilarityFunction.COSINE));
document.add(
new KnnByteVectorField(
"knnByteField4",
new byte[] {-127, 127, 127},
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
iw.addDocument(document);
Document document2 = new Document();
document2.add(new StringField("id", "2", Field.Store.NO));
document2.add(new SortedDocValuesField("id", new BytesRef("2")));
document2.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f}));
document2.add(
new KnnFloatVectorField(
"knnFloatField2",
new float[] {-5.2f, 8.7f, 3.1f},
VectorSimilarityFunction.DOT_PRODUCT));
document2.add(
new KnnFloatVectorField(
"knnFloatField3", new float[] {0.2f, -3.2f, 3.1f}, VectorSimilarityFunction.COSINE));
document2.add(new KnnFloatVectorField("knnFloatField5", new float[] {2.f, 13.2f, 9.1f}));
document2.add(new KnnByteVectorField("knnByteField1", new byte[] {1, -2, -30}));
document2.add(
new KnnByteVectorField(
"knnByteField2", new byte[] {40, 21, 3}, VectorSimilarityFunction.DOT_PRODUCT));
document2.add(
new KnnByteVectorField(
"knnByteField3", new byte[] {9, 2, 3}, VectorSimilarityFunction.COSINE));
document2.add(
new KnnByteVectorField(
"knnByteField4",
new byte[] {14, 29, 31},
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
iw.addDocument(document2);
Document document3 = new Document();
document3.add(new StringField("id", "3", Field.Store.NO));
document3.add(new SortedDocValuesField("id", new BytesRef("3")));
document3.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f}));
document3.add(
new KnnFloatVectorField(
"knnFloatField2", new float[] {-8.f, 7.f, -6.f}, VectorSimilarityFunction.DOT_PRODUCT));
document3.add(new KnnFloatVectorField("knnFloatField5", new float[] {5.2f, 3.2f, 3.1f}));
document3.add(new KnnByteVectorField("knnByteField1", new byte[] {-128, 0, 127}));
document3.add(
new KnnByteVectorField(
"knnByteField2", new byte[] {-1, -2, -3}, VectorSimilarityFunction.DOT_PRODUCT));
document3.add(
new KnnByteVectorField(
"knnByteField3", new byte[] {4, 2, 3}, VectorSimilarityFunction.COSINE));
document3.add(
new KnnByteVectorField(
"knnByteField4",
new byte[] {-4, -2, -128},
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
document3.add(new KnnByteVectorField("knnByteField5", new byte[] {-120, -2, 3}));
iw.addDocument(document3);
reader = iw.getReader();
searcher = newSearcher(reader);
iw.close();
}
@AfterClass
public static void afterClass() throws Exception {
searcher = null;
IOUtils.close(reader, dir, analyzer);
}
public void testEuclideanSimilarityValuesSource() throws Exception {
float[] floatQueryVector = new float[] {9.f, 1.f, -10.f};
// Checks the computed similarity score between indexed vectors and query vector
// using DVS is correct by passing indexed and query vector in #compare
DoubleValues dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField1");
assertTrue(
dv.advanceExact(0)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
assertTrue(
dv.advanceExact(1)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
assertTrue(
dv.advanceExact(2)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField5");
assertTrue(
dv.advanceExact(0)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new float[] {-6.7f, -1.0f, -0.9f}, floatQueryVector));
assertTrue(
dv.advanceExact(1)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new float[] {2.f, 13.2f, 9.1f}, floatQueryVector));
assertTrue(
dv.advanceExact(2)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new float[] {5.2f, 3.2f, 3.1f}, floatQueryVector));
byte[] byteQueryVector = new byte[] {-128, 2, 127};
dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField1");
assertTrue(
dv.advanceExact(0)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new byte[] {106, 80, 127}, byteQueryVector));
assertTrue(
dv.advanceExact(1)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new byte[] {1, -2, -30}, byteQueryVector));
assertTrue(
dv.advanceExact(2)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new byte[] {-128, 0, 127}, byteQueryVector));
dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField5");
assertFalse(dv.advanceExact(0));
assertFalse(dv.advanceExact(1));
assertTrue(
dv.advanceExact(2)
&& dv.doubleValue()
== VectorSimilarityFunction.EUCLIDEAN.compare(
new byte[] {-120, -2, 3}, byteQueryVector));
}
public void testDotSimilarityValuesSource() throws Exception {
float[] floatQueryVector = new float[] {10.f, 1.f, -8.5f};
// Checks the computed similarity score between indexed vectors and query vector
// using DVS is correct by passing indexed and query vector in #compare
DoubleValues dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField2");
assertTrue(
dv.advanceExact(0)
&& dv.doubleValue()
== VectorSimilarityFunction.DOT_PRODUCT.compare(
new float[] {2.2f, -3.2f, -3.1f}, floatQueryVector));
assertTrue(
dv.advanceExact(1)
&& dv.doubleValue()
== VectorSimilarityFunction.DOT_PRODUCT.compare(
new float[] {-5.2f, 8.7f, 3.1f}, floatQueryVector));
assertTrue(
dv.advanceExact(2)
&& dv.doubleValue()
== VectorSimilarityFunction.DOT_PRODUCT.compare(
new float[] {-8.f, 7.f, -6.f}, floatQueryVector));
byte[] byteQueryVector = new byte[] {-128, 2, 127};
dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField2");
assertTrue(
dv.advanceExact(0)
&& dv.doubleValue()
== VectorSimilarityFunction.DOT_PRODUCT.compare(
new byte[] {4, 2, 3}, byteQueryVector));
assertTrue(
dv.advanceExact(1)
&& dv.doubleValue()
== VectorSimilarityFunction.DOT_PRODUCT.compare(
new byte[] {40, 21, 3}, byteQueryVector));
assertTrue(
dv.advanceExact(2)
&& dv.doubleValue()
== VectorSimilarityFunction.DOT_PRODUCT.compare(
new byte[] {-1, -2, -3}, byteQueryVector));
}
public void testCosineSimilarityValuesSource() throws Exception {
float[] floatQueryVector = new float[] {0.6f, -1.6f, 38.0f};
// Checks the computed similarity score between indexed vectors and query vector
// using DVS is correct by passing indexed and query vector in #compare
DoubleValues dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField3");
assertTrue(
dv.advanceExact(0)
&& dv.doubleValue()
== VectorSimilarityFunction.COSINE.compare(
new float[] {4.5f, 10.3f, -7.f}, floatQueryVector));
assertTrue(
dv.advanceExact(1)
&& dv.doubleValue()
== VectorSimilarityFunction.COSINE.compare(
new float[] {0.2f, -3.2f, 3.1f}, floatQueryVector));
assertFalse(dv.advanceExact(2));
byte[] byteQueryVector = new byte[] {-10, 8, 0};
dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField3");
assertTrue(
dv.advanceExact(0)
&& dv.doubleValue()
== VectorSimilarityFunction.COSINE.compare(
new byte[] {-121, -64, -1}, byteQueryVector));
assertTrue(
dv.advanceExact(1)
&& dv.doubleValue()
== VectorSimilarityFunction.COSINE.compare(new byte[] {9, 2, 3}, byteQueryVector));
assertTrue(
dv.advanceExact(2)
&& dv.doubleValue()
== VectorSimilarityFunction.COSINE.compare(new byte[] {4, 2, 3}, byteQueryVector));
}
public void testMaximumProductSimilarityValuesSource() throws Exception {
float[] floatQueryVector = new float[] {1.f, -6.f, -10.f};
// Checks the computed similarity score between indexed vectors and query vector
// using DVS is correct by passing indexed and query vector in #compare
DoubleValues dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField4");
assertTrue(
dv.advanceExact(0)
&& dv.doubleValue()
== VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
new float[] {-1.3f, 1.0f, 1.0f}, floatQueryVector));
assertFalse(dv.advanceExact(1));
assertFalse(dv.advanceExact(2));
byte[] byteQueryVector = new byte[] {-127, 127, 127};
dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField4");
assertTrue(
dv.advanceExact(0)
&& dv.doubleValue()
== VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
new byte[] {-127, 127, 127}, byteQueryVector));
assertTrue(
dv.advanceExact(1)
&& dv.doubleValue()
== VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
new byte[] {14, 29, 31}, byteQueryVector));
assertTrue(
dv.advanceExact(2)
&& dv.doubleValue()
== VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
new byte[] {-4, -2, -128}, byteQueryVector));
}
public void testFailuresWithSimilarityValuesSource() throws Exception {
float[] floatQueryVector = new float[] {1.1f, 2.2f, 3.3f};
byte[] byteQueryVector = new byte[] {-10, 20, 30};
expectThrows(
IllegalArgumentException.class,
() ->
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), floatQueryVector, "knnByteField1"));
expectThrows(
IllegalArgumentException.class,
() ->
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), byteQueryVector, "knnFloatField1"));
DoubleValues dv =
DoubleValuesSource.similarityToQueryVector(
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField1");
assertTrue(dv.advanceExact(0));
assertEquals(
dv.doubleValue(),
VectorSimilarityFunction.EUCLIDEAN.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector),
0.0);
assertNotEquals(
dv.doubleValue(),
VectorSimilarityFunction.DOT_PRODUCT.compare(
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
assertNotEquals(
dv.doubleValue(),
VectorSimilarityFunction.COSINE.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector));
assertNotEquals(
dv.doubleValue(),
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
}
}