Refactor unit tests for vector functions. (#48662)

This PR performs the following changes:
* Split `ScoreScriptUtilsTests` into `DenseVectorFunctionTests` and
`SparseVectorFunctionTests`. This will make it easier to delete all sparse
vector function tests once we remove support on 8.x.
* As much as possible, break up the large test methods into individual tests
for each vector function (`cosineSimilarity`, `l2norm`, etc.).
This commit is contained in:
Julie Tibshirani 2019-10-30 15:20:49 -07:00
parent ede1681c5a
commit ae1ef5fd92
3 changed files with 328 additions and 358 deletions

View File

@ -0,0 +1,119 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.vectors.query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilarity;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProduct;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1Norm;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.DenseVectorScriptDocValues;
import org.junit.Before;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoderTests.mockEncodeDenseVector;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class DenseVectorFunctionTests extends ESTestCase {
private String field;
private float[] docVector;
private List<Number> queryVector;
private List<Number> invalidQueryVector;
@Before
public void setUpVectors() {
field = "vector";
docVector = new float[] {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
invalidQueryVector = Arrays.asList(0.5, 111.3);
}
public void testDenseVectorFunctions() {
for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
DenseVectorScriptDocValues docValues = mock(DenseVectorScriptDocValues.class);
when(docValues.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues));
testDotProduct(docValues, scoreScript);
testCosineSimilarity(docValues, scoreScript);
testL1Norm(docValues, scoreScript);
testL2Norm(docValues, scoreScript);
}
}
private void testDotProduct(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
DotProduct function = new DotProduct(scoreScript, queryVector, field);
double result = function.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
DotProduct deprecatedFunction = new DotProduct(scoreScript, queryVector, docValues);
double deprecatedResult = deprecatedFunction.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, deprecatedResult, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
DotProduct invalidFunction = new DotProduct(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::dotProduct);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}
private void testCosineSimilarity(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
CosineSimilarity function = new CosineSimilarity(scoreScript, queryVector, field);
double result = function.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result, 0.001);
CosineSimilarity deprecatedFunction = new CosineSimilarity(scoreScript, queryVector, docValues);
double deprecatedResult = deprecatedFunction.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, deprecatedResult, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
CosineSimilarity invalidFunction = new CosineSimilarity(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::cosineSimilarity);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}
private void testL1Norm(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
L1Norm function = new L1Norm(scoreScript, queryVector, field);
double result = function.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, result, 0.001);
L1Norm deprecatedFunction = new L1Norm(scoreScript, queryVector, docValues);
double deprecatedResult = deprecatedFunction.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, deprecatedResult, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
L1Norm invalidFunction = new L1Norm(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l1norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}
private void testL2Norm(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
L2Norm function = new L2Norm(scoreScript, queryVector, field);
double result = function.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, result, 0.001);
L2Norm deprecatedFunction = new L2Norm(scoreScript, queryVector, docValues);
double deprecatedResult = deprecatedFunction.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, deprecatedResult, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
L2Norm invalidFunction = new L2Norm(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l2norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}
}

View File

