diff --git a/docs/reference/mapping/types/dense-vector.asciidoc b/docs/reference/mapping/types/dense-vector.asciidoc index b97566361a0..f656092e472 100644 --- a/docs/reference/mapping/types/dense-vector.asciidoc +++ b/docs/reference/mapping/types/dense-vector.asciidoc @@ -9,7 +9,7 @@ not exceed 500. The number of dimensions can be different across documents. A `dense_vector` field is a single-valued field. -These vectors can be used for document scoring. +These vectors can be used for <>. For example, a document score can represent a distance between a given query vector and the indexed document vector. diff --git a/docs/reference/mapping/types/sparse-vector.asciidoc b/docs/reference/mapping/types/sparse-vector.asciidoc index 38561789b5d..8ed4920c4e6 100644 --- a/docs/reference/mapping/types/sparse-vector.asciidoc +++ b/docs/reference/mapping/types/sparse-vector.asciidoc @@ -9,7 +9,7 @@ not exceed 500. The number of dimensions can be different across documents. A `sparse_vector` field is a single-valued field. -These vectors can be used for document scoring. +These vectors can be used for <>. For example, a document score can represent a distance between a given query vector and the indexed document vector. diff --git a/docs/reference/query-dsl/script-score-query.asciidoc b/docs/reference/query-dsl/script-score-query.asciidoc index cdcfd0f0a50..ee68d3e40fe 100644 --- a/docs/reference/query-dsl/script-score-query.asciidoc +++ b/docs/reference/query-dsl/script-score-query.asciidoc @@ -74,6 +74,113 @@ to be the most efficient by using the internal mechanisms. -------------------------------------------------- // NOTCONSOLE +[[vector-functions]] +===== Functions for vector fields +These functions are used for +for <> and +<> fields. + +For dense_vector fields, `cosineSimilarity` calculates the measure of +cosine similarity between a given query vector and document vectors. + +[source,js] +-------------------------------------------------- +{ + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "cosineSimilarity(params.queryVector, doc['my_dense_vector'])", + "params": { + "queryVector": [4, 3.4, -0.2] <1> + } + } + } + } +} +-------------------------------------------------- +// NOTCONSOLE +<1> To take advantage of the script optimizations, provide a query vector as a script parameter. + +Similarly, for sparse_vector fields, `cosineSimilaritySparse` calculates cosine similarity +between a given query vector and document vectors. + +[source,js] +-------------------------------------------------- +{ + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "cosineSimilaritySparse(params.queryVector, doc['my_sparse_vector'])", + "params": { + "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} + } + } + } + } +} +-------------------------------------------------- +// NOTCONSOLE + +For dense_vector fields, `dotProduct` calculates the measure of +dot product between a given query vector and document vectors. + +[source,js] +-------------------------------------------------- +{ + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "dotProduct(params.queryVector, doc['my_dense_vector'])", + "params": { + "queryVector": [4, 3.4, -0.2] + } + } + } + } +} +-------------------------------------------------- +// NOTCONSOLE + +Similarly, for sparse_vector fields, `dotProductSparse` calculates dot product +between a given query vector and document vectors. + +[source,js] +-------------------------------------------------- +{ + "query": { + "script_score": { + "query": { + "match_all": {} + }, + "script": { + "source": "dotProductSparse(params.queryVector, doc['my_sparse_vector'])", + "params": { + "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} + } + } + } + } +} +-------------------------------------------------- +// NOTCONSOLE + +NOTE: If a document doesn't have a value for a vector field on which +a vector function is executed, 0 is returned as a result +for this document. + +NOTE: If a document's dense vector field has a number of dimensions +different from the query's vector, 0 is used for missing dimensions +in the calculations of vector functions. + [[random-functions]] ===== Random functions diff --git a/modules/mapper-extras/build.gradle b/modules/mapper-extras/build.gradle index 7831de3a68e..73fc8901ec7 100644 --- a/modules/mapper-extras/build.gradle +++ b/modules/mapper-extras/build.gradle @@ -20,4 +20,13 @@ esplugin { description 'Adds advanced field mappers' classname 'org.elasticsearch.index.mapper.MapperExtrasPlugin' + extendedPlugins = ['lang-painless'] } + +dependencies { + compileOnly project(':modules:lang-painless') +} + +integTestCluster { + module project(':modules:lang-painless') +} \ No newline at end of file diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java index 7beddc13ca5..f4a61c3ebd3 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/DenseVectorFieldMapper.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentParser.Token; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.query.QueryShardContext; +import org.elasticsearch.index.query.VectorDVIndexFieldData; import org.elasticsearch.search.DocValueFormat; import java.io.IOException; @@ -119,8 +120,7 @@ public class DenseVectorFieldMapper extends FieldMapper implements ArrayValueMap @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName) { - throw new UnsupportedOperationException( - "Field [" + name() + "] of type [" + typeName() + "] doesn't support sorting, scripting or aggregating"); + return new VectorDVIndexFieldData.Builder(true); } @Override diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java index f7288d50393..adf46d6a60d 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/SparseVectorFieldMapper.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentParser.Token; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.query.QueryShardContext; +import org.elasticsearch.index.query.VectorDVIndexFieldData; import org.elasticsearch.search.DocValueFormat; import java.io.IOException; @@ -119,8 +120,7 @@ public class SparseVectorFieldMapper extends FieldMapper { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName) { - throw new UnsupportedOperationException( - "Field [" + name() + "] of type [" + typeName() + "] doesn't support sorting, scripting or aggregating"); + return new VectorDVIndexFieldData.Builder(false); } @Override diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java index c21b006c883..fbf9955f466 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/VectorEncoderDecoder.java @@ -23,7 +23,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.InPlaceMergeSorter; // static utility functions for encoding and decoding dense_vector and sparse_vector fields -final class VectorEncoderDecoder { +public final class VectorEncoderDecoder { static final byte INT_BYTES = 4; static final byte SHORT_BYTES = 2; @@ -34,10 +34,11 @@ final class VectorEncoderDecoder { * BytesRef: int[] floats encoded as integers values, 2 bytes for each dimension * @param values - values of the sparse array * @param dims - dims of the sparse array - * @param dimCount - number of the dimension + * @param dimCount - number of the dimensions, necessary as values and dims are dynamically created arrays, + * and may be over-allocated * @return BytesRef */ - static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) { + public static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) { // 1. Sort dims and values sortSparseDimsValues(dims, values, dimCount); byte[] buf = new byte[dimCount * (INT_BYTES + SHORT_BYTES)]; @@ -66,9 +67,12 @@ final class VectorEncoderDecoder { /** * Decodes the first part of BytesRef into sparse vector dimensions - * @param vectorBR - vector decoded in BytesRef + * @param vectorBR - sparse vector encoded in BytesRef */ - static int[] decodeSparseVectorDims(BytesRef vectorBR) { + public static int[] decodeSparseVectorDims(BytesRef vectorBR) { + if (vectorBR == null) { + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); + } int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES); int[] dims = new int[dimCount]; int offset = vectorBR.offset; @@ -81,9 +85,12 @@ final class VectorEncoderDecoder { /** * Decodes the second part of the BytesRef into sparse vector values - * @param vectorBR - vector decoded in BytesRef + * @param vectorBR - sparse vector encoded in BytesRef */ - static float[] decodeSparseVector(BytesRef vectorBR) { + public static float[] decodeSparseVector(BytesRef vectorBR) { + if (vectorBR == null) { + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); + } int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES); int offset = vectorBR.offset + SHORT_BYTES * dimCount; //calculate the offset from where values are encoded float[] vector = new float[dimCount]; @@ -100,10 +107,14 @@ final class VectorEncoderDecoder { /** - Sort dimensions in the ascending order and - sort values in the same order as their corresponding dimensions - **/ - static void sortSparseDimsValues(int[] dims, float[] values, int n) { + * Sorts dimensions in the ascending order and + * sorts values in the same order as their corresponding dimensions + * + * @param dims - dimensions of the sparse query vector + * @param values - values for the sparse query vector + * @param n - number of dimensions + */ + public static void sortSparseDimsValues(int[] dims, float[] values, int n) { new InPlaceMergeSorter() { @Override public int compare(int i, int j) { @@ -123,8 +134,42 @@ final class VectorEncoderDecoder { }.sort(0, n); } - // Decodes a BytesRef into an array of floats - static float[] decodeDenseVector(BytesRef vectorBR) { + /** + * Sorts dimensions in the ascending order and + * sorts values in the same order as their corresponding dimensions + * + * @param dims - dimensions of the sparse query vector + * @param values - values for the sparse query vector + * @param n - number of dimensions + */ + public static void sortSparseDimsDoubleValues(int[] dims, double[] values, int n) { + new InPlaceMergeSorter() { + @Override + public int compare(int i, int j) { + return Integer.compare(dims[i], dims[j]); + } + + @Override + public void swap(int i, int j) { + int tempDim = dims[i]; + dims[i] = dims[j]; + dims[j] = tempDim; + + double tempValue = values[j]; + values[j] = values[i]; + values[i] = tempValue; + } + }.sort(0, n); + } + + /** + * Decodes a BytesRef into an array of floats + * @param vectorBR - dense vector encoded in BytesRef + */ + public static float[] decodeDenseVector(BytesRef vectorBR) { + if (vectorBR == null) { + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); + } int dimCount = vectorBR.length / INT_BYTES; float[] vector = new float[dimCount]; int offset = vectorBR.offset; diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/DocValuesWhitelistExtension.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/DocValuesWhitelistExtension.java new file mode 100644 index 00000000000..f463135d69f --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/DocValuesWhitelistExtension.java @@ -0,0 +1,42 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + + +import org.elasticsearch.painless.spi.PainlessExtension; +import org.elasticsearch.painless.spi.Whitelist; +import org.elasticsearch.painless.spi.WhitelistLoader; +import org.elasticsearch.script.ScoreScript; +import org.elasticsearch.script.ScriptContext; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class DocValuesWhitelistExtension implements PainlessExtension { + + private static final Whitelist WHITELIST = + WhitelistLoader.loadFromResourceFiles(DocValuesWhitelistExtension.class, "docvalues_whitelist.txt"); + + @Override + public Map, List> getContextWhitelists() { + return Collections.singletonMap(ScoreScript.CONTEXT, Collections.singletonList(WHITELIST)); + } +} diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java new file mode 100644 index 00000000000..93e80d2a653 --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/ScoreScriptUtils.java @@ -0,0 +1,218 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.mapper.VectorEncoderDecoder; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.index.mapper.VectorEncoderDecoder.sortSparseDimsDoubleValues; + +public class ScoreScriptUtils { + + //**************FUNCTIONS FOR DENSE VECTORS + + /** + * Calculate a dot product between a query's dense vector and documents' dense vectors + * + * @param queryVector the query vector parsed as {@code List} from json + * @param dvs VectorScriptDocValues representing encoded documents' vectors + */ + public static double dotProduct(List queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){ + BytesRef value = dvs.getEncodedValue(); + if (value == null) return 0; + float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); + return intDotProduct(queryVector, docVector); + } + + /** + * Calculate cosine similarity between a query's dense vector and documents' dense vectors + * + * CosineSimilarity is implemented as a class to use + * painless script caching to calculate queryVectorMagnitude + * only once per script execution for all documents. + * A user will call `cosineSimilarity(params.queryVector, doc['my_vector'])` + */ + public static final class CosineSimilarity { + final double queryVectorMagnitude; + final List queryVector; + + // calculate queryVectorMagnitude once per query execution + public CosineSimilarity(List queryVector) { + this.queryVector = queryVector; + double doubleValue; + double dotProduct = 0; + for (Number value : queryVector) { + doubleValue = value.doubleValue(); + dotProduct += doubleValue * doubleValue; + } + this.queryVectorMagnitude = Math.sqrt(dotProduct); + } + + public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { + BytesRef value = dvs.getEncodedValue(); + if (value == null) return 0; + float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); + + // calculate docVector magnitude + double dotProduct = 0f; + for (int dim = 0; dim < docVector.length; dim++) { + dotProduct += (double) docVector[dim] * docVector[dim]; + } + final double docVectorMagnitude = Math.sqrt(dotProduct); + + double docQueryDotProduct = intDotProduct(queryVector, docVector); + return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); + } + } + + private static double intDotProduct(List v1, float[] v2){ + int dims = Math.min(v1.size(), v2.length); + double v1v2DotProduct = 0; + int dim = 0; + Iterator v1Iter = v1.iterator(); + while(dim < dims) { + v1v2DotProduct += v1Iter.next().doubleValue() * v2[dim]; + dim++; + } + return v1v2DotProduct; + } + + + //**************FUNCTIONS FOR SPARSE VECTORS + + /** + * Calculate a dot product between a query's sparse vector and documents' sparse vectors + * + * DotProductSparse is implemented as a class to use + * painless script caching to prepare queryVector + * only once per script execution for all documents. + * A user will call `dotProductSparse(params.queryVector, doc['my_vector'])` + */ + public static final class DotProductSparse { + final double[] queryValues; + final int[] queryDims; + + // prepare queryVector once per script execution + // queryVector represents a map of dimensions to values + public DotProductSparse(Map queryVector) { + //break vector into two arrays dims and values + int n = queryVector.size(); + queryDims = new int[n]; + queryValues = new double[n]; + int i = 0; + for (Map.Entry dimValue : queryVector.entrySet()) { + try { + queryDims[i] = Integer.parseInt(dimValue.getKey()); + } catch (final NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e); + } + queryValues[i] = dimValue.getValue().doubleValue(); + i++; + } + // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions + sortSparseDimsDoubleValues(queryDims, queryValues, n); + } + + public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { + BytesRef value = dvs.getEncodedValue(); + if (value == null) return 0; + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(value); + return intDotProductSparse(queryValues, queryDims, docValues, docDims); + } + } + + /** + * Calculate cosine similarity between a query's sparse vector and documents' sparse vectors + * + * CosineSimilaritySparse is implemented as a class to use + * painless script caching to prepare queryVector and calculate queryVectorMagnitude + * only once per script execution for all documents. + * A user will call `cosineSimilaritySparse(params.queryVector, doc['my_vector'])` + */ + public static final class CosineSimilaritySparse { + final double[] queryValues; + final int[] queryDims; + final double queryVectorMagnitude; + + // prepare queryVector once per script execution + public CosineSimilaritySparse(Map queryVector) { + //break vector into two arrays dims and values + int n = queryVector.size(); + queryValues = new double[n]; + queryDims = new int[n]; + double dotProduct = 0; + int i = 0; + for (Map.Entry dimValue : queryVector.entrySet()) { + try { + queryDims[i] = Integer.parseInt(dimValue.getKey()); + } catch (final NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e); + } + queryValues[i] = dimValue.getValue().doubleValue(); + dotProduct += queryValues[i] * queryValues[i]; + i++; + } + this.queryVectorMagnitude = Math.sqrt(dotProduct); + // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions + sortSparseDimsDoubleValues(queryDims, queryValues, n); + } + + public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { + BytesRef value = dvs.getEncodedValue(); + if (value == null) return 0; + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(value); + + // calculate docVector magnitude + double dotProduct = 0; + for (float docValue : docValues) { + dotProduct += (double) docValue * docValue; + } + final double docVectorMagnitude = Math.sqrt(dotProduct); + + double docQueryDotProduct = intDotProductSparse(queryValues, queryDims, docValues, docDims); + return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); + } + } + + private static double intDotProductSparse(double[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) { + double v1v2DotProduct = 0; + int v1Index = 0; + int v2Index = 0; + // find common dimensions among vectors v1 and v2 and calculate dotProduct based on common dimensions + while (v1Index < v1Values.length && v2Index < v2Values.length) { + if (v1Dims[v1Index] == v2Dims[v2Index]) { + v1v2DotProduct += v1Values[v1Index] * v2Values[v2Index]; + v1Index++; + v2Index++; + } else if (v1Dims[v1Index] > v2Dims[v2Index]) { + v2Index++; + } else { + v1Index++; + } + } + return v1v2DotProduct; + } +} diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java new file mode 100644 index 00000000000..99e581ce4e5 --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVAtomicFieldData.java @@ -0,0 +1,80 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.fielddata.AtomicFieldData; +import org.elasticsearch.index.fielddata.ScriptDocValues; +import org.elasticsearch.index.fielddata.SortedBinaryDocValues; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; + +final class VectorDVAtomicFieldData implements AtomicFieldData { + + private final LeafReader reader; + private final String field; + private final boolean isDense; + + VectorDVAtomicFieldData(LeafReader reader, String field, boolean isDense) { + this.reader = reader; + this.field = field; + this.isDense = isDense; + } + + @Override + public long ramBytesUsed() { + return 0; // not exposed by Lucene + } + + @Override + public Collection getChildResources() { + return Collections.emptyList(); + } + + @Override + public SortedBinaryDocValues getBytesValues() { + throw new UnsupportedOperationException("String representation of doc values for vector fields is not supported"); + } + + @Override + public ScriptDocValues getScriptValues() { + try { + final BinaryDocValues values = DocValues.getBinary(reader, field); + if (isDense) { + return new VectorScriptDocValues.DenseVectorScriptDocValues(values); + } else { + return new VectorScriptDocValues.SparseVectorScriptDocValues(values); + } + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values for vector field!", e); + } + } + + @Override + public void close() { + // no-op + } +} diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java new file mode 100644 index 00000000000..9badf9f11b4 --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorDVIndexFieldData.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.SortField; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.fielddata.IndexFieldData; +import org.elasticsearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested; +import org.elasticsearch.index.fielddata.IndexFieldDataCache; +import org.elasticsearch.index.fielddata.plain.DocValuesIndexFieldData; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.search.MultiValueMode; + + +public class VectorDVIndexFieldData extends DocValuesIndexFieldData implements IndexFieldData { + private final boolean isDense; + + public VectorDVIndexFieldData(Index index, String fieldName, boolean isDense) { + super(index, fieldName); + this.isDense = isDense; + } + + @Override + public SortField sortField(@Nullable Object missingValue, MultiValueMode sortMode, Nested nested, boolean reverse) { + throw new IllegalArgumentException("can't sort on the vector field"); + } + + @Override + public VectorDVAtomicFieldData load(LeafReaderContext context) { + return new VectorDVAtomicFieldData(context.reader(), fieldName, isDense); + } + + @Override + public VectorDVAtomicFieldData loadDirect(LeafReaderContext context) throws Exception { + return load(context); + } + + public static class Builder implements IndexFieldData.Builder { + private final boolean isDense; + public Builder(boolean isDense) { + this.isDense = isDense; + } + + @Override + public IndexFieldData build(IndexSettings indexSettings, MappedFieldType fieldType, IndexFieldDataCache cache, + CircuitBreakerService breakerService, MapperService mapperService) { + final String fieldName = fieldType.name(); + return new VectorDVIndexFieldData(indexSettings.getIndex(), fieldName, isDense); + } + + } +} diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java new file mode 100644 index 00000000000..603881d3907 --- /dev/null +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/query/VectorScriptDocValues.java @@ -0,0 +1,78 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.fielddata.ScriptDocValues; + +import java.io.IOException; + +/** + * VectorScriptDocValues represents docValues for dense and sparse vector fields + */ +public abstract class VectorScriptDocValues extends ScriptDocValues { + + private final BinaryDocValues in; + private BytesRef value; + + VectorScriptDocValues(BinaryDocValues in) { + this.in = in; + } + + @Override + public void setNextDocId(int docId) throws IOException { + if (in.advanceExact(docId)) { + value = in.binaryValue(); + } else { + value = null; + } + } + + // package private access only for {@link ScoreScriptUtils} + BytesRef getEncodedValue() { + return value; + } + + @Override + public BytesRef get(int index) { + throw new UnsupportedOperationException("vector fields may only be used via vector functions in scripts"); + } + + @Override + public int size() { + throw new UnsupportedOperationException("vector fields may only be used via vector functions in scripts"); + } + + // not final, as it needs to be extended by Mockito for tests + public static class DenseVectorScriptDocValues extends VectorScriptDocValues { + public DenseVectorScriptDocValues(BinaryDocValues in) { + super(in); + } + } + + // not final, as it needs to be extended by Mockito for tests + public static class SparseVectorScriptDocValues extends VectorScriptDocValues { + public SparseVectorScriptDocValues(BinaryDocValues in) { + super(in); + } + } + +} diff --git a/modules/mapper-extras/src/main/resources/META-INF/services/org.elasticsearch.painless.spi.PainlessExtension b/modules/mapper-extras/src/main/resources/META-INF/services/org.elasticsearch.painless.spi.PainlessExtension new file mode 100644 index 00000000000..f4cc27a362e --- /dev/null +++ b/modules/mapper-extras/src/main/resources/META-INF/services/org.elasticsearch.painless.spi.PainlessExtension @@ -0,0 +1 @@ +org.elasticsearch.index.query.DocValuesWhitelistExtension \ No newline at end of file diff --git a/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt b/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt new file mode 100644 index 00000000000..3a8989e20b0 --- /dev/null +++ b/modules/mapper-extras/src/main/resources/org/elasticsearch/index/query/docvalues_whitelist.txt @@ -0,0 +1,32 @@ +# +# Licensed to Elasticsearch under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +class org.elasticsearch.index.query.VectorScriptDocValues { +} +class org.elasticsearch.index.query.VectorScriptDocValues$DenseVectorScriptDocValues { +} +class org.elasticsearch.index.query.VectorScriptDocValues$SparseVectorScriptDocValues { +} + +static_import { + double cosineSimilarity(List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilarity + double dotProduct(List, VectorScriptDocValues.DenseVectorScriptDocValues) from_class org.elasticsearch.index.query.ScoreScriptUtils + double dotProductSparse(Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$DotProductSparse + double cosineSimilaritySparse(Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.index.query.ScoreScriptUtils$CosineSimilaritySparse +} \ No newline at end of file diff --git a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/VectorEncoderDecoderTests.java b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/VectorEncoderDecoderTests.java index 67ab7826137..9b8a741192c 100644 --- a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/VectorEncoderDecoderTests.java +++ b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/VectorEncoderDecoderTests.java @@ -83,7 +83,7 @@ public class VectorEncoderDecoderTests extends ESTestCase { } // imitates the code in DenseVectorFieldMapper::parse - private BytesRef mockEncodeDenseVector(float[] dims) { + public static BytesRef mockEncodeDenseVector(float[] dims) { final short INT_BYTES = VectorEncoderDecoder.INT_BYTES; byte[] buf = new byte[INT_BYTES * dims.length]; int offset = 0; diff --git a/modules/mapper-extras/src/test/java/org/elasticsearch/index/query/ScoreScriptUtilsTests.java b/modules/mapper-extras/src/test/java/org/elasticsearch/index/query/ScoreScriptUtilsTests.java new file mode 100644 index 00000000000..bcdf0387c3f --- /dev/null +++ b/modules/mapper-extras/src/test/java/org/elasticsearch/index/query/ScoreScriptUtilsTests.java @@ -0,0 +1,82 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.query; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.mapper.VectorEncoderDecoder; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.index.query.ScoreScriptUtils.CosineSimilarity; +import org.elasticsearch.index.query.ScoreScriptUtils.DotProductSparse; +import org.elasticsearch.index.query.ScoreScriptUtils.CosineSimilaritySparse; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.index.mapper.VectorEncoderDecoderTests.mockEncodeDenseVector; +import static org.elasticsearch.index.query.ScoreScriptUtils.dotProduct; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + + +public class ScoreScriptUtilsTests extends ESTestCase { + public void testDenseVectorFunctions() { + float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; + BytesRef encodedDocVector = mockEncodeDenseVector(docVector); + VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); + when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + List queryVector = Arrays.asList(0.5, 111.3, -13.0, 14.8, -156.0); + + // test dotProduct + double result = dotProduct(queryVector, dvs); + assertEquals("dotProduct result is not equal to the expected value!", 65425.62, result, 0.1); + + // test cosineSimilarity + CosineSimilarity cosineSimilarity = new CosineSimilarity(queryVector); + double result2 = cosineSimilarity.cosineSimilarity(dvs); + assertEquals("cosineSimilarity result is not equal to the expected value!", 0.78, result2, 0.1); + } + + public void testSparseVectorFunctions() { + int[] docVectorDims = {2, 10, 50, 113, 4545}; + float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; + BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(docVectorDims, docVectorValues, docVectorDims.length); + VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); + when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + Map queryVector = new HashMap() {{ + put("2", 0.5); + put("10", 111.3); + put("50", -13.0); + put("113", 14.8); + put("4545", -156.0); + }}; + + // test dotProduct + DotProductSparse docProductSparse = new DotProductSparse(queryVector); + double result = docProductSparse.dotProductSparse(dvs); + assertEquals("dotProductSparse result is not equal to the expected value!", 65425.62, result, 0.1); + + // test cosineSimilarity + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.78, result2, 0.1); + } +} diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml new file mode 100644 index 00000000000..e5db535b69b --- /dev/null +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_basic.yml @@ -0,0 +1,100 @@ +setup: + - skip: + features: headers + version: " - 7.0.99" + reason: "dense_vector functions were introduced in 7.1.0" + + - do: + indices.create: + include_type_name: false + index: test-index + body: + settings: + number_of_replicas: 0 + mappings: + properties: + my_dense_vector: + type: dense_vector + - do: + index: + index: test-index + id: 1 + body: + my_dense_vector: [230.0, 300.33, -34.8988, 15.555, -200.0] + + - do: + index: + index: test-index + id: 2 + body: + my_dense_vector: [-0.5, 100.0, -13, 14.8, -156.0] + + - do: + index: + index: test-index + id: 3 + body: + my_dense_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - do: + indices.refresh: {} + +--- +"Dot Product": + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProduct(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "1"} + - gte: {hits.hits.0._score: 65425.62} + - lte: {hits.hits.0._score: 65425.63} + + - match: {hits.hits.1._id: "3"} + - gte: {hits.hits.1._score: 37111.98} + - lte: {hits.hits.1._score: 37111.99} + + - match: {hits.hits.2._id: "2"} + - gte: {hits.hits.2._score: 35853.78} + - lte: {hits.hits.2._score: 35853.79} + +--- +"Cosine Similarity": + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "3"} + - gte: {hits.hits.0._score: 0.999} + - lte: {hits.hits.0._score: 1.001} + + - match: {hits.hits.1._id: "2"} + - gte: {hits.hits.1._score: 0.998} + - lte: {hits.hits.1._score: 1.0} + + - match: {hits.hits.2._id: "1"} + - gte: {hits.hits.2._score: 0.78} + - lte: {hits.hits.2._score: 0.791} diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_indexing.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_indexing.yml deleted file mode 100644 index 846341cd8ec..00000000000 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/10_indexing.yml +++ /dev/null @@ -1,27 +0,0 @@ -setup: - - skip: - version: " - 6.99.99" - reason: "dense_vector field was introduced in 7.0.0" - - - do: - indices.create: - index: test-index - body: - settings: - number_of_replicas: 0 - mappings: - properties: - my_dense_vector: - type: dense_vector - - ---- -"Indexing": - - do: - index: - index: test-index - id: 1 - body: - my_dense_vector: [1.5, -10, 3455, 345452.4545] - - - match: { result: created } diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml new file mode 100644 index 00000000000..0b6cab59900 --- /dev/null +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/dense-vector/20_special_cases.yml @@ -0,0 +1,152 @@ +setup: + - skip: + features: headers + version: " - 7.0.99" + reason: "dense_vector functions were introduced in 7.1.0" + + - do: + indices.create: + include_type_name: false + index: test-index + body: + settings: + number_of_replicas: 0 + # we need to have 1 shard to get request failure in test "Dense vectors should error with sparse vector functions" + number_of_shards: 1 + mappings: + properties: + my_dense_vector: + type: dense_vector + + +--- +"Vectors of different dimensions and data types": +# document vectors of different dimensions + - do: + index: + index: test-index + id: 1 + body: + my_dense_vector: [10] + + - do: + index: + index: test-index + id: 2 + body: + my_dense_vector: [10, 10.5] + + - do: + index: + index: test-index + id: 3 + body: + my_dense_vector: [10, 10.5, 100.5] + + - do: + indices.refresh: {} + +# query vector of type integer + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [10] + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "3"} + +# query vector of type double + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [10.0] + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "3"} + +--- +"Distance functions for documents missing vector field should return 0": +- do: + index: + index: test-index + id: 1 + body: + my_dense_vector: [10] + +- do: + index: + index: test-index + id: 2 + body: + some_other_field: "random_value" + +- do: + indices.refresh: {} + +- do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [10.0] + +- match: {hits.total: 2} +- match: {hits.hits.0._id: "1"} +- match: {hits.hits.1._id: "2"} +- match: {hits.hits.1._score: 0.0} + +--- +"Dense vectors should error with sparse vector functions": +- do: + index: + index: test-index + id: 1 + body: + my_dense_vector: [10, 2, 0.15] + +- do: + indices.refresh: {} + +- do: + catch: bad_request + headers: + Content-Type: application/json + search: + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProductSparse(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: {"2": 0.5, "10" : 111.3} +- match: { error.root_cause.0.type: "script_exception" } diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml new file mode 100644 index 00000000000..142a80291ae --- /dev/null +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_basic.yml @@ -0,0 +1,100 @@ +setup: + - skip: + features: headers + version: " - 7.0.99" + reason: "sparse_vector functions were introduced in 7.1.0" + + - do: + indices.create: + include_type_name: false + index: test-index + body: + settings: + number_of_replicas: 0 + mappings: + properties: + my_sparse_vector: + type: sparse_vector + - do: + index: + index: test-index + id: 1 + body: + my_sparse_vector: {"2": 230.0, "10" : 300.33, "50": -34.8988, "113": 15.555, "4545": -200.0} + + - do: + index: + index: test-index + id: 2 + body: + my_sparse_vector: {"2": -0.5, "10" : 100.0, "50": -13, "113": 14.8, "4545": -156.0} + + - do: + index: + index: test-index + id: 3 + body: + my_sparse_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + + - do: + indices.refresh: {} + +--- +"Dot Product": +- do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + +- match: {hits.total: 3} + +- match: {hits.hits.0._id: "1"} +- gte: {hits.hits.0._score: 65425.62} +- lte: {hits.hits.0._score: 65425.63} + +- match: {hits.hits.1._id: "3"} +- gte: {hits.hits.1._score: 37111.98} +- lte: {hits.hits.1._score: 37111.99} + +- match: {hits.hits.2._id: "2"} +- gte: {hits.hits.2._score: 35853.78} +- lte: {hits.hits.2._score: 35853.79} + +--- +"Cosine Similarity": +- do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + +- match: {hits.total: 3} + +- match: {hits.hits.0._id: "3"} +- gte: {hits.hits.0._score: 0.999} +- lte: {hits.hits.0._score: 1.001} + +- match: {hits.hits.1._id: "2"} +- gte: {hits.hits.1._score: 0.998} +- lte: {hits.hits.1._score: 1.0} + +- match: {hits.hits.2._id: "1"} +- gte: {hits.hits.2._score: 0.78} +- lte: {hits.hits.2._score: 0.791} diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_indexing.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_indexing.yml deleted file mode 100644 index b3efff318b5..00000000000 --- a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/10_indexing.yml +++ /dev/null @@ -1,27 +0,0 @@ -setup: - - skip: - version: " - 6.99.99" - reason: "sparse_vector field was introduced in 7.0.0" - - - do: - indices.create: - index: test-index - body: - settings: - number_of_replicas: 0 - mappings: - properties: - my_sparse_vector: - type: sparse_vector - - ---- -"Indexing": - - do: - index: - index: test-index - id: 1 - body: - my_sparse_vector: { "50" : 1.8, "2" : -0.4, "10" : 1000.3, "4545" : -0.00004} - - - match: { result: created } diff --git a/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml new file mode 100644 index 00000000000..106a3d966a4 --- /dev/null +++ b/modules/mapper-extras/src/test/resources/rest-api-spec/test/sparse-vector/20_special_cases.yml @@ -0,0 +1,203 @@ +setup: + - skip: + features: headers + version: " - 7.0.99" + reason: "sparse_vector functions were introduced in 7.1.0" + + - do: + indices.create: + include_type_name: false + index: test-index + body: + settings: + number_of_replicas: 0 + # we need to have 1 shard to get request failure in test "Sparse vectors should error with dense vector functions" + number_of_shards: 1 + mappings: + properties: + my_sparse_vector: + type: sparse_vector + + +--- +"Vectors of different dimensions and data types": +# document vectors of different dimensions + - do: + index: + index: test-index + id: 1 + body: + my_sparse_vector: {"1": 10} + + - do: + index: + index: test-index + id: 2 + body: + my_sparse_vector: {"1": 10, "10" : 10.5} + + - do: + index: + index: test-index + id: 3 + body: + my_sparse_vector: {"1": 10, "10" : 10.5, "100": 100.5} + + - do: + indices.refresh: {} + +# query vector of type integer + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: {"1": 10} + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "3"} + +# query vector of type double + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: {"1": 10.0} + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "3"} + +--- +"Distance functions for documents missing vector field should return 0": +- do: + index: + index: test-index + id: 1 + body: + my_sparse_vector: {"1": 10} + +- do: + index: + index: test-index + id: 2 + body: + some_other_field: "random_value" + +- do: + indices.refresh: {} + +- do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: {"1": 10.0} + +- match: {hits.total: 2} +- match: {hits.hits.0._id: "1"} +- match: {hits.hits.1._id: "2"} +- match: {hits.hits.1._score: 0.0} + + +--- +"Dimensions can be sorted differently": +# All the documents' and query's vectors are the same, and should return cosineSimilarity equal to 1 +- do: + index: + index: test-index + id: 1 + body: + my_sparse_vector: {"2": 230.0, "11" : 300.33, "12": -34.8988, "30": 15.555, "100": -200.0} + +- do: + index: + index: test-index + id: 2 + body: + my_sparse_vector: {"100": -200.0, "12": -34.8988, "11" : 300.33, "113": 15.555, "2": 230.0} + +- do: + index: + index: test-index + id: 3 + body: + my_sparse_vector: {"100": -200.0, "30": 15.555, "12": -34.8988, "11" : 300.33, "2": 230.0} + +- do: + indices.refresh: {} + +- do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555} + +- match: {hits.total: 3} + +- gte: {hits.hits.0._score: 0.99} +- lte: {hits.hits.0._score: 1.001} +- gte: {hits.hits.1._score: 0.99} +- lte: {hits.hits.1._score: 1.001} +- gte: {hits.hits.2._score: 0.99} +- lte: {hits.hits.2._score: 1.001} + +--- +"Sparse vectors should error with dense vector functions": +- do: + index: + index: test-index + id: 1 + body: + my_sparse_vector: {"100": -200.0, "30": 15.555} + +- do: + indices.refresh: {} + +- do: + catch: bad_request + headers: + Content-Type: application/json + search: + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProduct(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: [0.5, 111] +- match: { error.root_cause.0.type: "script_exception" }