@ -1,358 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.vectors.query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper;
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilarity;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilaritySparse;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProduct;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProductSparse;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1Norm;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1NormSparse;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoderTests.mockEncodeDenseVector;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ScoreScriptUtilsTests extends ESTestCase {
public void testDenseVectorFunctions() {
testDenseVectorFunctions(Version.V_7_4_0);
testDenseVectorFunctions(Version.CURRENT);
}
private void testDenseVectorFunctions(Version indexVersion) {
String field = "vector";
float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
// test dotProduct
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, field);
double result = dotProduct.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
// test cosineSimilarity
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, field);
double result2 = cosineSimilarity.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001);
// test l1Norm
L1Norm l1norm = new L1Norm(scoreScript, queryVector, field);
double result3 = l1norm.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001);
// test l2norm
L2Norm l2norm = new L2Norm(scoreScript, queryVector, field);
double result4 = l2norm.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001);
// test dotProduct fails when queryVector has wrong number of dims
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
// test cosineSimilarity fails when queryVector has wrong number of dims
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, field);
e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
// test l1norm fails when queryVector has wrong number of dims
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, field);
e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
// test l2norm fails when queryVector has wrong number of dims
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, field);
e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}
public void testDeprecatedDenseVectorFunctions() {
testDeprecatedDenseVectorFunctions(Version.V_7_4_0);
testDeprecatedDenseVectorFunctions(Version.CURRENT);
}
private void testDeprecatedDenseVectorFunctions(Version indexVersion) {
float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
// test dotProduct
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, dvs);
double result = dotProduct.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test cosineSimilarity
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, dvs);
double result2 = cosineSimilarity.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l1Norm
L1Norm l1norm = new L1Norm(scoreScript, queryVector, dvs);
double result3 = l1norm.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l2norm
L2Norm l2norm = new L2Norm(scoreScript, queryVector, dvs);
double result4 = l2norm.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test dotProduct fails when queryVector has wrong number of dims
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, dvs);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test cosineSimilarity fails when queryVector has wrong number of dims
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, dvs);
e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l1norm fails when queryVector has wrong number of dims
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, dvs);
e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l2norm fails when queryVector has wrong number of dims
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, dvs);
e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
}
public void testSparseVectorFunctions() {
testSparseVectorFunctions(Version.V_7_4_0);
testSparseVectorFunctions(Version.CURRENT);
}
private void testSparseVectorFunctions(Version indexVersion) {
String field = "vector";
int[] docVectorDims = {2, 10, 50, 113, 4545};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
indexVersion, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5);
put("10", 111.3);
put("50", -13.0);
put("113", 14.8);
put("4545", -156.0);
}};
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001);
// test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001);
// test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
public void testDeprecatedSparseVectorFunctions() {
testDeprecatedSparseVectorFunctions(Version.V_7_4_0);
testDeprecatedSparseVectorFunctions(Version.CURRENT);
}
private void testDeprecatedSparseVectorFunctions(Version indexVersion) {
int[] docVectorDims = {2, 10, 50, 113, 4545};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
indexVersion, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5);
put("10", 111.3);
put("50", -13.0);
put("113", 14.8);
put("4545", -156.0);
}};
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, dvs);
double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, dvs);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, dvs);
double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, dvs);
double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
}
public void testSparseVectorMissingDimensions1() {
String field = "vector";
// Document vector's biggest dimension > query vector's biggest dimension
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5);
put("10", 111.3);
put("50", -13.0);
put("113", 14.8);
put("114", -20.5);
put("4545", -156.0);
}};
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
public void testSparseVectorMissingDimensions2() {
String field = "vector";
// Document vector's biggest dimension < query vector's biggest dimension
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5);
put("10", 111.3);
put("50", -13.0);
put("113", 14.8);
put("4545", -156.0);
put("4548", -20.5);
}};
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
}

View File

@ -0,0 +1,209 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.vectors.query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper;
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilaritySparse;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProductSparse;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1NormSparse;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.SparseVectorScriptDocValues;
import org.junit.Before;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests for the sparse-vector scoring functions in {@link ScoreScriptUtils}:
 * {@link DotProductSparse}, {@link CosineSimilaritySparse}, {@link L1NormSparse} and
 * {@link L2NormSparse}.
 *
 * <p>The functions are exercised through both the field-name constructors and the
 * doc-values constructors; the latter are expected to add
 * {@code ScoreScriptUtils.DEPRECATION_MESSAGE} to the emitted warnings.
 */
public class SparseVectorFunctionTests extends ESTestCase {

    /** Tolerance used by every floating-point comparison in this class. */
    private static final double DELTA = 0.001;

    private String field;
    private int[] docVectorDims;
    private float[] docVectorValues;
    private Map<String, Number> queryVector;

    /**
     * Prepares the fixtures shared by the per-function checks: the field name, a
     * five-dimension document vector and a query vector with the same dimensions.
     */
    @Before
    public void setUpVectors() {
        field = "vector";
        docVectorDims = new int[] {2, 10, 50, 113, 4545};
        docVectorValues = new float[] {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
        // A plain HashMap avoids the anonymous subclass created by double-brace initialisation.
        queryVector = new HashMap<>();
        queryVector.put("2", 0.5);
        queryVector.put("10", 111.3);
        queryVector.put("50", -13.0);
        queryVector.put("113", 14.8);
        queryVector.put("4545", -156.0);
    }

    /**
     * Builds the mocks for {@code Version.V_7_4_0} and {@code Version.CURRENT} in turn and
     * runs the four per-function checks against them.
     */
    public void testSparseVectorFunctions() {
        for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
            BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(indexVersion,
                docVectorDims, docVectorValues, docVectorDims.length);
            SparseVectorScriptDocValues docValues = mock(SparseVectorScriptDocValues.class);
            when(docValues.getEncodedValue()).thenReturn(encodedDocVector);

            ScoreScript scoreScript = mock(ScoreScript.class);
            when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
            when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues));

            testDotProduct(docValues, scoreScript);
            testCosineSimilarity(docValues, scoreScript);
            testL1Norm(docValues, scoreScript);
            testL2Norm(docValues, scoreScript);
        }
    }

    /** Checks {@code dotProductSparse} via the field-name and the doc-values constructor. */
    private void testDotProduct(SparseVectorScriptDocValues docValues, ScoreScript scoreScript) {
        DotProductSparse function = new DotProductSparse(scoreScript, queryVector, field);
        double result = function.dotProductSparse();
        assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, DELTA);

        DotProductSparse deprecatedFunction = new DotProductSparse(scoreScript, queryVector, docValues);
        double deprecatedResult = deprecatedFunction.dotProductSparse();
        assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, deprecatedResult, DELTA);
        assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
    }

    /** Checks {@code cosineSimilaritySparse} via the field-name and the doc-values constructor. */
    private void testCosineSimilarity(SparseVectorScriptDocValues docValues, ScoreScript scoreScript) {
        CosineSimilaritySparse function = new CosineSimilaritySparse(scoreScript, queryVector, field);
        double result = function.cosineSimilaritySparse();
        assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result, DELTA);

        CosineSimilaritySparse deprecatedFunction = new CosineSimilaritySparse(scoreScript, queryVector, docValues);
        double deprecatedResult = deprecatedFunction.cosineSimilaritySparse();
        assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, deprecatedResult, DELTA);
        assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
    }

    /** Checks {@code l1normSparse} via the field-name and the doc-values constructor. */
    private void testL1Norm(SparseVectorScriptDocValues docValues, ScoreScript scoreScript) {
        L1NormSparse function = new L1NormSparse(scoreScript, queryVector, field);
        double result = function.l1normSparse();
        // Failure message names the method actually under test (was "l1norm").
        assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result, DELTA);

        L1NormSparse deprecatedFunction = new L1NormSparse(scoreScript, queryVector, docValues);
        double deprecatedResult = deprecatedFunction.l1normSparse();
        assertEquals("l1normSparse result is not equal to the expected value!", 485.184, deprecatedResult, DELTA);
        assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
    }

    /** Checks {@code l2normSparse} via the field-name and the doc-values constructor. */
    private void testL2Norm(SparseVectorScriptDocValues docValues, ScoreScript scoreScript) {
        L2NormSparse function = new L2NormSparse(scoreScript, queryVector, field);
        double result = function.l2normSparse();
        // Failure message names the method actually under test (was "L2NormSparse").
        assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result, DELTA);

        L2NormSparse deprecatedFunction = new L2NormSparse(scoreScript, queryVector, docValues);
        double deprecatedResult = deprecatedFunction.l2normSparse();
        assertEquals("l2normSparse result is not equal to the expected value!", 301.361, deprecatedResult, DELTA);
        assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
    }

    /** Document vector's biggest dimension > query vector's biggest dimension. */
    public void testSparseVectorMissingDimensions1() {
        Map<String, Number> query = new HashMap<>();
        query.put("2", 0.5);
        query.put("10", 111.3);
        query.put("50", -13.0);
        query.put("113", 14.8);
        query.put("114", -20.5);
        query.put("4545", -156.0);
        assertMissingDimensionsResults(query);
    }

    /** Document vector's biggest dimension < query vector's biggest dimension. */
    public void testSparseVectorMissingDimensions2() {
        Map<String, Number> query = new HashMap<>();
        query.put("2", 0.5);
        query.put("10", 111.3);
        query.put("50", -13.0);
        query.put("113", 14.8);
        query.put("4545", -156.0);
        query.put("4548", -20.5);
        assertMissingDimensionsResults(query);
    }

    /**
     * Scores {@code query} against a six-dimension document vector (dimensions 2, 10, 50,
     * 113, 4545, 4546) encoded for {@code Version.CURRENT}, asserting the expected value and
     * the sparse-vector deprecation warning after each of the four functions.
     *
     * <p>Locals are named so that they do not shadow the fixtures assigned in
     * {@link #setUpVectors()}; only the {@code field} fixture is reused here.
     *
     * @param query the query vector, which shares only some dimensions with the document vector
     */
    private void assertMissingDimensionsResults(Map<String, Number> query) {
        int[] dims = {2, 10, 50, 113, 4545, 4546};
        float[] values = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
        BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(Version.CURRENT, dims, values, dims.length);
        SparseVectorScriptDocValues docValues = mock(SparseVectorScriptDocValues.class);
        when(docValues.getEncodedValue()).thenReturn(encodedDocVector);

        ScoreScript scoreScript = mock(ScoreScript.class);
        when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
        when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues));

        DotProductSparse dotProduct = new DotProductSparse(scoreScript, query, field);
        assertEquals("dotProductSparse result is not equal to the expected value!",
            65425.624, dotProduct.dotProductSparse(), DELTA);
        assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);

        CosineSimilaritySparse cosineSimilarity = new CosineSimilaritySparse(scoreScript, query, field);
        assertEquals("cosineSimilaritySparse result is not equal to the expected value!",
            0.786, cosineSimilarity.cosineSimilaritySparse(), DELTA);
        assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);

        L1NormSparse l1Norm = new L1NormSparse(scoreScript, query, field);
        assertEquals("l1normSparse result is not equal to the expected value!",
            517.184, l1Norm.l1normSparse(), DELTA);
        assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);

        L2NormSparse l2Norm = new L2NormSparse(scoreScript, query, field);
        assertEquals("l2normSparse result is not equal to the expected value!",
            302.277, l2Norm.l2normSparse(), DELTA);
        assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
    }
}